Use gethostbyname when using aiodns

Slixmpp behaves differently when resolving host names, whether aiodns
is used or not. With aiodns only DNS is used, while without
`asyncio.loop.getaddrinfo()` is used instead, which utilizes the Name
Service Switch (NSS) to resolve host names by other means (hosts-file,
mDNS, ...) as well.

To unify the behavior, this replaces the use of
`aiodns.DNSResolver().query()` with
`aiodns.DNSResolver().gethostbyname()`. This makes the behavior
resolving host names more consistent between using aiodns or not, as
both now honor the NSS configuration and removes the need for the
previously existing workaround to resolve localhost.
This commit is contained in:
Daniel Roschka 2022-07-29 12:04:01 +02:00
parent 1f47acaec1
commit d43c83800e
No known key found for this signature in database
GPG key ID: 885B16854284E0B2

View file

@ -15,7 +15,13 @@ from slixmpp.types import Protocol
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class AnswerProtocol(Protocol): class GetHostByNameAnswerProtocol(Protocol):
name: str
aliases: List[str]
addresses: List[str]
class QueryAnswerProtocol(Protocol):
host: str host: str
priority: int priority: int
weight: int weight: int
@ -23,6 +29,9 @@ class AnswerProtocol(Protocol):
class ResolverProtocol(Protocol): class ResolverProtocol(Protocol):
def gethostbyname(self, host: str, socket_family: socket.AddressFamily) -> Future:
...
def query(self, query: str, querytype: str) -> Future: def query(self, query: str, querytype: str) -> Future:
... ...
@ -147,11 +156,6 @@ async def resolve(host: str, port: int, *, loop: AbstractEventLoop,
results = [] results = []
for host, port in hosts: for host, port in hosts:
if host == 'localhost':
if use_ipv6:
results.append((host, '::1', port))
results.append((host, '127.0.0.1', port))
if use_ipv6: if use_ipv6:
aaaa = await get_AAAA(host, resolver=resolver, aaaa = await get_AAAA(host, resolver=resolver,
use_aiodns=use_aiodns, loop=loop) use_aiodns=use_aiodns, loop=loop)
@ -201,13 +205,13 @@ async def get_A(host: str, *, loop: AbstractEventLoop,
return [] return []
# Using aiodns: # Using aiodns:
future = resolver.query(host, 'A') future = resolver.gethostbyname(host, socket.AF_INET)
try: try:
recs = cast(Iterable[AnswerProtocol], await future) recs = cast(GetHostByNameAnswerProtocol, await future)
except Exception as e: except Exception as e:
log.debug('DNS: Exception while querying for %s A records: %s', host, e) log.debug('DNS: Exception while querying for %s A records: %s', host, e)
recs = [] return []
return [rec.host for rec in recs] return [addr for addr in recs.addresses]
async def get_AAAA(host: str, *, loop: AbstractEventLoop, async def get_AAAA(host: str, *, loop: AbstractEventLoop,
@ -249,13 +253,13 @@ async def get_AAAA(host: str, *, loop: AbstractEventLoop,
return [] return []
# Using aiodns: # Using aiodns:
future = resolver.query(host, 'AAAA') future = resolver.gethostbyname(host, socket.AF_INET6)
try: try:
recs = cast(Iterable[AnswerProtocol], await future) recs = cast(GetHostByNameAnswerProtocol, await future)
except Exception as e: except Exception as e:
log.debug('DNS: Exception while querying for %s AAAA records: %s', host, e) log.debug('DNS: Exception while querying for %s AAAA records: %s', host, e)
recs = [] return []
return [rec.host for rec in recs] return [addr for addr in recs.addresses]
async def get_SRV(host: str, port: int, service: str, async def get_SRV(host: str, port: int, service: str,
@ -295,12 +299,12 @@ async def get_SRV(host: str, port: int, service: str,
try: try:
future = resolver.query('_%s._%s.%s' % (service, proto, host), future = resolver.query('_%s._%s.%s' % (service, proto, host),
'SRV') 'SRV')
recs = cast(Iterable[AnswerProtocol], await future) recs = cast(Iterable[QueryAnswerProtocol], await future)
except Exception as e: except Exception as e:
log.debug('DNS: Exception while querying for %s SRV records: %s', host, e) log.debug('DNS: Exception while querying for %s SRV records: %s', host, e)
return [] return []
answers: Dict[int, List[AnswerProtocol]] = {} answers: Dict[int, List[QueryAnswerProtocol]] = {}
for rec in recs: for rec in recs:
if rec.priority not in answers: if rec.priority not in answers:
answers[rec.priority] = [] answers[rec.priority] = []