Allow the use of a custom loop instead of asyncio.get_event_loop()
This commit is contained in:
parent
f1e6d6b0a9
commit
a2852eb249
3 changed files with 42 additions and 36 deletions
|
@ -332,7 +332,7 @@ class XEP_0325(BasePlugin):
|
|||
self.sessions[session]["nodeDone"][node] = False
|
||||
|
||||
for node in self.sessions[session]["node_list"]:
|
||||
timer = asyncio.get_event_loop().call_later(self.nodes[node]['commTimeout'], partial(self._event_comm_timeout, args=(session, node)))
|
||||
timer = self.xmpp.loop.call_later(self.nodes[node]['commTimeout'], partial(self._event_comm_timeout, args=(session, node)))
|
||||
self.sessions[session]["commTimers"][node] = timer
|
||||
self.nodes[node]['device'].set_control_fields(process_fields, session=session, callback=self._device_set_command_callback)
|
||||
|
||||
|
|
|
@ -32,14 +32,14 @@ except ImportError as e:
|
|||
"Not all features will be available")
|
||||
|
||||
|
||||
def default_resolver():
|
||||
def default_resolver(loop):
|
||||
"""Return a basic DNS resolver object.
|
||||
|
||||
:returns: A :class:`aiodns.DNSResolver` object if aiodns
|
||||
is available. Otherwise, ``None``.
|
||||
"""
|
||||
if AIODNS_AVAILABLE:
|
||||
return aiodns.DNSResolver(loop=asyncio.get_event_loop(),
|
||||
return aiodns.DNSResolver(loop=loop,
|
||||
tries=1,
|
||||
timeout=1.0)
|
||||
return None
|
||||
|
@ -47,7 +47,7 @@ def default_resolver():
|
|||
|
||||
@asyncio.coroutine
|
||||
def resolve(host, port=None, service=None, proto='tcp',
|
||||
resolver=None, use_ipv6=True, use_aiodns=True):
|
||||
resolver=None, use_ipv6=True, use_aiodns=True, loop=None):
|
||||
"""Peform DNS resolution for a given hostname.
|
||||
|
||||
Resolution may perform SRV record lookups if a service and protocol
|
||||
|
@ -97,7 +97,7 @@ def resolve(host, port=None, service=None, proto='tcp',
|
|||
log.debug("DNS: Use of IPv6 has been disabled.")
|
||||
|
||||
if resolver is None and AIODNS_AVAILABLE and use_aiodns:
|
||||
resolver = aiodns.DNSResolver(loop=asyncio.get_event_loop())
|
||||
resolver = aiodns.DNSResolver(loop=loop)
|
||||
|
||||
# An IPv6 literal is allowed to be enclosed in square brackets, but
|
||||
# the brackets must be stripped in order to process the literal;
|
||||
|
@ -142,19 +142,19 @@ def resolve(host, port=None, service=None, proto='tcp',
|
|||
|
||||
if use_ipv6:
|
||||
aaaa = yield from get_AAAA(host, resolver=resolver,
|
||||
use_aiodns=use_aiodns)
|
||||
use_aiodns=use_aiodns, loop=loop)
|
||||
for address in aaaa:
|
||||
results.append((host, address, port))
|
||||
|
||||
a = yield from get_A(host, resolver=resolver,
|
||||
use_aiodns=use_aiodns)
|
||||
use_aiodns=use_aiodns, loop=loop)
|
||||
for address in a:
|
||||
results.append((host, address, port))
|
||||
|
||||
return results
|
||||
|
||||
@asyncio.coroutine
|
||||
def get_A(host, resolver=None, use_aiodns=True):
|
||||
def get_A(host, resolver=None, use_aiodns=True, loop=None):
|
||||
"""Lookup DNS A records for a given host.
|
||||
|
||||
If ``resolver`` is not provided, or is ``None``, then resolution will
|
||||
|
@ -177,7 +177,6 @@ def get_A(host, resolver=None, use_aiodns=True):
|
|||
# If not using aiodns, attempt lookup using the OS level
|
||||
# getaddrinfo() method.
|
||||
if resolver is None or not use_aiodns:
|
||||
loop = asyncio.get_event_loop()
|
||||
try:
|
||||
recs = yield from loop.getaddrinfo(host, None,
|
||||
family=socket.AF_INET,
|
||||
|
@ -198,7 +197,7 @@ def get_A(host, resolver=None, use_aiodns=True):
|
|||
|
||||
|
||||
@asyncio.coroutine
|
||||
def get_AAAA(host, resolver=None, use_aiodns=True):
|
||||
def get_AAAA(host, resolver=None, use_aiodns=True, loop=None):
|
||||
"""Lookup DNS AAAA records for a given host.
|
||||
|
||||
If ``resolver`` is not provided, or is ``None``, then resolution will
|
||||
|
@ -224,7 +223,6 @@ def get_AAAA(host, resolver=None, use_aiodns=True):
|
|||
if not socket.has_ipv6:
|
||||
log.debug("DNS: Unable to query %s for AAAA records: IPv6 is not supported", host)
|
||||
return []
|
||||
loop = asyncio.get_event_loop()
|
||||
try:
|
||||
recs = yield from loop.getaddrinfo(host, None,
|
||||
family=socket.AF_INET6,
|
||||
|
|
|
@ -116,6 +116,9 @@ class XMLStream(asyncio.BaseProtocol):
|
|||
|
||||
self._der_cert = None
|
||||
|
||||
# The asyncio event loop
|
||||
self._loop = None
|
||||
|
||||
#: The default port to return when querying DNS records.
|
||||
self.default_port = int(port)
|
||||
|
||||
|
@ -213,6 +216,16 @@ class XMLStream(asyncio.BaseProtocol):
|
|||
self.add_event_handler('disconnected', self._remove_schedules)
|
||||
self.add_event_handler('session_start', self._start_keepalive)
|
||||
|
||||
@property
|
||||
def loop(self):
|
||||
if self._loop is None:
|
||||
self._loop = asyncio.get_event_loop()
|
||||
return self._loop
|
||||
|
||||
@loop.setter
|
||||
def loop(self, value):
|
||||
self._loop = value
|
||||
|
||||
def new_id(self):
|
||||
"""Generate and return a new stream ID in hexadecimal form.
|
||||
|
||||
|
@ -270,7 +283,6 @@ class XMLStream(asyncio.BaseProtocol):
|
|||
|
||||
@asyncio.coroutine
|
||||
def _connect_routine(self):
|
||||
loop = asyncio.get_event_loop()
|
||||
self.event_when_connected = "connected"
|
||||
|
||||
try:
|
||||
|
@ -290,10 +302,10 @@ class XMLStream(asyncio.BaseProtocol):
|
|||
self.dns_answers = None
|
||||
|
||||
try:
|
||||
yield from loop.create_connection(lambda: self,
|
||||
self.address[0],
|
||||
self.address[1],
|
||||
ssl=self.use_ssl)
|
||||
yield from self.loop.create_connection(lambda: self,
|
||||
self.address[0],
|
||||
self.address[1],
|
||||
ssl=self.use_ssl)
|
||||
except Socket.gaierror as e:
|
||||
self.event('connection_failed',
|
||||
'No DNS record available for %s' % self.default_domain)
|
||||
|
@ -309,17 +321,16 @@ class XMLStream(asyncio.BaseProtocol):
|
|||
function will run forever. If timeout is a number, this function
|
||||
will return after the given time in seconds.
|
||||
"""
|
||||
loop = asyncio.get_event_loop()
|
||||
if timeout is None:
|
||||
if forever:
|
||||
loop.run_forever()
|
||||
self.loop.run_forever()
|
||||
else:
|
||||
loop.run_until_complete(self.disconnected)
|
||||
self.loop.run_until_complete(self.disconnected)
|
||||
else:
|
||||
tasks = [asyncio.sleep(timeout)]
|
||||
if not forever:
|
||||
tasks.append(self.disconnected)
|
||||
loop.run_until_complete(asyncio.wait(tasks))
|
||||
self.loop.run_until_complete(asyncio.wait(tasks))
|
||||
|
||||
def init_parser(self):
|
||||
"""init the XML parser. The parser must always be reset for each new
|
||||
|
@ -367,8 +378,7 @@ class XMLStream(asyncio.BaseProtocol):
|
|||
elif self.xml_depth == 1:
|
||||
# A stanza is an XML element that is a direct child of
|
||||
# the root element, hence the check of depth == 1
|
||||
asyncio.get_event_loop().\
|
||||
idle_call(functools.partial(self.__spawn_event, xml))
|
||||
self.loop.idle_call(functools.partial(self.__spawn_event, xml))
|
||||
if self.xml_root is not None:
|
||||
# Keep the root element empty of children to
|
||||
# save on memory use.
|
||||
|
@ -461,7 +471,6 @@ class XMLStream(asyncio.BaseProtocol):
|
|||
If the handshake is successful, the XML stream will need
|
||||
to be restarted.
|
||||
"""
|
||||
loop = asyncio.get_event_loop()
|
||||
self.event_when_connected = "tls_success"
|
||||
|
||||
if self.ciphers is not None:
|
||||
|
@ -478,9 +487,9 @@ class XMLStream(asyncio.BaseProtocol):
|
|||
self.ssl_context.verify_mode = ssl.CERT_REQUIRED
|
||||
self.ssl_context.load_verify_locations(cafile=self.ca_certs)
|
||||
|
||||
ssl_connect_routine = loop.create_connection(lambda: self, ssl=self.ssl_context,
|
||||
sock=self.socket,
|
||||
server_hostname=self.address[0])
|
||||
ssl_connect_routine = self.loop.create_connection(lambda: self, ssl=self.ssl_context,
|
||||
sock=self.socket,
|
||||
server_hostname=self.address[0])
|
||||
@asyncio.coroutine
|
||||
def ssl_coro():
|
||||
try:
|
||||
|
@ -621,14 +630,15 @@ class XMLStream(asyncio.BaseProtocol):
|
|||
if port is None:
|
||||
port = self.default_port
|
||||
|
||||
resolver = default_resolver()
|
||||
resolver = default_resolver(loop=self.loop)
|
||||
self.configure_dns(resolver, domain=domain, port=port)
|
||||
|
||||
result = yield from resolve(domain, port,
|
||||
service=self.dns_service,
|
||||
resolver=resolver,
|
||||
use_ipv6=self.use_ipv6,
|
||||
use_aiodns=self.use_aiodns)
|
||||
use_aiodns=self.use_aiodns,
|
||||
loop=self.loop)
|
||||
return result
|
||||
|
||||
@asyncio.coroutine
|
||||
|
@ -746,14 +756,13 @@ class XMLStream(asyncio.BaseProtocol):
|
|||
"""
|
||||
if seconds is None:
|
||||
seconds = RESPONSE_TIMEOUT
|
||||
loop = asyncio.get_event_loop()
|
||||
cb = functools.partial(callback, *args, **kwargs)
|
||||
if repeat:
|
||||
handle = loop.call_later(seconds, self._execute_and_reschedule,
|
||||
name, cb, seconds)
|
||||
handle = self.loop.call_later(seconds, self._execute_and_reschedule,
|
||||
name, cb, seconds)
|
||||
else:
|
||||
handle = loop.call_later(seconds, self._execute_and_unschedule,
|
||||
name, cb)
|
||||
handle = self.loop.call_later(seconds, self._execute_and_unschedule,
|
||||
name, cb)
|
||||
|
||||
# Save that handle, so we can just cancel this scheduled event by
|
||||
# canceling scheduled_events[name]
|
||||
|
@ -778,9 +787,8 @@ class XMLStream(asyncio.BaseProtocol):
|
|||
be called after the given number of seconds.
|
||||
"""
|
||||
self._safe_cb_run(name, cb)
|
||||
loop = asyncio.get_event_loop()
|
||||
handle = loop.call_later(seconds, self._execute_and_reschedule,
|
||||
name, cb, seconds)
|
||||
handle = self.loop.call_later(seconds, self._execute_and_reschedule,
|
||||
name, cb, seconds)
|
||||
self.scheduled_events[name] = handle
|
||||
|
||||
def _execute_and_unschedule(self, name, cb):
|
||||
|
|
Loading…
Reference in a new issue