Diffstat (limited to 'poky/bitbake/lib/hashserv/server.py')
-rw-r--r--  poky/bitbake/lib/hashserv/server.py  217
1 file changed, 25 insertions(+), 192 deletions(-)
diff --git a/poky/bitbake/lib/hashserv/server.py b/poky/bitbake/lib/hashserv/server.py
index a0dc0c170..8e8498973 100644
--- a/poky/bitbake/lib/hashserv/server.py
+++ b/poky/bitbake/lib/hashserv/server.py
@@ -6,15 +6,12 @@
from contextlib import closing, contextmanager
from datetime import datetime
import asyncio
-import json
import logging
import math
-import os
-import signal
-import socket
-import sys
import time
-from . import chunkify, DEFAULT_MAX_CHUNK, create_async_client, TABLE_COLUMNS
+from . import create_async_client, TABLE_COLUMNS
+import bb.asyncrpc
+
logger = logging.getLogger('hashserv.server')
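The protocol plumbing removed from this module now lives in bb.asyncrpc. For review purposes, here is a minimal sketch of the surface this patch assumes that module exports, inferred from the call sites below; the names come from the patch, but the bodies are illustrative, not the real implementation:

    import asyncio

    DEFAULT_MAX_CHUNK = 32 * 1024  # assumed value; only its existence matters here

    class ClientError(Exception):
        pass

    class ServerError(Exception):
        pass

    class AsyncServerConnection:
        def __init__(self, reader, writer, proto_name, logger):
            self.reader = reader
            self.writer = writer
            self.proto_name = proto_name
            self.logger = logger
            self.handlers = {}  # subclasses register message handlers here

        def validate_proto_version(self):
            raise NotImplementedError()  # subclass hook, see ServerClient below

        async def process_requests(self):
            # Performs the handshake, header parsing, and read/dispatch loop
            # that ServerClient.process_requests used to inline.
            ...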
@@ -109,12 +106,6 @@ class Stats(object):
return {k: getattr(self, k) for k in ('num', 'total_time', 'max_time', 'average', 'stdev')}
-class ClientError(Exception):
- pass
-
-class ServerError(Exception):
- pass
-
def insert_task(cursor, data, ignore=False):
keys = sorted(data.keys())
query = '''INSERT%s INTO tasks_v2 (%s) VALUES (%s)''' % (
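The diff context cuts insert_task short. The pattern it is building is a keyword-parameterized INSERT whose column list and :name placeholders are both derived from the sorted data keys; a plausible completion, for reference only:

    def insert_task(cursor, data, ignore=False):
        keys = sorted(data.keys())
        query = '''INSERT%s INTO tasks_v2 (%s) VALUES (%s)''' % (
            ' OR IGNORE' if ignore else '',       # tolerate duplicate rows on backfill
            ', '.join(keys),                      # column list, e.g. method, taskhash
            ', '.join(':' + k for k in keys))     # matching :name placeholders
        cursor.execute(query, data)               # sqlite3 binds the dict by key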
@@ -128,7 +119,6 @@ async def copy_from_upstream(client, db, method, taskhash):
if d is not None:
# Filter out unknown columns
d = {k: v for k, v in d.items() if k in TABLE_COLUMNS}
- keys = sorted(d.keys())
with closing(db.cursor()) as cursor:
insert_task(cursor, d)
@@ -141,7 +131,6 @@ async def copy_outhash_from_upstream(client, db, method, outhash, taskhash):
if d is not None:
# Filter out unknown columns
d = {k: v for k, v in d.items() if k in TABLE_COLUMNS}
- keys = sorted(d.keys())
with closing(db.cursor()) as cursor:
insert_task(cursor, d)
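Both hunks above drop the same assignment: keys = sorted(d.keys()) was computed and never used, since insert_task derives the key list itself, so the removal is pure dead code. The surviving flow, exercised against an in-memory database (schema reduced to three columns for illustration; insert_task inlined from the sketch after the previous hunk):

    import sqlite3
    from contextlib import closing

    TABLE_COLUMNS = ('method', 'taskhash', 'unihash')  # reduced for illustration

    def insert_task(cursor, data, ignore=False):       # as sketched above
        keys = sorted(data.keys())
        cursor.execute('INSERT%s INTO tasks_v2 (%s) VALUES (%s)' % (
            ' OR IGNORE' if ignore else '', ', '.join(keys),
            ', '.join(':' + k for k in keys)), data)

    db = sqlite3.connect(':memory:')
    db.execute('CREATE TABLE tasks_v2 (method TEXT, taskhash TEXT, unihash TEXT)')

    d = {'method': 'do_compile', 'taskhash': 'ab12', 'unihash': 'cd34', 'bogus': 1}
    d = {k: v for k, v in d.items() if k in TABLE_COLUMNS}  # drops 'bogus'
    with closing(db.cursor()) as cursor:
        insert_task(cursor, d)
        db.commit()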
@@ -149,7 +138,7 @@ async def copy_outhash_from_upstream(client, db, method, outhash, taskhash):
return d
-class ServerClient(object):
+class ServerClient(bb.asyncrpc.AsyncServerConnection):
FAST_QUERY = 'SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1'
ALL_QUERY = 'SELECT * FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1'
OUTHASH_QUERY = '''
@@ -168,21 +157,19 @@ class ServerClient(object):
'''
def __init__(self, reader, writer, db, request_stats, backfill_queue, upstream, read_only):
- self.reader = reader
- self.writer = writer
+ super().__init__(reader, writer, 'OEHASHEQUIV', logger)
self.db = db
self.request_stats = request_stats
- self.max_chunk = DEFAULT_MAX_CHUNK
+ self.max_chunk = bb.asyncrpc.DEFAULT_MAX_CHUNK
self.backfill_queue = backfill_queue
self.upstream = upstream
- self.handlers = {
+ self.handlers.update({
'get': self.handle_get,
'get-outhash': self.handle_get_outhash,
'get-stream': self.handle_get_stream,
'get-stats': self.handle_get_stats,
- 'chunk-stream': self.handle_chunk,
- }
+ })
if not read_only:
self.handlers.update({
@@ -192,56 +179,19 @@ class ServerClient(object):
'backfill-wait': self.handle_backfill_wait,
})
+ def validate_proto_version(self):
+ return (self.proto_version > (1, 0) and self.proto_version <= (1, 1))
+
async def process_requests(self):
if self.upstream is not None:
self.upstream_client = await create_async_client(self.upstream)
else:
self.upstream_client = None
- try:
-
-
- self.addr = self.writer.get_extra_info('peername')
- logger.debug('Client %r connected' % (self.addr,))
-
- # Read protocol and version
- protocol = await self.reader.readline()
- if protocol is None:
- return
-
- (proto_name, proto_version) = protocol.decode('utf-8').rstrip().split()
- if proto_name != 'OEHASHEQUIV':
- return
-
- proto_version = tuple(int(v) for v in proto_version.split('.'))
- if proto_version < (1, 0) or proto_version > (1, 1):
- return
-
- # Read headers. Currently, no headers are implemented, so look for
- # an empty line to signal the end of the headers
- while True:
- line = await self.reader.readline()
- if line is None:
- return
+ await super().process_requests()
- line = line.decode('utf-8').rstrip()
- if not line:
- break
-
- # Handle messages
- while True:
- d = await self.read_message()
- if d is None:
- break
- await self.dispatch_message(d)
- await self.writer.drain()
- except ClientError as e:
- logger.error(str(e))
- finally:
- if self.upstream_client is not None:
- await self.upstream_client.close()
-
- self.writer.close()
+ if self.upstream_client is not None:
+ await self.upstream_client.close()
async def dispatch_message(self, msg):
for k in self.handlers.keys():
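The handshake and header parsing deleted here are what bb.asyncrpc.AsyncServerConnection.process_requests now performs, with validate_proto_version() as the subclass hook. Note that the new predicate accepts only versions in ((1, 0), (1, 1)], so protocol 1.0, which the removed inline check still admitted, is no longer served. Reconstructed from the removed code, this is the wire format a client speaks; the command payload values are illustrative:

    import asyncio
    import json

    async def hashserv_get(host, port, method, taskhash):
        reader, writer = await asyncio.open_connection(host, port)
        writer.write(b'OEHASHEQUIV 1.1\n')   # protocol name and version, one line
        writer.write(b'\n')                  # blank line terminates the (empty) headers
        request = {'get': {'method': method, 'taskhash': taskhash}}
        writer.write((json.dumps(request) + '\n').encode('utf-8'))  # one JSON message per line
        await writer.drain()
        response = json.loads((await reader.readline()).decode('utf-8'))
        writer.close()
        return response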
@@ -255,47 +205,7 @@ class ServerClient(object):
await self.handlers[k](msg[k])
return
- raise ClientError("Unrecognized command %r" % msg)
-
- def write_message(self, msg):
- for c in chunkify(json.dumps(msg), self.max_chunk):
- self.writer.write(c.encode('utf-8'))
-
- async def read_message(self):
- l = await self.reader.readline()
- if not l:
- return None
-
- try:
- message = l.decode('utf-8')
-
- if not message.endswith('\n'):
- return None
-
- return json.loads(message)
- except (json.JSONDecodeError, UnicodeDecodeError) as e:
- logger.error('Bad message from client: %r' % message)
- raise e
-
- async def handle_chunk(self, request):
- lines = []
- try:
- while True:
- l = await self.reader.readline()
- l = l.rstrip(b"\n").decode("utf-8")
- if not l:
- break
- lines.append(l)
-
- msg = json.loads(''.join(lines))
- except (json.JSONDecodeError, UnicodeDecodeError) as e:
- logger.error('Bad message from client: %r' % message)
- raise e
-
- if 'chunk-stream' in msg:
- raise ClientError("Nested chunks are not allowed")
-
- await self.dispatch_message(msg)
+ raise bb.asyncrpc.ClientError("Unrecognized command %r" % msg)
async def handle_get(self, request):
method = request['method']
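The chunk protocol removed above sent an oversized JSON message as a 'chunk-stream' marker followed by newline-terminated slices and a blank terminator; handle_chunk joined the slices, re-parsed, and refused nested chunking. Two review notes: the error path logged `message`, a name never bound in handle_chunk, so a malformed chunk would have died with a NameError rather than producing the intended log line; and this framing, like the rest, now belongs to bb.asyncrpc. The reassembly step, in isolation:

    import json

    def reassemble_chunks(lines):
        # lines: decoded chunk payloads, read up to (not including) the
        # blank line that terminates the stream
        msg = json.loads(''.join(lines))
        if 'chunk-stream' in msg:
            raise ValueError('Nested chunks are not allowed')
        return msg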
@@ -499,74 +409,20 @@ class ServerClient(object):
cursor.close()
-class Server(object):
+class Server(bb.asyncrpc.AsyncServer):
def __init__(self, db, loop=None, upstream=None, read_only=False):
if upstream and read_only:
- raise ServerError("Read-only hashserv cannot pull from an upstream server")
+ raise bb.asyncrpc.ServerError("Read-only hashserv cannot pull from an upstream server")
+
+ super().__init__(logger, loop)
self.request_stats = Stats()
self.db = db
-
- if loop is None:
- self.loop = asyncio.new_event_loop()
- self.close_loop = True
- else:
- self.loop = loop
- self.close_loop = False
-
self.upstream = upstream
self.read_only = read_only
- self._cleanup_socket = None
-
- def start_tcp_server(self, host, port):
- self.server = self.loop.run_until_complete(
- asyncio.start_server(self.handle_client, host, port, loop=self.loop)
- )
-
- for s in self.server.sockets:
- logger.info('Listening on %r' % (s.getsockname(),))
- # Newer python does this automatically. Do it manually here for
- # maximum compatibility
- s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
- s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)
-
- name = self.server.sockets[0].getsockname()
- if self.server.sockets[0].family == socket.AF_INET6:
- self.address = "[%s]:%d" % (name[0], name[1])
- else:
- self.address = "%s:%d" % (name[0], name[1])
-
- def start_unix_server(self, path):
- def cleanup():
- os.unlink(path)
-
- cwd = os.getcwd()
- try:
- # Work around path length limits in AF_UNIX
- os.chdir(os.path.dirname(path))
- self.server = self.loop.run_until_complete(
- asyncio.start_unix_server(self.handle_client, os.path.basename(path), loop=self.loop)
- )
- finally:
- os.chdir(cwd)
-
- logger.info('Listening on %r' % path)
-
- self._cleanup_socket = cleanup
- self.address = "unix://%s" % os.path.abspath(path)
-
- async def handle_client(self, reader, writer):
- # writer.transport.set_write_buffer_limits(0)
- try:
- client = ServerClient(reader, writer, self.db, self.request_stats, self.backfill_queue, self.upstream, self.read_only)
- await client.process_requests()
- except Exception as e:
- import traceback
- logger.error('Error from client: %s' % str(e), exc_info=True)
- traceback.print_exc()
- writer.close()
- logger.info('Client disconnected')
+ def accept_client(self, reader, writer):
+ return ServerClient(reader, writer, self.db, self.request_stats, self.backfill_queue, self.upstream, self.read_only)
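accept_client() is the factory hook the base class calls once per accepted connection; bb.asyncrpc presumably drives the returned object's process_requests() itself, roughly replacing the handle_client() removed above. An illustrative sketch of that assumed base-class flow, not the real bb.asyncrpc code:

    class AsyncServer:
        async def handle_client(self, reader, writer):
            client = self.accept_client(reader, writer)  # subclass-provided factory
            try:
                await client.process_requests()
            except Exception:
                self.logger.exception('Error from client')
            finally:
                writer.close()
            self.logger.info('Client disconnected')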
@contextmanager
def _backfill_worker(self):
@@ -597,31 +453,8 @@ class Server(object):
else:
yield
- def serve_forever(self):
- def signal_handler():
- self.loop.stop()
-
- asyncio.set_event_loop(self.loop)
- try:
- self.backfill_queue = asyncio.Queue()
-
- self.loop.add_signal_handler(signal.SIGTERM, signal_handler)
-
- with self._backfill_worker():
- try:
- self.loop.run_forever()
- except KeyboardInterrupt:
- pass
-
- self.server.close()
-
- self.loop.run_until_complete(self.server.wait_closed())
- logger.info('Server shutting down')
- finally:
- if self.close_loop:
- if sys.version_info >= (3, 6):
- self.loop.run_until_complete(self.loop.shutdown_asyncgens())
- self.loop.close()
+ def run_loop_forever(self):
+ self.backfill_queue = asyncio.Queue()
- if self._cleanup_socket is not None:
- self._cleanup_socket()
+ with self._backfill_worker():
+ super().run_loop_forever()
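run_loop_forever() shrinks to queue creation plus the backfill worker wrapper, with signal handling, loop shutdown, and socket cleanup all delegated to the base class. The worker the wrapper manages (its body sits outside this hunk) drains backfill_queue, copying entries from the upstream server. The consumer pattern in minimal form; the tuple item format and the None shutdown sentinel are assumptions for illustration:

    import asyncio

    async def backfill_worker_task(queue, copy_one):
        # copy_one stands in for the upstream-copy coroutine, e.g. a wrapper
        # around copy_from_upstream(); None is assumed as a shutdown sentinel
        while True:
            item = await queue.get()
            if item is None:
                queue.task_done()
                break
            method, taskhash = item
            await copy_one(method, taskhash)
            queue.task_done()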