slixmpp/sleekxmpp/xmlstream/tostring.py
Lance Stout 9a08dfc7d4 Add support for using CDATA for escaping.
CDATA escaping is disabled by default, but may be enabled by setting:

    self.use_cdata = True

Closes issue #114
2012-07-24 03:25:55 -07:00

157 lines
5.1 KiB
Python

# -*- coding: utf-8 -*-
"""
sleekxmpp.xmlstream.tostring
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
This module converts XML objects into Unicode strings and
intelligently includes namespaces only when necessary to
keep the output readable.
Part of SleekXMPP: The Sleek XMPP Library
:copyright: (c) 2011 Nathanael C. Fritz
:license: MIT, see LICENSE for more details
"""
from __future__ import unicode_literals
import sys
if sys.version_info < (3, 0):
import types
XML_NS = 'http://www.w3.org/XML/1998/namespace'
def tostring(xml=None, xmlns='', stanza_ns='', stream=None,
outbuffer='', top_level=False, open_only=False):
"""Serialize an XML object to a Unicode string.
If namespaces are provided using ``xmlns`` or ``stanza_ns``, then
elements that use those namespaces will not include the xmlns attribute
in the output.
:param XML xml: The XML object to serialize.
:param string xmlns: Optional namespace of an element wrapping the XML
object.
:param string stanza_ns: The namespace of the stanza object that contains
the XML object.
:param stream: The XML stream that generated the XML object.
:param string outbuffer: Optional buffer for storing serializations
during recursive calls.
:param bool top_level: Indicates that the element is the outermost
element.
:type xml: :py:class:`~xml.etree.ElementTree.Element`
:type stream: :class:`~sleekxmpp.xmlstream.xmlstream.XMLStream`
:rtype: Unicode string
"""
# Add previous results to the start of the output.
output = [outbuffer]
# Extract the element's tag name.
tag_name = xml.tag.split('}', 1)[-1]
# Extract the element's namespace if it is defined.
if '}' in xml.tag:
tag_xmlns = xml.tag.split('}', 1)[0][1:]
else:
tag_xmlns = ''
default_ns = ''
stream_ns = ''
use_cdata = False
if stream:
default_ns = stream.default_ns
stream_ns = stream.stream_ns
use_cdata = stream.use_cdata
# Output the tag name and derived namespace of the element.
namespace = ''
if top_level and tag_xmlns not in ['', default_ns, stream_ns] or \
tag_xmlns not in ['', xmlns, stanza_ns, stream_ns]:
namespace = ' xmlns="%s"' % tag_xmlns
if stream and tag_xmlns in stream.namespace_map:
mapped_namespace = stream.namespace_map[tag_xmlns]
if mapped_namespace:
tag_name = "%s:%s" % (mapped_namespace, tag_name)
output.append("<%s" % tag_name)
output.append(namespace)
# Output escaped attribute values.
for attrib, value in xml.attrib.items():
value = escape(value, use_cdata)
if '}' not in attrib:
output.append(' %s="%s"' % (attrib, value))
else:
attrib_ns = attrib.split('}')[0][1:]
attrib = attrib.split('}')[1]
if stream and attrib_ns in stream.namespace_map:
mapped_ns = stream.namespace_map[attrib_ns]
if mapped_ns:
output.append(' %s:%s="%s"' % (mapped_ns,
attrib,
value))
elif attrib_ns == XML_NS:
output.append(' xml:%s="%s"' % (attrib, value))
if open_only:
# Only output the opening tag, regardless of content.
output.append(">")
return ''.join(output)
if len(xml) or xml.text:
# If there are additional child elements to serialize.
output.append(">")
if xml.text:
output.append(escape(xml.text, use_cdata))
if len(xml):
for child in xml:
output.append(tostring(child, tag_xmlns, stanza_ns, stream))
output.append("</%s>" % tag_name)
elif xml.text:
# If we only have text content.
output.append(">%s</%s>" % (escape(xml.text, use_cdata), tag_name))
else:
# Empty element.
output.append(" />")
if xml.tail:
# If there is additional text after the element.
output.append(escape(xml.tail, use_cdata))
return ''.join(output)
def escape(text, use_cdata=False):
"""Convert special characters in XML to escape sequences.
:param string text: The XML text to convert.
:rtype: Unicode string
"""
if sys.version_info < (3, 0):
if type(text) != types.UnicodeType:
text = unicode(text, 'utf-8', 'ignore')
escapes = {'&': '&amp;',
'<': '&lt;',
'>': '&gt;',
"'": '&apos;',
'"': '&quot;'}
if not use_cdata:
text = list(text)
for i, c in enumerate(text):
text[i] = escapes.get(c, c)
return ''.join(text)
else:
escape_needed = False
for c in text:
if c in escapes:
escape_needed = True
break
if escape_needed:
escaped = map(lambda x : "<![CDATA[%s]]>" % x, text.split("]]>"))
return "<![CDATA[]]]><![CDATA[]>]]>".join(escaped)
return text