Use gethostbyname when using aiodns
Slixmpp behaves differently when resolving host names, whether aiodns is used or not. With aiodns only DNS is used, while without `asyncio.loop.getaddrinfo()` is used instead, which utilizes the Name Service Switch (NSS) to resolve host names by other means (hosts-file, mDNS, ...) as well. To unify the behavior, this replaces the use of `aiodns.DNSResolver().query()` with `aiodns.DNSResolver().gethostbyname()`. This makes the behavior resolving host names more consistent between using aiodns or not, as both now honor the NSS configuration and removes the need for the previously existing workaround to resolve localhost.
This commit is contained in:
parent
1f47acaec1
commit
d43c83800e
1 changed files with 20 additions and 16 deletions
|
@ -15,7 +15,13 @@ from slixmpp.types import Protocol
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class AnswerProtocol(Protocol):
|
class GetHostByNameAnswerProtocol(Protocol):
|
||||||
|
name: str
|
||||||
|
aliases: List[str]
|
||||||
|
addresses: List[str]
|
||||||
|
|
||||||
|
|
||||||
|
class QueryAnswerProtocol(Protocol):
|
||||||
host: str
|
host: str
|
||||||
priority: int
|
priority: int
|
||||||
weight: int
|
weight: int
|
||||||
|
@ -23,6 +29,9 @@ class AnswerProtocol(Protocol):
|
||||||
|
|
||||||
|
|
||||||
class ResolverProtocol(Protocol):
|
class ResolverProtocol(Protocol):
|
||||||
|
def gethostbyname(self, host: str, socket_family: socket.AddressFamily) -> Future:
|
||||||
|
...
|
||||||
|
|
||||||
def query(self, query: str, querytype: str) -> Future:
|
def query(self, query: str, querytype: str) -> Future:
|
||||||
...
|
...
|
||||||
|
|
||||||
|
@ -147,11 +156,6 @@ async def resolve(host: str, port: int, *, loop: AbstractEventLoop,
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
for host, port in hosts:
|
for host, port in hosts:
|
||||||
if host == 'localhost':
|
|
||||||
if use_ipv6:
|
|
||||||
results.append((host, '::1', port))
|
|
||||||
results.append((host, '127.0.0.1', port))
|
|
||||||
|
|
||||||
if use_ipv6:
|
if use_ipv6:
|
||||||
aaaa = await get_AAAA(host, resolver=resolver,
|
aaaa = await get_AAAA(host, resolver=resolver,
|
||||||
use_aiodns=use_aiodns, loop=loop)
|
use_aiodns=use_aiodns, loop=loop)
|
||||||
|
@ -201,13 +205,13 @@ async def get_A(host: str, *, loop: AbstractEventLoop,
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# Using aiodns:
|
# Using aiodns:
|
||||||
future = resolver.query(host, 'A')
|
future = resolver.gethostbyname(host, socket.AF_INET)
|
||||||
try:
|
try:
|
||||||
recs = cast(Iterable[AnswerProtocol], await future)
|
recs = cast(GetHostByNameAnswerProtocol, await future)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.debug('DNS: Exception while querying for %s A records: %s', host, e)
|
log.debug('DNS: Exception while querying for %s A records: %s', host, e)
|
||||||
recs = []
|
return []
|
||||||
return [rec.host for rec in recs]
|
return [addr for addr in recs.addresses]
|
||||||
|
|
||||||
|
|
||||||
async def get_AAAA(host: str, *, loop: AbstractEventLoop,
|
async def get_AAAA(host: str, *, loop: AbstractEventLoop,
|
||||||
|
@ -249,13 +253,13 @@ async def get_AAAA(host: str, *, loop: AbstractEventLoop,
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# Using aiodns:
|
# Using aiodns:
|
||||||
future = resolver.query(host, 'AAAA')
|
future = resolver.gethostbyname(host, socket.AF_INET6)
|
||||||
try:
|
try:
|
||||||
recs = cast(Iterable[AnswerProtocol], await future)
|
recs = cast(GetHostByNameAnswerProtocol, await future)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.debug('DNS: Exception while querying for %s AAAA records: %s', host, e)
|
log.debug('DNS: Exception while querying for %s AAAA records: %s', host, e)
|
||||||
recs = []
|
return []
|
||||||
return [rec.host for rec in recs]
|
return [addr for addr in recs.addresses]
|
||||||
|
|
||||||
|
|
||||||
async def get_SRV(host: str, port: int, service: str,
|
async def get_SRV(host: str, port: int, service: str,
|
||||||
|
@ -295,12 +299,12 @@ async def get_SRV(host: str, port: int, service: str,
|
||||||
try:
|
try:
|
||||||
future = resolver.query('_%s._%s.%s' % (service, proto, host),
|
future = resolver.query('_%s._%s.%s' % (service, proto, host),
|
||||||
'SRV')
|
'SRV')
|
||||||
recs = cast(Iterable[AnswerProtocol], await future)
|
recs = cast(Iterable[QueryAnswerProtocol], await future)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.debug('DNS: Exception while querying for %s SRV records: %s', host, e)
|
log.debug('DNS: Exception while querying for %s SRV records: %s', host, e)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
answers: Dict[int, List[AnswerProtocol]] = {}
|
answers: Dict[int, List[QueryAnswerProtocol]] = {}
|
||||||
for rec in recs:
|
for rec in recs:
|
||||||
if rec.priority not in answers:
|
if rec.priority not in answers:
|
||||||
answers[rec.priority] = []
|
answers[rec.priority] = []
|
||||||
|
|
Loading…
Reference in a new issue