Switch from @asyncio.coroutine to async def everywhere.

This commit is contained in:
Emmanuel Gil Peyrot 2018-07-01 18:46:33 +02:00
parent 66909aafb3
commit 3502480384
14 changed files with 67 additions and 93 deletions

View file

@ -265,8 +265,7 @@ class ClientXMPP(BaseXMPP):
self.bindfail = False self.bindfail = False
self.features = set() self.features = set()
@asyncio.coroutine async def _handle_stream_features(self, features):
def _handle_stream_features(self, features):
"""Process the received stream features. """Process the received stream features.
:param features: The features stanza. :param features: The features stanza.
@ -275,7 +274,7 @@ class ClientXMPP(BaseXMPP):
if name in features['features']: if name in features['features']:
handler, restart = self._stream_feature_handlers[name] handler, restart = self._stream_feature_handlers[name]
if asyncio.iscoroutinefunction(handler): if asyncio.iscoroutinefunction(handler):
result = yield from handler(features) result = await handler(features)
else: else:
result = handler(features) result = handler(features)
if result and restart: if result and restart:

View file

@ -35,8 +35,7 @@ class FeatureBind(BasePlugin):
register_stanza_plugin(Iq, stanza.Bind) register_stanza_plugin(Iq, stanza.Bind)
register_stanza_plugin(StreamFeatures, stanza.Bind) register_stanza_plugin(StreamFeatures, stanza.Bind)
@asyncio.coroutine async def _handle_bind_resource(self, features):
def _handle_bind_resource(self, features):
""" """
Handle requesting a specific resource. Handle requesting a specific resource.
@ -51,7 +50,7 @@ class FeatureBind(BasePlugin):
if self.xmpp.requested_jid.resource: if self.xmpp.requested_jid.resource:
iq['bind']['resource'] = self.xmpp.requested_jid.resource iq['bind']['resource'] = self.xmpp.requested_jid.resource
yield from iq.send(callback=self._on_bind_response) await iq.send(callback=self._on_bind_response)
def _on_bind_response(self, response): def _on_bind_response(self, response):
self.xmpp.boundjid = JID(response['bind']['jid']) self.xmpp.boundjid = JID(response['bind']['jid'])

View file

@ -35,8 +35,7 @@ class FeatureSession(BasePlugin):
register_stanza_plugin(Iq, stanza.Session) register_stanza_plugin(Iq, stanza.Session)
register_stanza_plugin(StreamFeatures, stanza.Session) register_stanza_plugin(StreamFeatures, stanza.Session)
@asyncio.coroutine async def _handle_start_session(self, features):
def _handle_start_session(self, features):
""" """
Handle the start of the session. Handle the start of the session.
@ -51,7 +50,7 @@ class FeatureSession(BasePlugin):
iq = self.xmpp.Iq() iq = self.xmpp.Iq()
iq['type'] = 'set' iq['type'] = 'set'
iq.enable('session') iq.enable('session')
yield from iq.send(callback=self._on_start_session_response) await iq.send(callback=self._on_start_session_response)
def _on_start_session_response(self, response): def _on_start_session_response(self, response):
self.xmpp.features.add('session') self.xmpp.features.add('session')

View file

@ -31,8 +31,7 @@ class IBBytestream(object):
self.recv_queue = asyncio.Queue() self.recv_queue = asyncio.Queue()
@asyncio.coroutine async def send(self, data, timeout=None):
def send(self, data, timeout=None):
if not self.stream_started or self.stream_out_closed: if not self.stream_started or self.stream_out_closed:
raise socket.error raise socket.error
if len(data) > self.block_size: if len(data) > self.block_size:
@ -56,22 +55,20 @@ class IBBytestream(object):
iq['ibb_data']['sid'] = self.sid iq['ibb_data']['sid'] = self.sid
iq['ibb_data']['seq'] = seq iq['ibb_data']['seq'] = seq
iq['ibb_data']['data'] = data iq['ibb_data']['data'] = data
yield from iq.send(timeout=timeout) await iq.send(timeout=timeout)
return len(data) return len(data)
@asyncio.coroutine async def sendall(self, data, timeout=None):
def sendall(self, data, timeout=None):
sent_len = 0 sent_len = 0
while sent_len < len(data): while sent_len < len(data):
sent_len += yield from self.send(data[sent_len:self.block_size], timeout=timeout) sent_len += await self.send(data[sent_len:self.block_size], timeout=timeout)
@asyncio.coroutine async def sendfile(self, file, timeout=None):
def sendfile(self, file, timeout=None):
while True: while True:
data = file.read(self.block_size) data = file.read(self.block_size)
if not data: if not data:
break break
yield from self.send(data, timeout=timeout) await self.send(data, timeout=timeout)
def _recv_data(self, stanza): def _recv_data(self, stanza):
new_seq = stanza['ibb_data']['seq'] new_seq = stanza['ibb_data']['seq']

View file

@ -55,18 +55,17 @@ class XEP_0065(BasePlugin):
"""Returns the socket associated to the SID.""" """Returns the socket associated to the SID."""
return self._sessions.get(sid, None) return self._sessions.get(sid, None)
@asyncio.coroutine async def handshake(self, to, ifrom=None, sid=None, timeout=None):
def handshake(self, to, ifrom=None, sid=None, timeout=None):
""" Starts the handshake to establish the socks5 bytestreams """ Starts the handshake to establish the socks5 bytestreams
connection. connection.
""" """
if not self._proxies: if not self._proxies:
self._proxies = yield from self.discover_proxies() self._proxies = await self.discover_proxies()
if sid is None: if sid is None:
sid = uuid4().hex sid = uuid4().hex
used = yield from self.request_stream(to, sid=sid, ifrom=ifrom, timeout=timeout) used = await self.request_stream(to, sid=sid, ifrom=ifrom, timeout=timeout)
proxy = used['socks']['streamhost_used']['jid'] proxy = used['socks']['streamhost_used']['jid']
if proxy not in self._proxies: if proxy not in self._proxies:
@ -74,16 +73,16 @@ class XEP_0065(BasePlugin):
return return
try: try:
self._sessions[sid] = (yield from self._connect_proxy( self._sessions[sid] = (await self._connect_proxy(
self._get_dest_sha1(sid, self.xmpp.boundjid, to), self._get_dest_sha1(sid, self.xmpp.boundjid, to),
self._proxies[proxy][0], self._proxies[proxy][0],
self._proxies[proxy][1]))[1] self._proxies[proxy][1]))[1]
except socket.error: except socket.error:
return None return None
addr, port = yield from self._sessions[sid].connected addr, port = await self._sessions[sid].connected
# Request that the proxy activate the session with the target. # Request that the proxy activate the session with the target.
yield from self.activate(proxy, sid, to, timeout=timeout) await self.activate(proxy, sid, to, timeout=timeout)
sock = self.get_socket(sid) sock = self.get_socket(sid)
self.xmpp.event('stream:%s:%s' % (sid, to), sock) self.xmpp.event('stream:%s:%s' % (sid, to), sock)
return sock return sock
@ -105,8 +104,7 @@ class XEP_0065(BasePlugin):
iq['socks'].add_streamhost(proxy, host, port) iq['socks'].add_streamhost(proxy, host, port)
return iq.send(timeout=timeout, callback=callback) return iq.send(timeout=timeout, callback=callback)
@asyncio.coroutine async def discover_proxies(self, jid=None, ifrom=None, timeout=None):
def discover_proxies(self, jid=None, ifrom=None, timeout=None):
"""Auto-discover the JIDs of SOCKS5 proxies on an XMPP server.""" """Auto-discover the JIDs of SOCKS5 proxies on an XMPP server."""
if jid is None: if jid is None:
if self.xmpp.is_component: if self.xmpp.is_component:
@ -116,7 +114,7 @@ class XEP_0065(BasePlugin):
discovered = set() discovered = set()
disco_items = yield from self.xmpp['xep_0030'].get_items(jid, timeout=timeout) disco_items = await self.xmpp['xep_0030'].get_items(jid, timeout=timeout)
disco_items = {item[0] for item in disco_items['disco_items']['items']} disco_items = {item[0] for item in disco_items['disco_items']['items']}
disco_info_futures = {} disco_info_futures = {}
@ -125,7 +123,7 @@ class XEP_0065(BasePlugin):
for item in disco_items: for item in disco_items:
try: try:
disco_info = yield from disco_info_futures[item] disco_info = await disco_info_futures[item]
except XMPPError: except XMPPError:
continue continue
else: else:
@ -137,7 +135,7 @@ class XEP_0065(BasePlugin):
for jid in discovered: for jid in discovered:
try: try:
addr = yield from self.get_network_address(jid, ifrom=ifrom, timeout=timeout) addr = await self.get_network_address(jid, ifrom=ifrom, timeout=timeout)
self._proxies[jid] = (addr['socks']['streamhost']['host'], self._proxies[jid] = (addr['socks']['streamhost']['host'],
addr['socks']['streamhost']['port']) addr['socks']['streamhost']['port'])
except XMPPError: except XMPPError:
@ -182,9 +180,8 @@ class XEP_0065(BasePlugin):
streamhost['host'], streamhost['host'],
streamhost['port'])) streamhost['port']))
@asyncio.coroutine async def gather(futures, iq, streamhosts):
def gather(futures, iq, streamhosts): proxies = await asyncio.gather(*futures, return_exceptions=True)
proxies = yield from asyncio.gather(*futures, return_exceptions=True)
for streamhost, proxy in zip(streamhosts, proxies): for streamhost, proxy in zip(streamhosts, proxies):
if isinstance(proxy, ValueError): if isinstance(proxy, ValueError):
continue continue
@ -194,7 +191,7 @@ class XEP_0065(BasePlugin):
proxy = proxy[1] proxy = proxy[1]
# TODO: what if the future never happens? # TODO: what if the future never happens?
try: try:
addr, port = yield from proxy.connected addr, port = await proxy.connected
except socket.error: except socket.error:
log.exception('Socket error while connecting to the proxy.') log.exception('Socket error while connecting to the proxy.')
continue continue

View file

@ -137,8 +137,8 @@ class Socks5Protocol(asyncio.Protocol):
def resume_writing(self): def resume_writing(self):
self.paused.set_result(None) self.paused.set_result(None)
def write(self, data): async def write(self, data):
yield from self.paused await self.paused
self.transport.write(data) self.transport.write(data)
def _send_methods(self): def _send_methods(self):

View file

@ -137,8 +137,7 @@ class XEP_0115(BasePlugin):
self.xmpp.event('entity_caps', p) self.xmpp.event('entity_caps', p)
@asyncio.coroutine async def _process_caps(self, pres):
def _process_caps(self, pres):
if not pres['caps']['hash']: if not pres['caps']['hash']:
log.debug("Received unsupported legacy caps: %s, %s, %s", log.debug("Received unsupported legacy caps: %s, %s, %s",
pres['caps']['node'], pres['caps']['node'],
@ -169,7 +168,7 @@ class XEP_0115(BasePlugin):
log.debug("New caps verification string: %s", ver) log.debug("New caps verification string: %s", ver)
try: try:
node = '%s#%s' % (pres['caps']['node'], ver) node = '%s#%s' % (pres['caps']['node'], ver)
caps = yield from self.xmpp['xep_0030'].get_info(pres['from'], node, caps = await self.xmpp['xep_0030'].get_info(pres['from'], node,
coroutine=True) coroutine=True)
if isinstance(caps, Iq): if isinstance(caps, Iq):
@ -285,10 +284,9 @@ class XEP_0115(BasePlugin):
binary = hash(S.encode('utf8')).digest() binary = hash(S.encode('utf8')).digest()
return base64.b64encode(binary).decode('utf-8') return base64.b64encode(binary).decode('utf-8')
@asyncio.coroutine async def update_caps(self, jid=None, node=None, preserve=False):
def update_caps(self, jid=None, node=None, preserve=False):
try: try:
info = yield from self.xmpp['xep_0030'].get_info(jid, node, local=True) info = await self.xmpp['xep_0030'].get_info(jid, node, local=True)
if isinstance(info, Iq): if isinstance(info, Iq):
info = info['disco_info'] info = info['disco_info']
ver = self.generate_verstring(info, self.hash) ver = self.generate_verstring(info, self.hash)

View file

@ -98,10 +98,9 @@ class XEP_0153(BasePlugin):
first_future.add_done_callback(propagate_timeout_exception) first_future.add_done_callback(propagate_timeout_exception)
return future return future
@asyncio.coroutine async def _start(self, event):
def _start(self, event):
try: try:
vcard = yield from self.xmpp['xep_0054'].get_vcard(self.xmpp.boundjid.bare) vcard = await self.xmpp['xep_0054'].get_vcard(self.xmpp.boundjid.bare)
data = vcard['vcard_temp']['PHOTO']['BINVAL'] data = vcard['vcard_temp']['PHOTO']['BINVAL']
if not data: if not data:
new_hash = '' new_hash = ''

View file

@ -174,8 +174,7 @@ class XEP_0198(BasePlugin):
req = stanza.RequestAck(self.xmpp) req = stanza.RequestAck(self.xmpp)
self.xmpp.send_raw(str(req)) self.xmpp.send_raw(str(req))
@asyncio.coroutine async def _handle_sm_feature(self, features):
def _handle_sm_feature(self, features):
""" """
Enable or resume stream management. Enable or resume stream management.
@ -203,7 +202,7 @@ class XEP_0198(BasePlugin):
MatchXPath(stanza.Enabled.tag_name()), MatchXPath(stanza.Enabled.tag_name()),
MatchXPath(stanza.Failed.tag_name())])) MatchXPath(stanza.Failed.tag_name())]))
self.xmpp.register_handler(waiter) self.xmpp.register_handler(waiter)
result = yield from waiter.wait() result = await waiter.wait()
elif self.sm_id and self.allow_resume and 'bind' not in self.xmpp.features: elif self.sm_id and self.allow_resume and 'bind' not in self.xmpp.features:
self.enabled = True self.enabled = True
resume = stanza.Resume(self.xmpp) resume = stanza.Resume(self.xmpp)
@ -219,7 +218,7 @@ class XEP_0198(BasePlugin):
MatchXPath(stanza.Resumed.tag_name()), MatchXPath(stanza.Resumed.tag_name()),
MatchXPath(stanza.Failed.tag_name())])) MatchXPath(stanza.Failed.tag_name())]))
self.xmpp.register_handler(waiter) self.xmpp.register_handler(waiter)
result = yield from waiter.wait() result = await waiter.wait()
if result is not None and result.name == 'resumed': if result is not None and result.name == 'resumed':
return True return True
return False return False

View file

@ -104,11 +104,10 @@ class XEP_0199(BasePlugin):
def disable_keepalive(self, event=None): def disable_keepalive(self, event=None):
self.xmpp.cancel_schedule('Ping keepalive') self.xmpp.cancel_schedule('Ping keepalive')
@asyncio.coroutine async def _keepalive(self, event=None):
def _keepalive(self, event=None):
log.debug("Keepalive ping...") log.debug("Keepalive ping...")
try: try:
rtt = yield from self.ping(self.xmpp.boundjid.host, timeout=self.timeout) rtt = await self.ping(self.xmpp.boundjid.host, timeout=self.timeout)
except IqTimeout: except IqTimeout:
log.debug("Did not receive ping back in time." + \ log.debug("Did not receive ping back in time." + \
"Requesting Reconnect.") "Requesting Reconnect.")
@ -145,8 +144,7 @@ class XEP_0199(BasePlugin):
return iq.send(timeout=timeout, callback=callback, return iq.send(timeout=timeout, callback=callback,
timeout_callback=timeout_callback) timeout_callback=timeout_callback)
@asyncio.coroutine async def ping(self, jid=None, ifrom=None, timeout=None):
def ping(self, jid=None, ifrom=None, timeout=None):
"""Send a ping request and calculate RTT. """Send a ping request and calculate RTT.
This is a coroutine. This is a coroutine.
@ -174,7 +172,7 @@ class XEP_0199(BasePlugin):
log.debug('Pinging %s' % jid) log.debug('Pinging %s' % jid)
try: try:
yield from self.send_ping(jid, ifrom=ifrom, timeout=timeout) await self.send_ping(jid, ifrom=ifrom, timeout=timeout)
except IqError as e: except IqError as e:
if own_host: if own_host:
rtt = time.time() - start rtt = time.time() - start

View file

@ -45,10 +45,9 @@ class CoroutineCallback(BaseHandler):
if not asyncio.iscoroutinefunction(pointer): if not asyncio.iscoroutinefunction(pointer):
raise ValueError("Given function is not a coroutine") raise ValueError("Given function is not a coroutine")
@asyncio.coroutine async def pointer_wrapper(stanza, *args, **kwargs):
def pointer_wrapper(stanza, *args, **kwargs):
try: try:
yield from pointer(stanza, *args, **kwargs) await pointer(stanza, *args, **kwargs)
except Exception as e: except Exception as e:
stanza.exception(e) stanza.exception(e)

View file

@ -50,8 +50,7 @@ class Waiter(BaseHandler):
"""Do not process this handler during the main event loop.""" """Do not process this handler during the main event loop."""
pass pass
@asyncio.coroutine async def wait(self, timeout=None):
def wait(self, timeout=None):
"""Block an event handler while waiting for a stanza to arrive. """Block an event handler while waiting for a stanza to arrive.
Be aware that this will impact performance if called from a Be aware that this will impact performance if called from a
@ -70,7 +69,7 @@ class Waiter(BaseHandler):
stanza = None stanza = None
try: try:
stanza = yield from self._payload.get() stanza = await self._payload.get()
except TimeoutError: except TimeoutError:
log.warning("Timed out waiting for %s", self.name) log.warning("Timed out waiting for %s", self.name)
self.stream().remove_handler(self.name) self.stream().remove_handler(self.name)

View file

@ -45,8 +45,7 @@ def default_resolver(loop):
return None return None
@asyncio.coroutine async def resolve(host, port=None, service=None, proto='tcp',
def resolve(host, port=None, service=None, proto='tcp',
resolver=None, use_ipv6=True, use_aiodns=True, loop=None): resolver=None, use_ipv6=True, use_aiodns=True, loop=None):
"""Peform DNS resolution for a given hostname. """Peform DNS resolution for a given hostname.
@ -127,7 +126,7 @@ def resolve(host, port=None, service=None, proto='tcp',
if not service: if not service:
hosts = [(host, port)] hosts = [(host, port)]
else: else:
hosts = yield from get_SRV(host, port, service, proto, hosts = await get_SRV(host, port, service, proto,
resolver=resolver, resolver=resolver,
use_aiodns=use_aiodns) use_aiodns=use_aiodns)
if not hosts: if not hosts:
@ -141,19 +140,18 @@ def resolve(host, port=None, service=None, proto='tcp',
results.append((host, '127.0.0.1', port)) results.append((host, '127.0.0.1', port))
if use_ipv6: if use_ipv6:
aaaa = yield from get_AAAA(host, resolver=resolver, aaaa = await get_AAAA(host, resolver=resolver,
use_aiodns=use_aiodns, loop=loop) use_aiodns=use_aiodns, loop=loop)
for address in aaaa: for address in aaaa:
results.append((host, address, port)) results.append((host, address, port))
a = yield from get_A(host, resolver=resolver, a = await get_A(host, resolver=resolver,
use_aiodns=use_aiodns, loop=loop) use_aiodns=use_aiodns, loop=loop)
for address in a: for address in a:
results.append((host, address, port)) results.append((host, address, port))
return results return results
@asyncio.coroutine
def get_A(host, resolver=None, use_aiodns=True, loop=None): def get_A(host, resolver=None, use_aiodns=True, loop=None):
"""Lookup DNS A records for a given host. """Lookup DNS A records for a given host.
@ -178,7 +176,7 @@ def get_A(host, resolver=None, use_aiodns=True, loop=None):
# getaddrinfo() method. # getaddrinfo() method.
if resolver is None or not use_aiodns: if resolver is None or not use_aiodns:
try: try:
recs = yield from loop.getaddrinfo(host, None, recs = await loop.getaddrinfo(host, None,
family=socket.AF_INET, family=socket.AF_INET,
type=socket.SOCK_STREAM) type=socket.SOCK_STREAM)
return [rec[4][0] for rec in recs] return [rec[4][0] for rec in recs]
@ -189,15 +187,14 @@ def get_A(host, resolver=None, use_aiodns=True, loop=None):
# Using aiodns: # Using aiodns:
future = resolver.query(host, 'A') future = resolver.query(host, 'A')
try: try:
recs = yield from future recs = await future
except Exception as e: except Exception as e:
log.debug('DNS: Exception while querying for %s A records: %s', host, e) log.debug('DNS: Exception while querying for %s A records: %s', host, e)
recs = [] recs = []
return [rec.host for rec in recs] return [rec.host for rec in recs]
@asyncio.coroutine async def get_AAAA(host, resolver=None, use_aiodns=True, loop=None):
def get_AAAA(host, resolver=None, use_aiodns=True, loop=None):
"""Lookup DNS AAAA records for a given host. """Lookup DNS AAAA records for a given host.
If ``resolver`` is not provided, or is ``None``, then resolution will If ``resolver`` is not provided, or is ``None``, then resolution will
@ -224,7 +221,7 @@ def get_AAAA(host, resolver=None, use_aiodns=True, loop=None):
log.debug("DNS: Unable to query %s for AAAA records: IPv6 is not supported", host) log.debug("DNS: Unable to query %s for AAAA records: IPv6 is not supported", host)
return [] return []
try: try:
recs = yield from loop.getaddrinfo(host, None, recs = await loop.getaddrinfo(host, None,
family=socket.AF_INET6, family=socket.AF_INET6,
type=socket.SOCK_STREAM) type=socket.SOCK_STREAM)
return [rec[4][0] for rec in recs] return [rec[4][0] for rec in recs]
@ -236,14 +233,13 @@ def get_AAAA(host, resolver=None, use_aiodns=True, loop=None):
# Using aiodns: # Using aiodns:
future = resolver.query(host, 'AAAA') future = resolver.query(host, 'AAAA')
try: try:
recs = yield from future recs = await future
except Exception as e: except Exception as e:
log.debug('DNS: Exception while querying for %s AAAA records: %s', host, e) log.debug('DNS: Exception while querying for %s AAAA records: %s', host, e)
recs = [] recs = []
return [rec.host for rec in recs] return [rec.host for rec in recs]
@asyncio.coroutine async def get_SRV(host, port, service, proto='tcp', resolver=None, use_aiodns=True):
def get_SRV(host, port, service, proto='tcp', resolver=None, use_aiodns=True):
"""Perform SRV record resolution for a given host. """Perform SRV record resolution for a given host.
.. note:: .. note::
@ -277,7 +273,7 @@ def get_SRV(host, port, service, proto='tcp', resolver=None, use_aiodns=True):
try: try:
future = resolver.query('_%s._%s.%s' % (service, proto, host), future = resolver.query('_%s._%s.%s' % (service, proto, host),
'SRV') 'SRV')
recs = yield from future recs = await future
except Exception as e: except Exception as e:
log.debug('DNS: Exception while querying for %s SRV records: %s', host, e) log.debug('DNS: Exception while querying for %s SRV records: %s', host, e)
return [] return []

View file

@ -287,11 +287,10 @@ class XMLStream(asyncio.BaseProtocol):
self.event("connecting") self.event("connecting")
self._current_connection_attempt = asyncio.ensure_future(self._connect_routine()) self._current_connection_attempt = asyncio.ensure_future(self._connect_routine())
@asyncio.coroutine async def _connect_routine(self):
def _connect_routine(self):
self.event_when_connected = "connected" self.event_when_connected = "connected"
record = yield from self.pick_dns_answer(self.default_domain) record = await self.pick_dns_answer(self.default_domain)
if record is not None: if record is not None:
host, address, dns_port = record host, address, dns_port = record
port = dns_port if dns_port else self.address[1] port = dns_port if dns_port else self.address[1]
@ -307,9 +306,9 @@ class XMLStream(asyncio.BaseProtocol):
else: else:
ssl_context = None ssl_context = None
yield from asyncio.sleep(self.connect_loop_wait) await asyncio.sleep(self.connect_loop_wait)
try: try:
yield from self.loop.create_connection(lambda: self, await self.loop.create_connection(lambda: self,
self.address[0], self.address[0],
self.address[1], self.address[1],
ssl=ssl_context, ssl=ssl_context,
@ -540,10 +539,9 @@ class XMLStream(asyncio.BaseProtocol):
ssl_connect_routine = self.loop.create_connection(lambda: self, ssl=ssl_context, ssl_connect_routine = self.loop.create_connection(lambda: self, ssl=ssl_context,
sock=self.socket, sock=self.socket,
server_hostname=self.default_domain) server_hostname=self.default_domain)
@asyncio.coroutine async def ssl_coro():
def ssl_coro():
try: try:
transp, prot = yield from ssl_connect_routine transp, prot = await ssl_connect_routine
except ssl.SSLError as e: except ssl.SSLError as e:
log.debug('SSL: Unable to connect', exc_info=True) log.debug('SSL: Unable to connect', exc_info=True)
log.error('CERT: Invalid certificate trust chain.') log.error('CERT: Invalid certificate trust chain.')
@ -671,8 +669,7 @@ class XMLStream(asyncio.BaseProtocol):
idx += 1 idx += 1
return False return False
@asyncio.coroutine async def get_dns_records(self, domain, port=None):
def get_dns_records(self, domain, port=None):
"""Get the DNS records for a domain. """Get the DNS records for a domain.
:param domain: The domain in question. :param domain: The domain in question.
@ -684,7 +681,7 @@ class XMLStream(asyncio.BaseProtocol):
resolver = default_resolver(loop=self.loop) resolver = default_resolver(loop=self.loop)
self.configure_dns(resolver, domain=domain, port=port) self.configure_dns(resolver, domain=domain, port=port)
result = yield from resolve(domain, port, result = await resolve(domain, port,
service=self.dns_service, service=self.dns_service,
resolver=resolver, resolver=resolver,
use_ipv6=self.use_ipv6, use_ipv6=self.use_ipv6,
@ -692,8 +689,7 @@ class XMLStream(asyncio.BaseProtocol):
loop=self.loop) loop=self.loop)
return result return result
@asyncio.coroutine async def pick_dns_answer(self, domain, port=None):
def pick_dns_answer(self, domain, port=None):
"""Pick a server and port from DNS answers. """Pick a server and port from DNS answers.
Gets DNS answers if none available. Gets DNS answers if none available.
@ -703,7 +699,7 @@ class XMLStream(asyncio.BaseProtocol):
:param port: If the results don't include a port, use this one. :param port: If the results don't include a port, use this one.
""" """
if self.dns_answers is None: if self.dns_answers is None:
dns_records = yield from self.get_dns_records(domain, port) dns_records = await self.get_dns_records(domain, port)
self.dns_answers = iter(dns_records) self.dns_answers = iter(dns_records)
try: try:
@ -768,10 +764,9 @@ class XMLStream(asyncio.BaseProtocol):
# If the callback is a coroutine, schedule it instead of # If the callback is a coroutine, schedule it instead of
# running it directly # running it directly
if asyncio.iscoroutinefunction(handler_callback): if asyncio.iscoroutinefunction(handler_callback):
@asyncio.coroutine async def handler_callback_routine(cb):
def handler_callback_routine(cb):
try: try:
yield from cb(data) await cb(data)
except Exception as e: except Exception as e:
if old_exception: if old_exception:
old_exception(e) old_exception(e)