Merge branch 'reconnect-logic-doomed' into 'master'

fix reconnect logic

See merge request poezio/slixmpp!104
This commit is contained in:
Link Mauve 2021-01-29 16:11:29 +01:00
commit dbcd0c6050
2 changed files with 114 additions and 51 deletions

View file

@ -174,6 +174,9 @@ class XEP_0198(BasePlugin):
def send_ack(self):
"""Send the current ack count to the server."""
if not self.xmpp.transport:
log.debug('Disconnected: not sending ack')
return
ack = stanza.Ack(self.xmpp)
ack['h'] = self.handled
self.xmpp.send_raw(str(ack))
@ -198,20 +201,7 @@ class XEP_0198(BasePlugin):
# We've already negotiated stream management,
# so no need to do it again.
return False
if not self.sm_id:
if 'bind' in self.xmpp.features:
enable = stanza.Enable(self.xmpp)
enable['resume'] = self.allow_resume
enable.send()
log.debug("enabling SM")
waiter = Waiter('enabled_or_failed',
MatchMany([
MatchXPath(stanza.Enabled.tag_name()),
MatchXPath(stanza.Failed.tag_name())]))
self.xmpp.register_handler(waiter)
result = await waiter.wait()
elif self.sm_id and self.allow_resume and 'bind' not in self.xmpp.features:
if self.sm_id and self.allow_resume and 'bind' not in self.xmpp.features:
resume = stanza.Resume(self.xmpp)
resume['h'] = self.handled
resume['previd'] = self.sm_id
@ -229,6 +219,19 @@ class XEP_0198(BasePlugin):
result = await waiter.wait()
if result is not None and result.name == 'resumed':
return True
self.xmpp.event("session_end")
if 'bind' in self.xmpp.features:
enable = stanza.Enable(self.xmpp)
enable['resume'] = self.allow_resume
enable.send()
log.debug("enabling SM")
waiter = Waiter('enabled_or_failed',
MatchMany([
MatchXPath(stanza.Enabled.tag_name()),
MatchXPath(stanza.Failed.tag_name())]))
self.xmpp.register_handler(waiter)
result = await waiter.wait()
return False
def _handle_enabled(self, stanza):

View file

@ -12,7 +12,15 @@
:license: MIT, see LICENSE for more details
"""
from typing import Optional, Set, Callable, Any
from typing import (
Any,
Callable,
Iterable,
List,
Optional,
Set,
Union,
)
import functools
import logging
@ -21,7 +29,7 @@ import ssl
import weakref
import uuid
from asyncio import iscoroutinefunction, wait
from asyncio import iscoroutinefunction, wait, Future
import xml.etree.ElementTree as ET
@ -224,12 +232,13 @@ class XMLStream(asyncio.BaseProtocol):
self.disconnect_reason = None
#: An asyncio Future being done when the stream is disconnected.
self.disconnected = asyncio.Future()
self.disconnected: Future = Future()
self.add_event_handler('disconnected', self._remove_schedules)
self.add_event_handler('session_start', self._start_keepalive)
self._run_filters = None
self._run_out_filters: Optional[Future] = None
self.__slow_tasks: List[Future] = []
@property
def loop(self):
@ -250,6 +259,12 @@ class XMLStream(asyncio.BaseProtocol):
"""
return uuid.uuid4().hex
def _set_disconnected_future(self):
"""Set the self.disconnected future on disconnect"""
if not self.disconnected.done():
self.disconnected.set_result(True)
self.disconnected = asyncio.Future()
def connect(self, host='', port=0, use_ssl=False,
force_starttls=True, disable_starttls=False):
"""Create a new socket and connect to the server.
@ -272,8 +287,8 @@ class XMLStream(asyncio.BaseProtocol):
localhost
"""
if self._run_filters is None:
self._run_filters = asyncio.ensure_future(
if self._run_out_filters is None or self._run_out_filters.done():
self._run_out_filters = asyncio.ensure_future(
self.run_filters(),
loop=self.loop,
)
@ -418,10 +433,10 @@ class XMLStream(asyncio.BaseProtocol):
if self.xml_depth == 0:
# The stream's root element has closed,
# terminating the stream.
self.end_session_on_disconnect = True
log.debug("End of stream received")
self.disconnect_reason = "End of stream"
self.abort()
return
elif self.xml_depth == 1:
# A stanza is an XML element that is a direct child of
# the root element, hence the check of depth == 1
@ -463,11 +478,11 @@ class XMLStream(asyncio.BaseProtocol):
self.parser = None
self.transport = None
self.socket = None
if self._run_filters:
self._run_filters.cancel()
# Fire the events after cleanup
if self.end_session_on_disconnect:
self._reset_sendq()
self.event('session_end')
self._set_disconnected_future()
self.event("disconnected", self.disconnect_reason or exception and exception.strerror)
def cancel_connection_attempt(self):
@ -480,10 +495,8 @@ class XMLStream(asyncio.BaseProtocol):
if self._current_connection_attempt:
self._current_connection_attempt.cancel()
self._current_connection_attempt = None
if self._run_filters:
self._run_filters.cancel()
def disconnect(self, wait: float = 2.0, reason: Optional[str] = None, ignore_send_queue: bool = False) -> None:
def disconnect(self, wait: Union[float, int] = 2.0, reason: Optional[str] = None, ignore_send_queue: bool = False) -> Future:
"""Close the XML stream and wait for an acknowldgement from the server for
at most `wait` seconds. After the given number of seconds has
passed without a response from the server, or when the server
@ -491,10 +504,13 @@ class XMLStream(asyncio.BaseProtocol):
called. If wait is 0.0, this will call abort() directly without closing
the stream.
Does nothing if we are not connected.
Does nothing but trigger the disconnected event if we are not connected.
:param wait: Time to wait for a response from the server.
:param reason: An optional reason for the disconnect.
:param ignore_send_queue: Boolean to toggle if we want to ignore
the in-flight stanzas and disconnect immediately.
:return: A future that ends when all code involved in the disconnect has ended
"""
# Compat: docs/getting_started/sendlogout.rst has been promoting
# `disconnect(wait=True)` for ages. This doesn't mean anything to the
@ -504,50 +520,75 @@ class XMLStream(asyncio.BaseProtocol):
wait = 2.0
if self.transport:
if self.waiting_queue.empty() or ignore_send_queue:
self.disconnect_reason = reason
if self.waiting_queue.empty() or ignore_send_queue:
self.cancel_connection_attempt()
if wait > 0.0:
self.send_raw(self.stream_footer)
self.schedule('Disconnect wait', wait,
self.abort, repeat=False)
return asyncio.ensure_future(
self._end_stream_wait(wait, reason=reason),
loop=self.loop,
)
else:
asyncio.ensure_future(
return asyncio.ensure_future(
self._consume_send_queue_before_disconnecting(reason, wait),
loop=self.loop,
)
else:
self._set_disconnected_future()
self.event("disconnected", reason)
future = Future()
future.set_result(None)
return future
async def _consume_send_queue_before_disconnecting(self, reason: Optional[str], wait: float):
"""Wait until the send queue is empty before disconnecting"""
await self.waiting_queue.join()
try:
await asyncio.wait_for(
self.waiting_queue.join(),
wait,
loop=self.loop
)
except asyncio.TimeoutError:
wait = 0 # we already consumed the timeout
self.disconnect_reason = reason
self.cancel_connection_attempt()
if wait > 0.0:
await self._end_stream_wait(wait)
async def _end_stream_wait(self, wait: Union[int, float] = 2, reason: Optional[str] = None):
"""
Run abort() if we do not received the disconnected event
after a waiting time.
:param wait: The waiting time (defaults to 2)
"""
try:
self.send_raw(self.stream_footer)
self.schedule('Disconnect wait', wait,
self.abort, repeat=False)
await self.wait_until('disconnected', wait)
except asyncio.TimeoutError:
self.abort()
except NotConnectedError:
# We are not connected when sending the end of stream
# that means the disconnect has already been handled
pass
def abort(self):
"""
Forcibly close the connection
"""
self.cancel_connection_attempt()
if self.transport:
self.cancel_connection_attempt()
self.transport.close()
self.transport.abort()
self.event("killed")
self.disconnected.set_result(True)
self.disconnected = asyncio.Future()
self.event("disconnected", self.disconnect_reason)
def reconnect(self, wait=2.0, reason="Reconnecting"):
"""Calls disconnect(), and once we are disconnected (after the timeout, or
when the server acknowledgement is received), call connect()
"""
log.debug("reconnecting...")
self.add_event_handler('disconnected', lambda event: self.connect(), disposable=True)
async def handler(event):
# We yield here to allow synchronous handlers to work first
await asyncio.sleep(0, loop=self.loop)
self.connect()
self.add_event_handler('disconnected', handler, disposable=True)
self.disconnect(wait, reason)
def configure_socket(self):
@ -655,7 +696,6 @@ class XMLStream(asyncio.BaseProtocol):
def _remove_schedules(self, event):
"""Remove some schedules that become pointless when disconnected"""
self.cancel_schedule('Whitespace Keepalive')
self.cancel_schedule('Disconnect wait')
def start_stream_handler(self, xml):
"""Perform any initialization actions, such as handshakes,
@ -833,7 +873,7 @@ class XMLStream(asyncio.BaseProtocol):
"""
log.debug("Event triggered: %s", name)
handlers = self.__event_handlers.get(name, [])
handlers = self.__event_handlers.get(name, [])[:]
for handler in handlers:
handler_callback, disposable = handler
old_exception = getattr(data, 'exception', None)
@ -941,6 +981,18 @@ class XMLStream(asyncio.BaseProtocol):
"""
return xml
def _reset_sendq(self):
"""Clear sending tasks on session end"""
# Cancel all pending slow send tasks
log.debug('Cancelling %d slow send tasks', len(self.__slow_tasks))
for slow_task in self.__slow_tasks:
slow_task.cancel()
self.__slow_tasks.clear()
# Purge pending stanzas
while not self.waiting_queue.empty():
discarded = self.waiting_queue.get_nowait()
log.debug('Discarded stanza: %s', discarded)
async def _continue_slow_send(
self,
task: asyncio.Task,
@ -954,6 +1006,7 @@ class XMLStream(asyncio.BaseProtocol):
:param set already_used: Filters already used on this outgoing stanza
"""
data = await task
self.__slow_tasks.remove(task)
for filter in self.__filters['out']:
if filter in already_used:
continue
@ -975,7 +1028,6 @@ class XMLStream(asyncio.BaseProtocol):
else:
self.send_raw(data)
async def run_filters(self):
"""
Background loop that processes stanzas to send.
@ -995,11 +1047,13 @@ class XMLStream(asyncio.BaseProtocol):
timeout=1,
)
if pending:
self.slow_tasks.append(task)
asyncio.ensure_future(
self._continue_slow_send(
task,
already_run_filters
)
),
loop=self.loop,
)
raise Exception("Slow coro, rescheduling")
data = task.result()
@ -1142,9 +1196,15 @@ class XMLStream(asyncio.BaseProtocol):
:param int timeout: Timeout
"""
fut = asyncio.Future()
def result_handler(event_data):
if not fut.done():
fut.set_result(event_data)
else:
log.debug("Future registered on event '%s' was alredy done", event)
self.add_event_handler(
event,
fut.set_result,
result_handler,
disposable=True,
)
return await asyncio.wait_for(fut, timeout)
return await asyncio.wait_for(fut, timeout, loop=self.loop)