xmlstream: make dns_answers private
This commit is contained in:
parent
d3063a0368
commit
ccbba89cbd
1 changed files with 11 additions and 9 deletions
|
@ -16,10 +16,12 @@ from typing import (
|
|||
Any,
|
||||
Callable,
|
||||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
Union,
|
||||
Tuple,
|
||||
)
|
||||
|
||||
import functools
|
||||
|
@ -212,7 +214,7 @@ class XMLStream(asyncio.BaseProtocol):
|
|||
self._current_connection_attempt = None
|
||||
|
||||
#: A list of DNS results that have not yet been tried.
|
||||
self.dns_answers = None
|
||||
self._dns_answers: Optional[Iterator[Tuple[str, str, int]]] = None
|
||||
|
||||
#: The service name to check with DNS SRV records. For
|
||||
#: example, setting this to ``'xmpp-client'`` would query the
|
||||
|
@ -315,7 +317,7 @@ class XMLStream(asyncio.BaseProtocol):
|
|||
self.event('reconnect_delay', self._connect_loop_wait)
|
||||
await asyncio.sleep(self._connect_loop_wait, loop=self.loop)
|
||||
|
||||
record = await self.pick_dns_answer(self.default_domain)
|
||||
record = await self._pick_dns_answer(self.default_domain)
|
||||
if record is not None:
|
||||
host, address, dns_port = record
|
||||
port = dns_port if dns_port else self.address[1]
|
||||
|
@ -324,7 +326,7 @@ class XMLStream(asyncio.BaseProtocol):
|
|||
else:
|
||||
# No DNS records left, stop iterating
|
||||
# and try (host, port) as a last resort
|
||||
self.dns_answers = None
|
||||
self._dns_answers = None
|
||||
|
||||
if self.use_ssl:
|
||||
ssl_context = self.get_ssl_context()
|
||||
|
@ -392,7 +394,7 @@ class XMLStream(asyncio.BaseProtocol):
|
|||
self._current_connection_attempt = None
|
||||
self.init_parser()
|
||||
self.send_raw(self.stream_header)
|
||||
self.dns_answers = None
|
||||
self._dns_answers = None
|
||||
|
||||
def data_received(self, data):
|
||||
"""Called when incoming data is received on the socket.
|
||||
|
@ -777,7 +779,7 @@ class XMLStream(asyncio.BaseProtocol):
|
|||
idx += 1
|
||||
return False
|
||||
|
||||
async def get_dns_records(self, domain, port=None):
|
||||
async def get_dns_records(self, domain: str, port: Optional[int] = None) -> List[Tuple[str, str, int]]:
|
||||
"""Get the DNS records for a domain.
|
||||
|
||||
:param domain: The domain in question.
|
||||
|
@ -797,7 +799,7 @@ class XMLStream(asyncio.BaseProtocol):
|
|||
loop=self.loop)
|
||||
return result
|
||||
|
||||
async def pick_dns_answer(self, domain, port=None):
|
||||
async def _pick_dns_answer(self, domain: str, port: Optional[int] = None) -> Optional[Tuple[str, str, int]]:
|
||||
"""Pick a server and port from DNS answers.
|
||||
|
||||
Gets DNS answers if none available.
|
||||
|
@ -806,12 +808,12 @@ class XMLStream(asyncio.BaseProtocol):
|
|||
:param domain: The domain in question.
|
||||
:param port: If the results don't include a port, use this one.
|
||||
"""
|
||||
if self.dns_answers is None:
|
||||
if self._dns_answers is None:
|
||||
dns_records = await self.get_dns_records(domain, port)
|
||||
self.dns_answers = iter(dns_records)
|
||||
self._dns_answers = iter(dns_records)
|
||||
|
||||
try:
|
||||
return next(self.dns_answers)
|
||||
return next(self._dns_answers)
|
||||
except StopIteration:
|
||||
return
|
||||
|
||||
|
|
Loading…
Reference in a new issue