Switch from @asyncio.coroutine to async def everywhere.
This commit is contained in:
parent
66909aafb3
commit
3502480384
14 changed files with 67 additions and 93 deletions
|
@ -265,8 +265,7 @@ class ClientXMPP(BaseXMPP):
|
||||||
self.bindfail = False
|
self.bindfail = False
|
||||||
self.features = set()
|
self.features = set()
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def _handle_stream_features(self, features):
|
||||||
def _handle_stream_features(self, features):
|
|
||||||
"""Process the received stream features.
|
"""Process the received stream features.
|
||||||
|
|
||||||
:param features: The features stanza.
|
:param features: The features stanza.
|
||||||
|
@ -275,7 +274,7 @@ class ClientXMPP(BaseXMPP):
|
||||||
if name in features['features']:
|
if name in features['features']:
|
||||||
handler, restart = self._stream_feature_handlers[name]
|
handler, restart = self._stream_feature_handlers[name]
|
||||||
if asyncio.iscoroutinefunction(handler):
|
if asyncio.iscoroutinefunction(handler):
|
||||||
result = yield from handler(features)
|
result = await handler(features)
|
||||||
else:
|
else:
|
||||||
result = handler(features)
|
result = handler(features)
|
||||||
if result and restart:
|
if result and restart:
|
||||||
|
|
|
@ -35,8 +35,7 @@ class FeatureBind(BasePlugin):
|
||||||
register_stanza_plugin(Iq, stanza.Bind)
|
register_stanza_plugin(Iq, stanza.Bind)
|
||||||
register_stanza_plugin(StreamFeatures, stanza.Bind)
|
register_stanza_plugin(StreamFeatures, stanza.Bind)
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def _handle_bind_resource(self, features):
|
||||||
def _handle_bind_resource(self, features):
|
|
||||||
"""
|
"""
|
||||||
Handle requesting a specific resource.
|
Handle requesting a specific resource.
|
||||||
|
|
||||||
|
@ -51,7 +50,7 @@ class FeatureBind(BasePlugin):
|
||||||
if self.xmpp.requested_jid.resource:
|
if self.xmpp.requested_jid.resource:
|
||||||
iq['bind']['resource'] = self.xmpp.requested_jid.resource
|
iq['bind']['resource'] = self.xmpp.requested_jid.resource
|
||||||
|
|
||||||
yield from iq.send(callback=self._on_bind_response)
|
await iq.send(callback=self._on_bind_response)
|
||||||
|
|
||||||
def _on_bind_response(self, response):
|
def _on_bind_response(self, response):
|
||||||
self.xmpp.boundjid = JID(response['bind']['jid'])
|
self.xmpp.boundjid = JID(response['bind']['jid'])
|
||||||
|
|
|
@ -35,8 +35,7 @@ class FeatureSession(BasePlugin):
|
||||||
register_stanza_plugin(Iq, stanza.Session)
|
register_stanza_plugin(Iq, stanza.Session)
|
||||||
register_stanza_plugin(StreamFeatures, stanza.Session)
|
register_stanza_plugin(StreamFeatures, stanza.Session)
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def _handle_start_session(self, features):
|
||||||
def _handle_start_session(self, features):
|
|
||||||
"""
|
"""
|
||||||
Handle the start of the session.
|
Handle the start of the session.
|
||||||
|
|
||||||
|
@ -51,7 +50,7 @@ class FeatureSession(BasePlugin):
|
||||||
iq = self.xmpp.Iq()
|
iq = self.xmpp.Iq()
|
||||||
iq['type'] = 'set'
|
iq['type'] = 'set'
|
||||||
iq.enable('session')
|
iq.enable('session')
|
||||||
yield from iq.send(callback=self._on_start_session_response)
|
await iq.send(callback=self._on_start_session_response)
|
||||||
|
|
||||||
def _on_start_session_response(self, response):
|
def _on_start_session_response(self, response):
|
||||||
self.xmpp.features.add('session')
|
self.xmpp.features.add('session')
|
||||||
|
|
|
@ -31,8 +31,7 @@ class IBBytestream(object):
|
||||||
|
|
||||||
self.recv_queue = asyncio.Queue()
|
self.recv_queue = asyncio.Queue()
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def send(self, data, timeout=None):
|
||||||
def send(self, data, timeout=None):
|
|
||||||
if not self.stream_started or self.stream_out_closed:
|
if not self.stream_started or self.stream_out_closed:
|
||||||
raise socket.error
|
raise socket.error
|
||||||
if len(data) > self.block_size:
|
if len(data) > self.block_size:
|
||||||
|
@ -56,22 +55,20 @@ class IBBytestream(object):
|
||||||
iq['ibb_data']['sid'] = self.sid
|
iq['ibb_data']['sid'] = self.sid
|
||||||
iq['ibb_data']['seq'] = seq
|
iq['ibb_data']['seq'] = seq
|
||||||
iq['ibb_data']['data'] = data
|
iq['ibb_data']['data'] = data
|
||||||
yield from iq.send(timeout=timeout)
|
await iq.send(timeout=timeout)
|
||||||
return len(data)
|
return len(data)
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def sendall(self, data, timeout=None):
|
||||||
def sendall(self, data, timeout=None):
|
|
||||||
sent_len = 0
|
sent_len = 0
|
||||||
while sent_len < len(data):
|
while sent_len < len(data):
|
||||||
sent_len += yield from self.send(data[sent_len:self.block_size], timeout=timeout)
|
sent_len += await self.send(data[sent_len:self.block_size], timeout=timeout)
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def sendfile(self, file, timeout=None):
|
||||||
def sendfile(self, file, timeout=None):
|
|
||||||
while True:
|
while True:
|
||||||
data = file.read(self.block_size)
|
data = file.read(self.block_size)
|
||||||
if not data:
|
if not data:
|
||||||
break
|
break
|
||||||
yield from self.send(data, timeout=timeout)
|
await self.send(data, timeout=timeout)
|
||||||
|
|
||||||
def _recv_data(self, stanza):
|
def _recv_data(self, stanza):
|
||||||
new_seq = stanza['ibb_data']['seq']
|
new_seq = stanza['ibb_data']['seq']
|
||||||
|
|
|
@ -55,18 +55,17 @@ class XEP_0065(BasePlugin):
|
||||||
"""Returns the socket associated to the SID."""
|
"""Returns the socket associated to the SID."""
|
||||||
return self._sessions.get(sid, None)
|
return self._sessions.get(sid, None)
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def handshake(self, to, ifrom=None, sid=None, timeout=None):
|
||||||
def handshake(self, to, ifrom=None, sid=None, timeout=None):
|
|
||||||
""" Starts the handshake to establish the socks5 bytestreams
|
""" Starts the handshake to establish the socks5 bytestreams
|
||||||
connection.
|
connection.
|
||||||
"""
|
"""
|
||||||
if not self._proxies:
|
if not self._proxies:
|
||||||
self._proxies = yield from self.discover_proxies()
|
self._proxies = await self.discover_proxies()
|
||||||
|
|
||||||
if sid is None:
|
if sid is None:
|
||||||
sid = uuid4().hex
|
sid = uuid4().hex
|
||||||
|
|
||||||
used = yield from self.request_stream(to, sid=sid, ifrom=ifrom, timeout=timeout)
|
used = await self.request_stream(to, sid=sid, ifrom=ifrom, timeout=timeout)
|
||||||
proxy = used['socks']['streamhost_used']['jid']
|
proxy = used['socks']['streamhost_used']['jid']
|
||||||
|
|
||||||
if proxy not in self._proxies:
|
if proxy not in self._proxies:
|
||||||
|
@ -74,16 +73,16 @@ class XEP_0065(BasePlugin):
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._sessions[sid] = (yield from self._connect_proxy(
|
self._sessions[sid] = (await self._connect_proxy(
|
||||||
self._get_dest_sha1(sid, self.xmpp.boundjid, to),
|
self._get_dest_sha1(sid, self.xmpp.boundjid, to),
|
||||||
self._proxies[proxy][0],
|
self._proxies[proxy][0],
|
||||||
self._proxies[proxy][1]))[1]
|
self._proxies[proxy][1]))[1]
|
||||||
except socket.error:
|
except socket.error:
|
||||||
return None
|
return None
|
||||||
addr, port = yield from self._sessions[sid].connected
|
addr, port = await self._sessions[sid].connected
|
||||||
|
|
||||||
# Request that the proxy activate the session with the target.
|
# Request that the proxy activate the session with the target.
|
||||||
yield from self.activate(proxy, sid, to, timeout=timeout)
|
await self.activate(proxy, sid, to, timeout=timeout)
|
||||||
sock = self.get_socket(sid)
|
sock = self.get_socket(sid)
|
||||||
self.xmpp.event('stream:%s:%s' % (sid, to), sock)
|
self.xmpp.event('stream:%s:%s' % (sid, to), sock)
|
||||||
return sock
|
return sock
|
||||||
|
@ -105,8 +104,7 @@ class XEP_0065(BasePlugin):
|
||||||
iq['socks'].add_streamhost(proxy, host, port)
|
iq['socks'].add_streamhost(proxy, host, port)
|
||||||
return iq.send(timeout=timeout, callback=callback)
|
return iq.send(timeout=timeout, callback=callback)
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def discover_proxies(self, jid=None, ifrom=None, timeout=None):
|
||||||
def discover_proxies(self, jid=None, ifrom=None, timeout=None):
|
|
||||||
"""Auto-discover the JIDs of SOCKS5 proxies on an XMPP server."""
|
"""Auto-discover the JIDs of SOCKS5 proxies on an XMPP server."""
|
||||||
if jid is None:
|
if jid is None:
|
||||||
if self.xmpp.is_component:
|
if self.xmpp.is_component:
|
||||||
|
@ -116,7 +114,7 @@ class XEP_0065(BasePlugin):
|
||||||
|
|
||||||
discovered = set()
|
discovered = set()
|
||||||
|
|
||||||
disco_items = yield from self.xmpp['xep_0030'].get_items(jid, timeout=timeout)
|
disco_items = await self.xmpp['xep_0030'].get_items(jid, timeout=timeout)
|
||||||
disco_items = {item[0] for item in disco_items['disco_items']['items']}
|
disco_items = {item[0] for item in disco_items['disco_items']['items']}
|
||||||
|
|
||||||
disco_info_futures = {}
|
disco_info_futures = {}
|
||||||
|
@ -125,7 +123,7 @@ class XEP_0065(BasePlugin):
|
||||||
|
|
||||||
for item in disco_items:
|
for item in disco_items:
|
||||||
try:
|
try:
|
||||||
disco_info = yield from disco_info_futures[item]
|
disco_info = await disco_info_futures[item]
|
||||||
except XMPPError:
|
except XMPPError:
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
|
@ -137,7 +135,7 @@ class XEP_0065(BasePlugin):
|
||||||
|
|
||||||
for jid in discovered:
|
for jid in discovered:
|
||||||
try:
|
try:
|
||||||
addr = yield from self.get_network_address(jid, ifrom=ifrom, timeout=timeout)
|
addr = await self.get_network_address(jid, ifrom=ifrom, timeout=timeout)
|
||||||
self._proxies[jid] = (addr['socks']['streamhost']['host'],
|
self._proxies[jid] = (addr['socks']['streamhost']['host'],
|
||||||
addr['socks']['streamhost']['port'])
|
addr['socks']['streamhost']['port'])
|
||||||
except XMPPError:
|
except XMPPError:
|
||||||
|
@ -182,9 +180,8 @@ class XEP_0065(BasePlugin):
|
||||||
streamhost['host'],
|
streamhost['host'],
|
||||||
streamhost['port']))
|
streamhost['port']))
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def gather(futures, iq, streamhosts):
|
||||||
def gather(futures, iq, streamhosts):
|
proxies = await asyncio.gather(*futures, return_exceptions=True)
|
||||||
proxies = yield from asyncio.gather(*futures, return_exceptions=True)
|
|
||||||
for streamhost, proxy in zip(streamhosts, proxies):
|
for streamhost, proxy in zip(streamhosts, proxies):
|
||||||
if isinstance(proxy, ValueError):
|
if isinstance(proxy, ValueError):
|
||||||
continue
|
continue
|
||||||
|
@ -194,7 +191,7 @@ class XEP_0065(BasePlugin):
|
||||||
proxy = proxy[1]
|
proxy = proxy[1]
|
||||||
# TODO: what if the future never happens?
|
# TODO: what if the future never happens?
|
||||||
try:
|
try:
|
||||||
addr, port = yield from proxy.connected
|
addr, port = await proxy.connected
|
||||||
except socket.error:
|
except socket.error:
|
||||||
log.exception('Socket error while connecting to the proxy.')
|
log.exception('Socket error while connecting to the proxy.')
|
||||||
continue
|
continue
|
||||||
|
|
|
@ -137,8 +137,8 @@ class Socks5Protocol(asyncio.Protocol):
|
||||||
def resume_writing(self):
|
def resume_writing(self):
|
||||||
self.paused.set_result(None)
|
self.paused.set_result(None)
|
||||||
|
|
||||||
def write(self, data):
|
async def write(self, data):
|
||||||
yield from self.paused
|
await self.paused
|
||||||
self.transport.write(data)
|
self.transport.write(data)
|
||||||
|
|
||||||
def _send_methods(self):
|
def _send_methods(self):
|
||||||
|
|
|
@ -137,8 +137,7 @@ class XEP_0115(BasePlugin):
|
||||||
|
|
||||||
self.xmpp.event('entity_caps', p)
|
self.xmpp.event('entity_caps', p)
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def _process_caps(self, pres):
|
||||||
def _process_caps(self, pres):
|
|
||||||
if not pres['caps']['hash']:
|
if not pres['caps']['hash']:
|
||||||
log.debug("Received unsupported legacy caps: %s, %s, %s",
|
log.debug("Received unsupported legacy caps: %s, %s, %s",
|
||||||
pres['caps']['node'],
|
pres['caps']['node'],
|
||||||
|
@ -169,7 +168,7 @@ class XEP_0115(BasePlugin):
|
||||||
log.debug("New caps verification string: %s", ver)
|
log.debug("New caps verification string: %s", ver)
|
||||||
try:
|
try:
|
||||||
node = '%s#%s' % (pres['caps']['node'], ver)
|
node = '%s#%s' % (pres['caps']['node'], ver)
|
||||||
caps = yield from self.xmpp['xep_0030'].get_info(pres['from'], node,
|
caps = await self.xmpp['xep_0030'].get_info(pres['from'], node,
|
||||||
coroutine=True)
|
coroutine=True)
|
||||||
|
|
||||||
if isinstance(caps, Iq):
|
if isinstance(caps, Iq):
|
||||||
|
@ -285,10 +284,9 @@ class XEP_0115(BasePlugin):
|
||||||
binary = hash(S.encode('utf8')).digest()
|
binary = hash(S.encode('utf8')).digest()
|
||||||
return base64.b64encode(binary).decode('utf-8')
|
return base64.b64encode(binary).decode('utf-8')
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def update_caps(self, jid=None, node=None, preserve=False):
|
||||||
def update_caps(self, jid=None, node=None, preserve=False):
|
|
||||||
try:
|
try:
|
||||||
info = yield from self.xmpp['xep_0030'].get_info(jid, node, local=True)
|
info = await self.xmpp['xep_0030'].get_info(jid, node, local=True)
|
||||||
if isinstance(info, Iq):
|
if isinstance(info, Iq):
|
||||||
info = info['disco_info']
|
info = info['disco_info']
|
||||||
ver = self.generate_verstring(info, self.hash)
|
ver = self.generate_verstring(info, self.hash)
|
||||||
|
|
|
@ -98,10 +98,9 @@ class XEP_0153(BasePlugin):
|
||||||
first_future.add_done_callback(propagate_timeout_exception)
|
first_future.add_done_callback(propagate_timeout_exception)
|
||||||
return future
|
return future
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def _start(self, event):
|
||||||
def _start(self, event):
|
|
||||||
try:
|
try:
|
||||||
vcard = yield from self.xmpp['xep_0054'].get_vcard(self.xmpp.boundjid.bare)
|
vcard = await self.xmpp['xep_0054'].get_vcard(self.xmpp.boundjid.bare)
|
||||||
data = vcard['vcard_temp']['PHOTO']['BINVAL']
|
data = vcard['vcard_temp']['PHOTO']['BINVAL']
|
||||||
if not data:
|
if not data:
|
||||||
new_hash = ''
|
new_hash = ''
|
||||||
|
|
|
@ -174,8 +174,7 @@ class XEP_0198(BasePlugin):
|
||||||
req = stanza.RequestAck(self.xmpp)
|
req = stanza.RequestAck(self.xmpp)
|
||||||
self.xmpp.send_raw(str(req))
|
self.xmpp.send_raw(str(req))
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def _handle_sm_feature(self, features):
|
||||||
def _handle_sm_feature(self, features):
|
|
||||||
"""
|
"""
|
||||||
Enable or resume stream management.
|
Enable or resume stream management.
|
||||||
|
|
||||||
|
@ -203,7 +202,7 @@ class XEP_0198(BasePlugin):
|
||||||
MatchXPath(stanza.Enabled.tag_name()),
|
MatchXPath(stanza.Enabled.tag_name()),
|
||||||
MatchXPath(stanza.Failed.tag_name())]))
|
MatchXPath(stanza.Failed.tag_name())]))
|
||||||
self.xmpp.register_handler(waiter)
|
self.xmpp.register_handler(waiter)
|
||||||
result = yield from waiter.wait()
|
result = await waiter.wait()
|
||||||
elif self.sm_id and self.allow_resume and 'bind' not in self.xmpp.features:
|
elif self.sm_id and self.allow_resume and 'bind' not in self.xmpp.features:
|
||||||
self.enabled = True
|
self.enabled = True
|
||||||
resume = stanza.Resume(self.xmpp)
|
resume = stanza.Resume(self.xmpp)
|
||||||
|
@ -219,7 +218,7 @@ class XEP_0198(BasePlugin):
|
||||||
MatchXPath(stanza.Resumed.tag_name()),
|
MatchXPath(stanza.Resumed.tag_name()),
|
||||||
MatchXPath(stanza.Failed.tag_name())]))
|
MatchXPath(stanza.Failed.tag_name())]))
|
||||||
self.xmpp.register_handler(waiter)
|
self.xmpp.register_handler(waiter)
|
||||||
result = yield from waiter.wait()
|
result = await waiter.wait()
|
||||||
if result is not None and result.name == 'resumed':
|
if result is not None and result.name == 'resumed':
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
|
@ -104,11 +104,10 @@ class XEP_0199(BasePlugin):
|
||||||
def disable_keepalive(self, event=None):
|
def disable_keepalive(self, event=None):
|
||||||
self.xmpp.cancel_schedule('Ping keepalive')
|
self.xmpp.cancel_schedule('Ping keepalive')
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def _keepalive(self, event=None):
|
||||||
def _keepalive(self, event=None):
|
|
||||||
log.debug("Keepalive ping...")
|
log.debug("Keepalive ping...")
|
||||||
try:
|
try:
|
||||||
rtt = yield from self.ping(self.xmpp.boundjid.host, timeout=self.timeout)
|
rtt = await self.ping(self.xmpp.boundjid.host, timeout=self.timeout)
|
||||||
except IqTimeout:
|
except IqTimeout:
|
||||||
log.debug("Did not receive ping back in time." + \
|
log.debug("Did not receive ping back in time." + \
|
||||||
"Requesting Reconnect.")
|
"Requesting Reconnect.")
|
||||||
|
@ -145,8 +144,7 @@ class XEP_0199(BasePlugin):
|
||||||
return iq.send(timeout=timeout, callback=callback,
|
return iq.send(timeout=timeout, callback=callback,
|
||||||
timeout_callback=timeout_callback)
|
timeout_callback=timeout_callback)
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def ping(self, jid=None, ifrom=None, timeout=None):
|
||||||
def ping(self, jid=None, ifrom=None, timeout=None):
|
|
||||||
"""Send a ping request and calculate RTT.
|
"""Send a ping request and calculate RTT.
|
||||||
This is a coroutine.
|
This is a coroutine.
|
||||||
|
|
||||||
|
@ -174,7 +172,7 @@ class XEP_0199(BasePlugin):
|
||||||
|
|
||||||
log.debug('Pinging %s' % jid)
|
log.debug('Pinging %s' % jid)
|
||||||
try:
|
try:
|
||||||
yield from self.send_ping(jid, ifrom=ifrom, timeout=timeout)
|
await self.send_ping(jid, ifrom=ifrom, timeout=timeout)
|
||||||
except IqError as e:
|
except IqError as e:
|
||||||
if own_host:
|
if own_host:
|
||||||
rtt = time.time() - start
|
rtt = time.time() - start
|
||||||
|
|
|
@ -45,10 +45,9 @@ class CoroutineCallback(BaseHandler):
|
||||||
if not asyncio.iscoroutinefunction(pointer):
|
if not asyncio.iscoroutinefunction(pointer):
|
||||||
raise ValueError("Given function is not a coroutine")
|
raise ValueError("Given function is not a coroutine")
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def pointer_wrapper(stanza, *args, **kwargs):
|
||||||
def pointer_wrapper(stanza, *args, **kwargs):
|
|
||||||
try:
|
try:
|
||||||
yield from pointer(stanza, *args, **kwargs)
|
await pointer(stanza, *args, **kwargs)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
stanza.exception(e)
|
stanza.exception(e)
|
||||||
|
|
||||||
|
|
|
@ -50,8 +50,7 @@ class Waiter(BaseHandler):
|
||||||
"""Do not process this handler during the main event loop."""
|
"""Do not process this handler during the main event loop."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def wait(self, timeout=None):
|
||||||
def wait(self, timeout=None):
|
|
||||||
"""Block an event handler while waiting for a stanza to arrive.
|
"""Block an event handler while waiting for a stanza to arrive.
|
||||||
|
|
||||||
Be aware that this will impact performance if called from a
|
Be aware that this will impact performance if called from a
|
||||||
|
@ -70,7 +69,7 @@ class Waiter(BaseHandler):
|
||||||
|
|
||||||
stanza = None
|
stanza = None
|
||||||
try:
|
try:
|
||||||
stanza = yield from self._payload.get()
|
stanza = await self._payload.get()
|
||||||
except TimeoutError:
|
except TimeoutError:
|
||||||
log.warning("Timed out waiting for %s", self.name)
|
log.warning("Timed out waiting for %s", self.name)
|
||||||
self.stream().remove_handler(self.name)
|
self.stream().remove_handler(self.name)
|
||||||
|
|
|
@ -45,8 +45,7 @@ def default_resolver(loop):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def resolve(host, port=None, service=None, proto='tcp',
|
||||||
def resolve(host, port=None, service=None, proto='tcp',
|
|
||||||
resolver=None, use_ipv6=True, use_aiodns=True, loop=None):
|
resolver=None, use_ipv6=True, use_aiodns=True, loop=None):
|
||||||
"""Peform DNS resolution for a given hostname.
|
"""Peform DNS resolution for a given hostname.
|
||||||
|
|
||||||
|
@ -127,7 +126,7 @@ def resolve(host, port=None, service=None, proto='tcp',
|
||||||
if not service:
|
if not service:
|
||||||
hosts = [(host, port)]
|
hosts = [(host, port)]
|
||||||
else:
|
else:
|
||||||
hosts = yield from get_SRV(host, port, service, proto,
|
hosts = await get_SRV(host, port, service, proto,
|
||||||
resolver=resolver,
|
resolver=resolver,
|
||||||
use_aiodns=use_aiodns)
|
use_aiodns=use_aiodns)
|
||||||
if not hosts:
|
if not hosts:
|
||||||
|
@ -141,19 +140,18 @@ def resolve(host, port=None, service=None, proto='tcp',
|
||||||
results.append((host, '127.0.0.1', port))
|
results.append((host, '127.0.0.1', port))
|
||||||
|
|
||||||
if use_ipv6:
|
if use_ipv6:
|
||||||
aaaa = yield from get_AAAA(host, resolver=resolver,
|
aaaa = await get_AAAA(host, resolver=resolver,
|
||||||
use_aiodns=use_aiodns, loop=loop)
|
use_aiodns=use_aiodns, loop=loop)
|
||||||
for address in aaaa:
|
for address in aaaa:
|
||||||
results.append((host, address, port))
|
results.append((host, address, port))
|
||||||
|
|
||||||
a = yield from get_A(host, resolver=resolver,
|
a = await get_A(host, resolver=resolver,
|
||||||
use_aiodns=use_aiodns, loop=loop)
|
use_aiodns=use_aiodns, loop=loop)
|
||||||
for address in a:
|
for address in a:
|
||||||
results.append((host, address, port))
|
results.append((host, address, port))
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
@asyncio.coroutine
|
|
||||||
def get_A(host, resolver=None, use_aiodns=True, loop=None):
|
def get_A(host, resolver=None, use_aiodns=True, loop=None):
|
||||||
"""Lookup DNS A records for a given host.
|
"""Lookup DNS A records for a given host.
|
||||||
|
|
||||||
|
@ -178,7 +176,7 @@ def get_A(host, resolver=None, use_aiodns=True, loop=None):
|
||||||
# getaddrinfo() method.
|
# getaddrinfo() method.
|
||||||
if resolver is None or not use_aiodns:
|
if resolver is None or not use_aiodns:
|
||||||
try:
|
try:
|
||||||
recs = yield from loop.getaddrinfo(host, None,
|
recs = await loop.getaddrinfo(host, None,
|
||||||
family=socket.AF_INET,
|
family=socket.AF_INET,
|
||||||
type=socket.SOCK_STREAM)
|
type=socket.SOCK_STREAM)
|
||||||
return [rec[4][0] for rec in recs]
|
return [rec[4][0] for rec in recs]
|
||||||
|
@ -189,15 +187,14 @@ def get_A(host, resolver=None, use_aiodns=True, loop=None):
|
||||||
# Using aiodns:
|
# Using aiodns:
|
||||||
future = resolver.query(host, 'A')
|
future = resolver.query(host, 'A')
|
||||||
try:
|
try:
|
||||||
recs = yield from future
|
recs = await future
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.debug('DNS: Exception while querying for %s A records: %s', host, e)
|
log.debug('DNS: Exception while querying for %s A records: %s', host, e)
|
||||||
recs = []
|
recs = []
|
||||||
return [rec.host for rec in recs]
|
return [rec.host for rec in recs]
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def get_AAAA(host, resolver=None, use_aiodns=True, loop=None):
|
||||||
def get_AAAA(host, resolver=None, use_aiodns=True, loop=None):
|
|
||||||
"""Lookup DNS AAAA records for a given host.
|
"""Lookup DNS AAAA records for a given host.
|
||||||
|
|
||||||
If ``resolver`` is not provided, or is ``None``, then resolution will
|
If ``resolver`` is not provided, or is ``None``, then resolution will
|
||||||
|
@ -224,7 +221,7 @@ def get_AAAA(host, resolver=None, use_aiodns=True, loop=None):
|
||||||
log.debug("DNS: Unable to query %s for AAAA records: IPv6 is not supported", host)
|
log.debug("DNS: Unable to query %s for AAAA records: IPv6 is not supported", host)
|
||||||
return []
|
return []
|
||||||
try:
|
try:
|
||||||
recs = yield from loop.getaddrinfo(host, None,
|
recs = await loop.getaddrinfo(host, None,
|
||||||
family=socket.AF_INET6,
|
family=socket.AF_INET6,
|
||||||
type=socket.SOCK_STREAM)
|
type=socket.SOCK_STREAM)
|
||||||
return [rec[4][0] for rec in recs]
|
return [rec[4][0] for rec in recs]
|
||||||
|
@ -236,14 +233,13 @@ def get_AAAA(host, resolver=None, use_aiodns=True, loop=None):
|
||||||
# Using aiodns:
|
# Using aiodns:
|
||||||
future = resolver.query(host, 'AAAA')
|
future = resolver.query(host, 'AAAA')
|
||||||
try:
|
try:
|
||||||
recs = yield from future
|
recs = await future
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.debug('DNS: Exception while querying for %s AAAA records: %s', host, e)
|
log.debug('DNS: Exception while querying for %s AAAA records: %s', host, e)
|
||||||
recs = []
|
recs = []
|
||||||
return [rec.host for rec in recs]
|
return [rec.host for rec in recs]
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def get_SRV(host, port, service, proto='tcp', resolver=None, use_aiodns=True):
|
||||||
def get_SRV(host, port, service, proto='tcp', resolver=None, use_aiodns=True):
|
|
||||||
"""Perform SRV record resolution for a given host.
|
"""Perform SRV record resolution for a given host.
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
|
@ -277,7 +273,7 @@ def get_SRV(host, port, service, proto='tcp', resolver=None, use_aiodns=True):
|
||||||
try:
|
try:
|
||||||
future = resolver.query('_%s._%s.%s' % (service, proto, host),
|
future = resolver.query('_%s._%s.%s' % (service, proto, host),
|
||||||
'SRV')
|
'SRV')
|
||||||
recs = yield from future
|
recs = await future
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.debug('DNS: Exception while querying for %s SRV records: %s', host, e)
|
log.debug('DNS: Exception while querying for %s SRV records: %s', host, e)
|
||||||
return []
|
return []
|
||||||
|
|
|
@ -287,11 +287,10 @@ class XMLStream(asyncio.BaseProtocol):
|
||||||
self.event("connecting")
|
self.event("connecting")
|
||||||
self._current_connection_attempt = asyncio.ensure_future(self._connect_routine())
|
self._current_connection_attempt = asyncio.ensure_future(self._connect_routine())
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def _connect_routine(self):
|
||||||
def _connect_routine(self):
|
|
||||||
self.event_when_connected = "connected"
|
self.event_when_connected = "connected"
|
||||||
|
|
||||||
record = yield from self.pick_dns_answer(self.default_domain)
|
record = await self.pick_dns_answer(self.default_domain)
|
||||||
if record is not None:
|
if record is not None:
|
||||||
host, address, dns_port = record
|
host, address, dns_port = record
|
||||||
port = dns_port if dns_port else self.address[1]
|
port = dns_port if dns_port else self.address[1]
|
||||||
|
@ -307,9 +306,9 @@ class XMLStream(asyncio.BaseProtocol):
|
||||||
else:
|
else:
|
||||||
ssl_context = None
|
ssl_context = None
|
||||||
|
|
||||||
yield from asyncio.sleep(self.connect_loop_wait)
|
await asyncio.sleep(self.connect_loop_wait)
|
||||||
try:
|
try:
|
||||||
yield from self.loop.create_connection(lambda: self,
|
await self.loop.create_connection(lambda: self,
|
||||||
self.address[0],
|
self.address[0],
|
||||||
self.address[1],
|
self.address[1],
|
||||||
ssl=ssl_context,
|
ssl=ssl_context,
|
||||||
|
@ -540,10 +539,9 @@ class XMLStream(asyncio.BaseProtocol):
|
||||||
ssl_connect_routine = self.loop.create_connection(lambda: self, ssl=ssl_context,
|
ssl_connect_routine = self.loop.create_connection(lambda: self, ssl=ssl_context,
|
||||||
sock=self.socket,
|
sock=self.socket,
|
||||||
server_hostname=self.default_domain)
|
server_hostname=self.default_domain)
|
||||||
@asyncio.coroutine
|
async def ssl_coro():
|
||||||
def ssl_coro():
|
|
||||||
try:
|
try:
|
||||||
transp, prot = yield from ssl_connect_routine
|
transp, prot = await ssl_connect_routine
|
||||||
except ssl.SSLError as e:
|
except ssl.SSLError as e:
|
||||||
log.debug('SSL: Unable to connect', exc_info=True)
|
log.debug('SSL: Unable to connect', exc_info=True)
|
||||||
log.error('CERT: Invalid certificate trust chain.')
|
log.error('CERT: Invalid certificate trust chain.')
|
||||||
|
@ -671,8 +669,7 @@ class XMLStream(asyncio.BaseProtocol):
|
||||||
idx += 1
|
idx += 1
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def get_dns_records(self, domain, port=None):
|
||||||
def get_dns_records(self, domain, port=None):
|
|
||||||
"""Get the DNS records for a domain.
|
"""Get the DNS records for a domain.
|
||||||
|
|
||||||
:param domain: The domain in question.
|
:param domain: The domain in question.
|
||||||
|
@ -684,7 +681,7 @@ class XMLStream(asyncio.BaseProtocol):
|
||||||
resolver = default_resolver(loop=self.loop)
|
resolver = default_resolver(loop=self.loop)
|
||||||
self.configure_dns(resolver, domain=domain, port=port)
|
self.configure_dns(resolver, domain=domain, port=port)
|
||||||
|
|
||||||
result = yield from resolve(domain, port,
|
result = await resolve(domain, port,
|
||||||
service=self.dns_service,
|
service=self.dns_service,
|
||||||
resolver=resolver,
|
resolver=resolver,
|
||||||
use_ipv6=self.use_ipv6,
|
use_ipv6=self.use_ipv6,
|
||||||
|
@ -692,8 +689,7 @@ class XMLStream(asyncio.BaseProtocol):
|
||||||
loop=self.loop)
|
loop=self.loop)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@asyncio.coroutine
|
async def pick_dns_answer(self, domain, port=None):
|
||||||
def pick_dns_answer(self, domain, port=None):
|
|
||||||
"""Pick a server and port from DNS answers.
|
"""Pick a server and port from DNS answers.
|
||||||
|
|
||||||
Gets DNS answers if none available.
|
Gets DNS answers if none available.
|
||||||
|
@ -703,7 +699,7 @@ class XMLStream(asyncio.BaseProtocol):
|
||||||
:param port: If the results don't include a port, use this one.
|
:param port: If the results don't include a port, use this one.
|
||||||
"""
|
"""
|
||||||
if self.dns_answers is None:
|
if self.dns_answers is None:
|
||||||
dns_records = yield from self.get_dns_records(domain, port)
|
dns_records = await self.get_dns_records(domain, port)
|
||||||
self.dns_answers = iter(dns_records)
|
self.dns_answers = iter(dns_records)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -768,10 +764,9 @@ class XMLStream(asyncio.BaseProtocol):
|
||||||
# If the callback is a coroutine, schedule it instead of
|
# If the callback is a coroutine, schedule it instead of
|
||||||
# running it directly
|
# running it directly
|
||||||
if asyncio.iscoroutinefunction(handler_callback):
|
if asyncio.iscoroutinefunction(handler_callback):
|
||||||
@asyncio.coroutine
|
async def handler_callback_routine(cb):
|
||||||
def handler_callback_routine(cb):
|
|
||||||
try:
|
try:
|
||||||
yield from cb(data)
|
await cb(data)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if old_exception:
|
if old_exception:
|
||||||
old_exception(e)
|
old_exception(e)
|
||||||
|
|
Loading…
Reference in a new issue