diff --git a/sleekxmpp/xmlstream/stanzabase.py b/sleekxmpp/xmlstream/stanzabase.py
index 75b9b921..b95d837f 100644
--- a/sleekxmpp/xmlstream/stanzabase.py
+++ b/sleekxmpp/xmlstream/stanzabase.py
@@ -374,6 +374,49 @@ class ElementBase(object):
else:
return stanza.text
+ def _setSubText(self, name, text=None, keep=False):
+ """
+ Set the text contents of a sub element.
+
+ In case the element does not exist, a element will be created,
+ and its text contents will be set.
+
+ If the text is set to an empty string, or None, then the
+ element will be removed, unless keep is set to True.
+
+ Arguments:
+ name -- The name or XPath expression of the element.
+ text -- The new textual content of the element. If the text
+ is an empty string or None, the element will be removed
+ unless the parameter keep is True.
+ keep -- Indicates if the element should be kept if its text is
+ removed. Defaults to False.
+ """
+ name = self._fix_ns(name)
+ element = self.xml.find(name)
+
+ if not text and not keep:
+ return self.__delitem__(name)
+
+ if element is None:
+ # We need to add the element. If the provided name was
+ # an XPath expression, some of the intermediate elements
+ # may already exist. If so, we want to use those instead
+ # of generating new elements.
+ last_xml = self.xml
+ walked = []
+ for ename in name.split('/'):
+ walked.append(ename)
+ element = self.xml.find("/".join(walked))
+ if element is None:
+ element = ET.Element(ename)
+ last_xml.append(element)
+ last_xml = element
+ element = last_xml
+
+ element.text = text
+ return element
+
@property
def attrib(self): #backwards compatibility
return self
@@ -469,18 +512,6 @@ class ElementBase(object):
return False
return True
- def _setSubText(self, name, attrib={}, text=None):
- if '}' not in name:
- name = "{%s}%s" % (self.namespace, name)
- if text is None or text == '':
- return self.__delitem__(name)
- stanza = self.xml.find(name)
- if stanza is None:
- stanza = ET.Element(name)
- self.xml.append(stanza)
- stanza.text = text
- return stanza
-
def _delSub(self, name):
if '}' not in name:
name = "{%s}%s" % (self.namespace, name)
diff --git a/tests/test_elementbase.py b/tests/test_elementbase.py
index 78cf47d6..2b61489a 100644
--- a/tests/test_elementbase.py
+++ b/tests/test_elementbase.py
@@ -301,5 +301,56 @@ class TestElementBase(SleekTest):
self.failUnless(stanza['bar'] == 'found',
"_getSubText value incorrect: %s." % stanza['bar'])
+ def testSubElement(self):
+ """Test setting the contents of a sub element."""
+
+ class TestStanza(ElementBase):
+ name = "foo"
+ namespace = "foo"
+ interfaces = set(('bar', 'baz'))
+
+ def setBaz(self, value):
+ self._setSubText("wrapper/baz", text=value)
+
+ def getBaz(self):
+ return self._getSubText("wrapper/baz")
+
+ def setBar(self, value):
+ self._setSubText("wrapper/bar", text=value)
+
+ def getBar(self):
+ return self._getSubText("wrapper/bar")
+
+ stanza = TestStanza()
+ stanza['bar'] = 'a'
+ stanza['baz'] = 'b'
+ self.checkStanza(TestStanza, stanza, """
+
+
+ a
+ b
+
+
+ """)
+ stanza._setSubText('bar', text='', keep=True)
+ self.checkStanza(TestStanza, stanza, """
+
+
+
+ b
+
+
+ """, use_values=False)
+
+ stanza['bar'] = 'a'
+ stanza._setSubText('bar', text='')
+ self.checkStanza(TestStanza, stanza, """
+
+
+ b
+
+
+ """)
+
suite = unittest.TestLoader().loadTestsFromTestCase(TestElementBase)