diff options
Diffstat (limited to 'poky/bitbake/lib/hashserv')
-rw-r--r-- | poky/bitbake/lib/hashserv/__init__.py | 4 | ||||
-rw-r--r-- | poky/bitbake/lib/hashserv/client.py | 8 | ||||
-rw-r--r-- | poky/bitbake/lib/hashserv/server.py | 93 | ||||
-rw-r--r-- | poky/bitbake/lib/hashserv/tests.py | 45 |
4 files changed, 123 insertions, 27 deletions
diff --git a/poky/bitbake/lib/hashserv/__init__.py b/poky/bitbake/lib/hashserv/__init__.py index 55f48410d..5f2e101e5 100644 --- a/poky/bitbake/lib/hashserv/__init__.py +++ b/poky/bitbake/lib/hashserv/__init__.py @@ -94,10 +94,10 @@ def chunkify(msg, max_chunk): yield "\n" -def create_server(addr, dbname, *, sync=True, upstream=None): +def create_server(addr, dbname, *, sync=True, upstream=None, read_only=False): from . import server db = setup_database(dbname, sync=sync) - s = server.Server(db, upstream=upstream) + s = server.Server(db, upstream=upstream, read_only=read_only) (typ, a) = parse_address(addr) if typ == ADDR_TYPE_UNIX: diff --git a/poky/bitbake/lib/hashserv/client.py b/poky/bitbake/lib/hashserv/client.py index 0ffd0c2ae..e05c1eb56 100644 --- a/poky/bitbake/lib/hashserv/client.py +++ b/poky/bitbake/lib/hashserv/client.py @@ -99,7 +99,7 @@ class AsyncClient(object): l = await get_line() m = json.loads(l) - if "chunk-stream" in m: + if m and "chunk-stream" in m: lines = [] while True: l = (await get_line()).rstrip("\n") @@ -170,6 +170,12 @@ class AsyncClient(object): {"get": {"taskhash": taskhash, "method": method, "all": all_properties}} ) + async def get_outhash(self, method, outhash, taskhash): + await self._set_mode(self.MODE_NORMAL) + return await self.send_message( + {"get-outhash": {"outhash": outhash, "taskhash": taskhash, "method": method}} + ) + async def get_stats(self): await self._set_mode(self.MODE_NORMAL) return await self.send_message({"get-stats": None}) diff --git a/poky/bitbake/lib/hashserv/server.py b/poky/bitbake/lib/hashserv/server.py index 3ff4c51cc..a0dc0c170 100644 --- a/poky/bitbake/lib/hashserv/server.py +++ b/poky/bitbake/lib/hashserv/server.py @@ -112,6 +112,9 @@ class Stats(object): class ClientError(Exception): pass +class ServerError(Exception): + pass + def insert_task(cursor, data, ignore=False): keys = sorted(data.keys()) query = '''INSERT%s INTO tasks_v2 (%s) VALUES (%s)''' % ( @@ -127,6 +130,18 @@ async def copy_from_upstream(client, db, method, taskhash): d = {k: v for k, v in d.items() if k in TABLE_COLUMNS} keys = sorted(d.keys()) + with closing(db.cursor()) as cursor: + insert_task(cursor, d) + db.commit() + + return d + +async def copy_outhash_from_upstream(client, db, method, outhash, taskhash): + d = await client.get_outhash(method, outhash, taskhash) + if d is not None: + # Filter out unknown columns + d = {k: v for k, v in d.items() if k in TABLE_COLUMNS} + keys = sorted(d.keys()) with closing(db.cursor()) as cursor: insert_task(cursor, d) @@ -137,8 +152,22 @@ async def copy_from_upstream(client, db, method, taskhash): class ServerClient(object): FAST_QUERY = 'SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1' ALL_QUERY = 'SELECT * FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1' - - def __init__(self, reader, writer, db, request_stats, backfill_queue, upstream): + OUTHASH_QUERY = ''' + -- Find tasks with a matching outhash (that is, tasks that + -- are equivalent) + SELECT * FROM tasks_v2 WHERE method=:method AND outhash=:outhash + + -- If there is an exact match on the taskhash, return it. + -- Otherwise return the oldest matching outhash of any + -- taskhash + ORDER BY CASE WHEN taskhash=:taskhash THEN 1 ELSE 2 END, + created ASC + + -- Only return one row + LIMIT 1 + ''' + + def __init__(self, reader, writer, db, request_stats, backfill_queue, upstream, read_only): self.reader = reader self.writer = writer self.db = db @@ -149,15 +178,20 @@ class ServerClient(object): self.handlers = { 'get': self.handle_get, - 'report': self.handle_report, - 'report-equiv': self.handle_equivreport, + 'get-outhash': self.handle_get_outhash, 'get-stream': self.handle_get_stream, 'get-stats': self.handle_get_stats, - 'reset-stats': self.handle_reset_stats, 'chunk-stream': self.handle_chunk, - 'backfill-wait': self.handle_backfill_wait, } + if not read_only: + self.handlers.update({ + 'report': self.handle_report, + 'report-equiv': self.handle_equivreport, + 'reset-stats': self.handle_reset_stats, + 'backfill-wait': self.handle_backfill_wait, + }) + async def process_requests(self): if self.upstream is not None: self.upstream_client = await create_async_client(self.upstream) @@ -282,6 +316,21 @@ class ServerClient(object): self.write_message(d) + async def handle_get_outhash(self, request): + with closing(self.db.cursor()) as cursor: + cursor.execute(self.OUTHASH_QUERY, + {k: request[k] for k in ('method', 'outhash', 'taskhash')}) + + row = cursor.fetchone() + + if row is not None: + logger.debug('Found equivalent outhash %s -> %s', (row['outhash'], row['unihash'])) + d = {k: row[k] for k in row.keys()} + else: + d = None + + self.write_message(d) + async def handle_get_stream(self, request): self.write_message('ok') @@ -335,23 +384,19 @@ class ServerClient(object): async def handle_report(self, data): with closing(self.db.cursor()) as cursor: - cursor.execute(''' - -- Find tasks with a matching outhash (that is, tasks that - -- are equivalent) - SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND outhash=:outhash - - -- If there is an exact match on the taskhash, return it. - -- Otherwise return the oldest matching outhash of any - -- taskhash - ORDER BY CASE WHEN taskhash=:taskhash THEN 1 ELSE 2 END, - created ASC - - -- Only return one row - LIMIT 1 - ''', {k: data[k] for k in ('method', 'outhash', 'taskhash')}) + cursor.execute(self.OUTHASH_QUERY, + {k: data[k] for k in ('method', 'outhash', 'taskhash')}) row = cursor.fetchone() + if row is None and self.upstream_client: + # Try upstream + row = await copy_outhash_from_upstream(self.upstream_client, + self.db, + data['method'], + data['outhash'], + data['taskhash']) + # If no matching outhash was found, or one *was* found but it # wasn't an exact match on the taskhash, a new entry for this # taskhash should be added @@ -455,7 +500,10 @@ class ServerClient(object): class Server(object): - def __init__(self, db, loop=None, upstream=None): + def __init__(self, db, loop=None, upstream=None, read_only=False): + if upstream and read_only: + raise ServerError("Read-only hashserv cannot pull from an upstream server") + self.request_stats = Stats() self.db = db @@ -467,6 +515,7 @@ class Server(object): self.close_loop = False self.upstream = upstream + self.read_only = read_only self._cleanup_socket = None @@ -510,7 +559,7 @@ class Server(object): async def handle_client(self, reader, writer): # writer.transport.set_write_buffer_limits(0) try: - client = ServerClient(reader, writer, self.db, self.request_stats, self.backfill_queue, self.upstream) + client = ServerClient(reader, writer, self.db, self.request_stats, self.backfill_queue, self.upstream, self.read_only) await client.process_requests() except Exception as e: import traceback diff --git a/poky/bitbake/lib/hashserv/tests.py b/poky/bitbake/lib/hashserv/tests.py index 77a19b807..1a696481e 100644 --- a/poky/bitbake/lib/hashserv/tests.py +++ b/poky/bitbake/lib/hashserv/tests.py @@ -6,6 +6,7 @@ # from . import create_server, create_client +from .client import HashConnectionError import hashlib import logging import multiprocessing @@ -29,7 +30,7 @@ class HashEquivalenceTestSetup(object): server_index = 0 - def start_server(self, dbpath=None, upstream=None): + def start_server(self, dbpath=None, upstream=None, read_only=False): self.server_index += 1 if dbpath is None: dbpath = os.path.join(self.temp_dir.name, "db%d.sqlite" % self.server_index) @@ -38,7 +39,10 @@ class HashEquivalenceTestSetup(object): thread.terminate() thread.join() - server = create_server(self.get_server_addr(self.server_index), dbpath, upstream=upstream) + server = create_server(self.get_server_addr(self.server_index), + dbpath, + upstream=upstream, + read_only=read_only) server.dbpath = dbpath server.thread = multiprocessing.Process(target=_run_server, args=(server, self.server_index)) @@ -242,6 +246,43 @@ class HashEquivalenceCommonTests(object): self.assertClientGetHash(side_client, taskhash4, unihash4) self.assertClientGetHash(self.client, taskhash4, None) + # Test that reporting a unihash in the downstream is able to find a + # match which was previously reported to the upstream server + taskhash5 = '35788efcb8dfb0a02659d81cf2bfd695fb30faf9' + outhash5 = '2765d4a5884be49b28601445c2760c5f21e7e5c0ee2b7e3fce98fd7e5970796f' + unihash5 = 'f46d3fbb439bd9b921095da657a4de906510d2cd' + result = self.client.report_unihash(taskhash5, self.METHOD, outhash5, unihash5) + + taskhash6 = '35788efcb8dfb0a02659d81cf2bfd695fb30fafa' + unihash6 = 'f46d3fbb439bd9b921095da657a4de906510d2ce' + result = down_client.report_unihash(taskhash6, self.METHOD, outhash5, unihash6) + self.assertEqual(result['unihash'], unihash5, 'Server failed to copy unihash from upstream') + + def test_ro_server(self): + (ro_client, ro_server) = self.start_server(dbpath=self.server.dbpath, read_only=True) + + # Report a hash via the read-write server + taskhash = '35788efcb8dfb0a02659d81cf2bfd695fb30faf9' + outhash = '2765d4a5884be49b28601445c2760c5f21e7e5c0ee2b7e3fce98fd7e5970796f' + unihash = 'f46d3fbb439bd9b921095da657a4de906510d2cd' + + result = self.client.report_unihash(taskhash, self.METHOD, outhash, unihash) + self.assertEqual(result['unihash'], unihash, 'Server returned bad unihash') + + # Check the hash via the read-only server + self.assertClientGetHash(ro_client, taskhash, unihash) + + # Ensure that reporting via the read-only server fails + taskhash2 = 'c665584ee6817aa99edfc77a44dd853828279370' + outhash2 = '3c979c3db45c569f51ab7626a4651074be3a9d11a84b1db076f5b14f7d39db44' + unihash2 = '90e9bc1d1f094c51824adca7f8ea79a048d68824' + + with self.assertRaises(HashConnectionError): + ro_client.report_unihash(taskhash2, self.METHOD, outhash2, unihash2) + + # Ensure that the database was not modified + self.assertClientGetHash(self.client, taskhash2, None) + class TestHashEquivalenceUnixServer(HashEquivalenceTestSetup, HashEquivalenceCommonTests, unittest.TestCase): def get_server_addr(self, server_idx): |