From 4e12e228cb2fbc6fe2941070fe1ea44e01e4a9fb Mon Sep 17 00:00:00 2001 From: Lance Stout Date: Fri, 10 Aug 2012 12:40:28 -0700 Subject: [PATCH] Fix tracking service name for DIGEST-MD5 --- .../features/feature_mechanisms/mechanisms.py | 2 +- sleekxmpp/xmlstream/resolver.py | 21 +++++++++++-------- sleekxmpp/xmlstream/xmlstream.py | 7 +++++-- 3 files changed, 18 insertions(+), 12 deletions(-) diff --git a/sleekxmpp/features/feature_mechanisms/mechanisms.py b/sleekxmpp/features/feature_mechanisms/mechanisms.py index 9a391628..dae2f59f 100644 --- a/sleekxmpp/features/feature_mechanisms/mechanisms.py +++ b/sleekxmpp/features/feature_mechanisms/mechanisms.py @@ -109,7 +109,7 @@ class FeatureMechanisms(BasePlugin): elif value == 'realm': result[value] = self.xmpp.boundjid.domain elif value == 'service-name': - result[value] = self.xmpp.address[0] + result[value] = self.xmpp._service_name elif value == 'service': result[value] = 'xmpp' elif value in creds: diff --git a/sleekxmpp/xmlstream/resolver.py b/sleekxmpp/xmlstream/resolver.py index 0d7a8c0d..394daa64 100644 --- a/sleekxmpp/xmlstream/resolver.py +++ b/sleekxmpp/xmlstream/resolver.py @@ -102,7 +102,7 @@ def resolve(host, port=None, service=None, proto='tcp', try: # If `host` is an IPv4 literal, we can return it immediately. ipv4 = socket.inet_aton(host) - yield (host, port) + yield (host, host, port) except socket.error: pass @@ -112,7 +112,7 @@ def resolve(host, port=None, service=None, proto='tcp', # it immediately. if hasattr(socket, 'inet_pton'): ipv6 = socket.inet_pton(socket.AF_INET6, host) - yield (host, port) + yield (host, host, port) except socket.error: pass @@ -128,16 +128,16 @@ def resolve(host, port=None, service=None, proto='tcp', results = [] if host == 'localhost': if use_ipv6: - results.append(('::1', port)) - results.append(('127.0.0.1', port)) + results.append((host, '::1', port)) + results.append((host, '127.0.0.1', port)) if use_ipv6: for address in get_AAAA(host, resolver=resolver): - results.append((address, port)) + results.append((host, address, port)) for address in get_A(host, resolver=resolver): - results.append((address, port)) + results.append((host, address, port)) - for address, port in results: - yield address, port + for host, address, port in results: + yield host, address, port def get_A(host, resolver=None): @@ -297,7 +297,10 @@ def get_SRV(host, port, service, proto='tcp', resolver=None): for running_sum in sums: if running_sum >= selected: rec = sums[running_sum] - sorted_recs.append((rec.target.to_text(), rec.port)) + host = rec.target.to_text() + if host.endswith('.'): + host = host[:-1] + sorted_recs.append((host, rec.port)) answers[priority].remove(rec) break diff --git a/sleekxmpp/xmlstream/xmlstream.py b/sleekxmpp/xmlstream/xmlstream.py index f72171a1..8f8e94fd 100644 --- a/sleekxmpp/xmlstream/xmlstream.py +++ b/sleekxmpp/xmlstream/xmlstream.py @@ -192,6 +192,7 @@ class XMLStream(object): #: The expected name of the server, for validation. self._expected_server_name = '' + self._service_name = '' #: The desired, or actual, address of the connected server. self.address = (host, int(port)) @@ -473,8 +474,10 @@ class XMLStream(object): if self.default_domain: try: - self.address = self.pick_dns_answer(self.default_domain, - self.address[1]) + host, address, port = self.pick_dns_answer(self.default_domain, + self.address[1]) + self.address = (address, port) + self._service_name = host except StopIteration: log.debug("No remaining DNS records to try.") self.dns_answers = None