Fix JID validation bugs, add lots of tests.
This commit is contained in:
parent
78aa5c3dfa
commit
352ee2f2fd
4 changed files with 160 additions and 8 deletions
|
@ -10,7 +10,7 @@ from sleekxmpp.basexmpp import BaseXMPP
|
|||
from sleekxmpp.clientxmpp import ClientXMPP
|
||||
from sleekxmpp.componentxmpp import ComponentXMPP
|
||||
from sleekxmpp.stanza import Message, Presence, Iq
|
||||
from sleekxmpp.jid import JID
|
||||
from sleekxmpp.jid import JID, InvalidJID
|
||||
from sleekxmpp.xmlstream.handler import *
|
||||
from sleekxmpp.xmlstream import XMLStream, RestartStream
|
||||
from sleekxmpp.xmlstream.matcher import *
|
||||
|
|
|
@ -140,13 +140,12 @@ def _validate_node(node):
|
|||
"""
|
||||
try:
|
||||
if node is not None:
|
||||
if not node:
|
||||
raise InvalidJID('Localpart must not be 0 bytes')
|
||||
|
||||
node = nodeprep(node)
|
||||
|
||||
if not node:
|
||||
raise InvalidJID('Localpart must not be 0 bytes')
|
||||
if len(node) > 1023:
|
||||
raise InvalidJID('Localpart must be less than 1024 bytes')
|
||||
return node
|
||||
except stringprep_profiles.StringPrepError:
|
||||
raise InvalidJID('Invalid local part')
|
||||
|
@ -179,6 +178,7 @@ def _validate_domain(domain):
|
|||
if not ip_addr and hasattr(socket, 'inet_pton'):
|
||||
try:
|
||||
socket.inet_pton(socket.AF_INET6, domain.strip('[]'))
|
||||
domain = '[%s]' % domain.strip('[]')
|
||||
ip_addr = True
|
||||
except socket.error:
|
||||
pass
|
||||
|
@ -186,12 +186,19 @@ def _validate_domain(domain):
|
|||
if not ip_addr:
|
||||
# This is a domain name, which must be checked further
|
||||
|
||||
if domain and domain[-1] == '.':
|
||||
domain = domain[:-1]
|
||||
|
||||
domain_parts = []
|
||||
for label in domain.split('.'):
|
||||
try:
|
||||
label = encodings.idna.nameprep(label)
|
||||
encodings.idna.ToASCII(label)
|
||||
pass_nameprep = True
|
||||
except UnicodeError:
|
||||
pass_nameprep = False
|
||||
|
||||
if not pass_nameprep:
|
||||
raise InvalidJID('Could not encode domain as ASCII')
|
||||
|
||||
if label.startswith('xn--'):
|
||||
|
@ -209,6 +216,8 @@ def _validate_domain(domain):
|
|||
|
||||
if not domain:
|
||||
raise InvalidJID('Domain must not be 0 bytes')
|
||||
if len(domain) > 1023:
|
||||
raise InvalidJID('Domain must be less than 1024 bytes')
|
||||
|
||||
return domain
|
||||
|
||||
|
@ -222,13 +231,12 @@ def _validate_resource(resource):
|
|||
"""
|
||||
try:
|
||||
if resource is not None:
|
||||
if not resource:
|
||||
raise InvalidJID('Resource must not be 0 bytes')
|
||||
|
||||
resource = resourceprep(resource)
|
||||
|
||||
if not resource:
|
||||
raise InvalidJID('Resource must not be 0 bytes')
|
||||
if len(resource) > 1023:
|
||||
raise InvalidJID('Resource must be less than 1024 bytes')
|
||||
return resource
|
||||
except stringprep_profiles.StringPrepError:
|
||||
raise InvalidJID('Invalid resource')
|
||||
|
|
|
@ -77,6 +77,9 @@ def check_bidi(data):
|
|||
character MUST be the first character of the string, and a
|
||||
RandALCat character MUST be the last character of the string.
|
||||
"""
|
||||
if not data:
|
||||
return data
|
||||
|
||||
has_lcat = False
|
||||
has_randal = False
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from sleekxmpp.test import *
|
||||
from sleekxmpp import JID
|
||||
from sleekxmpp import JID, InvalidJID
|
||||
|
||||
|
||||
class TestJIDClass(SleekTest):
|
||||
|
@ -137,5 +137,146 @@ class TestJIDClass(SleekTest):
|
|||
self.assertFalse(jid1 == jid2, "Same JIDs are not considered equal")
|
||||
self.assertTrue(jid1 != jid2, "Same JIDs are considered not equal")
|
||||
|
||||
def testZeroLengthDomain(self):
|
||||
self.assertRaises(InvalidJID, JID, domain='')
|
||||
self.assertRaises(InvalidJID, JID, 'user@/resource')
|
||||
|
||||
def testZeroLengthLocalPart(self):
|
||||
self.assertRaises(InvalidJID, JID, local='', domain='test.com')
|
||||
self.assertRaises(InvalidJID, JID, '@/test.com')
|
||||
|
||||
def testZeroLengthResource(self):
|
||||
self.assertRaises(InvalidJID, JID, domain='test.com', resource='')
|
||||
self.assertRaises(InvalidJID, JID, 'test.com/')
|
||||
|
||||
def test1023LengthDomain(self):
|
||||
domain = ('a.' * 509) + 'a.com'
|
||||
jid1 = JID(domain=domain)
|
||||
jid2 = JID('user@%s/resource' % domain)
|
||||
|
||||
def test1023LengthLocalPart(self):
|
||||
local = 'a' * 1023
|
||||
jid1 = JID(local=local, domain='test.com')
|
||||
jid2 = JID('%s@test.com' % local)
|
||||
|
||||
def test1023LengthResource(self):
|
||||
resource = 'r' * 1023
|
||||
jid1 = JID(domain='test.com', resource=resource)
|
||||
jid2 = JID('test.com/%s' % resource)
|
||||
|
||||
def test1024LengthDomain(self):
|
||||
domain = ('a.' * 509) + 'aa.com'
|
||||
self.assertRaises(InvalidJID, JID, domain=domain)
|
||||
self.assertRaises(InvalidJID, JID, 'user@%s/resource' % domain)
|
||||
|
||||
def test1024LengthLocalPart(self):
|
||||
local = 'a' * 1024
|
||||
self.assertRaises(InvalidJID, JID, local=local, domain='test.com')
|
||||
self.assertRaises(InvalidJID, JID, '%s@/test.com' % local)
|
||||
|
||||
def test1024LengthResource(self):
|
||||
resource = 'r' * 1024
|
||||
self.assertRaises(InvalidJID, JID, domain='test.com', resource=resource)
|
||||
self.assertRaises(InvalidJID, JID, 'test.com/%s' % resource)
|
||||
|
||||
def testTooLongDomainLabel(self):
|
||||
domain = ('a' * 64) + '.com'
|
||||
self.assertRaises(InvalidJID, JID, domain=domain)
|
||||
self.assertRaises(InvalidJID, JID, 'user@%s/resource' % domain)
|
||||
|
||||
def testDomainEmptyLabel(self):
|
||||
domain = 'aaa..bbb.com'
|
||||
self.assertRaises(InvalidJID, JID, domain=domain)
|
||||
self.assertRaises(InvalidJID, JID, 'user@%s/resource' % domain)
|
||||
|
||||
def testDomainIPv4(self):
|
||||
domain = '127.0.0.1'
|
||||
jid1 = JID(domain=domain)
|
||||
jid2 = JID('user@%s/resource' % domain)
|
||||
|
||||
def testDomainIPv6(self):
|
||||
domain = '[::1]'
|
||||
jid1 = JID(domain=domain)
|
||||
jid2 = JID('user@%s/resource' % domain)
|
||||
|
||||
def testDomainInvalidIPv6NoBrackets(self):
|
||||
domain = '::1'
|
||||
jid1 = JID(domain=domain)
|
||||
jid2 = JID('user@%s/resource' % domain)
|
||||
|
||||
self.assertEqual(jid1.domain, '[::1]')
|
||||
self.assertEqual(jid2.domain, '[::1]')
|
||||
|
||||
def testDomainInvalidIPv6MissingBracket(self):
|
||||
domain = '[::1'
|
||||
jid1 = JID(domain=domain)
|
||||
jid2 = JID('user@%s/resource' % domain)
|
||||
|
||||
self.assertEqual(jid1.domain, '[::1]')
|
||||
self.assertEqual(jid2.domain, '[::1]')
|
||||
|
||||
def testDomainWithPort(self):
|
||||
domain = 'example.com:5555'
|
||||
self.assertRaises(InvalidJID, JID, domain=domain)
|
||||
self.assertRaises(InvalidJID, JID, 'user@%s/resource' % domain)
|
||||
|
||||
def testDomainWithTrailingDot(self):
|
||||
domain = 'example.com.'
|
||||
jid1 = JID(domain=domain)
|
||||
jid2 = JID('user@%s/resource' % domain)
|
||||
|
||||
self.assertEqual(jid1.domain, 'example.com')
|
||||
self.assertEqual(jid2.domain, 'example.com')
|
||||
|
||||
def testDomainWithDashes(self):
|
||||
domain = 'example.com-'
|
||||
self.assertRaises(InvalidJID, JID, domain=domain)
|
||||
self.assertRaises(InvalidJID, JID, 'user@%s/resource' % domain)
|
||||
|
||||
domain = '-example.com'
|
||||
self.assertRaises(InvalidJID, JID, domain=domain)
|
||||
self.assertRaises(InvalidJID, JID, 'user@%s/resource' % domain)
|
||||
|
||||
def testACEDomain(self):
|
||||
domain = 'xn--bcher-kva.ch'
|
||||
jid1 = JID(domain=domain)
|
||||
jid2 = JID('user@%s/resource' % domain)
|
||||
|
||||
self.assertEqual(jid1.domain.encode('utf-8'), b'b\xc3\xbccher.ch')
|
||||
self.assertEqual(jid2.domain.encode('utf-8'), b'b\xc3\xbccher.ch')
|
||||
|
||||
def testJIDEscapeExistingSequences(self):
|
||||
jid = JID(local='blah\\foo\\20bar', domain='example.com')
|
||||
self.assertEqual(jid.local, 'blah\\foo\\5c20bar')
|
||||
|
||||
def testJIDEscape(self):
|
||||
jid = JID(local='here\'s_a_wild_&_/cr%zy/_address_for:<wv>("IMPS")',
|
||||
domain='example.com')
|
||||
self.assertEqual(jid.local, r'here\27s_a_wild_\26_\2fcr%zy\2f_address_for\3a\3cwv\3e(\22IMPS\22)')
|
||||
|
||||
def testJIDUnescape(self):
|
||||
jid = JID(local='here\'s_a_wild_&_/cr%zy/_address_for:<wv>("IMPS")',
|
||||
domain='example.com')
|
||||
ujid = jid.unescape()
|
||||
self.assertEqual(ujid.local, 'here\'s_a_wild_&_/cr%zy/_address_for:<wv>("IMPS")')
|
||||
|
||||
jid = JID(local='blah\\foo\\20bar', domain='example.com')
|
||||
ujid = jid.unescape()
|
||||
self.assertEqual(ujid.local, 'blah\\foo\\20bar')
|
||||
|
||||
def testStartOrEndWithEscapedSpaces(self):
|
||||
local = ' foo'
|
||||
self.assertRaises(InvalidJID, JID, local=local, domain='example.com')
|
||||
self.assertRaises(InvalidJID, JID, '%s@example.com' % local)
|
||||
|
||||
local = 'bar '
|
||||
self.assertRaises(InvalidJID, JID, local=local, domain='example.com')
|
||||
self.assertRaises(InvalidJID, JID, '%s@example.com' % local)
|
||||
|
||||
# Need more input for these cases. A JID starting with \20 *is* valid
|
||||
# according to RFC 6122, but is not according to XEP-0106.
|
||||
#self.assertRaises(InvalidJID, JID, '%s@example.com' % '\\20foo2')
|
||||
#self.assertRaises(InvalidJID, JID, '%s@example.com' % 'bar2\\20')
|
||||
|
||||
|
||||
suite = unittest.TestLoader().loadTestsFromTestCase(TestJIDClass)
|
||||
|
|
Loading…
Reference in a new issue