Update tostring to inject xmlns definitions when needed.

This commit is contained in:
Lance Stout 2013-01-24 02:43:46 -08:00
parent 4f9a95b011
commit 403b1802ec

View file

@ -24,8 +24,8 @@ if sys.version_info < (3, 0):
XML_NS = 'http://www.w3.org/XML/1998/namespace'
def tostring(xml=None, xmlns='', stream=None,
outbuffer='', top_level=False, open_only=False):
def tostring(xml=None, xmlns='', stream=None, outbuffer='',
top_level=False, open_only=False, namespaces=None):
"""Serialize an XML object to a Unicode string.
If an outer xmlns is provided using ``xmlns``, then the current element's
@ -41,7 +41,8 @@ def tostring(xml=None, xmlns='', stream=None,
during recursive calls.
:param bool top_level: Indicates that the element is the outermost
element.
:param set namespaces: Track which namespaces are in active use so
that new ones can be declared when needed.
:type xml: :py:class:`~xml.etree.ElementTree.Element`
:type stream: :class:`~sleekxmpp.xmlstream.xmlstream.XMLStream`
@ -63,6 +64,7 @@ def tostring(xml=None, xmlns='', stream=None,
default_ns = ''
stream_ns = ''
use_cdata = False
if stream:
default_ns = stream.default_ns
stream_ns = stream.stream_ns
@ -82,6 +84,7 @@ def tostring(xml=None, xmlns='', stream=None,
output.append(namespace)
# Output escaped attribute values.
new_namespaces = set()
for attrib, value in xml.attrib.items():
value = escape(value, use_cdata)
if '}' not in attrib:
@ -92,9 +95,15 @@ def tostring(xml=None, xmlns='', stream=None,
if stream and attrib_ns in stream.namespace_map:
mapped_ns = stream.namespace_map[attrib_ns]
if mapped_ns:
output.append(' %s:%s="%s"' % (mapped_ns,
attrib,
value))
if namespaces is None:
namespaces = set()
if attrib_ns not in namespaces:
namespaces.add(attrib_ns)
new_namespaces.add(attrib_ns)
output.append(' xmlns:%s="%s"' % (
mapped_ns, attrib_ns))
output.append(' %s:%s="%s"' % (
mapped_ns, attrib, value))
elif attrib_ns == XML_NS:
output.append(' xml:%s="%s"' % (attrib, value))
@ -110,7 +119,8 @@ def tostring(xml=None, xmlns='', stream=None,
output.append(escape(xml.text, use_cdata))
if len(xml):
for child in xml:
output.append(tostring(child, tag_xmlns, stream))
output.append(tostring(child, tag_xmlns, stream,
namespaces=namespaces))
output.append("</%s>" % tag_name)
elif xml.text:
# If we only have text content.
@ -121,6 +131,11 @@ def tostring(xml=None, xmlns='', stream=None,
if xml.tail:
# If there is additional text after the element.
output.append(escape(xml.tail, use_cdata))
for ns in new_namespaces:
# Remove namespaces introduced in this context. This is necessary
# because the namespaces object continues to be shared with other
# contexts.
namespaces.remove(ns)
return ''.join(output)