refactor: type the resolver

almost perfect, except for python < 3.9 making it so we can’t have nice
things.
This commit is contained in:
mathieui 2021-04-20 22:14:01 +02:00
parent 9f01d368c0
commit 4931e7e604
2 changed files with 50 additions and 22 deletions

View file

@ -16,11 +16,13 @@ try:
from typing import ( from typing import (
Literal, Literal,
TypedDict, TypedDict,
Protocol,
) )
except ImportError: except ImportError:
from typing_extensions import ( from typing_extensions import (
Literal, Literal,
TypedDict, TypedDict,
Protocol,
) )
from slixmpp.jid import JID from slixmpp.jid import JID

View file

@ -1,18 +1,32 @@
# slixmpp.xmlstream.dns # slixmpp.xmlstream.dns
# ~~~~~~~~~~~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~~~~~~~~~~
# :copyright: (c) 2012 Nathanael C. Fritz # :copyright: (c) 2012 Nathanael C. Fritz
# :license: MIT, see LICENSE for more details # :license: MIT, see LICENSE for more details
from slixmpp.xmlstream.asyncio import asyncio
import socket import socket
import sys
import logging import logging
import random import random
from asyncio import Future, AbstractEventLoop
from typing import Optional, Tuple, Dict, List, Iterable, cast
from slixmpp.types import Protocol
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class AnswerProtocol(Protocol):
host: str
priority: int
weight: int
port: int
class ResolverProtocol(Protocol):
def query(self, query: str, querytype: str) -> Future:
...
#: Global flag indicating the availability of the ``aiodns`` package. #: Global flag indicating the availability of the ``aiodns`` package.
#: Installing ``aiodns`` can be done via: #: Installing ``aiodns`` can be done via:
#: #:
@ -23,12 +37,12 @@ AIODNS_AVAILABLE = False
try: try:
import aiodns import aiodns
AIODNS_AVAILABLE = True AIODNS_AVAILABLE = True
except ImportError as e: except ImportError:
log.debug("Could not find aiodns package. " + \ log.debug("Could not find aiodns package. "
"Not all features will be available") "Not all features will be available")
def default_resolver(loop): def default_resolver(loop: AbstractEventLoop) -> Optional[ResolverProtocol]:
"""Return a basic DNS resolver object. """Return a basic DNS resolver object.
:returns: A :class:`aiodns.DNSResolver` object if aiodns :returns: A :class:`aiodns.DNSResolver` object if aiodns
@ -41,8 +55,11 @@ def default_resolver(loop):
return None return None
async def resolve(host, port=None, service=None, proto='tcp', async def resolve(host: str, port: int, *, loop: AbstractEventLoop,
resolver=None, use_ipv6=True, use_aiodns=True, loop=None): service: Optional[str] = None, proto: str = 'tcp',
resolver: Optional[ResolverProtocol] = None,
use_ipv6: bool = True,
use_aiodns: bool = True) -> List[Tuple[str, str, int]]:
"""Peform DNS resolution for a given hostname. """Peform DNS resolution for a given hostname.
Resolution may perform SRV record lookups if a service and protocol Resolution may perform SRV record lookups if a service and protocol
@ -91,8 +108,8 @@ async def resolve(host, port=None, service=None, proto='tcp',
if not use_ipv6: if not use_ipv6:
log.debug("DNS: Use of IPv6 has been disabled.") log.debug("DNS: Use of IPv6 has been disabled.")
if resolver is None and AIODNS_AVAILABLE and use_aiodns: if resolver is None and use_aiodns:
resolver = aiodns.DNSResolver(loop=loop) resolver = default_resolver(loop=loop)
# An IPv6 literal is allowed to be enclosed in square brackets, but # An IPv6 literal is allowed to be enclosed in square brackets, but
# the brackets must be stripped in order to process the literal; # the brackets must be stripped in order to process the literal;
@ -101,7 +118,7 @@ async def resolve(host, port=None, service=None, proto='tcp',
try: try:
# If `host` is an IPv4 literal, we can return it immediately. # If `host` is an IPv4 literal, we can return it immediately.
ipv4 = socket.inet_aton(host) socket.inet_aton(host)
return [(host, host, port)] return [(host, host, port)]
except socket.error: except socket.error:
pass pass
@ -111,7 +128,7 @@ async def resolve(host, port=None, service=None, proto='tcp',
# Likewise, If `host` is an IPv6 literal, we can return # Likewise, If `host` is an IPv6 literal, we can return
# it immediately. # it immediately.
if hasattr(socket, 'inet_pton'): if hasattr(socket, 'inet_pton'):
ipv6 = socket.inet_pton(socket.AF_INET6, host) socket.inet_pton(socket.AF_INET6, host)
return [(host, host, port)] return [(host, host, port)]
except (socket.error, ValueError): except (socket.error, ValueError):
pass pass
@ -148,7 +165,10 @@ async def resolve(host, port=None, service=None, proto='tcp',
return results return results
async def get_A(host, resolver=None, use_aiodns=True, loop=None):
async def get_A(host: str, *, loop: AbstractEventLoop,
resolver: Optional[ResolverProtocol] = None,
use_aiodns: bool = True) -> List[str]:
"""Lookup DNS A records for a given host. """Lookup DNS A records for a given host.
If ``resolver`` is not provided, or is ``None``, then resolution will If ``resolver`` is not provided, or is ``None``, then resolution will
@ -172,10 +192,10 @@ async def get_A(host, resolver=None, use_aiodns=True, loop=None):
# getaddrinfo() method. # getaddrinfo() method.
if resolver is None or not use_aiodns: if resolver is None or not use_aiodns:
try: try:
recs = await loop.getaddrinfo(host, None, inet_recs = await loop.getaddrinfo(host, None,
family=socket.AF_INET, family=socket.AF_INET,
type=socket.SOCK_STREAM) type=socket.SOCK_STREAM)
return [rec[4][0] for rec in recs] return [rec[4][0] for rec in inet_recs]
except socket.gaierror: except socket.gaierror:
log.debug("DNS: Error retrieving A address info for %s." % host) log.debug("DNS: Error retrieving A address info for %s." % host)
return [] return []
@ -183,14 +203,16 @@ async def get_A(host, resolver=None, use_aiodns=True, loop=None):
# Using aiodns: # Using aiodns:
future = resolver.query(host, 'A') future = resolver.query(host, 'A')
try: try:
recs = await future recs = cast(Iterable[AnswerProtocol], await future)
except Exception as e: except Exception as e:
log.debug('DNS: Exception while querying for %s A records: %s', host, e) log.debug('DNS: Exception while querying for %s A records: %s', host, e)
recs = [] recs = []
return [rec.host for rec in recs] return [rec.host for rec in recs]
async def get_AAAA(host, resolver=None, use_aiodns=True, loop=None): async def get_AAAA(host: str, *, loop: AbstractEventLoop,
resolver: Optional[ResolverProtocol] = None,
use_aiodns: bool = True) -> List[str]:
"""Lookup DNS AAAA records for a given host. """Lookup DNS AAAA records for a given host.
If ``resolver`` is not provided, or is ``None``, then resolution will If ``resolver`` is not provided, or is ``None``, then resolution will
@ -217,10 +239,10 @@ async def get_AAAA(host, resolver=None, use_aiodns=True, loop=None):
log.debug("DNS: Unable to query %s for AAAA records: IPv6 is not supported", host) log.debug("DNS: Unable to query %s for AAAA records: IPv6 is not supported", host)
return [] return []
try: try:
recs = await loop.getaddrinfo(host, None, inet_recs = await loop.getaddrinfo(host, None,
family=socket.AF_INET6, family=socket.AF_INET6,
type=socket.SOCK_STREAM) type=socket.SOCK_STREAM)
return [rec[4][0] for rec in recs] return [rec[4][0] for rec in inet_recs]
except (OSError, socket.gaierror): except (OSError, socket.gaierror):
log.debug("DNS: Error retrieving AAAA address " + \ log.debug("DNS: Error retrieving AAAA address " + \
"info for %s." % host) "info for %s." % host)
@ -229,13 +251,17 @@ async def get_AAAA(host, resolver=None, use_aiodns=True, loop=None):
# Using aiodns: # Using aiodns:
future = resolver.query(host, 'AAAA') future = resolver.query(host, 'AAAA')
try: try:
recs = await future recs = cast(Iterable[AnswerProtocol], await future)
except Exception as e: except Exception as e:
log.debug('DNS: Exception while querying for %s AAAA records: %s', host, e) log.debug('DNS: Exception while querying for %s AAAA records: %s', host, e)
recs = [] recs = []
return [rec.host for rec in recs] return [rec.host for rec in recs]
async def get_SRV(host, port, service, proto='tcp', resolver=None, use_aiodns=True):
async def get_SRV(host: str, port: int, service: str,
proto: str = 'tcp',
resolver: Optional[ResolverProtocol] = None,
use_aiodns: bool = True) -> List[Tuple[str, int]]:
"""Perform SRV record resolution for a given host. """Perform SRV record resolution for a given host.
.. note:: .. note::
@ -269,12 +295,12 @@ async def get_SRV(host, port, service, proto='tcp', resolver=None, use_aiodns=Tr
try: try:
future = resolver.query('_%s._%s.%s' % (service, proto, host), future = resolver.query('_%s._%s.%s' % (service, proto, host),
'SRV') 'SRV')
recs = await future recs = cast(Iterable[AnswerProtocol], await future)
except Exception as e: except Exception as e:
log.debug('DNS: Exception while querying for %s SRV records: %s', host, e) log.debug('DNS: Exception while querying for %s SRV records: %s', host, e)
return [] return []
answers = {} answers: Dict[int, List[AnswerProtocol]] = {}
for rec in recs: for rec in recs:
if rec.priority not in answers: if rec.priority not in answers:
answers[rec.priority] = [] answers[rec.priority] = []