diff --git a/slixmpp/clientxmpp.py b/slixmpp/clientxmpp.py index 754db100..8ef3493b 100644 --- a/slixmpp/clientxmpp.py +++ b/slixmpp/clientxmpp.py @@ -138,8 +138,8 @@ class ClientXMPP(BaseXMPP): self.credentials['password'] = value def connect(self, address: Optional[Tuple[str, int]] = None, # type: ignore - use_ssl: bool = False, force_starttls: bool = True, - disable_starttls: bool = False) -> None: + use_ssl: Optional[bool] = None, force_starttls: Optional[bool] = None, + disable_starttls: Optional[bool] = None) -> None: """Connect to the XMPP server. When no address is given, a SRV lookup for the server will @@ -166,8 +166,8 @@ class ClientXMPP(BaseXMPP): host, port = (self.boundjid.host, 5222) self.dns_service = 'xmpp-client' - return XMLStream.connect(self, host, port, use_ssl=use_ssl, - force_starttls=force_starttls, disable_starttls=disable_starttls) + XMLStream.connect(self, host, port, use_ssl=use_ssl, + force_starttls=force_starttls, disable_starttls=disable_starttls) def register_feature(self, name: str, handler: Callable, restart: bool = False, order: int = 5000) -> None: """Register a stream feature handler. diff --git a/slixmpp/componentxmpp.py b/slixmpp/componentxmpp.py index 2ef6ea45..3ed7bf82 100644 --- a/slixmpp/componentxmpp.py +++ b/slixmpp/componentxmpp.py @@ -9,6 +9,8 @@ import logging import hashlib +from typing import Optional + from slixmpp import Message, Iq, Presence from slixmpp.basexmpp import BaseXMPP from slixmpp.stanza import Handshake @@ -93,7 +95,7 @@ class ComponentXMPP(BaseXMPP): for st in Message, Iq, Presence: register_stanza_plugin(st, Error) - def connect(self, host=None, port=None, use_ssl=False): + def connect(self, host: str = '', port: int = 0, use_ssl: Optional[bool] = None) -> None: """Connect to the server. @@ -104,16 +106,15 @@ class ComponentXMPP(BaseXMPP): :param use_ssl: Flag indicating if SSL should be used by connecting directly to a port using SSL. """ - if host is None: - host = self.server_host - if port is None: - port = self.server_port + if host is not None: + self.server_host = host + if port: + self.server_port = port self.server_name = self.boundjid.host log.debug("Connecting to %s:%s", host, port) - return XMLStream.connect(self, host=host, port=port, - use_ssl=use_ssl) + XMLStream.connect(self, host=self.server_host, port=self.server_port, use_ssl=use_ssl) def incoming_filter(self, xml): """ diff --git a/slixmpp/xmlstream/xmlstream.py b/slixmpp/xmlstream/xmlstream.py index c7af73a4..65409ec7 100644 --- a/slixmpp/xmlstream/xmlstream.py +++ b/slixmpp/xmlstream/xmlstream.py @@ -290,8 +290,8 @@ class XMLStream(asyncio.BaseProtocol): self.xml_depth = 0 self.xml_root = None - self.force_starttls = None - self.disable_starttls = None + self.force_starttls = True + self.disable_starttls = False self.waiting_queue = asyncio.Queue() @@ -405,8 +405,9 @@ class XMLStream(asyncio.BaseProtocol): self.disconnected.set_result(True) self.disconnected = asyncio.Future() - def connect(self, host: str = '', port: int = 0, use_ssl: Optional[bool] = False, - force_starttls: Optional[bool] = True, disable_starttls: Optional[bool] = False) -> None: + def connect(self, host: str = '', port: int = 0, use_ssl: Optional[bool] = None, + force_starttls: Optional[bool] = None, + disable_starttls: Optional[bool] = None) -> None: """Create a new socket and connect to the server. :param host: The name of the desired server for the connection.