xmlstream: make dns_answers private

This commit is contained in:
mathieui 2021-02-04 18:42:01 +01:00
parent d3063a0368
commit ccbba89cbd

View file

@ -16,10 +16,12 @@ from typing import (
Any,
Callable,
Iterable,
Iterator,
List,
Optional,
Set,
Union,
Tuple,
)
import functools
@ -212,7 +214,7 @@ class XMLStream(asyncio.BaseProtocol):
self._current_connection_attempt = None
#: A list of DNS results that have not yet been tried.
self.dns_answers = None
self._dns_answers: Optional[Iterator[Tuple[str, str, int]]] = None
#: The service name to check with DNS SRV records. For
#: example, setting this to ``'xmpp-client'`` would query the
@ -315,7 +317,7 @@ class XMLStream(asyncio.BaseProtocol):
self.event('reconnect_delay', self._connect_loop_wait)
await asyncio.sleep(self._connect_loop_wait, loop=self.loop)
record = await self.pick_dns_answer(self.default_domain)
record = await self._pick_dns_answer(self.default_domain)
if record is not None:
host, address, dns_port = record
port = dns_port if dns_port else self.address[1]
@ -324,7 +326,7 @@ class XMLStream(asyncio.BaseProtocol):
else:
# No DNS records left, stop iterating
# and try (host, port) as a last resort
self.dns_answers = None
self._dns_answers = None
if self.use_ssl:
ssl_context = self.get_ssl_context()
@ -392,7 +394,7 @@ class XMLStream(asyncio.BaseProtocol):
self._current_connection_attempt = None
self.init_parser()
self.send_raw(self.stream_header)
self.dns_answers = None
self._dns_answers = None
def data_received(self, data):
"""Called when incoming data is received on the socket.
@ -777,7 +779,7 @@ class XMLStream(asyncio.BaseProtocol):
idx += 1
return False
async def get_dns_records(self, domain, port=None):
async def get_dns_records(self, domain: str, port: Optional[int] = None) -> List[Tuple[str, str, int]]:
"""Get the DNS records for a domain.
:param domain: The domain in question.
@ -797,7 +799,7 @@ class XMLStream(asyncio.BaseProtocol):
loop=self.loop)
return result
async def pick_dns_answer(self, domain, port=None):
async def _pick_dns_answer(self, domain: str, port: Optional[int] = None) -> Optional[Tuple[str, str, int]]:
"""Pick a server and port from DNS answers.
Gets DNS answers if none available.
@ -806,12 +808,12 @@ class XMLStream(asyncio.BaseProtocol):
:param domain: The domain in question.
:param port: If the results don't include a port, use this one.
"""
if self.dns_answers is None:
if self._dns_answers is None:
dns_records = await self.get_dns_records(domain, port)
self.dns_answers = iter(dns_records)
self._dns_answers = iter(dns_records)
try:
return next(self.dns_answers)
return next(self._dns_answers)
except StopIteration:
return