Merge branch 'reconnect-logic-doomed' into 'master'
fix reconnect logic See merge request poezio/slixmpp!104
This commit is contained in:
commit
dbcd0c6050
2 changed files with 114 additions and 51 deletions
|
@ -174,6 +174,9 @@ class XEP_0198(BasePlugin):
|
|||
|
||||
def send_ack(self):
|
||||
"""Send the current ack count to the server."""
|
||||
if not self.xmpp.transport:
|
||||
log.debug('Disconnected: not sending ack')
|
||||
return
|
||||
ack = stanza.Ack(self.xmpp)
|
||||
ack['h'] = self.handled
|
||||
self.xmpp.send_raw(str(ack))
|
||||
|
@ -198,20 +201,7 @@ class XEP_0198(BasePlugin):
|
|||
# We've already negotiated stream management,
|
||||
# so no need to do it again.
|
||||
return False
|
||||
if not self.sm_id:
|
||||
if 'bind' in self.xmpp.features:
|
||||
enable = stanza.Enable(self.xmpp)
|
||||
enable['resume'] = self.allow_resume
|
||||
enable.send()
|
||||
log.debug("enabling SM")
|
||||
|
||||
waiter = Waiter('enabled_or_failed',
|
||||
MatchMany([
|
||||
MatchXPath(stanza.Enabled.tag_name()),
|
||||
MatchXPath(stanza.Failed.tag_name())]))
|
||||
self.xmpp.register_handler(waiter)
|
||||
result = await waiter.wait()
|
||||
elif self.sm_id and self.allow_resume and 'bind' not in self.xmpp.features:
|
||||
if self.sm_id and self.allow_resume and 'bind' not in self.xmpp.features:
|
||||
resume = stanza.Resume(self.xmpp)
|
||||
resume['h'] = self.handled
|
||||
resume['previd'] = self.sm_id
|
||||
|
@ -229,6 +219,19 @@ class XEP_0198(BasePlugin):
|
|||
result = await waiter.wait()
|
||||
if result is not None and result.name == 'resumed':
|
||||
return True
|
||||
self.xmpp.event("session_end")
|
||||
if 'bind' in self.xmpp.features:
|
||||
enable = stanza.Enable(self.xmpp)
|
||||
enable['resume'] = self.allow_resume
|
||||
enable.send()
|
||||
log.debug("enabling SM")
|
||||
|
||||
waiter = Waiter('enabled_or_failed',
|
||||
MatchMany([
|
||||
MatchXPath(stanza.Enabled.tag_name()),
|
||||
MatchXPath(stanza.Failed.tag_name())]))
|
||||
self.xmpp.register_handler(waiter)
|
||||
result = await waiter.wait()
|
||||
return False
|
||||
|
||||
def _handle_enabled(self, stanza):
|
||||
|
|
|
@ -12,7 +12,15 @@
|
|||
:license: MIT, see LICENSE for more details
|
||||
"""
|
||||
|
||||
from typing import Optional, Set, Callable, Any
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
Union,
|
||||
)
|
||||
|
||||
import functools
|
||||
import logging
|
||||
|
@ -21,7 +29,7 @@ import ssl
|
|||
import weakref
|
||||
import uuid
|
||||
|
||||
from asyncio import iscoroutinefunction, wait
|
||||
from asyncio import iscoroutinefunction, wait, Future
|
||||
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
|
@ -224,12 +232,13 @@ class XMLStream(asyncio.BaseProtocol):
|
|||
self.disconnect_reason = None
|
||||
|
||||
#: An asyncio Future being done when the stream is disconnected.
|
||||
self.disconnected = asyncio.Future()
|
||||
self.disconnected: Future = Future()
|
||||
|
||||
self.add_event_handler('disconnected', self._remove_schedules)
|
||||
self.add_event_handler('session_start', self._start_keepalive)
|
||||
|
||||
self._run_filters = None
|
||||
|
||||
self._run_out_filters: Optional[Future] = None
|
||||
self.__slow_tasks: List[Future] = []
|
||||
|
||||
@property
|
||||
def loop(self):
|
||||
|
@ -250,6 +259,12 @@ class XMLStream(asyncio.BaseProtocol):
|
|||
"""
|
||||
return uuid.uuid4().hex
|
||||
|
||||
def _set_disconnected_future(self):
|
||||
"""Set the self.disconnected future on disconnect"""
|
||||
if not self.disconnected.done():
|
||||
self.disconnected.set_result(True)
|
||||
self.disconnected = asyncio.Future()
|
||||
|
||||
def connect(self, host='', port=0, use_ssl=False,
|
||||
force_starttls=True, disable_starttls=False):
|
||||
"""Create a new socket and connect to the server.
|
||||
|
@ -272,8 +287,8 @@ class XMLStream(asyncio.BaseProtocol):
|
|||
localhost
|
||||
|
||||
"""
|
||||
if self._run_filters is None:
|
||||
self._run_filters = asyncio.ensure_future(
|
||||
if self._run_out_filters is None or self._run_out_filters.done():
|
||||
self._run_out_filters = asyncio.ensure_future(
|
||||
self.run_filters(),
|
||||
loop=self.loop,
|
||||
)
|
||||
|
@ -418,10 +433,10 @@ class XMLStream(asyncio.BaseProtocol):
|
|||
if self.xml_depth == 0:
|
||||
# The stream's root element has closed,
|
||||
# terminating the stream.
|
||||
self.end_session_on_disconnect = True
|
||||
log.debug("End of stream received")
|
||||
self.disconnect_reason = "End of stream"
|
||||
self.abort()
|
||||
return
|
||||
elif self.xml_depth == 1:
|
||||
# A stanza is an XML element that is a direct child of
|
||||
# the root element, hence the check of depth == 1
|
||||
|
@ -463,11 +478,11 @@ class XMLStream(asyncio.BaseProtocol):
|
|||
self.parser = None
|
||||
self.transport = None
|
||||
self.socket = None
|
||||
if self._run_filters:
|
||||
self._run_filters.cancel()
|
||||
# Fire the events after cleanup
|
||||
if self.end_session_on_disconnect:
|
||||
self._reset_sendq()
|
||||
self.event('session_end')
|
||||
self._set_disconnected_future()
|
||||
self.event("disconnected", self.disconnect_reason or exception and exception.strerror)
|
||||
|
||||
def cancel_connection_attempt(self):
|
||||
|
@ -480,10 +495,8 @@ class XMLStream(asyncio.BaseProtocol):
|
|||
if self._current_connection_attempt:
|
||||
self._current_connection_attempt.cancel()
|
||||
self._current_connection_attempt = None
|
||||
if self._run_filters:
|
||||
self._run_filters.cancel()
|
||||
|
||||
def disconnect(self, wait: float = 2.0, reason: Optional[str] = None, ignore_send_queue: bool = False) -> None:
|
||||
def disconnect(self, wait: Union[float, int] = 2.0, reason: Optional[str] = None, ignore_send_queue: bool = False) -> Future:
|
||||
"""Close the XML stream and wait for an acknowldgement from the server for
|
||||
at most `wait` seconds. After the given number of seconds has
|
||||
passed without a response from the server, or when the server
|
||||
|
@ -491,10 +504,13 @@ class XMLStream(asyncio.BaseProtocol):
|
|||
called. If wait is 0.0, this will call abort() directly without closing
|
||||
the stream.
|
||||
|
||||
Does nothing if we are not connected.
|
||||
Does nothing but trigger the disconnected event if we are not connected.
|
||||
|
||||
:param wait: Time to wait for a response from the server.
|
||||
|
||||
:param reason: An optional reason for the disconnect.
|
||||
:param ignore_send_queue: Boolean to toggle if we want to ignore
|
||||
the in-flight stanzas and disconnect immediately.
|
||||
:return: A future that ends when all code involved in the disconnect has ended
|
||||
"""
|
||||
# Compat: docs/getting_started/sendlogout.rst has been promoting
|
||||
# `disconnect(wait=True)` for ages. This doesn't mean anything to the
|
||||
|
@ -504,50 +520,75 @@ class XMLStream(asyncio.BaseProtocol):
|
|||
wait = 2.0
|
||||
|
||||
if self.transport:
|
||||
self.disconnect_reason = reason
|
||||
if self.waiting_queue.empty() or ignore_send_queue:
|
||||
self.disconnect_reason = reason
|
||||
self.cancel_connection_attempt()
|
||||
if wait > 0.0:
|
||||
self.send_raw(self.stream_footer)
|
||||
self.schedule('Disconnect wait', wait,
|
||||
self.abort, repeat=False)
|
||||
return asyncio.ensure_future(
|
||||
self._end_stream_wait(wait, reason=reason),
|
||||
loop=self.loop,
|
||||
)
|
||||
else:
|
||||
asyncio.ensure_future(
|
||||
return asyncio.ensure_future(
|
||||
self._consume_send_queue_before_disconnecting(reason, wait),
|
||||
loop=self.loop,
|
||||
)
|
||||
else:
|
||||
self._set_disconnected_future()
|
||||
self.event("disconnected", reason)
|
||||
future = Future()
|
||||
future.set_result(None)
|
||||
return future
|
||||
|
||||
async def _consume_send_queue_before_disconnecting(self, reason: Optional[str], wait: float):
|
||||
"""Wait until the send queue is empty before disconnecting"""
|
||||
await self.waiting_queue.join()
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
self.waiting_queue.join(),
|
||||
wait,
|
||||
loop=self.loop
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
wait = 0 # we already consumed the timeout
|
||||
self.disconnect_reason = reason
|
||||
self.cancel_connection_attempt()
|
||||
if wait > 0.0:
|
||||
await self._end_stream_wait(wait)
|
||||
|
||||
async def _end_stream_wait(self, wait: Union[int, float] = 2, reason: Optional[str] = None):
|
||||
"""
|
||||
Run abort() if we do not received the disconnected event
|
||||
after a waiting time.
|
||||
|
||||
:param wait: The waiting time (defaults to 2)
|
||||
"""
|
||||
try:
|
||||
self.send_raw(self.stream_footer)
|
||||
self.schedule('Disconnect wait', wait,
|
||||
self.abort, repeat=False)
|
||||
await self.wait_until('disconnected', wait)
|
||||
except asyncio.TimeoutError:
|
||||
self.abort()
|
||||
except NotConnectedError:
|
||||
# We are not connected when sending the end of stream
|
||||
# that means the disconnect has already been handled
|
||||
pass
|
||||
|
||||
def abort(self):
|
||||
"""
|
||||
Forcibly close the connection
|
||||
"""
|
||||
self.cancel_connection_attempt()
|
||||
if self.transport:
|
||||
self.cancel_connection_attempt()
|
||||
self.transport.close()
|
||||
self.transport.abort()
|
||||
self.event("killed")
|
||||
self.disconnected.set_result(True)
|
||||
self.disconnected = asyncio.Future()
|
||||
self.event("disconnected", self.disconnect_reason)
|
||||
|
||||
def reconnect(self, wait=2.0, reason="Reconnecting"):
|
||||
"""Calls disconnect(), and once we are disconnected (after the timeout, or
|
||||
when the server acknowledgement is received), call connect()
|
||||
"""
|
||||
log.debug("reconnecting...")
|
||||
self.add_event_handler('disconnected', lambda event: self.connect(), disposable=True)
|
||||
async def handler(event):
|
||||
# We yield here to allow synchronous handlers to work first
|
||||
await asyncio.sleep(0, loop=self.loop)
|
||||
self.connect()
|
||||
self.add_event_handler('disconnected', handler, disposable=True)
|
||||
self.disconnect(wait, reason)
|
||||
|
||||
def configure_socket(self):
|
||||
|
@ -655,7 +696,6 @@ class XMLStream(asyncio.BaseProtocol):
|
|||
def _remove_schedules(self, event):
|
||||
"""Remove some schedules that become pointless when disconnected"""
|
||||
self.cancel_schedule('Whitespace Keepalive')
|
||||
self.cancel_schedule('Disconnect wait')
|
||||
|
||||
def start_stream_handler(self, xml):
|
||||
"""Perform any initialization actions, such as handshakes,
|
||||
|
@ -833,7 +873,7 @@ class XMLStream(asyncio.BaseProtocol):
|
|||
"""
|
||||
log.debug("Event triggered: %s", name)
|
||||
|
||||
handlers = self.__event_handlers.get(name, [])
|
||||
handlers = self.__event_handlers.get(name, [])[:]
|
||||
for handler in handlers:
|
||||
handler_callback, disposable = handler
|
||||
old_exception = getattr(data, 'exception', None)
|
||||
|
@ -941,6 +981,18 @@ class XMLStream(asyncio.BaseProtocol):
|
|||
"""
|
||||
return xml
|
||||
|
||||
def _reset_sendq(self):
|
||||
"""Clear sending tasks on session end"""
|
||||
# Cancel all pending slow send tasks
|
||||
log.debug('Cancelling %d slow send tasks', len(self.__slow_tasks))
|
||||
for slow_task in self.__slow_tasks:
|
||||
slow_task.cancel()
|
||||
self.__slow_tasks.clear()
|
||||
# Purge pending stanzas
|
||||
while not self.waiting_queue.empty():
|
||||
discarded = self.waiting_queue.get_nowait()
|
||||
log.debug('Discarded stanza: %s', discarded)
|
||||
|
||||
async def _continue_slow_send(
|
||||
self,
|
||||
task: asyncio.Task,
|
||||
|
@ -954,6 +1006,7 @@ class XMLStream(asyncio.BaseProtocol):
|
|||
:param set already_used: Filters already used on this outgoing stanza
|
||||
"""
|
||||
data = await task
|
||||
self.__slow_tasks.remove(task)
|
||||
for filter in self.__filters['out']:
|
||||
if filter in already_used:
|
||||
continue
|
||||
|
@ -975,7 +1028,6 @@ class XMLStream(asyncio.BaseProtocol):
|
|||
else:
|
||||
self.send_raw(data)
|
||||
|
||||
|
||||
async def run_filters(self):
|
||||
"""
|
||||
Background loop that processes stanzas to send.
|
||||
|
@ -995,11 +1047,13 @@ class XMLStream(asyncio.BaseProtocol):
|
|||
timeout=1,
|
||||
)
|
||||
if pending:
|
||||
self.slow_tasks.append(task)
|
||||
asyncio.ensure_future(
|
||||
self._continue_slow_send(
|
||||
task,
|
||||
already_run_filters
|
||||
)
|
||||
),
|
||||
loop=self.loop,
|
||||
)
|
||||
raise Exception("Slow coro, rescheduling")
|
||||
data = task.result()
|
||||
|
@ -1142,9 +1196,15 @@ class XMLStream(asyncio.BaseProtocol):
|
|||
:param int timeout: Timeout
|
||||
"""
|
||||
fut = asyncio.Future()
|
||||
def result_handler(event_data):
|
||||
if not fut.done():
|
||||
fut.set_result(event_data)
|
||||
else:
|
||||
log.debug("Future registered on event '%s' was alredy done", event)
|
||||
|
||||
self.add_event_handler(
|
||||
event,
|
||||
fut.set_result,
|
||||
result_handler,
|
||||
disposable=True,
|
||||
)
|
||||
return await asyncio.wait_for(fut, timeout)
|
||||
return await asyncio.wait_for(fut, timeout, loop=self.loop)
|
||||
|
|
Loading…
Reference in a new issue