From 8fc6814b6d7c05621dd2ca1d914960c6dbdadb49 Mon Sep 17 00:00:00 2001 From: mathieui Date: Sat, 4 Jun 2016 20:51:59 +0200 Subject: [PATCH] Update XEP-0198 for asyncio --- slixmpp/plugins/xep_0198/stream_management.py | 93 ++++++++++--------- 1 file changed, 48 insertions(+), 45 deletions(-) diff --git a/slixmpp/plugins/xep_0198/stream_management.py b/slixmpp/plugins/xep_0198/stream_management.py index acf37cd7..fbc9e023 100644 --- a/slixmpp/plugins/xep_0198/stream_management.py +++ b/slixmpp/plugins/xep_0198/stream_management.py @@ -6,8 +6,8 @@ See the file LICENSE for copying permission. """ +import asyncio import logging -import threading import collections from slixmpp.stanza import Message, Presence, Iq, StreamFeatures @@ -70,15 +70,10 @@ class XEP_0198(BasePlugin): return self.window_counter = self.window - self.window_counter_lock = threading.Lock() - self.enabled = threading.Event() + self.enabled = False self.unacked_queue = collections.deque() - self.seq_lock = threading.Lock() - self.handled_lock = threading.Lock() - self.ack_lock = threading.Lock() - register_stanza_plugin(StreamFeatures, stanza.StreamManagement) self.xmpp.register_stanza(stanza.Enable) self.xmpp.register_stanza(stanza.Enabled) @@ -161,7 +156,7 @@ class XEP_0198(BasePlugin): def session_end(self, event): """Reset stream management state.""" - self.enabled.clear() + self.enabled = False self.unacked_queue.clear() self.sm_id = None self.handled = 0 @@ -171,15 +166,15 @@ class XEP_0198(BasePlugin): def send_ack(self): """Send the current ack count to the server.""" ack = stanza.Ack(self.xmpp) - with self.handled_lock: - ack['h'] = self.handled + ack['h'] = self.handled self.xmpp.send_raw(str(ack)) def request_ack(self, e=None): """Request an ack from the server.""" req = stanza.RequestAck(self.xmpp) - self.xmpp.send_queue.put(str(req)) + self.xmpp.send_raw(str(req)) + @asyncio.coroutine def _handle_sm_feature(self, features): """ Enable or resume stream management. @@ -196,13 +191,21 @@ class XEP_0198(BasePlugin): return False if not self.sm_id: if 'bind' in self.xmpp.features: - self.enabled.set() enable = stanza.Enable(self.xmpp) enable['resume'] = self.allow_resume enable.send() + self.enabled = True self.handled = 0 - elif self.sm_id and self.allow_resume: - self.enabled.set() + self.unacked_queue.clear() + + waiter = Waiter('enabled_or_failed', + MatchMany([ + MatchXPath(stanza.Enabled.tag_name()), + MatchXPath(stanza.Failed.tag_name())])) + self.xmpp.register_handler(waiter) + result = yield from waiter.wait() + elif self.sm_id and self.allow_resume and 'bind' not in self.xmpp.features: + self.enabled = True resume = stanza.Resume(self.xmpp) resume['h'] = self.handled resume['previd'] = self.sm_id @@ -216,7 +219,7 @@ class XEP_0198(BasePlugin): MatchXPath(stanza.Resumed.tag_name()), MatchXPath(stanza.Failed.tag_name())])) self.xmpp.register_handler(waiter) - result = waiter.wait() + result = yield from waiter.wait() if result is not None and result.name == 'resumed': return True return False @@ -250,7 +253,7 @@ class XEP_0198(BasePlugin): Raises an :term:`sm_failed` event. """ - self.enabled.clear() + self.enabled = False self.unacked_queue.clear() self.xmpp.event('sm_failed', stanza) @@ -262,21 +265,24 @@ class XEP_0198(BasePlugin): if ack['h'] == self.last_ack: return - with self.ack_lock: - num_acked = (ack['h'] - self.last_ack) % MAX_SEQ - num_unacked = len(self.unacked_queue) - log.debug("Ack: %s, Last Ack: %s, " + \ - "Unacked: %s, Num Acked: %s, " + \ - "Remaining: %s", - ack['h'], - self.last_ack, - num_unacked, - num_acked, - num_unacked - num_acked) - for x in range(num_acked): - seq, stanza = self.unacked_queue.popleft() - self.xmpp.event('stanza_acked', stanza) - self.last_ack = ack['h'] + num_acked = (ack['h'] - self.last_ack) % MAX_SEQ + num_unacked = len(self.unacked_queue) + log.debug("Ack: %s, Last Ack: %s, " + \ + "Unacked: %s, Num Acked: %s, " + \ + "Remaining: %s", + ack['h'], + self.last_ack, + num_unacked, + num_acked, + num_unacked - num_acked) + if num_acked > len(self.unacked_queue) or num_acked < 0: + log.error('Inconsistent sequence numbers from the server,' + ' ignoring and replacing ours with them.') + num_acked = len(self.unacked_queue) + for x in range(num_acked): + seq, stanza = self.unacked_queue.popleft() + self.xmpp.event('stanza_acked', stanza) + self.last_ack = ack['h'] def _handle_request_ack(self, req): """Handle an ack request by sending an ack.""" @@ -284,30 +290,27 @@ class XEP_0198(BasePlugin): def _handle_incoming(self, stanza): """Increment the handled counter for each inbound stanza.""" - if not self.enabled.is_set(): + if not self.enabled: return stanza if isinstance(stanza, (Message, Presence, Iq)): - with self.handled_lock: - # Sequence numbers are mod 2^32 - self.handled = (self.handled + 1) % MAX_SEQ + # Sequence numbers are mod 2^32 + self.handled = (self.handled + 1) % MAX_SEQ return stanza def _handle_outgoing(self, stanza): """Store outgoing stanzas in a queue to be acked.""" - if not self.enabled.is_set(): + if not self.enabled: return stanza if isinstance(stanza, (Message, Presence, Iq)): seq = None - with self.seq_lock: - # Sequence numbers are mod 2^32 - self.seq = (self.seq + 1) % MAX_SEQ - seq = self.seq + # Sequence numbers are mod 2^32 + self.seq = (self.seq + 1) % MAX_SEQ + seq = self.seq self.unacked_queue.append((seq, stanza)) - with self.window_counter_lock: - self.window_counter -= 1 - if self.window_counter == 0: - self.window_counter = self.window - self.request_ack() + self.window_counter -= 1 + if self.window_counter == 0: + self.window_counter = self.window + self.request_ack() return stanza