diff --git a/sleekxmpp/jid.py b/sleekxmpp/jid.py index e6da5746..dc6eb6b9 100644 --- a/sleekxmpp/jid.py +++ b/sleekxmpp/jid.py @@ -29,6 +29,30 @@ ILLEGAL_CHARS = '\x00\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r' + \ JID_PATTERN = "^(?:([^\"&'/:<>@]{1,1023})@)?([^/@]{1,1023})(?:/(.{1,1023}))?$" +JID_ESCAPE_SEQUENCES = set(['\\20', '\\22', '\\26', '\\27', '\\2f', + '\\3a', '\\3c', '\\3e', '\\40', '\\5c']) + +JID_ESCAPE_TRANSFORMATIONS = {' ': '\\20', + '"': '\\22', + '&': '\\26', + "'": '\\27', + '/': '\\2f', + ':': '\\3a', + '<': '\\3c', + '>': '\\3e', + '@': '\\40'} + +JID_UNESCAPE_TRANSFORMATIONS = {'\\20': ' ', + '\\22': '"', + '\\26': '&', + '\\27': "'", + '\\2f': '/', + '\\3a': ':', + '\\3c': '<', + '\\3e': '>', + '\\40': '@', + '\\5c': '\\'} + nodeprep = stringprep_profiles.create( nfkc=True, @@ -48,7 +72,7 @@ nodeprep = stringprep_profiles.create( stringprep.in_table_c7, stringprep.in_table_c8, stringprep.in_table_c9, - lambda c: c in '\'"&/:<>@'], + lambda c: c in ' \'"&/:<>@'], unassigned=[stringprep.in_table_a1]) @@ -70,21 +94,33 @@ resourceprep = stringprep_profiles.create( unassigned=[stringprep.in_table_a1]) -class InvalidJID(ValueError): - pass - - -def parse_jid(data): +def _parse_jid(data): """ Parse string data into the node, domain, and resource components of a JID. """ match = re.match(JID_PATTERN, data) if not match: - raise InvalidJID + raise InvalidJID('JID could not be parsed') (node, domain, resource) = match.groups() + _validate_node(node) + _validate_domain(domain) + _validate_resource(resource) + + return node, domain, resource + + +def _validate_node(node): + try: + if node is not None: + node = nodeprep(node) + except stringprep_profiles.StringPrepError: + raise InvalidJID('Invalid local part') + + +def _validate_domain(domain): ip_addr = False try: @@ -107,27 +143,122 @@ def parse_jid(data): label = encodings.idna.nameprep(label) encodings.idna.ToASCII(label) except UnicodeError: - raise InvalidJID + raise InvalidJID('Could not encode domain as ASCII') for char in label: if char in ILLEGAL_CHARS: - raise InvalidJID + raise InvalidJID('Domain contains illegar characters') if '-' in (label[0], label[-1]): - raise InvalidJID + raise InvalidJID('Domain started or ended with -') domain_parts.append(label) domain = '.'.join(domain_parts) + if not domain: + raise InvalidJID('Missing domain') + + +def _validate_resource(resource): try: - if node is not None: - node = nodeprep(node) if resource is not None: resource = resourceprep(resource) except stringprep_profiles.StringPrepError: - raise InvalidJID + raise InvalidJID('Invalid resource') - return node, domain, resource + +def _escape_node(node): + result = [] + + for i, char in enumerate(node): + if char == '\\': + if ''.join((data[i:i+3])) in JID_ESCAPE_SEQUENCES: + result.append('\\5c') + continue + result.append(char) + + for i, char in enumerate(result): + result[i] = JID_ESCAPE_TRANSFORMATIONS.get(char, char) + + escaped = ''.join(result) + + if escaped.startswith('\\20') or escaped.endswith('\\20'): + raise InvalidJID('Escaped local part starts or ends with "\\20"') + + _validate_node(escaped) + + return escaped + + +def _unescape_node(node): + unescaped = [] + seq = '' + for i, char in enumerate(node): + if char == '\\': + seq = node[i:i+3] + if seq not in JID_ESCAPE_SEQUENCES: + seq = '' + if seq: + if len(seq) == 3: + unescaped.append(JID_UNESCAPE_TRANSFORMATIONS.get(seq, char)) + + # Pop character off the escape sequence, and ignore it + seq = seq[1:] + else: + unescaped.append(char) + unescaped = ''.join(unescaped) + + return unescaped + + +def _format_jid(local=None, domain=None, resource=None): + result = [] + if local: + result.append(local) + result.append('@') + if domain: + result.append(domain) + if resource: + result.append('/') + result.append(resource) + return ''.join(result) + + +class InvalidJID(ValueError): + pass + + +class UnescapedJID(object): + + def __init__(self, local, domain, resource): + self._jid = (local, domain, resource) + + def __getattr__(self, name): + """ + :param name: one of: user, server, domain, resource, + full, or bare. + """ + if name == 'resource': + return self._jid[2] or '' + elif name in ('user', 'username', 'local', 'node'): + return self._jid[0] or '' + elif name in ('server', 'domain', 'host'): + return self._jid[1] or '' + elif name in ('full', 'jid'): + return _format_jid(*self._jid) + elif name == 'bare': + return _format_jid(self._jid[0], self._jid[1]) + elif name == '_jid': + return getattr(super(JID, self), '_jid') + else: + return None + + def __str__(self): + """Use the full JID as the string value.""" + return _format_jid(*self._jid) + + def __repr__(self): + return self.__str__() class JID(object): @@ -157,21 +288,37 @@ class JID(object): :param string jid: A string of the form ``'[user@]domain[/resource]'``. """ - def __init__(self, jid=None, local=None, domain=None, resource=None): + def __init__(self, jid=None, **kwargs): """Initialize a new JID""" self._jid = (None, None, None) if jid is None or jid == '': jid = (None, None, None) elif not isinstance(jid, JID): - jid = parse_jid(jid) + jid = _parse_jid(jid) else: jid = jid._jid - orig_local, orig_domain, orig_resource = jid - self._jid = (local or orig_local or None, - domain or orig_domain or None, - resource or orig_resource or None) + local, domain, resource = jid + validated = True + + local = kwargs.get('local', local) + domain = kwargs.get('domain', domain) + resource = kwargs.get('resource', resource) + + if 'local' in kwargs: + local = _escape_node(local) + if 'domain' in kwargs: + _validate_domain(domain) + if 'resource' in kwargs: + _validate_resource(resource) + + self._jid = (local, domain, resource) + + def unescape(self): + return UnescapedJID(_unescape_node(self._jid[0]), + self._jid[1], + self._jid[2]) def regenerate(self): """Deprecated""" @@ -185,8 +332,7 @@ class JID(object): self._jid = JID(data)._jid def __getattr__(self, name): - """handle getting the jid values, using cache if available. - + """ :param name: one of: user, server, domain, resource, full, or bare. """ @@ -197,16 +343,16 @@ class JID(object): elif name in ('server', 'domain', 'host'): return self._jid[1] or '' elif name in ('full', 'jid'): - return str(self) + return _format_jid(*self._jid) elif name == 'bare': - return str(JID(local=self._jid[0], - domain=self._jid[1])) + return _format_jid(self._jid[0], self._jid[1]) + elif name == '_jid': + return getattr(super(JID, self), '_jid') else: - object.__getattr__(self, name) + return None def __setattr__(self, name, value): - """handle getting the jid values, using cache if available. - + """ :param name: one of: ``user``, ``username``, ``local``, ``node``, ``server``, ``domain``, ``host``, ``resource``, ``full``, ``jid``, or ``bare``. @@ -223,21 +369,12 @@ class JID(object): elif name == 'bare': parsed = JID(value)._jid self._jid = (parsed[0], parsed[1], self._jid[2]) - else: - object.__setattr__(self, name, value) + elif name == '_jid': + super(JID, self).__setattr__('_jid', value) def __str__(self): """Use the full JID as the string value.""" - result = [] - if self._jid[0]: - result.append(self._jid[0]) - result.append('@') - if self._jid[1]: - result.append(self._jid[1]) - if self._jid[2]: - result.append('/') - result.append(self._jid[2]) - return ''.join(result) + return _format_jid(*self._jid) def __repr__(self): return self.__str__() @@ -246,6 +383,9 @@ class JID(object): """ Two JIDs are considered equal if they have the same full JID value. """ + if isinstance(other, UnescapedJID): + return False + other = JID(other) return self._jid == other._jid