add a separate place for slow ass filters
This commit is contained in:
parent
672f1b28f6
commit
d97efa0bd8
1 changed files with 35 additions and 6 deletions
|
@ -21,7 +21,7 @@ import ssl
|
|||
import weakref
|
||||
import uuid
|
||||
|
||||
from asyncio import iscoroutinefunction
|
||||
from asyncio import iscoroutinefunction, wait
|
||||
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
|
@ -896,6 +896,31 @@ class XMLStream(asyncio.BaseProtocol):
|
|||
"""
|
||||
return xml
|
||||
|
||||
async def continue_slow_send(self, task, already_used):
|
||||
log.debug('rescheduled task: %s', task)
|
||||
data = await task
|
||||
log.debug('data for rescheduled task %s : %s', task, data)
|
||||
for filter in self.__filters['out']:
|
||||
if filter in already_used:
|
||||
continue
|
||||
if iscoroutinefunction(filter):
|
||||
data = await task
|
||||
else:
|
||||
data = filter(data)
|
||||
if data is None:
|
||||
return
|
||||
|
||||
if isinstance(data, ElementBase):
|
||||
for filter in self.__filters['out_sync']:
|
||||
data = filter(data)
|
||||
if data is None:
|
||||
return
|
||||
str_data = tostring(data.xml, xmlns=self.default_ns,
|
||||
stream=self, top_level=True)
|
||||
self.send_raw(str_data)
|
||||
else:
|
||||
self.send_raw(data)
|
||||
|
||||
async def run_filters(self):
|
||||
"""
|
||||
Background loop that processes stanzas to send.
|
||||
|
@ -905,9 +930,16 @@ class XMLStream(asyncio.BaseProtocol):
|
|||
try:
|
||||
if isinstance(data, ElementBase):
|
||||
if use_filters:
|
||||
already_run_filters = set()
|
||||
for filter in self.__filters['out']:
|
||||
already_run_filters.add(filter)
|
||||
if iscoroutinefunction(filter):
|
||||
data = await filter(data)
|
||||
task = asyncio.create_task(filter(data))
|
||||
completed, pending = await wait({task}, timeout=1)
|
||||
if pending:
|
||||
asyncio.ensure_future(self.continue_slow_send(task, already_run_filters))
|
||||
raise Exception("Slow coro, rescheduling")
|
||||
data = task.result()
|
||||
else:
|
||||
data = filter(data)
|
||||
if data is None:
|
||||
|
@ -916,10 +948,7 @@ class XMLStream(asyncio.BaseProtocol):
|
|||
if isinstance(data, ElementBase):
|
||||
if use_filters:
|
||||
for filter in self.__filters['out_sync']:
|
||||
if iscoroutinefunction(filter):
|
||||
data = await filter(data)
|
||||
else:
|
||||
data = filter(data)
|
||||
data = filter(data)
|
||||
if data is None:
|
||||
raise Exception('Empty stanza')
|
||||
str_data = tostring(data.xml, xmlns=self.default_ns,
|
||||
|
|
Loading…
Reference in a new issue