stanzabase: make _get_plugin part of the public API

it is the only way I know of checking if an element is present in a
stanza without creating it or checking the XML manually.
This commit is contained in:
mathieui 2021-02-04 18:56:18 +01:00
parent 917cb555d5
commit 69b265b975
2 changed files with 23 additions and 15 deletions

View file

@ -106,7 +106,7 @@ class XEP_0405(BasePlugin):
contacts = [] contacts = []
mix = [] mix = []
for item in result['roster']: for item in result['roster']:
channel = item._get_plugin('channel', check=True) channel = item.get_plugin('channel', check=True)
if channel: if channel:
mix.append(item) mix.append(item)
else: else:

View file

@ -12,11 +12,12 @@
:license: MIT, see LICENSE for more details :license: MIT, see LICENSE for more details
""" """
from __future__ import with_statement, unicode_literals from __future__ import annotations
import copy import copy
import logging import logging
import weakref import weakref
from typing import Optional
from xml.etree import ElementTree as ET from xml.etree import ElementTree as ET
from slixmpp.xmlstream import JID from slixmpp.xmlstream import JID
@ -466,7 +467,13 @@ class ElementBase(object):
""" """
return self.init_plugin(attrib, lang) return self.init_plugin(attrib, lang)
def _get_plugin(self, name, lang=None, check=False): def get_plugin(self, name: str, lang: Optional[str] = None, check: bool = False) -> Optional[ElementBase]:
"""Retrieve a stanza plugin.
:param check: Return None instead of creating the object if True.
:param name: Stanza plugin attribute name.
:param lang: xml:lang of the element to retrieve.
"""
if lang is None: if lang is None:
lang = self.get_lang() lang = self.get_lang()
@ -614,7 +621,7 @@ class ElementBase(object):
self[full_interface] = value self[full_interface] = value
elif interface in self.plugin_attrib_map: elif interface in self.plugin_attrib_map:
if interface not in iterable_interfaces: if interface not in iterable_interfaces:
plugin = self._get_plugin(interface, lang) plugin = self.get_plugin(interface, lang)
if plugin: if plugin:
plugin.values = value plugin.values = value
return self return self
@ -660,7 +667,7 @@ class ElementBase(object):
if self.plugin_overrides: if self.plugin_overrides:
name = self.plugin_overrides.get(get_method, None) name = self.plugin_overrides.get(get_method, None)
if name: if name:
plugin = self._get_plugin(name, lang) plugin = self.get_plugin(name, lang)
if plugin: if plugin:
handler = getattr(plugin, get_method, None) handler = getattr(plugin, get_method, None)
if handler: if handler:
@ -677,7 +684,7 @@ class ElementBase(object):
else: else:
return self._get_attr(attrib) return self._get_attr(attrib)
elif attrib in self.plugin_attrib_map: elif attrib in self.plugin_attrib_map:
plugin = self._get_plugin(attrib, lang) plugin = self.get_plugin(attrib, lang)
if plugin and plugin.is_extension: if plugin and plugin.is_extension:
return plugin[full_attrib] return plugin[full_attrib]
return plugin return plugin
@ -732,7 +739,7 @@ class ElementBase(object):
if self.plugin_overrides: if self.plugin_overrides:
name = self.plugin_overrides.get(set_method, None) name = self.plugin_overrides.get(set_method, None)
if name: if name:
plugin = self._get_plugin(name, lang) plugin = self.get_plugin(name, lang)
if plugin: if plugin:
handler = getattr(plugin, set_method, None) handler = getattr(plugin, set_method, None)
if handler: if handler:
@ -764,7 +771,7 @@ class ElementBase(object):
else: else:
self.__delitem__(attrib) self.__delitem__(attrib)
elif attrib in self.plugin_attrib_map: elif attrib in self.plugin_attrib_map:
plugin = self._get_plugin(attrib, lang) plugin = self.get_plugin(attrib, lang)
if plugin: if plugin:
plugin[full_attrib] = value plugin[full_attrib] = value
return self return self
@ -816,7 +823,7 @@ class ElementBase(object):
if self.plugin_overrides: if self.plugin_overrides:
name = self.plugin_overrides.get(del_method, None) name = self.plugin_overrides.get(del_method, None)
if name: if name:
plugin = self._get_plugin(attrib, lang) plugin = self.get_plugin(attrib, lang)
if plugin: if plugin:
handler = getattr(plugin, del_method, None) handler = getattr(plugin, del_method, None)
if handler: if handler:
@ -832,7 +839,7 @@ class ElementBase(object):
else: else:
self._del_attr(attrib) self._del_attr(attrib)
elif attrib in self.plugin_attrib_map: elif attrib in self.plugin_attrib_map:
plugin = self._get_plugin(attrib, lang, check=True) plugin = self.get_plugin(attrib, lang, check=True)
if not plugin: if not plugin:
return self return self
if plugin.is_extension: if plugin.is_extension:
@ -1037,12 +1044,10 @@ class ElementBase(object):
parent_path = "/".join(path[:len(path) - level - 1]) parent_path = "/".join(path[:len(path) - level - 1])
elements = self.xml.findall(element_path) elements = self.xml.findall(element_path)
if parent_path == '': if parent_path == '':
parent_path = None parent_path = None
if parent_path is not None: if parent_path is not None:
parent = self.xml.find(parent_path) parent = self.xml.find(parent_path)
if elements: if elements:
if parent is None: if parent is None:
parent = self.xml parent = self.xml
@ -1117,7 +1122,7 @@ class ElementBase(object):
next_tag = xpath[1].split('@')[0].split('}')[-1] next_tag = xpath[1].split('@')[0].split('}')[-1]
langs = [name[1] for name in self.plugins if name[0] == next_tag] langs = [name[1] for name in self.plugins if name[0] == next_tag]
for lang in langs: for lang in langs:
plugin = self._get_plugin(next_tag, lang) plugin = self.get_plugin(next_tag, lang)
if plugin and plugin.match(xpath[1:]): if plugin and plugin.match(xpath[1:]):
return True return True
return False return False
@ -1341,6 +1346,9 @@ class ElementBase(object):
"""Use the stanza's serialized XML as its representation.""" """Use the stanza's serialized XML as its representation."""
return self.__str__() return self.__str__()
# Compatibility.
_get_plugin = get_plugin
class StanzaBase(ElementBase): class StanzaBase(ElementBase):