asyncio: avoid StreamReader.readuntil

This commit is contained in:
Janus 2017-12-12 19:07:54 +01:00
parent ef9236dd0c
commit 88f906bc2a
2 changed files with 29 additions and 19 deletions

View File

@ -34,6 +34,7 @@ import asyncio
import json import json
import asyncio.streams import asyncio.streams
from asyncio.sslproto import SSLProtocol from asyncio.sslproto import SSLProtocol
import io
import requests import requests
@ -93,6 +94,7 @@ class Interface(util.PrintError):
self.unanswered_requests = {} self.unanswered_requests = {}
self.last_ping = 0 self.last_ping = 0
self.closed_remotely = False self.closed_remotely = False
self.buf = bytes()
def conn_coro(self, context): def conn_coro(self, context):
return asyncio.open_connection(self.host, self.port, ssl=context) return asyncio.open_connection(self.host, self.port, ssl=context)
@ -163,26 +165,35 @@ class Interface(util.PrintError):
if self.writer: if self.writer:
self.writer.close() self.writer.close()
def _try_extract(self):
try:
pos = self.buf.index(b"\n")
except ValueError:
return
obj = self.buf[:pos]
try:
obj = json.loads(obj.decode("ascii"))
except ValueError:
return
else:
self.buf = self.buf[pos+1:]
self.last_action = time.time()
return obj
async def get(self): async def get(self):
reader, _ = await self._get_read_write() reader, _ = await self._get_read_write()
obj = b""
while True: while True:
if len(obj) > 3000000: tried = self._try_extract()
raise BaseException("too much data: " + str(len(obj))) if tried: return tried
try: temp = io.BytesIO()
obj += await reader.readuntil(b"\n") starttime = time.time()
except asyncio.LimitOverrunError as e: while time.time() - starttime < 1:
obj += await reader.read(e.consumed) try:
except asyncio.streams.IncompleteReadError as e: data = await asyncio.wait_for(reader.read(2**8), 1)
return None temp.write(data)
try: except asyncio.TimeoutError:
obj = json.loads(obj.decode("ascii")) break
except ValueError: self.buf += temp.getvalue()
continue
else:
self.last_action = time.time()
return obj
def idle_time(self): def idle_time(self):
return time.time() - self.last_action return time.time() - self.last_action

View File

@ -38,8 +38,7 @@ def makeProtocolFactory(receivedQueue, connUpLock, ca_certs):
class ReaderEmulator: class ReaderEmulator:
def __init__(self, receivedQueue): def __init__(self, receivedQueue):
self.receivedQueue = receivedQueue self.receivedQueue = receivedQueue
async def readuntil(self, splitter): async def read(self, _bufferSize):
assert splitter == b"\n"
return await self.receivedQueue.get() return await self.receivedQueue.get()
class WriterEmulator: class WriterEmulator:
@ -66,7 +65,7 @@ if __name__ == "__main__":
reader, writer = await sslInSocksReaderWriter(aiosocks.Socks4Addr("127.0.0.1", 9050), None, "songbird.bauerj.eu", 50002, None) reader, writer = await sslInSocksReaderWriter(aiosocks.Socks4Addr("127.0.0.1", 9050), None, "songbird.bauerj.eu", 50002, None)
writer.write(b'{"id":0,"method":"server.version","args":["3.0.2", "1.1"]}\n') writer.write(b'{"id":0,"method":"server.version","args":["3.0.2", "1.1"]}\n')
await writer.drain() await writer.drain()
print(await reader.readuntil(b"\n")) print(await reader.read(4096))
writer.close() writer.close()
fut.set_result("finished") fut.set_result("finished")
except BaseException as e: except BaseException as e: