Merge branch 'aiodns-gethostbyname' into 'master'
Use gethostbyname when using aiodns See merge request poezio/slixmpp!212
This commit is contained in:
commit
b3a6c7a4ea
1 changed files with 20 additions and 16 deletions
|
@ -15,7 +15,13 @@ from slixmpp.types import Protocol
|
|||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AnswerProtocol(Protocol):
|
||||
class GetHostByNameAnswerProtocol(Protocol):
|
||||
name: str
|
||||
aliases: List[str]
|
||||
addresses: List[str]
|
||||
|
||||
|
||||
class QueryAnswerProtocol(Protocol):
|
||||
host: str
|
||||
priority: int
|
||||
weight: int
|
||||
|
@ -23,6 +29,9 @@ class AnswerProtocol(Protocol):
|
|||
|
||||
|
||||
class ResolverProtocol(Protocol):
|
||||
def gethostbyname(self, host: str, socket_family: socket.AddressFamily) -> Future:
|
||||
...
|
||||
|
||||
def query(self, query: str, querytype: str) -> Future:
|
||||
...
|
||||
|
||||
|
@ -147,11 +156,6 @@ async def resolve(host: str, port: int, *, loop: AbstractEventLoop,
|
|||
|
||||
results = []
|
||||
for host, port in hosts:
|
||||
if host == 'localhost':
|
||||
if use_ipv6:
|
||||
results.append((host, '::1', port))
|
||||
results.append((host, '127.0.0.1', port))
|
||||
|
||||
if use_ipv6:
|
||||
aaaa = await get_AAAA(host, resolver=resolver,
|
||||
use_aiodns=use_aiodns, loop=loop)
|
||||
|
@ -201,13 +205,13 @@ async def get_A(host: str, *, loop: AbstractEventLoop,
|
|||
return []
|
||||
|
||||
# Using aiodns:
|
||||
future = resolver.query(host, 'A')
|
||||
future = resolver.gethostbyname(host, socket.AF_INET)
|
||||
try:
|
||||
recs = cast(Iterable[AnswerProtocol], await future)
|
||||
recs = cast(GetHostByNameAnswerProtocol, await future)
|
||||
except Exception as e:
|
||||
log.debug('DNS: Exception while querying for %s A records: %s', host, e)
|
||||
recs = []
|
||||
return [rec.host for rec in recs]
|
||||
return []
|
||||
return [addr for addr in recs.addresses]
|
||||
|
||||
|
||||
async def get_AAAA(host: str, *, loop: AbstractEventLoop,
|
||||
|
@ -249,13 +253,13 @@ async def get_AAAA(host: str, *, loop: AbstractEventLoop,
|
|||
return []
|
||||
|
||||
# Using aiodns:
|
||||
future = resolver.query(host, 'AAAA')
|
||||
future = resolver.gethostbyname(host, socket.AF_INET6)
|
||||
try:
|
||||
recs = cast(Iterable[AnswerProtocol], await future)
|
||||
recs = cast(GetHostByNameAnswerProtocol, await future)
|
||||
except Exception as e:
|
||||
log.debug('DNS: Exception while querying for %s AAAA records: %s', host, e)
|
||||
recs = []
|
||||
return [rec.host for rec in recs]
|
||||
return []
|
||||
return [addr for addr in recs.addresses]
|
||||
|
||||
|
||||
async def get_SRV(host: str, port: int, service: str,
|
||||
|
@ -295,12 +299,12 @@ async def get_SRV(host: str, port: int, service: str,
|
|||
try:
|
||||
future = resolver.query('_%s._%s.%s' % (service, proto, host),
|
||||
'SRV')
|
||||
recs = cast(Iterable[AnswerProtocol], await future)
|
||||
recs = cast(Iterable[QueryAnswerProtocol], await future)
|
||||
except Exception as e:
|
||||
log.debug('DNS: Exception while querying for %s SRV records: %s', host, e)
|
||||
return []
|
||||
|
||||
answers: Dict[int, List[AnswerProtocol]] = {}
|
||||
answers: Dict[int, List[QueryAnswerProtocol]] = {}
|
||||
for rec in recs:
|
||||
if rec.priority not in answers:
|
||||
answers[rec.priority] = []
|
||||
|
|
Loading…
Reference in a new issue