diff options
Diffstat (limited to 'poky/bitbake/lib/hashserv/server.py')
-rw-r--r-- | poky/bitbake/lib/hashserv/server.py | 149 |
1 files changed, 119 insertions, 30 deletions
diff --git a/poky/bitbake/lib/hashserv/server.py b/poky/bitbake/lib/hashserv/server.py index 81050715e..3ff4c51cc 100644 --- a/poky/bitbake/lib/hashserv/server.py +++ b/poky/bitbake/lib/hashserv/server.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: GPL-2.0-only # -from contextlib import closing +from contextlib import closing, contextmanager from datetime import datetime import asyncio import json @@ -12,8 +12,9 @@ import math import os import signal import socket +import sys import time -from . import chunkify, DEFAULT_MAX_CHUNK +from . import chunkify, DEFAULT_MAX_CHUNK, create_async_client, TABLE_COLUMNS logger = logging.getLogger('hashserv.server') @@ -111,16 +112,40 @@ class Stats(object): class ClientError(Exception): pass +def insert_task(cursor, data, ignore=False): + keys = sorted(data.keys()) + query = '''INSERT%s INTO tasks_v2 (%s) VALUES (%s)''' % ( + " OR IGNORE" if ignore else "", + ', '.join(keys), + ', '.join(':' + k for k in keys)) + cursor.execute(query, data) + +async def copy_from_upstream(client, db, method, taskhash): + d = await client.get_taskhash(method, taskhash, True) + if d is not None: + # Filter out unknown columns + d = {k: v for k, v in d.items() if k in TABLE_COLUMNS} + keys = sorted(d.keys()) + + + with closing(db.cursor()) as cursor: + insert_task(cursor, d) + db.commit() + + return d + class ServerClient(object): FAST_QUERY = 'SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1' ALL_QUERY = 'SELECT * FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1' - def __init__(self, reader, writer, db, request_stats): + def __init__(self, reader, writer, db, request_stats, backfill_queue, upstream): self.reader = reader self.writer = writer self.db = db self.request_stats = request_stats self.max_chunk = DEFAULT_MAX_CHUNK + self.backfill_queue = backfill_queue + self.upstream = upstream self.handlers = { 'get': self.handle_get, @@ -130,10 +155,18 @@ class ServerClient(object): 'get-stats': self.handle_get_stats, 'reset-stats': self.handle_reset_stats, 'chunk-stream': self.handle_chunk, + 'backfill-wait': self.handle_backfill_wait, } async def process_requests(self): + if self.upstream is not None: + self.upstream_client = await create_async_client(self.upstream) + else: + self.upstream_client = None + try: + + self.addr = self.writer.get_extra_info('peername') logger.debug('Client %r connected' % (self.addr,)) @@ -171,6 +204,9 @@ class ServerClient(object): except ClientError as e: logger.error(str(e)) finally: + if self.upstream_client is not None: + await self.upstream_client.close() + self.writer.close() async def dispatch_message(self, msg): @@ -239,15 +275,19 @@ class ServerClient(object): if row is not None: logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash'])) d = {k: row[k] for k in row.keys()} - - self.write_message(d) + elif self.upstream_client is not None: + d = await copy_from_upstream(self.upstream_client, self.db, method, taskhash) else: - self.write_message(None) + d = None + + self.write_message(d) async def handle_get_stream(self, request): self.write_message('ok') while True: + upstream = None + l = await self.reader.readline() if not l: return @@ -272,6 +312,12 @@ class ServerClient(object): if row is not None: msg = ('%s\n' % row['unihash']).encode('utf-8') #logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash'])) + elif self.upstream_client is not None: + upstream = await self.upstream_client.get_unihash(method, taskhash) + if upstream: + msg = ("%s\n" % upstream).encode("utf-8") + else: + msg = "\n".encode("utf-8") else: msg = '\n'.encode('utf-8') @@ -282,6 +328,11 @@ class ServerClient(object): await self.writer.drain() + # Post to the backfill queue after writing the result to minimize + # the turn around time on a request + if upstream is not None: + await self.backfill_queue.put((method, taskhash)) + async def handle_report(self, data): with closing(self.db.cursor()) as cursor: cursor.execute(''' @@ -324,11 +375,7 @@ class ServerClient(object): if k in data: insert_data[k] = data[k] - cursor.execute('''INSERT INTO tasks_v2 (%s) VALUES (%s)''' % ( - ', '.join(sorted(insert_data.keys())), - ', '.join(':' + k for k in sorted(insert_data.keys()))), - insert_data) - + insert_task(cursor, insert_data) self.db.commit() logger.info('Adding taskhash %s with unihash %s', @@ -358,11 +405,7 @@ class ServerClient(object): if k in data: insert_data[k] = data[k] - cursor.execute('''INSERT OR IGNORE INTO tasks_v2 (%s) VALUES (%s)''' % ( - ', '.join(sorted(insert_data.keys())), - ', '.join(':' + k for k in sorted(insert_data.keys()))), - insert_data) - + insert_task(cursor, insert_data, ignore=True) self.db.commit() # Fetch the unihash that will be reported for the taskhash. If the @@ -394,6 +437,13 @@ class ServerClient(object): self.request_stats.reset() self.write_message(d) + async def handle_backfill_wait(self, request): + d = { + 'tasks': self.backfill_queue.qsize(), + } + await self.backfill_queue.join() + self.write_message(d) + def query_equivalent(self, method, taskhash, query): # This is part of the inner loop and must be as fast as possible try: @@ -405,7 +455,7 @@ class ServerClient(object): class Server(object): - def __init__(self, db, loop=None): + def __init__(self, db, loop=None, upstream=None): self.request_stats = Stats() self.db = db @@ -416,6 +466,8 @@ class Server(object): self.loop = loop self.close_loop = False + self.upstream = upstream + self._cleanup_socket = None def start_tcp_server(self, host, port): @@ -458,7 +510,7 @@ class Server(object): async def handle_client(self, reader, writer): # writer.transport.set_write_buffer_limits(0) try: - client = ServerClient(reader, writer, self.db, self.request_stats) + client = ServerClient(reader, writer, self.db, self.request_stats, self.backfill_queue, self.upstream) await client.process_requests() except Exception as e: import traceback @@ -467,23 +519,60 @@ class Server(object): writer.close() logger.info('Client disconnected') + @contextmanager + def _backfill_worker(self): + async def backfill_worker_task(): + client = await create_async_client(self.upstream) + try: + while True: + item = await self.backfill_queue.get() + if item is None: + self.backfill_queue.task_done() + break + method, taskhash = item + await copy_from_upstream(client, self.db, method, taskhash) + self.backfill_queue.task_done() + finally: + await client.close() + + async def join_worker(worker): + await self.backfill_queue.put(None) + await worker + + if self.upstream is not None: + worker = asyncio.ensure_future(backfill_worker_task()) + try: + yield + finally: + self.loop.run_until_complete(join_worker(worker)) + else: + yield + def serve_forever(self): def signal_handler(): self.loop.stop() - self.loop.add_signal_handler(signal.SIGTERM, signal_handler) - + asyncio.set_event_loop(self.loop) try: - self.loop.run_forever() - except KeyboardInterrupt: - pass + self.backfill_queue = asyncio.Queue() + + self.loop.add_signal_handler(signal.SIGTERM, signal_handler) - self.server.close() - self.loop.run_until_complete(self.server.wait_closed()) - logger.info('Server shutting down') + with self._backfill_worker(): + try: + self.loop.run_forever() + except KeyboardInterrupt: + pass - if self.close_loop: - self.loop.close() + self.server.close() + + self.loop.run_until_complete(self.server.wait_closed()) + logger.info('Server shutting down') + finally: + if self.close_loop: + if sys.version_info >= (3, 6): + self.loop.run_until_complete(self.loop.shutdown_asyncgens()) + self.loop.close() - if self._cleanup_socket is not None: - self._cleanup_socket() + if self._cleanup_socket is not None: + self._cleanup_socket() |