From c1d36cad4679419679896c81e4d4520ec1627132 Mon Sep 17 00:00:00 2001 From: Lance Stout Date: Thu, 29 Mar 2012 15:11:24 -0700 Subject: [PATCH] Add better DNS resolver wrapper. --- sleekxmpp/clientxmpp.py | 41 ++--- sleekxmpp/xmlstream/resolver.py | 287 +++++++++++++++++++++++++++++++ sleekxmpp/xmlstream/xmlstream.py | 124 +++---------- 3 files changed, 324 insertions(+), 128 deletions(-) create mode 100644 sleekxmpp/xmlstream/resolver.py diff --git a/sleekxmpp/clientxmpp.py b/sleekxmpp/clientxmpp.py index 590192db..d2f24d16 100644 --- a/sleekxmpp/clientxmpp.py +++ b/sleekxmpp/clientxmpp.py @@ -84,6 +84,8 @@ class ClientXMPP(BaseXMPP): self._stream_feature_handlers = {} self._stream_feature_order = [] + self.dns_service = 'xmpp-client' + #TODO: Use stream state here self.authenticated = False self.sessionstarted = False @@ -139,43 +141,20 @@ class ClientXMPP(BaseXMPP): should be used. Defaults to ``False``. """ self.session_started_event.clear() - if not address: + + # If an address was provided, disable using DNS SRV lookup; + # otherwise, use the domain from the client JID with the standard + # XMPP client port and allow SRV lookup. + if address: + self.dns_service = None + else: address = (self.boundjid.host, 5222) + self.dns_service = 'xmpp-client' return XMLStream.connect(self, address[0], address[1], use_tls=use_tls, use_ssl=use_ssl, reattempt=reattempt) - def get_dns_records(self, domain, port=None): - """Get the DNS records for a domain, including SRV records. - - :param domain: The domain in question. - :param port: If the results don't include a port, use this one. - """ - if port is None: - port = self.default_port - if DNSPYTHON: - try: - record = "_xmpp-client._tcp.%s" % domain - answers = [] - log.debug("Querying SRV records for %s" % domain) - for answer in dns.resolver.query(record, dns.rdatatype.SRV): - address = (answer.target.to_text()[:-1], answer.port) - log.debug("Found SRV record: %s", address) - answers.append((address, answer.priority, answer.weight)) - except (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer): - log.warning("No SRV records for %s", domain) - answers = super(ClientXMPP, self).get_dns_records(domain, port) - except dns.exception.Timeout: - log.warning("DNS resolution timed out " + \ - "for SRV record of %s", domain) - answers = super(ClientXMPP, self).get_dns_records(domain, port) - return answers - else: - log.warning("dnspython is not installed -- " + \ - "relying on OS A/AAAA record resolution") - return [((domain, port), 0, 0)] - def register_feature(self, name, handler, restart=False, order=5000): """Register a stream feature handler. diff --git a/sleekxmpp/xmlstream/resolver.py b/sleekxmpp/xmlstream/resolver.py new file mode 100644 index 00000000..ecb76519 --- /dev/null +++ b/sleekxmpp/xmlstream/resolver.py @@ -0,0 +1,287 @@ +# -*- encoding: utf-8 -*- + +""" + sleekxmpp.xmlstream.dns + ~~~~~~~~~~~~~~~~~~~~~~~ + + :copyright: (c) 2012 Nathanael C. Fritz + :license: MIT, see LICENSE for more details +""" + +import socket +import logging +import random + + +log = logging.getLogger(__name__) + + +#: Global flag indicating the availability of the ``dnspython`` package. +#: Installing ``dnspython`` can be done via: +#: +#: .. code-block:: sh +#: +#: pip install dnspython +#: +#: For Python3, installation may require installing from source using +#: the ``python3`` branch: +#: +#: .. code-block:: sh +#: +#: git clone http://github.com/rthalley/dnspython +#: cd dnspython +#: git checkout python3 +#: python3 setup.py install +USE_DNSPYTHON = False +try: + import dns.resolver + USE_DNSPYTHON = True +except ImportError as e: + log.debug("Could not find dnspython package. " + \ + "Not all features will be available") + + +def default_resolver(): + """Return a basic DNS resolver object. + + :returns: A :class:`dns.resolver.Resolver` object if dnspython + is available. Otherwise, ``None``. + """ + if USE_DNSPYTHON: + return dns.resolver.get_default_resolver() + return None + + +def resolve(host, port=None, service=None, proto='tcp', resolver=None): + """Peform DNS resolution for a given hostname. + + Resolution may perform SRV record lookups if a service and protocol + are specified. The returned addresses will be sorted according to + the SRV priorities and weights. + + If no resolver is provided, the dnspython resolver will be used if + available. Otherwise the built-in socket facilities will be used, + but those do not provide SRV support. + + If SRV records were used, queries to resolve alternative hosts will + be made as needed instead of all at once. + + :param host: The hostname to resolve. + :param port: A default port to connect with. SRV records may + dictate use of a different port. + :param service: Optional SRV service name without leading underscore. + :param proto: Optional SRV protocol name without leading underscore. + :param resolver: Optionally provide a DNS resolver object that has + been custom configured. + + :type host: string + :type port: int + :type service: string + :type proto: string + :type resolver: :class:`dns.resolver.Resolver` + + :return: An iterable of IP address, port pairs in the order + dictated by SRV priorities and weights, if applicable. + """ + if resolver is None and USE_DNSPYTHON: + resolver = dns.resolver.get_default_resolver() + + # An IPv6 literal is allowed to be enclosed in square brackets, but + # the brackets must be stripped in order to process the literal; + # otherwise, things break. + host = host.strip('[]') + + try: + # If `host` is an IPv4 literal, we can return it immediately. + ipv4 = socket.inet_pton(socket.AF_INET, host) + yield [(host, port)] + except socket.error: + pass + + try: + # Likewise, If `host` is an IPv6 literal, we can return it immediately. + ipv6 = socket.inet_pton(socket.AF_INET6, host) + yield [(host, port)] + except socket.error: + pass + + # If no service was provided, then we can just do A/AAAA lookups on the + # provided host. Otherwise we need to get an ordered list of hosts to + # resolve based on SRV records. + if not service: + hosts = [(host, port)] + else: + hosts = get_SRV(host, port, service, proto, resolver=resolver) + + for host, port in hosts: + results = [] + for address in get_AAAA(host, resolver=resolver): + results.append((address, port)) + for address in get_A(host, resolver=resolver): + results.append((address, port)) + + for address, port in results: + yield address, port + + +def get_A(host, resolver=None): + """Lookup DNS A records for a given host. + + If ``resolver`` is not provided, or is ``None``, then resolution will + be performed using the built-in :mod:`socket` module. + + :param host: The hostname to resolve for A record IPv4 addresses. + :param resolver: Optional DNS resolver object to use for the query. + + :type host: string + :type resolver: :class:`dns.resolver.Resolver` or ``None`` + + :return: A list of IPv4 literals. + """ + log.debug("DNS: Querying %s for A records." % host) + + # If not using dnspython, attempt lookup using the OS level + # getaddrinfo() method. + if resolver is None: + try: + recs = socket.getaddrinfo(host, None, socket.AF_INET, + socket.SOCK_STREAM) + return [rec[4][0] for rec in recs] + except socket.gaierror: + log.debug("DNS: Error retreiving A address info for %s." % host) + return [] + + # Using dnspython: + try: + recs = resolver.query(host, dns.rdatatype.A) + return [rec.to_text() for rec in recs] + except (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer): + log.debug("DNS: No A records for %s." % host) + return [] + except dns.exception.Timeout: + log.debug("DNS: A record resolution timed out for %s." % host) + return [] + except dns.exception.DNSException as e: + log.debug("DNS: Error querying A records for %s." % host) + log.exception(e) + return [] + + +def get_AAAA(host, resolver=None): + """Lookup DNS AAAA records for a given host. + + If ``resolver`` is not provided, or is ``None``, then resolution will + be performed using the built-in :mod:`socket` module. + + :param host: The hostname to resolve for AAAA record IPv6 addresses. + :param resolver: Optional DNS resolver object to use for the query. + + :type host: string + :type resolver: :class:`dns.resolver.Resolver` or ``None`` + + :return: A list of IPv6 literals. + """ + log.debug("DNS: Querying %s for AAAA records." % host) + + # If not using dnspython, attempt lookup using the OS level + # getaddrinfo() method. + if resolver is None: + try: + recs = socket.getaddrinfo(host, None, socket.AF_INET6, + socket.SOCK_STREAM) + return [rec[4][0] for rec in recs] + except socket.gaierror: + log.debug("DNS: Error retreiving AAAA address " + \ + "info for %s." % host) + return [] + + # Using dnspython: + try: + recs = resolver.query(host, dns.rdatatype.AAAA) + return [rec.to_text() for rec in recs] + except (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer): + log.debug("DNS: No AAAA records for %s." % host) + return [] + except dns.exception.Timeout: + log.debug("DNS: AAAA record resolution timed out for %s." % host) + return [] + except dns.exception.DNSException as e: + log.debug("DNS: Error querying AAAA records for %s." % host) + log.exception(e) + return [] + + +def get_SRV(host, port, service, proto='tcp', resolver=None): + """Perform SRV record resolution for a given host. + + .. note:: + + This function requires the use of the ``dnspython`` package. Calling + :func:`get_SRV` without ``dnspython`` will return the provided host + and port without performing any DNS queries. + + :param host: The hostname to resolve. + :param port: A default port to connect with. SRV records may + dictate use of a different port. + :param service: Optional SRV service name without leading underscore. + :param proto: Optional SRV protocol name without leading underscore. + :param resolver: Optionally provide a DNS resolver object that has + been custom configured. + + :type host: string + :type port: int + :type service: string + :type proto: string + :type resolver: :class:`dns.resolver.Resolver` + + :return: A list of hostname, port pairs in the order dictacted + by SRV priorities and weights. + """ + if resolver is None: + return [(host, port)] + + log.debug("Querying SRV records for %s" % host) + try: + recs = resolver.query('_%s._%s.%s' % (service, proto, host), + dns.rdatatype.SRV) + except (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer): + log.debug("DNS: No SRV records for %s." % host) + return [(host, port)] + except dns.exception.Timeout: + log.debug("DNS: SRV record resolution timed out for %s." % host) + return [(host, port)] + except dns.exception.DNSException as e: + log.debug("DNS: Error querying SRV records for %s." % host) + log.exception(e) + return [(host, port)] + + if len(recs) == 1 and recs[0].target == '.': + return [(host, port)] + + answers = {} + for rec in recs: + if rec.priority not in answers: + answers[rec.priority] = [] + if rec.weight == 0: + answers[rec.priority].insert(0, rec) + else: + answers[rec.priority].append(rec) + + sorted_recs = [] + for priority in sorted(answers.keys()): + while answers[priority]: + running_sum = 0 + sums = {} + for rec in answers[priority]: + running_sum += rec.weight + sums[running_sum] = rec + + selected = random.randint(0, running_sum + 1) + for running_sum in sums: + if running_sum >= selected: + rec = sums[running_sum] + sorted_recs.append((rec.target.to_text(), rec.port)) + answers[priority].remove(rec) + break + + return sorted_recs diff --git a/sleekxmpp/xmlstream/xmlstream.py b/sleekxmpp/xmlstream/xmlstream.py index 45f27929..31ca9cfe 100644 --- a/sleekxmpp/xmlstream/xmlstream.py +++ b/sleekxmpp/xmlstream/xmlstream.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ sleekxmpp.xmlstream.xmlstream ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -39,19 +38,13 @@ from sleekxmpp.xmlstream import Scheduler, tostring from sleekxmpp.xmlstream.stanzabase import StanzaBase, ET, ElementBase from sleekxmpp.xmlstream.handler import Waiter, XMLCallback from sleekxmpp.xmlstream.matcher import MatchXMLMask +from sleekxmpp.xmlstream.resolver import resolve, default_resolver # In Python 2.x, file socket objects are broken. A patched socket # wrapper is provided for this case in filesocket.py. if sys.version_info < (3, 0): from sleekxmpp.xmlstream.filesocket import FileSocket, Socket26 -try: - import dns.resolver -except ImportError: - DNSPYTHON = False -else: - DNSPYTHON = True - #: The time in seconds to wait before timing out waiting for response stanzas. RESPONSE_TIMEOUT = 30 @@ -306,6 +299,11 @@ class XMLStream(object): #: A list of DNS results that have not yet been tried. self.dns_answers = [] + #: The service name to check with DNS SRV records. For + #: example, setting this to ``'xmpp-client'`` would query the + #: ``_xmpp-client._tcp`` service. + self.dns_service = None + self.add_event_handler('connected', self._handle_connected) self.add_event_handler('session_start', self._start_keepalive) self.add_event_handler('disconnected', self._end_keepalive) @@ -445,25 +443,10 @@ class XMLStream(object): self.stop.set() return False - try: - # Look for IPv6 addresses, in addition to IPv4 - for res in Socket.getaddrinfo(self.address[0], - int(self.address[1]), - 0, - Socket.SOCK_STREAM): - log.debug("Trying: %s", res[-1]) - af, sock_type, proto, canonical, sock_addr = res - try: - self.socket = self.socket_class(af, sock_type, proto) - break - except Socket.error: - log.debug("Could not open IPv%s socket." % proto) - except Socket.gaierror: - log.warning("Socket could not be opened: no connectivity" + \ - " or wrong IP versions.") - if reattempt: - self.reconnect_delay = delay - return False + af = Socket.AF_INET + if ':' in self.address[0]: + af = Socket.AF_INET6 + self.socket = self.socket_class(af, Socket.SOCK_STREAM) self.configure_socket() @@ -511,7 +494,10 @@ class XMLStream(object): except Socket.error as serr: error_msg = "Could not connect to %s:%s. Socket Error #%s: %s" self.event('socket_error', serr, direct=True) - log.error(error_msg, self.address[0], self.address[1], + domain = self.address[0] + if ':' in domain: + domain = '[%s]' % domain + log.error(error_msg, domain, self.address[1], serr.errno, serr.strerror) if reattempt: self.reconnect_delay = delay @@ -915,50 +901,11 @@ class XMLStream(object): """ if port is None: port = self.default_port - if DNSPYTHON: - resolver = dns.resolver.get_default_resolver() - self.configure_dns(resolver, domain=domain, port=port) + + resolver = default_resolver() + self.configure_dns(resolver, domain=domain, port=port) - v4_answers = [] - v6_answers = [] - answers = [] - - try: - log.debug("Querying A records for %s" % domain) - v4_answers = resolver.query(domain, dns.rdatatype.A) - except (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer): - log.warning("No A records for %s", domain) - v4_answers = [((domain, port), 0, 0)] - except dns.exception.Timeout: - log.warning("DNS resolution timed out " + \ - "for A record of %s", domain) - v4_answers = [((domain, port), 0, 0)] - else: - for ans in v4_answers: - log.debug("Found A record: %s", ans.address) - answers.append(((ans.address, port), 0, 0)) - - try: - log.debug("Querying AAAA records for %s" % domain) - v6_answers = resolver.query(domain, dns.rdatatype.AAAA) - except (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer): - log.warning("No AAAA records for %s", domain) - v6_answers = [((domain, port), 0, 0)] - except dns.exception.Timeout: - log.warning("DNS resolution timed out " + \ - "for AAAA record of %s", domain) - v6_answers = [((domain, port), 0, 0)] - else: - for ans in v6_answers: - log.debug("Found AAAA record: %s", ans.address) - answers.append(((ans.address, port), 0, 0)) - - return answers - else: - log.warning("dnspython is not installed -- " + \ - "relying on OS A/AAAA record resolution") - self.configure_dns(None, domain=domain, port=port) - return [((domain, port), 0, 0)] + return resolve(domain, port, service=self.dns_service, resolver=resolver) def pick_dns_answer(self, domain, port=None): """Pick a server and port from DNS answers. @@ -971,33 +918,16 @@ class XMLStream(object): """ if not self.dns_answers: self.dns_answers = self.get_dns_records(domain, port) - addresses = {} - intmax = 0 - topprio = 65535 - for answer in self.dns_answers: - topprio = min(topprio, answer[1]) - for answer in self.dns_answers: - if answer[1] == topprio: - intmax += answer[2] - addresses[intmax] = answer[0] - - #python3 returns a generator for dictionary keys - items = [x for x in addresses.keys()] - items.sort() - - address = (domain, port) - picked = random.randint(0, intmax) - for item in items: - if picked <= item: - address = addresses[item] - break - for idx, answer in enumerate(self.dns_answers): - if self.dns_answers[0] == address: - self.dns_answers.pop(idx) - break - log.debug("Trying to connect to %s:%s", *address) - return address + try: + if sys.version_info < (3, 0): + return self.dns_answers.next() + else: + return next(self.dns_answers) + except StopIteration: + self.dns_answers = None + return (domain, port) + def add_event_handler(self, name, pointer, threaded=False, disposable=False): """Add a custom event handler that will be executed whenever