From d43c83800e51c2455b5070a1ccaca56b57fb1575 Mon Sep 17 00:00:00 2001 From: Daniel Roschka Date: Fri, 29 Jul 2022 12:04:01 +0200 Subject: [PATCH] Use gethostbyname when using aiodns Slixmpp behaves differently when resolving host names, whether aiodns is used or not. With aiodns only DNS is used, while without `asyncio.loop.getaddrinfo()` is used instead, which utilizes the Name Service Switch (NSS) to resolve host names by other means (hosts-file, mDNS, ...) as well. To unify the behavior, this replaces the use of `aiodns.DNSResolver().query()` with `aiodns.DNSResolver().gethostbyname()`. This makes the behavior resolving host names more consistent between using aiodns or not, as both now honor the NSS configuration and removes the need for the previously existing workaround to resolve localhost. --- slixmpp/xmlstream/resolver.py | 36 +++++++++++++++++++---------------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/slixmpp/xmlstream/resolver.py b/slixmpp/xmlstream/resolver.py index e524da3b..3de6629d 100644 --- a/slixmpp/xmlstream/resolver.py +++ b/slixmpp/xmlstream/resolver.py @@ -15,7 +15,13 @@ from slixmpp.types import Protocol log = logging.getLogger(__name__) -class AnswerProtocol(Protocol): +class GetHostByNameAnswerProtocol(Protocol): + name: str + aliases: List[str] + addresses: List[str] + + +class QueryAnswerProtocol(Protocol): host: str priority: int weight: int @@ -23,6 +29,9 @@ class AnswerProtocol(Protocol): class ResolverProtocol(Protocol): + def gethostbyname(self, host: str, socket_family: socket.AddressFamily) -> Future: + ... + def query(self, query: str, querytype: str) -> Future: ... @@ -147,11 +156,6 @@ async def resolve(host: str, port: int, *, loop: AbstractEventLoop, results = [] for host, port in hosts: - if host == 'localhost': - if use_ipv6: - results.append((host, '::1', port)) - results.append((host, '127.0.0.1', port)) - if use_ipv6: aaaa = await get_AAAA(host, resolver=resolver, use_aiodns=use_aiodns, loop=loop) @@ -201,13 +205,13 @@ async def get_A(host: str, *, loop: AbstractEventLoop, return [] # Using aiodns: - future = resolver.query(host, 'A') + future = resolver.gethostbyname(host, socket.AF_INET) try: - recs = cast(Iterable[AnswerProtocol], await future) + recs = cast(GetHostByNameAnswerProtocol, await future) except Exception as e: log.debug('DNS: Exception while querying for %s A records: %s', host, e) - recs = [] - return [rec.host for rec in recs] + return [] + return [addr for addr in recs.addresses] async def get_AAAA(host: str, *, loop: AbstractEventLoop, @@ -249,13 +253,13 @@ async def get_AAAA(host: str, *, loop: AbstractEventLoop, return [] # Using aiodns: - future = resolver.query(host, 'AAAA') + future = resolver.gethostbyname(host, socket.AF_INET6) try: - recs = cast(Iterable[AnswerProtocol], await future) + recs = cast(GetHostByNameAnswerProtocol, await future) except Exception as e: log.debug('DNS: Exception while querying for %s AAAA records: %s', host, e) - recs = [] - return [rec.host for rec in recs] + return [] + return [addr for addr in recs.addresses] async def get_SRV(host: str, port: int, service: str, @@ -295,12 +299,12 @@ async def get_SRV(host: str, port: int, service: str, try: future = resolver.query('_%s._%s.%s' % (service, proto, host), 'SRV') - recs = cast(Iterable[AnswerProtocol], await future) + recs = cast(Iterable[QueryAnswerProtocol], await future) except Exception as e: log.debug('DNS: Exception while querying for %s SRV records: %s', host, e) return [] - answers: Dict[int, List[AnswerProtocol]] = {} + answers: Dict[int, List[QueryAnswerProtocol]] = {} for rec in recs: if rec.priority not in answers: answers[rec.priority] = []