asyncio: avoid StreamReader.readuntil

This commit is contained in:
Janus 2017-12-12 19:07:54 +01:00
parent ef9236dd0c
commit 88f906bc2a
2 changed files with 29 additions and 19 deletions

View File

@ -34,6 +34,7 @@ import asyncio
import json
import asyncio.streams
from asyncio.sslproto import SSLProtocol
import io
import requests
@ -93,6 +94,7 @@ class Interface(util.PrintError):
self.unanswered_requests = {}
self.last_ping = 0
self.closed_remotely = False
self.buf = bytes()
def conn_coro(self, context):
return asyncio.open_connection(self.host, self.port, ssl=context)
@ -163,26 +165,35 @@ class Interface(util.PrintError):
if self.writer:
self.writer.close()
def _try_extract(self):
try:
pos = self.buf.index(b"\n")
except ValueError:
return
obj = self.buf[:pos]
try:
obj = json.loads(obj.decode("ascii"))
except ValueError:
return
else:
self.buf = self.buf[pos+1:]
self.last_action = time.time()
return obj
async def get(self):
reader, _ = await self._get_read_write()
obj = b""
while True:
if len(obj) > 3000000:
raise BaseException("too much data: " + str(len(obj)))
try:
obj += await reader.readuntil(b"\n")
except asyncio.LimitOverrunError as e:
obj += await reader.read(e.consumed)
except asyncio.streams.IncompleteReadError as e:
return None
try:
obj = json.loads(obj.decode("ascii"))
except ValueError:
continue
else:
self.last_action = time.time()
return obj
tried = self._try_extract()
if tried: return tried
temp = io.BytesIO()
starttime = time.time()
while time.time() - starttime < 1:
try:
data = await asyncio.wait_for(reader.read(2**8), 1)
temp.write(data)
except asyncio.TimeoutError:
break
self.buf += temp.getvalue()
def idle_time(self):
return time.time() - self.last_action

View File

@ -38,8 +38,7 @@ def makeProtocolFactory(receivedQueue, connUpLock, ca_certs):
class ReaderEmulator:
def __init__(self, receivedQueue):
self.receivedQueue = receivedQueue
async def readuntil(self, splitter):
assert splitter == b"\n"
async def read(self, _bufferSize):
return await self.receivedQueue.get()
class WriterEmulator:
@ -66,7 +65,7 @@ if __name__ == "__main__":
reader, writer = await sslInSocksReaderWriter(aiosocks.Socks4Addr("127.0.0.1", 9050), None, "songbird.bauerj.eu", 50002, None)
writer.write(b'{"id":0,"method":"server.version","args":["3.0.2", "1.1"]}\n')
await writer.drain()
print(await reader.readuntil(b"\n"))
print(await reader.read(4096))
writer.close()
fut.set_result("finished")
except BaseException as e: