From 06358d0665ee68962acef6936e063f4baadc1930 Mon Sep 17 00:00:00 2001 From: mathieui Date: Sun, 22 Feb 2015 20:13:48 +0100 Subject: [PATCH] Use CallbackCoroutine with Iq callbacks too --- slixmpp/stanza/iq.py | 22 +++++++++++++--------- slixmpp/xmlstream/xmlstream.py | 1 + 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/slixmpp/stanza/iq.py b/slixmpp/stanza/iq.py index e2cef50d..0d8051e2 100644 --- a/slixmpp/stanza/iq.py +++ b/slixmpp/stanza/iq.py @@ -8,7 +8,7 @@ from slixmpp.stanza.rootstanza import RootStanza from slixmpp.xmlstream import StanzaBase, ET -from slixmpp.xmlstream.handler import Waiter, Callback +from slixmpp.xmlstream.handler import Waiter, Callback, CoroutineCallback from slixmpp.xmlstream.asyncio import asyncio from slixmpp.xmlstream.matcher import MatchIDSender, MatcherId from slixmpp.exceptions import IqTimeout, IqError @@ -249,6 +249,10 @@ class Iq(RootStanza): if callback is not None and self['type'] in ('get', 'set'): handler_name = 'IqCallback_%s' % self['id'] + if asyncio.iscoroutinefunction(callback): + constr = CoroutineCallback + else: + constr = Callback if timeout_callback: self.callback = callback self.timeout_callback = timeout_callback @@ -256,15 +260,15 @@ class Iq(RootStanza): timeout, self._fire_timeout, repeat=False) - handler = Callback(handler_name, - matcher, - self._handle_result, - once=True) + handler = constr(handler_name, + matcher, + self._handle_result, + once=True) else: - handler = Callback(handler_name, - matcher, - callback, - once=True) + handler = constr(handler_name, + matcher, + callback, + once=True) self.stream.register_handler(handler) StanzaBase.send(self) return handler_name diff --git a/slixmpp/xmlstream/xmlstream.py b/slixmpp/xmlstream/xmlstream.py index 573ca829..959169d1 100644 --- a/slixmpp/xmlstream/xmlstream.py +++ b/slixmpp/xmlstream/xmlstream.py @@ -489,6 +489,7 @@ class XMLStream(asyncio.BaseProtocol): ssl_connect_routine = loop.create_connection(lambda: self, ssl=self.ssl_context, sock=self.socket, server_hostname=self.address[0]) + @asyncio.coroutine def ssl_coro(): try: transp, prot = yield from ssl_connect_routine