Tweak request handling
Pause serving sessions whose socket buffer is full (anti-DoS) Serve requests in batches of 8 Don't store the session in the request RPC has priority 0; every other session at least 1 Periodically consolidate small session groups into 1
This commit is contained in:
parent
b3b3f047c2
commit
263e88ad57
@ -43,7 +43,7 @@ class RPCClient(JSONRPC):
|
||||
future.cancel()
|
||||
print('request timed out after {}s'.format(timeout))
|
||||
else:
|
||||
await request.process(1)
|
||||
await request.process(self)
|
||||
|
||||
async def handle_response(self, result, error, method):
|
||||
if result and method in ('groups', 'sessions'):
|
||||
|
||||
@ -15,63 +15,59 @@ import time
|
||||
from lib.util import LoggedClass
|
||||
|
||||
|
||||
class SingleRequest(object):
|
||||
class RequestBase(object):
|
||||
'''An object that represents a queued request.'''
|
||||
|
||||
def __init__(self, remaining):
|
||||
self.remaining = remaining
|
||||
|
||||
class SingleRequest(RequestBase):
|
||||
'''An object that represents a single request.'''
|
||||
def __init__(self, session, payload):
|
||||
|
||||
def __init__(self, payload):
|
||||
super().__init__(1)
|
||||
self.payload = payload
|
||||
self.session = session
|
||||
self.count = 1
|
||||
|
||||
def remaining(self):
|
||||
return self.count
|
||||
|
||||
async def process(self, limit):
|
||||
async def process(self, session):
|
||||
'''Asynchronously handle the JSON request.'''
|
||||
binary = await self.session.process_single_payload(self.payload)
|
||||
self.remaining = 0
|
||||
binary = await session.process_single_payload(self.payload)
|
||||
if binary:
|
||||
self.session._send_bytes(binary)
|
||||
self.count = 0
|
||||
return 1
|
||||
session._send_bytes(binary)
|
||||
|
||||
def __str__(self):
|
||||
return str(self.payload)
|
||||
|
||||
|
||||
class BatchRequest(object):
|
||||
class BatchRequest(RequestBase):
|
||||
'''An object that represents a batch request and its processing state.
|
||||
|
||||
Batches are processed in chunks.
|
||||
'''
|
||||
|
||||
def __init__(self, session, payload):
|
||||
self.session = session
|
||||
def __init__(self, payload):
|
||||
super().__init__(len(payload))
|
||||
self.payload = payload
|
||||
self.done = 0
|
||||
self.parts = []
|
||||
|
||||
def remaining(self):
|
||||
return len(self.payload) - self.done
|
||||
|
||||
async def process(self, limit):
|
||||
async def process(self, session):
|
||||
'''Asynchronously handle the JSON batch according to the JSON 2.0
|
||||
spec.'''
|
||||
count = min(limit, self.remaining())
|
||||
for n in range(count):
|
||||
item = self.payload[self.done]
|
||||
part = await self.session.process_single_payload(item)
|
||||
target = max(self.remaining - 4, 0)
|
||||
while self.remaining > target:
|
||||
item = self.payload[len(self.payload) - self.remaining]
|
||||
self.remaining -= 1
|
||||
part = await session.process_single_payload(item)
|
||||
if part:
|
||||
self.parts.append(part)
|
||||
self.done += 1
|
||||
|
||||
total_len = sum(len(part) + 2 for part in self.parts)
|
||||
self.session.check_oversized_request(total_len)
|
||||
session.check_oversized_request(total_len)
|
||||
|
||||
if not self.remaining():
|
||||
if not self.remaining:
|
||||
if self.parts:
|
||||
binary = b'[' + b', '.join(self.parts) + b']'
|
||||
self.session._send_bytes(binary)
|
||||
|
||||
return count
|
||||
session._send_bytes(binary)
|
||||
|
||||
def __str__(self):
|
||||
return str(self.payload)
|
||||
@ -151,6 +147,7 @@ class JSONRPC(asyncio.Protocol, LoggedClass):
|
||||
self.bandwidth_used = 0
|
||||
self.bandwidth_limit = 5000000
|
||||
self.transport = None
|
||||
self.pause = False
|
||||
# Parts of an incomplete JSON line. We buffer them until
|
||||
# getting a newline.
|
||||
self.parts = []
|
||||
@ -186,11 +183,22 @@ class JSONRPC(asyncio.Protocol, LoggedClass):
|
||||
'''Handle an incoming client connection.'''
|
||||
self.transport = transport
|
||||
self.peer_info = transport.get_extra_info('peername')
|
||||
transport.set_write_buffer_limits(high=500000)
|
||||
|
||||
def connection_lost(self, exc):
|
||||
'''Handle client disconnection.'''
|
||||
pass
|
||||
|
||||
def pause_writing(self):
|
||||
'''Called by asyncio when the write buffer is full.'''
|
||||
self.log_info('pausing writing')
|
||||
self.pause = True
|
||||
|
||||
def resume_writing(self):
|
||||
'''Called by asyncio when the write buffer has room.'''
|
||||
self.log_info('resuming writing')
|
||||
self.pause = False
|
||||
|
||||
def close_connection(self):
|
||||
self.stop = time.time()
|
||||
if self.transport:
|
||||
@ -263,9 +271,9 @@ class JSONRPC(asyncio.Protocol, LoggedClass):
|
||||
if not message:
|
||||
self.send_json_error('empty batch', self.INVALID_REQUEST)
|
||||
return
|
||||
request = BatchRequest(self, message)
|
||||
request = BatchRequest(message)
|
||||
else:
|
||||
request = SingleRequest(self, message)
|
||||
request = SingleRequest(message)
|
||||
|
||||
'''Queue the request for asynchronous handling.'''
|
||||
self.enqueue_request(request)
|
||||
|
||||
@ -21,7 +21,7 @@ from functools import partial
|
||||
import pylru
|
||||
|
||||
from lib.hash import sha256, double_sha256, hash_to_str, hex_str_to_hash
|
||||
from lib.jsonrpc import JSONRPC
|
||||
from lib.jsonrpc import JSONRPC, RequestBase
|
||||
from lib.tx import Deserializer
|
||||
import lib.util as util
|
||||
from server.block_processor import BlockProcessor
|
||||
@ -217,16 +217,15 @@ class ServerManager(util.LoggedClass):
|
||||
|
||||
BANDS = 5
|
||||
|
||||
class NotificationRequest(object):
|
||||
def __init__(self, fn_call):
|
||||
self.fn_call = fn_call
|
||||
class NotificationRequest(RequestBase):
|
||||
def __init__(self, height, touched):
|
||||
super().__init__(1)
|
||||
self.height = height
|
||||
self.touched = touched
|
||||
|
||||
def remaining(self):
|
||||
return 0
|
||||
|
||||
async def process(self, limit):
|
||||
await self.fn_call()
|
||||
return 0
|
||||
async def process(self, session):
|
||||
self.remaining = 0
|
||||
await session.notify(self.height, self.touched)
|
||||
|
||||
def __init__(self, env):
|
||||
super().__init__()
|
||||
@ -294,8 +293,8 @@ class ServerManager(util.LoggedClass):
|
||||
if isinstance(session, LocalRPC):
|
||||
return 0
|
||||
group_bandwidth = sum(s.bandwidth_used for s in self.sessions[session])
|
||||
return (bisect_left(self.bands, session.bandwidth_used)
|
||||
+ bisect_left(self.bands, group_bandwidth) + 1) // 2
|
||||
return 1 + (bisect_left(self.bands, session.bandwidth_used)
|
||||
+ bisect_left(self.bands, group_bandwidth) + 1) // 2
|
||||
|
||||
async def enqueue_delayed_sessions(self):
|
||||
now = time.time()
|
||||
@ -317,9 +316,14 @@ class ServerManager(util.LoggedClass):
|
||||
item = (priority, self.next_queue_id, session)
|
||||
self.next_queue_id += 1
|
||||
|
||||
secs = priority - self.BANDS
|
||||
if secs >= 0:
|
||||
secs = int(session.pause)
|
||||
if secs:
|
||||
session.log_info('delaying processing whilst paused')
|
||||
excess = priority - self.BANDS
|
||||
if excess > 0:
|
||||
secs = excess
|
||||
session.log_info('delaying response {:d}s'.format(secs))
|
||||
if secs:
|
||||
self.delayed_sessions.append((time.time() + secs, item))
|
||||
else:
|
||||
self.queue.put_nowait(item)
|
||||
@ -403,8 +407,8 @@ class ServerManager(util.LoggedClass):
|
||||
|
||||
for session in self.sessions:
|
||||
if isinstance(session, ElectrumX):
|
||||
fn_call = partial(session.notify, self.bp.db_height, touched)
|
||||
session.enqueue_request(self.NotificationRequest(fn_call))
|
||||
request = self.NotificationRequest(self.bp.db_height, touched)
|
||||
session.enqueue_request(request)
|
||||
# Periodically log sessions
|
||||
if self.env.log_sessions and time.time() > self.next_log_sessions:
|
||||
data = self.session_data(for_log=True)
|
||||
@ -480,7 +484,7 @@ class ServerManager(util.LoggedClass):
|
||||
if now > self.next_stale_check:
|
||||
self.next_stale_check = now + 60
|
||||
self.clear_stale_sessions()
|
||||
group = self.groups[int(session.start - self.start) // 60]
|
||||
group = self.groups[int(session.start - self.start) // 180]
|
||||
group.add(session)
|
||||
self.sessions[session] = group
|
||||
session.log_info('connection from {}, {:,d} total'
|
||||
@ -521,9 +525,14 @@ class ServerManager(util.LoggedClass):
|
||||
if stale:
|
||||
self.logger.info('closing stale connections {}'.format(stale))
|
||||
|
||||
# Clear out empty groups
|
||||
for key in [k for k, v in self.groups.items() if not v]:
|
||||
del self.groups[key]
|
||||
# Consolidate small groups
|
||||
keys = [k for k, v in self.groups.items() if len(v) <= 2
|
||||
and sum(session.bandwidth_used for session in v) < 10000]
|
||||
if len(keys) > 1:
|
||||
group = set.union(*(self.groups[key] for key in keys))
|
||||
for key in keys:
|
||||
del self.groups[key]
|
||||
self.groups[max(keys)] = group
|
||||
|
||||
def new_subscription(self):
|
||||
if self.subscription_count >= self.max_subs:
|
||||
@ -728,7 +737,7 @@ class Session(JSONRPC):
|
||||
return status
|
||||
|
||||
def requests_remaining(self):
|
||||
return sum(request.remaining() for request in self.requests)
|
||||
return sum(request.remaining for request in self.requests)
|
||||
|
||||
def enqueue_request(self, request):
|
||||
'''Add a request to the session's list.'''
|
||||
@ -738,28 +747,28 @@ class Session(JSONRPC):
|
||||
|
||||
async def serve_requests(self):
|
||||
'''Serve requests in batches.'''
|
||||
done_reqs = 0
|
||||
done_jobs = 0
|
||||
limit = 4
|
||||
total = 0
|
||||
errs = []
|
||||
# Process 8 items at a time
|
||||
for request in self.requests:
|
||||
try:
|
||||
done_jobs += await request.process(limit - done_jobs)
|
||||
initial = request.remaining
|
||||
await request.process(self)
|
||||
total += initial - request.remaining
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception:
|
||||
# Getting here should probably be considered a bug and fixed
|
||||
# Should probably be considered a bug and fixed
|
||||
self.log_error('error handling request {}'.format(request))
|
||||
traceback.print_exc()
|
||||
done_reqs += 1
|
||||
else:
|
||||
if not request.remaining():
|
||||
done_reqs += 1
|
||||
if done_jobs >= limit:
|
||||
errs.append(request)
|
||||
if total >= 8:
|
||||
break
|
||||
|
||||
self.log_info('done {:,d} items'.format(total))
|
||||
# Remove completed requests and re-enqueue ourself if any remain.
|
||||
if done_reqs:
|
||||
self.requests = self.requests[done_reqs:]
|
||||
self.requests = [req for req in self.requests
|
||||
if req.remaining and not req in errs]
|
||||
if self.requests:
|
||||
self.manager.enqueue_session(self)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user