refactor: type the resolver
almost perfect, except for python < 3.9 making it so we can’t have nice things.
This commit is contained in:
parent
9f01d368c0
commit
4931e7e604
2 changed files with 50 additions and 22 deletions
|
@ -16,11 +16,13 @@ try:
|
||||||
from typing import (
|
from typing import (
|
||||||
Literal,
|
Literal,
|
||||||
TypedDict,
|
TypedDict,
|
||||||
|
Protocol,
|
||||||
)
|
)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
from typing_extensions import (
|
from typing_extensions import (
|
||||||
Literal,
|
Literal,
|
||||||
TypedDict,
|
TypedDict,
|
||||||
|
Protocol,
|
||||||
)
|
)
|
||||||
|
|
||||||
from slixmpp.jid import JID
|
from slixmpp.jid import JID
|
||||||
|
|
|
@ -1,18 +1,32 @@
|
||||||
|
|
||||||
# slixmpp.xmlstream.dns
|
# slixmpp.xmlstream.dns
|
||||||
# ~~~~~~~~~~~~~~~~~~~~~~~
|
# ~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
# :copyright: (c) 2012 Nathanael C. Fritz
|
# :copyright: (c) 2012 Nathanael C. Fritz
|
||||||
# :license: MIT, see LICENSE for more details
|
# :license: MIT, see LICENSE for more details
|
||||||
|
|
||||||
from slixmpp.xmlstream.asyncio import asyncio
|
|
||||||
import socket
|
import socket
|
||||||
|
import sys
|
||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
|
from asyncio import Future, AbstractEventLoop
|
||||||
|
from typing import Optional, Tuple, Dict, List, Iterable, cast
|
||||||
|
from slixmpp.types import Protocol
|
||||||
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AnswerProtocol(Protocol):
|
||||||
|
host: str
|
||||||
|
priority: int
|
||||||
|
weight: int
|
||||||
|
port: int
|
||||||
|
|
||||||
|
|
||||||
|
class ResolverProtocol(Protocol):
|
||||||
|
def query(self, query: str, querytype: str) -> Future:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
#: Global flag indicating the availability of the ``aiodns`` package.
|
#: Global flag indicating the availability of the ``aiodns`` package.
|
||||||
#: Installing ``aiodns`` can be done via:
|
#: Installing ``aiodns`` can be done via:
|
||||||
#:
|
#:
|
||||||
|
@ -23,12 +37,12 @@ AIODNS_AVAILABLE = False
|
||||||
try:
|
try:
|
||||||
import aiodns
|
import aiodns
|
||||||
AIODNS_AVAILABLE = True
|
AIODNS_AVAILABLE = True
|
||||||
except ImportError as e:
|
except ImportError:
|
||||||
log.debug("Could not find aiodns package. " + \
|
log.debug("Could not find aiodns package. "
|
||||||
"Not all features will be available")
|
"Not all features will be available")
|
||||||
|
|
||||||
|
|
||||||
def default_resolver(loop):
|
def default_resolver(loop: AbstractEventLoop) -> Optional[ResolverProtocol]:
|
||||||
"""Return a basic DNS resolver object.
|
"""Return a basic DNS resolver object.
|
||||||
|
|
||||||
:returns: A :class:`aiodns.DNSResolver` object if aiodns
|
:returns: A :class:`aiodns.DNSResolver` object if aiodns
|
||||||
|
@ -41,8 +55,11 @@ def default_resolver(loop):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
async def resolve(host, port=None, service=None, proto='tcp',
|
async def resolve(host: str, port: int, *, loop: AbstractEventLoop,
|
||||||
resolver=None, use_ipv6=True, use_aiodns=True, loop=None):
|
service: Optional[str] = None, proto: str = 'tcp',
|
||||||
|
resolver: Optional[ResolverProtocol] = None,
|
||||||
|
use_ipv6: bool = True,
|
||||||
|
use_aiodns: bool = True) -> List[Tuple[str, str, int]]:
|
||||||
"""Peform DNS resolution for a given hostname.
|
"""Peform DNS resolution for a given hostname.
|
||||||
|
|
||||||
Resolution may perform SRV record lookups if a service and protocol
|
Resolution may perform SRV record lookups if a service and protocol
|
||||||
|
@ -91,8 +108,8 @@ async def resolve(host, port=None, service=None, proto='tcp',
|
||||||
if not use_ipv6:
|
if not use_ipv6:
|
||||||
log.debug("DNS: Use of IPv6 has been disabled.")
|
log.debug("DNS: Use of IPv6 has been disabled.")
|
||||||
|
|
||||||
if resolver is None and AIODNS_AVAILABLE and use_aiodns:
|
if resolver is None and use_aiodns:
|
||||||
resolver = aiodns.DNSResolver(loop=loop)
|
resolver = default_resolver(loop=loop)
|
||||||
|
|
||||||
# An IPv6 literal is allowed to be enclosed in square brackets, but
|
# An IPv6 literal is allowed to be enclosed in square brackets, but
|
||||||
# the brackets must be stripped in order to process the literal;
|
# the brackets must be stripped in order to process the literal;
|
||||||
|
@ -101,7 +118,7 @@ async def resolve(host, port=None, service=None, proto='tcp',
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# If `host` is an IPv4 literal, we can return it immediately.
|
# If `host` is an IPv4 literal, we can return it immediately.
|
||||||
ipv4 = socket.inet_aton(host)
|
socket.inet_aton(host)
|
||||||
return [(host, host, port)]
|
return [(host, host, port)]
|
||||||
except socket.error:
|
except socket.error:
|
||||||
pass
|
pass
|
||||||
|
@ -111,7 +128,7 @@ async def resolve(host, port=None, service=None, proto='tcp',
|
||||||
# Likewise, If `host` is an IPv6 literal, we can return
|
# Likewise, If `host` is an IPv6 literal, we can return
|
||||||
# it immediately.
|
# it immediately.
|
||||||
if hasattr(socket, 'inet_pton'):
|
if hasattr(socket, 'inet_pton'):
|
||||||
ipv6 = socket.inet_pton(socket.AF_INET6, host)
|
socket.inet_pton(socket.AF_INET6, host)
|
||||||
return [(host, host, port)]
|
return [(host, host, port)]
|
||||||
except (socket.error, ValueError):
|
except (socket.error, ValueError):
|
||||||
pass
|
pass
|
||||||
|
@ -148,7 +165,10 @@ async def resolve(host, port=None, service=None, proto='tcp',
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
async def get_A(host, resolver=None, use_aiodns=True, loop=None):
|
|
||||||
|
async def get_A(host: str, *, loop: AbstractEventLoop,
|
||||||
|
resolver: Optional[ResolverProtocol] = None,
|
||||||
|
use_aiodns: bool = True) -> List[str]:
|
||||||
"""Lookup DNS A records for a given host.
|
"""Lookup DNS A records for a given host.
|
||||||
|
|
||||||
If ``resolver`` is not provided, or is ``None``, then resolution will
|
If ``resolver`` is not provided, or is ``None``, then resolution will
|
||||||
|
@ -172,10 +192,10 @@ async def get_A(host, resolver=None, use_aiodns=True, loop=None):
|
||||||
# getaddrinfo() method.
|
# getaddrinfo() method.
|
||||||
if resolver is None or not use_aiodns:
|
if resolver is None or not use_aiodns:
|
||||||
try:
|
try:
|
||||||
recs = await loop.getaddrinfo(host, None,
|
inet_recs = await loop.getaddrinfo(host, None,
|
||||||
family=socket.AF_INET,
|
family=socket.AF_INET,
|
||||||
type=socket.SOCK_STREAM)
|
type=socket.SOCK_STREAM)
|
||||||
return [rec[4][0] for rec in recs]
|
return [rec[4][0] for rec in inet_recs]
|
||||||
except socket.gaierror:
|
except socket.gaierror:
|
||||||
log.debug("DNS: Error retrieving A address info for %s." % host)
|
log.debug("DNS: Error retrieving A address info for %s." % host)
|
||||||
return []
|
return []
|
||||||
|
@ -183,14 +203,16 @@ async def get_A(host, resolver=None, use_aiodns=True, loop=None):
|
||||||
# Using aiodns:
|
# Using aiodns:
|
||||||
future = resolver.query(host, 'A')
|
future = resolver.query(host, 'A')
|
||||||
try:
|
try:
|
||||||
recs = await future
|
recs = cast(Iterable[AnswerProtocol], await future)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.debug('DNS: Exception while querying for %s A records: %s', host, e)
|
log.debug('DNS: Exception while querying for %s A records: %s', host, e)
|
||||||
recs = []
|
recs = []
|
||||||
return [rec.host for rec in recs]
|
return [rec.host for rec in recs]
|
||||||
|
|
||||||
|
|
||||||
async def get_AAAA(host, resolver=None, use_aiodns=True, loop=None):
|
async def get_AAAA(host: str, *, loop: AbstractEventLoop,
|
||||||
|
resolver: Optional[ResolverProtocol] = None,
|
||||||
|
use_aiodns: bool = True) -> List[str]:
|
||||||
"""Lookup DNS AAAA records for a given host.
|
"""Lookup DNS AAAA records for a given host.
|
||||||
|
|
||||||
If ``resolver`` is not provided, or is ``None``, then resolution will
|
If ``resolver`` is not provided, or is ``None``, then resolution will
|
||||||
|
@ -217,10 +239,10 @@ async def get_AAAA(host, resolver=None, use_aiodns=True, loop=None):
|
||||||
log.debug("DNS: Unable to query %s for AAAA records: IPv6 is not supported", host)
|
log.debug("DNS: Unable to query %s for AAAA records: IPv6 is not supported", host)
|
||||||
return []
|
return []
|
||||||
try:
|
try:
|
||||||
recs = await loop.getaddrinfo(host, None,
|
inet_recs = await loop.getaddrinfo(host, None,
|
||||||
family=socket.AF_INET6,
|
family=socket.AF_INET6,
|
||||||
type=socket.SOCK_STREAM)
|
type=socket.SOCK_STREAM)
|
||||||
return [rec[4][0] for rec in recs]
|
return [rec[4][0] for rec in inet_recs]
|
||||||
except (OSError, socket.gaierror):
|
except (OSError, socket.gaierror):
|
||||||
log.debug("DNS: Error retrieving AAAA address " + \
|
log.debug("DNS: Error retrieving AAAA address " + \
|
||||||
"info for %s." % host)
|
"info for %s." % host)
|
||||||
|
@ -229,13 +251,17 @@ async def get_AAAA(host, resolver=None, use_aiodns=True, loop=None):
|
||||||
# Using aiodns:
|
# Using aiodns:
|
||||||
future = resolver.query(host, 'AAAA')
|
future = resolver.query(host, 'AAAA')
|
||||||
try:
|
try:
|
||||||
recs = await future
|
recs = cast(Iterable[AnswerProtocol], await future)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.debug('DNS: Exception while querying for %s AAAA records: %s', host, e)
|
log.debug('DNS: Exception while querying for %s AAAA records: %s', host, e)
|
||||||
recs = []
|
recs = []
|
||||||
return [rec.host for rec in recs]
|
return [rec.host for rec in recs]
|
||||||
|
|
||||||
async def get_SRV(host, port, service, proto='tcp', resolver=None, use_aiodns=True):
|
|
||||||
|
async def get_SRV(host: str, port: int, service: str,
|
||||||
|
proto: str = 'tcp',
|
||||||
|
resolver: Optional[ResolverProtocol] = None,
|
||||||
|
use_aiodns: bool = True) -> List[Tuple[str, int]]:
|
||||||
"""Perform SRV record resolution for a given host.
|
"""Perform SRV record resolution for a given host.
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
|
@ -269,12 +295,12 @@ async def get_SRV(host, port, service, proto='tcp', resolver=None, use_aiodns=Tr
|
||||||
try:
|
try:
|
||||||
future = resolver.query('_%s._%s.%s' % (service, proto, host),
|
future = resolver.query('_%s._%s.%s' % (service, proto, host),
|
||||||
'SRV')
|
'SRV')
|
||||||
recs = await future
|
recs = cast(Iterable[AnswerProtocol], await future)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.debug('DNS: Exception while querying for %s SRV records: %s', host, e)
|
log.debug('DNS: Exception while querying for %s SRV records: %s', host, e)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
answers = {}
|
answers: Dict[int, List[AnswerProtocol]] = {}
|
||||||
for rec in recs:
|
for rec in recs:
|
||||||
if rec.priority not in answers:
|
if rec.priority not in answers:
|
||||||
answers[rec.priority] = []
|
answers[rec.priority] = []
|
||||||
|
|
Loading…
Reference in a new issue