diff --git a/sleekxmpp/__init__.py b/sleekxmpp/__init__.py index 84b1114f..f0dc2ce2 100644 --- a/sleekxmpp/__init__.py +++ b/sleekxmpp/__init__.py @@ -10,7 +10,7 @@ from sleekxmpp.basexmpp import BaseXMPP from sleekxmpp.clientxmpp import ClientXMPP from sleekxmpp.componentxmpp import ComponentXMPP from sleekxmpp.stanza import Message, Presence, Iq -from sleekxmpp.jid import JID +from sleekxmpp.jid import JID, InvalidJID from sleekxmpp.xmlstream.handler import * from sleekxmpp.xmlstream import XMLStream, RestartStream from sleekxmpp.xmlstream.matcher import * diff --git a/sleekxmpp/jid.py b/sleekxmpp/jid.py index f0b7423b..9e9c0d0b 100644 --- a/sleekxmpp/jid.py +++ b/sleekxmpp/jid.py @@ -140,13 +140,12 @@ def _validate_node(node): """ try: if node is not None: - if not node: - raise InvalidJID('Localpart must not be 0 bytes') - node = nodeprep(node) if not node: raise InvalidJID('Localpart must not be 0 bytes') + if len(node) > 1023: + raise InvalidJID('Localpart must be less than 1024 bytes') return node except stringprep_profiles.StringPrepError: raise InvalidJID('Invalid local part') @@ -179,6 +178,7 @@ def _validate_domain(domain): if not ip_addr and hasattr(socket, 'inet_pton'): try: socket.inet_pton(socket.AF_INET6, domain.strip('[]')) + domain = '[%s]' % domain.strip('[]') ip_addr = True except socket.error: pass @@ -186,12 +186,19 @@ def _validate_domain(domain): if not ip_addr: # This is a domain name, which must be checked further + if domain and domain[-1] == '.': + domain = domain[:-1] + domain_parts = [] for label in domain.split('.'): try: label = encodings.idna.nameprep(label) encodings.idna.ToASCII(label) + pass_nameprep = True except UnicodeError: + pass_nameprep = False + + if not pass_nameprep: raise InvalidJID('Could not encode domain as ASCII') if label.startswith('xn--'): @@ -209,6 +216,8 @@ def _validate_domain(domain): if not domain: raise InvalidJID('Domain must not be 0 bytes') + if len(domain) > 1023: + raise InvalidJID('Domain must be less than 1024 bytes') return domain @@ -222,13 +231,12 @@ def _validate_resource(resource): """ try: if resource is not None: - if not resource: - raise InvalidJID('Resource must not be 0 bytes') - resource = resourceprep(resource) if not resource: raise InvalidJID('Resource must not be 0 bytes') + if len(resource) > 1023: + raise InvalidJID('Resource must be less than 1024 bytes') return resource except stringprep_profiles.StringPrepError: raise InvalidJID('Invalid resource') diff --git a/sleekxmpp/util/stringprep_profiles.py b/sleekxmpp/util/stringprep_profiles.py index a75bb9dd..6844c9ac 100644 --- a/sleekxmpp/util/stringprep_profiles.py +++ b/sleekxmpp/util/stringprep_profiles.py @@ -77,6 +77,9 @@ def check_bidi(data): character MUST be the first character of the string, and a RandALCat character MUST be the last character of the string. """ + if not data: + return data + has_lcat = False has_randal = False diff --git a/tests/test_jid.py b/tests/test_jid.py index 7b800520..aeb635a1 100644 --- a/tests/test_jid.py +++ b/tests/test_jid.py @@ -1,5 +1,5 @@ from sleekxmpp.test import * -from sleekxmpp import JID +from sleekxmpp import JID, InvalidJID class TestJIDClass(SleekTest): @@ -137,5 +137,146 @@ class TestJIDClass(SleekTest): self.assertFalse(jid1 == jid2, "Same JIDs are not considered equal") self.assertTrue(jid1 != jid2, "Same JIDs are considered not equal") + def testZeroLengthDomain(self): + self.assertRaises(InvalidJID, JID, domain='') + self.assertRaises(InvalidJID, JID, 'user@/resource') + + def testZeroLengthLocalPart(self): + self.assertRaises(InvalidJID, JID, local='', domain='test.com') + self.assertRaises(InvalidJID, JID, '@/test.com') + + def testZeroLengthResource(self): + self.assertRaises(InvalidJID, JID, domain='test.com', resource='') + self.assertRaises(InvalidJID, JID, 'test.com/') + + def test1023LengthDomain(self): + domain = ('a.' * 509) + 'a.com' + jid1 = JID(domain=domain) + jid2 = JID('user@%s/resource' % domain) + + def test1023LengthLocalPart(self): + local = 'a' * 1023 + jid1 = JID(local=local, domain='test.com') + jid2 = JID('%s@test.com' % local) + + def test1023LengthResource(self): + resource = 'r' * 1023 + jid1 = JID(domain='test.com', resource=resource) + jid2 = JID('test.com/%s' % resource) + + def test1024LengthDomain(self): + domain = ('a.' * 509) + 'aa.com' + self.assertRaises(InvalidJID, JID, domain=domain) + self.assertRaises(InvalidJID, JID, 'user@%s/resource' % domain) + + def test1024LengthLocalPart(self): + local = 'a' * 1024 + self.assertRaises(InvalidJID, JID, local=local, domain='test.com') + self.assertRaises(InvalidJID, JID, '%s@/test.com' % local) + + def test1024LengthResource(self): + resource = 'r' * 1024 + self.assertRaises(InvalidJID, JID, domain='test.com', resource=resource) + self.assertRaises(InvalidJID, JID, 'test.com/%s' % resource) + + def testTooLongDomainLabel(self): + domain = ('a' * 64) + '.com' + self.assertRaises(InvalidJID, JID, domain=domain) + self.assertRaises(InvalidJID, JID, 'user@%s/resource' % domain) + + def testDomainEmptyLabel(self): + domain = 'aaa..bbb.com' + self.assertRaises(InvalidJID, JID, domain=domain) + self.assertRaises(InvalidJID, JID, 'user@%s/resource' % domain) + + def testDomainIPv4(self): + domain = '127.0.0.1' + jid1 = JID(domain=domain) + jid2 = JID('user@%s/resource' % domain) + + def testDomainIPv6(self): + domain = '[::1]' + jid1 = JID(domain=domain) + jid2 = JID('user@%s/resource' % domain) + + def testDomainInvalidIPv6NoBrackets(self): + domain = '::1' + jid1 = JID(domain=domain) + jid2 = JID('user@%s/resource' % domain) + + self.assertEqual(jid1.domain, '[::1]') + self.assertEqual(jid2.domain, '[::1]') + + def testDomainInvalidIPv6MissingBracket(self): + domain = '[::1' + jid1 = JID(domain=domain) + jid2 = JID('user@%s/resource' % domain) + + self.assertEqual(jid1.domain, '[::1]') + self.assertEqual(jid2.domain, '[::1]') + + def testDomainWithPort(self): + domain = 'example.com:5555' + self.assertRaises(InvalidJID, JID, domain=domain) + self.assertRaises(InvalidJID, JID, 'user@%s/resource' % domain) + + def testDomainWithTrailingDot(self): + domain = 'example.com.' + jid1 = JID(domain=domain) + jid2 = JID('user@%s/resource' % domain) + + self.assertEqual(jid1.domain, 'example.com') + self.assertEqual(jid2.domain, 'example.com') + + def testDomainWithDashes(self): + domain = 'example.com-' + self.assertRaises(InvalidJID, JID, domain=domain) + self.assertRaises(InvalidJID, JID, 'user@%s/resource' % domain) + + domain = '-example.com' + self.assertRaises(InvalidJID, JID, domain=domain) + self.assertRaises(InvalidJID, JID, 'user@%s/resource' % domain) + + def testACEDomain(self): + domain = 'xn--bcher-kva.ch' + jid1 = JID(domain=domain) + jid2 = JID('user@%s/resource' % domain) + + self.assertEqual(jid1.domain.encode('utf-8'), b'b\xc3\xbccher.ch') + self.assertEqual(jid2.domain.encode('utf-8'), b'b\xc3\xbccher.ch') + + def testJIDEscapeExistingSequences(self): + jid = JID(local='blah\\foo\\20bar', domain='example.com') + self.assertEqual(jid.local, 'blah\\foo\\5c20bar') + + def testJIDEscape(self): + jid = JID(local='here\'s_a_wild_&_/cr%zy/_address_for:("IMPS")', + domain='example.com') + self.assertEqual(jid.local, r'here\27s_a_wild_\26_\2fcr%zy\2f_address_for\3a\3cwv\3e(\22IMPS\22)') + + def testJIDUnescape(self): + jid = JID(local='here\'s_a_wild_&_/cr%zy/_address_for:("IMPS")', + domain='example.com') + ujid = jid.unescape() + self.assertEqual(ujid.local, 'here\'s_a_wild_&_/cr%zy/_address_for:("IMPS")') + + jid = JID(local='blah\\foo\\20bar', domain='example.com') + ujid = jid.unescape() + self.assertEqual(ujid.local, 'blah\\foo\\20bar') + + def testStartOrEndWithEscapedSpaces(self): + local = ' foo' + self.assertRaises(InvalidJID, JID, local=local, domain='example.com') + self.assertRaises(InvalidJID, JID, '%s@example.com' % local) + + local = 'bar ' + self.assertRaises(InvalidJID, JID, local=local, domain='example.com') + self.assertRaises(InvalidJID, JID, '%s@example.com' % local) + + # Need more input for these cases. A JID starting with \20 *is* valid + # according to RFC 6122, but is not according to XEP-0106. + #self.assertRaises(InvalidJID, JID, '%s@example.com' % '\\20foo2') + #self.assertRaises(InvalidJID, JID, '%s@example.com' % 'bar2\\20') + suite = unittest.TestLoader().loadTestsFromTestCase(TestJIDClass)