refactor: type the resolver

almost perfect, except for python < 3.9 making it so we can’t have nice
things.
This commit is contained in:
mathieui 2021-04-20 22:14:01 +02:00
parent 9f01d368c0
commit 4931e7e604
2 changed files with 50 additions and 22 deletions

View file

@ -16,11 +16,13 @@ try:
from typing import (
Literal,
TypedDict,
Protocol,
)
except ImportError:
from typing_extensions import (
Literal,
TypedDict,
Protocol,
)
from slixmpp.jid import JID

View file

@ -1,18 +1,32 @@
# slixmpp.xmlstream.dns
# ~~~~~~~~~~~~~~~~~~~~~~~
# :copyright: (c) 2012 Nathanael C. Fritz
# :license: MIT, see LICENSE for more details
from slixmpp.xmlstream.asyncio import asyncio
import socket
import sys
import logging
import random
from asyncio import Future, AbstractEventLoop
from typing import Optional, Tuple, Dict, List, Iterable, cast
from slixmpp.types import Protocol
log = logging.getLogger(__name__)
class AnswerProtocol(Protocol):
host: str
priority: int
weight: int
port: int
class ResolverProtocol(Protocol):
def query(self, query: str, querytype: str) -> Future:
...
#: Global flag indicating the availability of the ``aiodns`` package.
#: Installing ``aiodns`` can be done via:
#:
@ -23,12 +37,12 @@ AIODNS_AVAILABLE = False
try:
import aiodns
AIODNS_AVAILABLE = True
except ImportError as e:
log.debug("Could not find aiodns package. " + \
except ImportError:
log.debug("Could not find aiodns package. "
"Not all features will be available")
def default_resolver(loop):
def default_resolver(loop: AbstractEventLoop) -> Optional[ResolverProtocol]:
"""Return a basic DNS resolver object.
:returns: A :class:`aiodns.DNSResolver` object if aiodns
@ -41,8 +55,11 @@ def default_resolver(loop):
return None
async def resolve(host, port=None, service=None, proto='tcp',
resolver=None, use_ipv6=True, use_aiodns=True, loop=None):
async def resolve(host: str, port: int, *, loop: AbstractEventLoop,
service: Optional[str] = None, proto: str = 'tcp',
resolver: Optional[ResolverProtocol] = None,
use_ipv6: bool = True,
use_aiodns: bool = True) -> List[Tuple[str, str, int]]:
"""Peform DNS resolution for a given hostname.
Resolution may perform SRV record lookups if a service and protocol
@ -91,8 +108,8 @@ async def resolve(host, port=None, service=None, proto='tcp',
if not use_ipv6:
log.debug("DNS: Use of IPv6 has been disabled.")
if resolver is None and AIODNS_AVAILABLE and use_aiodns:
resolver = aiodns.DNSResolver(loop=loop)
if resolver is None and use_aiodns:
resolver = default_resolver(loop=loop)
# An IPv6 literal is allowed to be enclosed in square brackets, but
# the brackets must be stripped in order to process the literal;
@ -101,7 +118,7 @@ async def resolve(host, port=None, service=None, proto='tcp',
try:
# If `host` is an IPv4 literal, we can return it immediately.
ipv4 = socket.inet_aton(host)
socket.inet_aton(host)
return [(host, host, port)]
except socket.error:
pass
@ -111,7 +128,7 @@ async def resolve(host, port=None, service=None, proto='tcp',
# Likewise, If `host` is an IPv6 literal, we can return
# it immediately.
if hasattr(socket, 'inet_pton'):
ipv6 = socket.inet_pton(socket.AF_INET6, host)
socket.inet_pton(socket.AF_INET6, host)
return [(host, host, port)]
except (socket.error, ValueError):
pass
@ -148,7 +165,10 @@ async def resolve(host, port=None, service=None, proto='tcp',
return results
async def get_A(host, resolver=None, use_aiodns=True, loop=None):
async def get_A(host: str, *, loop: AbstractEventLoop,
resolver: Optional[ResolverProtocol] = None,
use_aiodns: bool = True) -> List[str]:
"""Lookup DNS A records for a given host.
If ``resolver`` is not provided, or is ``None``, then resolution will
@ -172,10 +192,10 @@ async def get_A(host, resolver=None, use_aiodns=True, loop=None):
# getaddrinfo() method.
if resolver is None or not use_aiodns:
try:
recs = await loop.getaddrinfo(host, None,
inet_recs = await loop.getaddrinfo(host, None,
family=socket.AF_INET,
type=socket.SOCK_STREAM)
return [rec[4][0] for rec in recs]
return [rec[4][0] for rec in inet_recs]
except socket.gaierror:
log.debug("DNS: Error retrieving A address info for %s." % host)
return []
@ -183,14 +203,16 @@ async def get_A(host, resolver=None, use_aiodns=True, loop=None):
# Using aiodns:
future = resolver.query(host, 'A')
try:
recs = await future
recs = cast(Iterable[AnswerProtocol], await future)
except Exception as e:
log.debug('DNS: Exception while querying for %s A records: %s', host, e)
recs = []
return [rec.host for rec in recs]
async def get_AAAA(host, resolver=None, use_aiodns=True, loop=None):
async def get_AAAA(host: str, *, loop: AbstractEventLoop,
resolver: Optional[ResolverProtocol] = None,
use_aiodns: bool = True) -> List[str]:
"""Lookup DNS AAAA records for a given host.
If ``resolver`` is not provided, or is ``None``, then resolution will
@ -217,10 +239,10 @@ async def get_AAAA(host, resolver=None, use_aiodns=True, loop=None):
log.debug("DNS: Unable to query %s for AAAA records: IPv6 is not supported", host)
return []
try:
recs = await loop.getaddrinfo(host, None,
inet_recs = await loop.getaddrinfo(host, None,
family=socket.AF_INET6,
type=socket.SOCK_STREAM)
return [rec[4][0] for rec in recs]
return [rec[4][0] for rec in inet_recs]
except (OSError, socket.gaierror):
log.debug("DNS: Error retrieving AAAA address " + \
"info for %s." % host)
@ -229,13 +251,17 @@ async def get_AAAA(host, resolver=None, use_aiodns=True, loop=None):
# Using aiodns:
future = resolver.query(host, 'AAAA')
try:
recs = await future
recs = cast(Iterable[AnswerProtocol], await future)
except Exception as e:
log.debug('DNS: Exception while querying for %s AAAA records: %s', host, e)
recs = []
return [rec.host for rec in recs]
async def get_SRV(host, port, service, proto='tcp', resolver=None, use_aiodns=True):
async def get_SRV(host: str, port: int, service: str,
proto: str = 'tcp',
resolver: Optional[ResolverProtocol] = None,
use_aiodns: bool = True) -> List[Tuple[str, int]]:
"""Perform SRV record resolution for a given host.
.. note::
@ -269,12 +295,12 @@ async def get_SRV(host, port, service, proto='tcp', resolver=None, use_aiodns=Tr
try:
future = resolver.query('_%s._%s.%s' % (service, proto, host),
'SRV')
recs = await future
recs = cast(Iterable[AnswerProtocol], await future)
except Exception as e:
log.debug('DNS: Exception while querying for %s SRV records: %s', host, e)
return []
answers = {}
answers: Dict[int, List[AnswerProtocol]] = {}
for rec in recs:
if rec.priority not in answers:
answers[rec.priority] = []