Allow the use of a custom loop instead of asyncio.get_event_loop()

This commit is contained in:
mathieui 2015-05-12 00:02:32 +02:00
parent f1e6d6b0a9
commit a2852eb249
No known key found for this signature in database
GPG key ID: C59F84CEEFD616E3
3 changed files with 42 additions and 36 deletions

View file

@ -332,7 +332,7 @@ class XEP_0325(BasePlugin):
self.sessions[session]["nodeDone"][node] = False
for node in self.sessions[session]["node_list"]:
timer = asyncio.get_event_loop().call_later(self.nodes[node]['commTimeout'], partial(self._event_comm_timeout, args=(session, node)))
timer = self.xmpp.loop.call_later(self.nodes[node]['commTimeout'], partial(self._event_comm_timeout, args=(session, node)))
self.sessions[session]["commTimers"][node] = timer
self.nodes[node]['device'].set_control_fields(process_fields, session=session, callback=self._device_set_command_callback)

View file

@ -32,14 +32,14 @@ except ImportError as e:
"Not all features will be available")
def default_resolver():
def default_resolver(loop):
"""Return a basic DNS resolver object.
:returns: A :class:`aiodns.DNSResolver` object if aiodns
is available. Otherwise, ``None``.
"""
if AIODNS_AVAILABLE:
return aiodns.DNSResolver(loop=asyncio.get_event_loop(),
return aiodns.DNSResolver(loop=loop,
tries=1,
timeout=1.0)
return None
@ -47,7 +47,7 @@ def default_resolver():
@asyncio.coroutine
def resolve(host, port=None, service=None, proto='tcp',
resolver=None, use_ipv6=True, use_aiodns=True):
resolver=None, use_ipv6=True, use_aiodns=True, loop=None):
"""Peform DNS resolution for a given hostname.
Resolution may perform SRV record lookups if a service and protocol
@ -97,7 +97,7 @@ def resolve(host, port=None, service=None, proto='tcp',
log.debug("DNS: Use of IPv6 has been disabled.")
if resolver is None and AIODNS_AVAILABLE and use_aiodns:
resolver = aiodns.DNSResolver(loop=asyncio.get_event_loop())
resolver = aiodns.DNSResolver(loop=loop)
# An IPv6 literal is allowed to be enclosed in square brackets, but
# the brackets must be stripped in order to process the literal;
@ -142,19 +142,19 @@ def resolve(host, port=None, service=None, proto='tcp',
if use_ipv6:
aaaa = yield from get_AAAA(host, resolver=resolver,
use_aiodns=use_aiodns)
use_aiodns=use_aiodns, loop=loop)
for address in aaaa:
results.append((host, address, port))
a = yield from get_A(host, resolver=resolver,
use_aiodns=use_aiodns)
use_aiodns=use_aiodns, loop=loop)
for address in a:
results.append((host, address, port))
return results
@asyncio.coroutine
def get_A(host, resolver=None, use_aiodns=True):
def get_A(host, resolver=None, use_aiodns=True, loop=None):
"""Lookup DNS A records for a given host.
If ``resolver`` is not provided, or is ``None``, then resolution will
@ -177,7 +177,6 @@ def get_A(host, resolver=None, use_aiodns=True):
# If not using aiodns, attempt lookup using the OS level
# getaddrinfo() method.
if resolver is None or not use_aiodns:
loop = asyncio.get_event_loop()
try:
recs = yield from loop.getaddrinfo(host, None,
family=socket.AF_INET,
@ -198,7 +197,7 @@ def get_A(host, resolver=None, use_aiodns=True):
@asyncio.coroutine
def get_AAAA(host, resolver=None, use_aiodns=True):
def get_AAAA(host, resolver=None, use_aiodns=True, loop=None):
"""Lookup DNS AAAA records for a given host.
If ``resolver`` is not provided, or is ``None``, then resolution will
@ -224,7 +223,6 @@ def get_AAAA(host, resolver=None, use_aiodns=True):
if not socket.has_ipv6:
log.debug("DNS: Unable to query %s for AAAA records: IPv6 is not supported", host)
return []
loop = asyncio.get_event_loop()
try:
recs = yield from loop.getaddrinfo(host, None,
family=socket.AF_INET6,

View file

@ -116,6 +116,9 @@ class XMLStream(asyncio.BaseProtocol):
self._der_cert = None
# The asyncio event loop
self._loop = None
#: The default port to return when querying DNS records.
self.default_port = int(port)
@ -213,6 +216,16 @@ class XMLStream(asyncio.BaseProtocol):
self.add_event_handler('disconnected', self._remove_schedules)
self.add_event_handler('session_start', self._start_keepalive)
@property
def loop(self):
if self._loop is None:
self._loop = asyncio.get_event_loop()
return self._loop
@loop.setter
def loop(self, value):
self._loop = value
def new_id(self):
"""Generate and return a new stream ID in hexadecimal form.
@ -270,7 +283,6 @@ class XMLStream(asyncio.BaseProtocol):
@asyncio.coroutine
def _connect_routine(self):
loop = asyncio.get_event_loop()
self.event_when_connected = "connected"
try:
@ -290,10 +302,10 @@ class XMLStream(asyncio.BaseProtocol):
self.dns_answers = None
try:
yield from loop.create_connection(lambda: self,
self.address[0],
self.address[1],
ssl=self.use_ssl)
yield from self.loop.create_connection(lambda: self,
self.address[0],
self.address[1],
ssl=self.use_ssl)
except Socket.gaierror as e:
self.event('connection_failed',
'No DNS record available for %s' % self.default_domain)
@ -309,17 +321,16 @@ class XMLStream(asyncio.BaseProtocol):
function will run forever. If timeout is a number, this function
will return after the given time in seconds.
"""
loop = asyncio.get_event_loop()
if timeout is None:
if forever:
loop.run_forever()
self.loop.run_forever()
else:
loop.run_until_complete(self.disconnected)
self.loop.run_until_complete(self.disconnected)
else:
tasks = [asyncio.sleep(timeout)]
if not forever:
tasks.append(self.disconnected)
loop.run_until_complete(asyncio.wait(tasks))
self.loop.run_until_complete(asyncio.wait(tasks))
def init_parser(self):
"""init the XML parser. The parser must always be reset for each new
@ -367,8 +378,7 @@ class XMLStream(asyncio.BaseProtocol):
elif self.xml_depth == 1:
# A stanza is an XML element that is a direct child of
# the root element, hence the check of depth == 1
asyncio.get_event_loop().\
idle_call(functools.partial(self.__spawn_event, xml))
self.loop.idle_call(functools.partial(self.__spawn_event, xml))
if self.xml_root is not None:
# Keep the root element empty of children to
# save on memory use.
@ -461,7 +471,6 @@ class XMLStream(asyncio.BaseProtocol):
If the handshake is successful, the XML stream will need
to be restarted.
"""
loop = asyncio.get_event_loop()
self.event_when_connected = "tls_success"
if self.ciphers is not None:
@ -478,9 +487,9 @@ class XMLStream(asyncio.BaseProtocol):
self.ssl_context.verify_mode = ssl.CERT_REQUIRED
self.ssl_context.load_verify_locations(cafile=self.ca_certs)
ssl_connect_routine = loop.create_connection(lambda: self, ssl=self.ssl_context,
sock=self.socket,
server_hostname=self.address[0])
ssl_connect_routine = self.loop.create_connection(lambda: self, ssl=self.ssl_context,
sock=self.socket,
server_hostname=self.address[0])
@asyncio.coroutine
def ssl_coro():
try:
@ -621,14 +630,15 @@ class XMLStream(asyncio.BaseProtocol):
if port is None:
port = self.default_port
resolver = default_resolver()
resolver = default_resolver(loop=self.loop)
self.configure_dns(resolver, domain=domain, port=port)
result = yield from resolve(domain, port,
service=self.dns_service,
resolver=resolver,
use_ipv6=self.use_ipv6,
use_aiodns=self.use_aiodns)
use_aiodns=self.use_aiodns,
loop=self.loop)
return result
@asyncio.coroutine
@ -746,14 +756,13 @@ class XMLStream(asyncio.BaseProtocol):
"""
if seconds is None:
seconds = RESPONSE_TIMEOUT
loop = asyncio.get_event_loop()
cb = functools.partial(callback, *args, **kwargs)
if repeat:
handle = loop.call_later(seconds, self._execute_and_reschedule,
name, cb, seconds)
handle = self.loop.call_later(seconds, self._execute_and_reschedule,
name, cb, seconds)
else:
handle = loop.call_later(seconds, self._execute_and_unschedule,
name, cb)
handle = self.loop.call_later(seconds, self._execute_and_unschedule,
name, cb)
# Save that handle, so we can just cancel this scheduled event by
# canceling scheduled_events[name]
@ -778,9 +787,8 @@ class XMLStream(asyncio.BaseProtocol):
be called after the given number of seconds.
"""
self._safe_cb_run(name, cb)
loop = asyncio.get_event_loop()
handle = loop.call_later(seconds, self._execute_and_reschedule,
name, cb, seconds)
handle = self.loop.call_later(seconds, self._execute_and_reschedule,
name, cb, seconds)
self.scheduled_events[name] = handle
def _execute_and_unschedule(self, name, cb):