refactor: type the resolver
almost perfect, except for python < 3.9 making it so we can’t have nice things.
This commit is contained in:
parent
9f01d368c0
commit
4931e7e604
2 changed files with 50 additions and 22 deletions
|
@ -16,11 +16,13 @@ try:
|
|||
from typing import (
|
||||
Literal,
|
||||
TypedDict,
|
||||
Protocol,
|
||||
)
|
||||
except ImportError:
|
||||
from typing_extensions import (
|
||||
Literal,
|
||||
TypedDict,
|
||||
Protocol,
|
||||
)
|
||||
|
||||
from slixmpp.jid import JID
|
||||
|
|
|
@ -1,18 +1,32 @@
|
|||
|
||||
# slixmpp.xmlstream.dns
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~
|
||||
# :copyright: (c) 2012 Nathanael C. Fritz
|
||||
# :license: MIT, see LICENSE for more details
|
||||
|
||||
from slixmpp.xmlstream.asyncio import asyncio
|
||||
import socket
|
||||
import sys
|
||||
import logging
|
||||
import random
|
||||
from asyncio import Future, AbstractEventLoop
|
||||
from typing import Optional, Tuple, Dict, List, Iterable, cast
|
||||
from slixmpp.types import Protocol
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AnswerProtocol(Protocol):
|
||||
host: str
|
||||
priority: int
|
||||
weight: int
|
||||
port: int
|
||||
|
||||
|
||||
class ResolverProtocol(Protocol):
|
||||
def query(self, query: str, querytype: str) -> Future:
|
||||
...
|
||||
|
||||
|
||||
#: Global flag indicating the availability of the ``aiodns`` package.
|
||||
#: Installing ``aiodns`` can be done via:
|
||||
#:
|
||||
|
@ -23,12 +37,12 @@ AIODNS_AVAILABLE = False
|
|||
try:
|
||||
import aiodns
|
||||
AIODNS_AVAILABLE = True
|
||||
except ImportError as e:
|
||||
log.debug("Could not find aiodns package. " + \
|
||||
except ImportError:
|
||||
log.debug("Could not find aiodns package. "
|
||||
"Not all features will be available")
|
||||
|
||||
|
||||
def default_resolver(loop):
|
||||
def default_resolver(loop: AbstractEventLoop) -> Optional[ResolverProtocol]:
|
||||
"""Return a basic DNS resolver object.
|
||||
|
||||
:returns: A :class:`aiodns.DNSResolver` object if aiodns
|
||||
|
@ -41,8 +55,11 @@ def default_resolver(loop):
|
|||
return None
|
||||
|
||||
|
||||
async def resolve(host, port=None, service=None, proto='tcp',
|
||||
resolver=None, use_ipv6=True, use_aiodns=True, loop=None):
|
||||
async def resolve(host: str, port: int, *, loop: AbstractEventLoop,
|
||||
service: Optional[str] = None, proto: str = 'tcp',
|
||||
resolver: Optional[ResolverProtocol] = None,
|
||||
use_ipv6: bool = True,
|
||||
use_aiodns: bool = True) -> List[Tuple[str, str, int]]:
|
||||
"""Peform DNS resolution for a given hostname.
|
||||
|
||||
Resolution may perform SRV record lookups if a service and protocol
|
||||
|
@ -91,8 +108,8 @@ async def resolve(host, port=None, service=None, proto='tcp',
|
|||
if not use_ipv6:
|
||||
log.debug("DNS: Use of IPv6 has been disabled.")
|
||||
|
||||
if resolver is None and AIODNS_AVAILABLE and use_aiodns:
|
||||
resolver = aiodns.DNSResolver(loop=loop)
|
||||
if resolver is None and use_aiodns:
|
||||
resolver = default_resolver(loop=loop)
|
||||
|
||||
# An IPv6 literal is allowed to be enclosed in square brackets, but
|
||||
# the brackets must be stripped in order to process the literal;
|
||||
|
@ -101,7 +118,7 @@ async def resolve(host, port=None, service=None, proto='tcp',
|
|||
|
||||
try:
|
||||
# If `host` is an IPv4 literal, we can return it immediately.
|
||||
ipv4 = socket.inet_aton(host)
|
||||
socket.inet_aton(host)
|
||||
return [(host, host, port)]
|
||||
except socket.error:
|
||||
pass
|
||||
|
@ -111,7 +128,7 @@ async def resolve(host, port=None, service=None, proto='tcp',
|
|||
# Likewise, If `host` is an IPv6 literal, we can return
|
||||
# it immediately.
|
||||
if hasattr(socket, 'inet_pton'):
|
||||
ipv6 = socket.inet_pton(socket.AF_INET6, host)
|
||||
socket.inet_pton(socket.AF_INET6, host)
|
||||
return [(host, host, port)]
|
||||
except (socket.error, ValueError):
|
||||
pass
|
||||
|
@ -148,7 +165,10 @@ async def resolve(host, port=None, service=None, proto='tcp',
|
|||
|
||||
return results
|
||||
|
||||
async def get_A(host, resolver=None, use_aiodns=True, loop=None):
|
||||
|
||||
async def get_A(host: str, *, loop: AbstractEventLoop,
|
||||
resolver: Optional[ResolverProtocol] = None,
|
||||
use_aiodns: bool = True) -> List[str]:
|
||||
"""Lookup DNS A records for a given host.
|
||||
|
||||
If ``resolver`` is not provided, or is ``None``, then resolution will
|
||||
|
@ -172,10 +192,10 @@ async def get_A(host, resolver=None, use_aiodns=True, loop=None):
|
|||
# getaddrinfo() method.
|
||||
if resolver is None or not use_aiodns:
|
||||
try:
|
||||
recs = await loop.getaddrinfo(host, None,
|
||||
inet_recs = await loop.getaddrinfo(host, None,
|
||||
family=socket.AF_INET,
|
||||
type=socket.SOCK_STREAM)
|
||||
return [rec[4][0] for rec in recs]
|
||||
return [rec[4][0] for rec in inet_recs]
|
||||
except socket.gaierror:
|
||||
log.debug("DNS: Error retrieving A address info for %s." % host)
|
||||
return []
|
||||
|
@ -183,14 +203,16 @@ async def get_A(host, resolver=None, use_aiodns=True, loop=None):
|
|||
# Using aiodns:
|
||||
future = resolver.query(host, 'A')
|
||||
try:
|
||||
recs = await future
|
||||
recs = cast(Iterable[AnswerProtocol], await future)
|
||||
except Exception as e:
|
||||
log.debug('DNS: Exception while querying for %s A records: %s', host, e)
|
||||
recs = []
|
||||
return [rec.host for rec in recs]
|
||||
|
||||
|
||||
async def get_AAAA(host, resolver=None, use_aiodns=True, loop=None):
|
||||
async def get_AAAA(host: str, *, loop: AbstractEventLoop,
|
||||
resolver: Optional[ResolverProtocol] = None,
|
||||
use_aiodns: bool = True) -> List[str]:
|
||||
"""Lookup DNS AAAA records for a given host.
|
||||
|
||||
If ``resolver`` is not provided, or is ``None``, then resolution will
|
||||
|
@ -217,10 +239,10 @@ async def get_AAAA(host, resolver=None, use_aiodns=True, loop=None):
|
|||
log.debug("DNS: Unable to query %s for AAAA records: IPv6 is not supported", host)
|
||||
return []
|
||||
try:
|
||||
recs = await loop.getaddrinfo(host, None,
|
||||
inet_recs = await loop.getaddrinfo(host, None,
|
||||
family=socket.AF_INET6,
|
||||
type=socket.SOCK_STREAM)
|
||||
return [rec[4][0] for rec in recs]
|
||||
return [rec[4][0] for rec in inet_recs]
|
||||
except (OSError, socket.gaierror):
|
||||
log.debug("DNS: Error retrieving AAAA address " + \
|
||||
"info for %s." % host)
|
||||
|
@ -229,13 +251,17 @@ async def get_AAAA(host, resolver=None, use_aiodns=True, loop=None):
|
|||
# Using aiodns:
|
||||
future = resolver.query(host, 'AAAA')
|
||||
try:
|
||||
recs = await future
|
||||
recs = cast(Iterable[AnswerProtocol], await future)
|
||||
except Exception as e:
|
||||
log.debug('DNS: Exception while querying for %s AAAA records: %s', host, e)
|
||||
recs = []
|
||||
return [rec.host for rec in recs]
|
||||
|
||||
async def get_SRV(host, port, service, proto='tcp', resolver=None, use_aiodns=True):
|
||||
|
||||
async def get_SRV(host: str, port: int, service: str,
|
||||
proto: str = 'tcp',
|
||||
resolver: Optional[ResolverProtocol] = None,
|
||||
use_aiodns: bool = True) -> List[Tuple[str, int]]:
|
||||
"""Perform SRV record resolution for a given host.
|
||||
|
||||
.. note::
|
||||
|
@ -269,12 +295,12 @@ async def get_SRV(host, port, service, proto='tcp', resolver=None, use_aiodns=Tr
|
|||
try:
|
||||
future = resolver.query('_%s._%s.%s' % (service, proto, host),
|
||||
'SRV')
|
||||
recs = await future
|
||||
recs = cast(Iterable[AnswerProtocol], await future)
|
||||
except Exception as e:
|
||||
log.debug('DNS: Exception while querying for %s SRV records: %s', host, e)
|
||||
return []
|
||||
|
||||
answers = {}
|
||||
answers: Dict[int, List[AnswerProtocol]] = {}
|
||||
for rec in recs:
|
||||
if rec.priority not in answers:
|
||||
answers[rec.priority] = []
|
||||
|
|
Loading…
Reference in a new issue