XEP-0047: use coroutines for send(), sendall() and the new sendfile().

This commit is contained in:
Emmanuel Gil Peyrot 2015-04-14 19:19:46 +02:00
parent 058c530787
commit 4415d3be1a
2 changed files with 20 additions and 11 deletions

View file

@ -31,7 +31,8 @@ class IBBytestream(object):
self.recv_queue = asyncio.Queue() self.recv_queue = asyncio.Queue()
def send(self, data): @asyncio.coroutine
def send(self, data, timeout=None):
if not self.stream_started or self.stream_out_closed: if not self.stream_started or self.stream_out_closed:
raise socket.error raise socket.error
if len(data) > self.block_size: if len(data) > self.block_size:
@ -55,17 +56,22 @@ class IBBytestream(object):
iq['ibb_data']['sid'] = self.sid iq['ibb_data']['sid'] = self.sid
iq['ibb_data']['seq'] = seq iq['ibb_data']['seq'] = seq
iq['ibb_data']['data'] = data iq['ibb_data']['data'] = data
iq.send(callback=self._recv_ack) yield from iq.send(timeout=timeout)
return len(data) return len(data)
def sendall(self, data): @asyncio.coroutine
def sendall(self, data, timeout=None):
sent_len = 0 sent_len = 0
while sent_len < len(data): while sent_len < len(data):
sent_len += self.send(data[sent_len:self.block_size]) sent_len += yield from self.send(data[sent_len:self.block_size], timeout=timeout)
def _recv_ack(self, iq): @asyncio.coroutine
if iq['type'] == 'error': def sendfile(self, file, timeout=None):
self.close() while True:
data = file.read(self.block_size)
if not data:
break
yield from self.send(data, timeout=timeout)
def _recv_data(self, stanza): def _recv_data(self, stanza):
new_seq = stanza['ibb_data']['seq'] new_seq = stanza['ibb_data']['seq']
@ -80,7 +86,7 @@ class IBBytestream(object):
raise XMPPError('not-acceptable') raise XMPPError('not-acceptable')
self.recv_queue.put_nowait(data) self.recv_queue.put_nowait(data)
self.xmpp.event('ibb_stream_data', {'stream': self, 'data': data}) self.xmpp.event('ibb_stream_data', self)
if isinstance(stanza, Iq): if isinstance(stanza, Iq):
stanza.reply().send() stanza.reply().send()
@ -93,7 +99,7 @@ class IBBytestream(object):
raise socket.error raise socket.error
return self.recv_queue.get_nowait() return self.recv_queue.get_nowait()
def close(self): def close(self, timeout=None):
iq = self.xmpp.Iq() iq = self.xmpp.Iq()
iq['type'] = 'set' iq['type'] = 'set'
iq['to'] = self.peer_jid iq['to'] = self.peer_jid
@ -102,8 +108,9 @@ class IBBytestream(object):
self.stream_out_closed = True self.stream_out_closed = True
def _close_stream(_): def _close_stream(_):
self.stream_in_closed = True self.stream_in_closed = True
iq.send(callback=_close_stream) future = iq.send(timeout=timeout, callback=_close_stream)
self.xmpp.event('ibb_stream_end', self) self.xmpp.event('ibb_stream_end', self)
return future
def _closed(self, iq): def _closed(self, iq):
self.stream_in_closed = True self.stream_in_closed = True

View file

@ -1,3 +1,4 @@
import asyncio
import threading import threading
import time import time
@ -78,6 +79,7 @@ class TestInBandByteStreams(SlixTest):
self.assertEqual(events, set(['ibb_stream_start', 'callback'])) self.assertEqual(events, set(['ibb_stream_start', 'callback']))
@asyncio.coroutine
def testSendData(self): def testSendData(self):
"""Test sending data over an in-band bytestream.""" """Test sending data over an in-band bytestream."""
@ -115,7 +117,7 @@ class TestInBandByteStreams(SlixTest):
# Test sending data out # Test sending data out
stream.send("Testing") yield from stream.send("Testing")
self.send(""" self.send("""
<iq type="set" id="2" <iq type="set" id="2"