Merge branch 'aiodns-gethostbyname' into 'master'

Use gethostbyname when using aiodns

See merge request poezio/slixmpp!212
This commit is contained in:
mathieui 2022-09-09 16:04:14 +00:00
commit b3a6c7a4ea

View file

@ -15,7 +15,13 @@ from slixmpp.types import Protocol
log = logging.getLogger(__name__)
class AnswerProtocol(Protocol):
class GetHostByNameAnswerProtocol(Protocol):
name: str
aliases: List[str]
addresses: List[str]
class QueryAnswerProtocol(Protocol):
host: str
priority: int
weight: int
@ -23,6 +29,9 @@ class AnswerProtocol(Protocol):
class ResolverProtocol(Protocol):
def gethostbyname(self, host: str, socket_family: socket.AddressFamily) -> Future:
...
def query(self, query: str, querytype: str) -> Future:
...
@ -147,11 +156,6 @@ async def resolve(host: str, port: int, *, loop: AbstractEventLoop,
results = []
for host, port in hosts:
if host == 'localhost':
if use_ipv6:
results.append((host, '::1', port))
results.append((host, '127.0.0.1', port))
if use_ipv6:
aaaa = await get_AAAA(host, resolver=resolver,
use_aiodns=use_aiodns, loop=loop)
@ -201,13 +205,13 @@ async def get_A(host: str, *, loop: AbstractEventLoop,
return []
# Using aiodns:
future = resolver.query(host, 'A')
future = resolver.gethostbyname(host, socket.AF_INET)
try:
recs = cast(Iterable[AnswerProtocol], await future)
recs = cast(GetHostByNameAnswerProtocol, await future)
except Exception as e:
log.debug('DNS: Exception while querying for %s A records: %s', host, e)
recs = []
return [rec.host for rec in recs]
return []
return [addr for addr in recs.addresses]
async def get_AAAA(host: str, *, loop: AbstractEventLoop,
@ -249,13 +253,13 @@ async def get_AAAA(host: str, *, loop: AbstractEventLoop,
return []
# Using aiodns:
future = resolver.query(host, 'AAAA')
future = resolver.gethostbyname(host, socket.AF_INET6)
try:
recs = cast(Iterable[AnswerProtocol], await future)
recs = cast(GetHostByNameAnswerProtocol, await future)
except Exception as e:
log.debug('DNS: Exception while querying for %s AAAA records: %s', host, e)
recs = []
return [rec.host for rec in recs]
return []
return [addr for addr in recs.addresses]
async def get_SRV(host: str, port: int, service: str,
@ -295,12 +299,12 @@ async def get_SRV(host: str, port: int, service: str,
try:
future = resolver.query('_%s._%s.%s' % (service, proto, host),
'SRV')
recs = cast(Iterable[AnswerProtocol], await future)
recs = cast(Iterable[QueryAnswerProtocol], await future)
except Exception as e:
log.debug('DNS: Exception while querying for %s SRV records: %s', host, e)
return []
answers: Dict[int, List[AnswerProtocol]] = {}
answers: Dict[int, List[QueryAnswerProtocol]] = {}
for rec in recs:
if rec.priority not in answers:
answers[rec.priority] = []