2009-06-03 22:56:51 +00:00
|
|
|
import logging
from xml.parsers.expat import ExpatError

try:
    from xml.etree import cElementTree
except ImportError:  # cElementTree was removed in Python 3.9
    from xml.etree import ElementTree as cElementTree

from . import base
|
|
|
|
|
2009-07-11 19:31:20 +00:00
|
|
|
# Module-wide switch read by MatchXMLMask.maskcmp, overriding its use_ns
# argument. When True, element tags are compared by local name only (the
# "{namespace}" prefix is stripped). When False, tags must match exactly, or
# an un-namespaced mask tag may match once the matcher's default_ns is applied.
ignore_ns = False
|
|
|
|
|
2009-06-03 22:56:51 +00:00
|
|
|
class MatchXMLMask(base.MatcherBase):
    """Match stanzas against an XML "mask".

    A stanza matches when all of the following hold:

    - Its root element has the mask's tag.
    - Every attribute and every non-empty text given in the mask is present
      with the same value.
    - Every child element of the mask matches, recursively, the first child
      of the stanza with the same tag.

    Extra attributes and children in the stanza are allowed.
    """

    def __init__(self, criteria):
        """Create a matcher from a mask.

        :param criteria: The mask, either an XML string (parsed here) or an
                         already-parsed Element (presumably stored as-is by
                         MatcherBase -- confirm against base).
        """
        base.MatcherBase.__init__(self, criteria)
        if isinstance(criteria, str):
            self._criteria = cElementTree.fromstring(self._criteria)
        # Namespace assumed for an un-namespaced mask tag when comparing tags
        # in namespace-aware mode (see maskcmp).
        self.default_ns = 'jabber:client'

    def setDefaultNS(self, ns):
        """Set the namespace applied to un-namespaced mask tags."""
        self.default_ns = ns

    def match(self, xml):
        """Return True if ``xml`` matches the mask.

        :param xml: A stanza object exposing its root Element as ``.xml``,
                    or a bare Element.
        """
        if hasattr(xml, 'xml'):
            xml = xml.xml
        return self.maskcmp(xml, self._criteria, True)

    def maskcmp(self, source, maskobj, use_ns=False, default_ns='__no_ns__'):
        """Compare an Element against an Element (or XML string) mask.

        :param source: The Element to test. None (a missing child found during
                       the recursive walk) never matches.
        :param maskobj: The mask, as an Element or an XML string.
        :param use_ns: Ignored; namespace handling is decided by the
                       module-level ``ignore_ns`` flag. Kept for compatibility.
        :param default_ns: Unused; kept for compatibility.
        :returns: True if ``source`` satisfies the mask, otherwise False.
        """
        # The module-level switch deliberately overrides the caller's value.
        use_ns = not ignore_ns
        # TODO require namespaces
        if source is None:
            return False

        if not hasattr(maskobj, 'attrib'):
            # The mask is a string; turn it into an Element first.
            try:
                maskobj = cElementTree.fromstring(maskobj)
            except (ExpatError, SyntaxError) as err:
                # Older parsers raise ExpatError; newer ElementTree raises
                # ParseError, a SyntaxError subclass. An unparsable mask
                # cannot match anything.
                logging.warning("Expat error: %s\nIn parsing: %s", err, maskobj)
                return False

        if use_ns:
            # Accept an exact tag match, or an un-namespaced mask tag that
            # matches once default_ns is applied.
            namespaced = "{%s}%s" % (self.default_ns, maskobj.tag)
            if source.tag != maskobj.tag and source.tag != namespaced:
                return False
        elif source.tag.split('}', 1)[-1] != maskobj.tag.split('}', 1)[-1]:
            # Namespace-agnostic mode: compare local names only.
            return False

        if maskobj.text and source.text != maskobj.text:
            return False

        # Mask attribute values are always strings, so a missing source
        # attribute (None) can never compare equal.
        for name, value in maskobj.attrib.items():
            if source.attrib.get(name) != value:
                return False

        # Every mask child must match the first same-tagged source child.
        for subelement in maskobj:
            if use_ns:
                child = source.find(subelement.tag)
            else:
                child = self.getChildIgnoreNS(source, subelement.tag)
            if not self.maskcmp(child, subelement, use_ns):
                return False
        return True

    def getChildIgnoreNS(self, xml, tag):
        """Return the first child of ``xml`` whose local name equals that of
        ``tag`` (any "{namespace}" prefix ignored), or None if there is none.
        """
        local = tag.split('}')[-1]
        for child in xml:
            if child.tag.split('}')[-1] == local:
                return child
        return None
|