def test_push_to_remote(self):
    """Check that ``ExperimentRemoteDoc.push_to_remote`` forwards updates.

    Two scenarios are exercised against a mocked ``client.update``:

    * the pushed updates carry no ``heartbeat`` field: the forwarded
      updates must contain a ``heartbeat`` key plus every pushed field;
    * the pushed updates already carry a ``heartbeat``: they must be
      forwarded exactly as given.

    In both cases the mocked server response must be stored on
    ``doc.last_response``.
    """
    client = MLStorageClient('http://127.0.0.1:8080')
    doc_id = ObjectId()
    return_value = {'id': doc_id, 'flag': 123}
    now_time = datetime.utcnow().replace(tzinfo=pytz.UTC).isoformat()
    client.update = Mock(return_value=return_value)

    # test push_to_remote without heartbeat field
    push_updates = {'id': doc_id, 'flag': 456}
    # Snapshot the expectation before pushing. If ``push_to_remote`` edited
    # ``push_updates`` in place and forwarded that same dict, comparing the
    # recorded argument against ``push_updates`` would compare an object
    # with itself and prove nothing.
    expected_updates = dict(push_updates)
    doc = ExperimentRemoteDoc(client, doc_id)
    doc.now_time_literal = Mock(return_value=now_time)
    doc.push_to_remote(push_updates)
    self.assertEqual(doc.last_response, return_value)
    self.assertEqual(client.update.call_args[0][0], doc_id)
    # Copy before popping so the recorded call argument (possibly the
    # caller's own dict) is left untouched.
    updates_arg = dict(client.update.call_args[0][1])
    self.assertIn('heartbeat', updates_arg)
    updates_arg.pop('heartbeat')
    self.assertEqual(updates_arg, expected_updates)

    # test with heartbeat field
    # Use a fixed timestamp that can never coincide with ``now_time``. Two
    # back-to-back ``utcnow()`` calls may return the same value on
    # coarse-resolution clocks, which would hide a regression where the
    # supplied heartbeat gets overwritten by ``now_time_literal()``.
    now_time2 = datetime(2000, 1, 1, tzinfo=pytz.UTC).isoformat()
    push_updates = {'id': doc_id, 'flag': 456, 'heartbeat': now_time2}
    expected_updates = dict(push_updates)
    doc = ExperimentRemoteDoc(client, doc_id)
    doc.now_time_literal = Mock(return_value=now_time)
    doc.push_to_remote(push_updates)
    self.assertEqual(doc.last_response, return_value)
    self.assertEqual(client.update.call_args[0][0], doc_id)
    self.assertEqual(client.update.call_args[0][1], expected_updates)
def test_set_finished(self):
    """Check guarding, retry and success paths of ``set_finished``.

    * Calling ``set_finished`` inside ``with doc:`` must raise
      ``RuntimeError`` (the background worker is running).
    * When ``client.update`` keeps failing, ``set_finished`` must try once
      immediately and then once after each entry of ``retry_intervals``,
      re-raising the last error and leaving ``has_set_finished`` False.
    * When ``client.update`` succeeds, it is called exactly once with the
      status/heartbeat/stop_time/extra fields, the response is stored on
      ``last_response``, and ``has_set_finished`` becomes True.
    """
    client = MLStorageClient('http://127.0.0.1:8080')
    doc_id = ObjectId()
    doc = ExperimentRemoteDoc(client, doc_id)
    now_time = datetime.utcnow().replace(tzinfo=pytz.UTC).isoformat()
    doc.now_time_literal = Mock(return_value=now_time)

    # set_finished should only be called when thread is None
    client.update = Mock(return_value={})
    with doc:
        with pytest.raises(RuntimeError,
                           match='`set_finished` must only be called '
                                 'when the background worker is not '
                                 'running'):
            _ = doc.set_finished('FAILED', retry_intervals=(0.1, 0.2))

    # test error retry
    retry_times = []
    expected_updates = {
        'status': 'COMPLETED',
        'heartbeat': now_time,
        'stop_time': now_time,
        'abc': 123,
    }

    def f(v_id, v_updates):
        # Always fail, recording when each attempt happened.
        self.assertEqual(v_id, doc_id)
        self.assertEqual(v_updates, expected_updates)
        retry_times.append(time.time())
        raise RuntimeError(f'retry count: {len(retry_times)}')

    client.update = Mock(wraps=f)
    # Allowed deviation (seconds) between the observed and the requested
    # retry delays. 10 ms is too tight for sleep/scheduler jitter on busy CI
    # machines. 50 ms still cleanly distinguishes the 0.1 s and 0.2 s
    # intervals, whose difference is 100 ms.
    tolerance = 0.05
    # Measure right before the call so unrelated setup time is not counted.
    start_time = time.time()
    with pytest.raises(RuntimeError, match='retry count: 3'):
        doc.set_finished('COMPLETED', {'abc': 123},
                         retry_intervals=(0.1, 0.2))
    self.assertEqual(len(retry_times), 3)
    self.assertLess(abs(retry_times[0] - start_time), tolerance)
    self.assertLess(abs(retry_times[1] - retry_times[0] - 0.1), tolerance)
    self.assertLess(abs(retry_times[2] - retry_times[1] - 0.2), tolerance)
    self.assertEqual(doc.has_set_finished, False)

    # test success
    return_value = {'id': doc_id, 'flags': 456}
    client.update = Mock(return_value=return_value)
    doc.set_finished('COMPLETED', {'abc': 123}, retry_intervals=(0.1, 0.2))
    self.assertEqual(doc.last_response, return_value)
    self.assertEqual(client.update.call_count, 1)
    self.assertEqual(client.update.call_args[0][0], doc_id)
    self.assertEqual(client.update.call_args[0][1], expected_updates)
    self.assertEqual(doc.has_set_finished, True)
def test_interface(self):
    """Check URI normalisation and the error raised for non-JSON replies."""
    # With or without a trailing slash, the stored base URI is the same.
    for raw_uri in ('http://127.0.0.1', 'http://127.0.0.1/'):
        self.assertEqual(MLStorageClient(raw_uri).uri, 'http://127.0.0.1')

    # A plain-text body from the query endpoint must surface as an IOError
    # that names the URL and the HTTP status code.
    httpretty.register_uri(httpretty.POST, 'http://127.0.0.1/v1/_query',
                           body='hello')
    not_json_pattern = (r'The response from http://127.0.0.1/v1/'
                        r'_query\?skip=0 is not JSON: HTTP code is '
                        r'200')
    with pytest.raises(IOError, match=not_json_pattern):
        self.client.query()
def setUp(self):
    """Create the client shared by the test methods."""
    base_uri = 'http://127.0.0.1'
    self.client = MLStorageClient(base_uri)
class MLStorageClientTestCase(unittest.TestCase):
    """Tests for ``MLStorageClient`` against ``httpretty``-mocked endpoints.

    Each test registers fake endpoints under ``http://127.0.0.1/v1/`` and
    checks both the request the client sends and the value it returns.
    """

    def setUp(self):
        self.client = MLStorageClient('http://127.0.0.1')

    @httpretty.activate
    def test_interface(self):
        # The base URI is exposed as ``uri`` with any trailing slash removed.
        c = MLStorageClient('http://127.0.0.1')
        self.assertEqual(c.uri, 'http://127.0.0.1')
        c = MLStorageClient('http://127.0.0.1/')
        self.assertEqual(c.uri, 'http://127.0.0.1')

        # test invalid response should trigger error
        httpretty.register_uri(httpretty.POST, 'http://127.0.0.1/v1/_query',
                               body='hello')
        with pytest.raises(IOError,
                           match=r'The response from http://127.0.0.1/v1/'
                                 r'_query\?skip=0 is not JSON: HTTP code is '
                                 r'200'):
            _ = self.client.query()

    @httpretty.activate
    def test_query(self):
        def callback(request, uri, response_headers, expected_body,
                     expected_skip, expected_limit, expected_sort,
                     response_body):
            # Validate the request; the ``expected_*`` / ``response_body``
            # arguments are bound per scenario via ``functools.partial``.
            content_type = request.headers.get('Content-Type').split(';', 1)[0]
            self.assertEqual(content_type, 'application/json')
            self.assertEqual(request.querystring['skip'][0],
                             str(expected_skip))
            if expected_limit is not None:
                self.assertEqual(request.querystring['limit'][0],
                                 str(expected_limit))
            if expected_sort is not None:
                self.assertEqual(request.querystring['sort'][0],
                                 str(expected_sort))
            self.assertEqual(request.body, expected_body)
            response_headers['content-type'] = \
                'application/json; charset=utf-8'
            return [200, response_headers, response_body]

        object_ids = [str(ObjectId()) for _ in range(5)]
        docs = [{
            '_id': oid,
            'storage_dir': f'/{oid}',
            'uuid': uuid.uuid4()
        } for oid in object_ids]

        # test bare query
        httpretty.register_uri(httpretty.POST, 'http://127.0.0.1/v1/_query',
                               body=partial(callback,
                                            expected_body=b'{}',
                                            expected_skip=0,
                                            expected_limit=None,
                                            expected_sort=None,
                                            response_body=json_dumps(
                                                docs[:2])))
        ret = self.client.query()
        self.assertListEqual(ret, docs[:2])
        for obj_id in object_ids[:2]:  # test the storage dir cache
            self.assertEqual(self.client.get_storage_dir(obj_id),
                             f'/{obj_id}')
        # object_ids[2] was never returned by a query, so it is not cached
        # and no ``_get`` endpoint is registered for it. The client
        # presumably falls through to a real request, which is refused.
        # NOTE(review): this assumes nothing listens on 127.0.0.1:80.
        with pytest.raises(requests.exceptions.ConnectionError):
            _ = self.client.get_storage_dir(object_ids[2])

        # test query
        httpretty.register_uri(httpretty.POST, 'http://127.0.0.1/v1/_query',
                               body=partial(callback,
                                            expected_body=b'{"name":"hint"}',
                                            expected_skip=1,
                                            expected_limit=99,
                                            expected_sort='-start_time',
                                            response_body=json_dumps(
                                                docs[2:4])))
        ret = self.client.query(filter={'name': 'hint'}, sort='-start_time',
                                skip=1, limit=99)
        self.assertListEqual(ret, docs[2:4])

    @httpretty.activate
    def test_get(self):
        object_id = str(ObjectId())
        doc = {
            '_id': object_id,
            'storage_dir': f'/{object_id}',
            'uuid': uuid.uuid4(),
        }

        def get_callback(request, uri, response_headers):
            response_headers['content-type'] = \
                'application/json; charset=utf-8'
            return [200, response_headers, json_dumps(doc)]

        httpretty.register_uri(httpretty.GET,
                               f'http://127.0.0.1/v1/_get/{object_id}',
                               body=get_callback)
        ret = self.client.get(object_id)
        self.assertDictEqual(ret, doc)
        # ``get`` must also populate the storage-dir cache.
        self.assertEqual(self.client.get_storage_dir(object_id),
                         doc['storage_dir'])

    @httpretty.activate
    def test_heartbeat(self):
        def heartbeat_callback(request, uri, response_headers):
            # The heartbeat request carries no body.
            self.assertEqual(request.body, b'')
            response_headers['content-type'] = \
                'application/json; charset=utf-8'
            heartbeat_received[0] = True
            return [200, response_headers, b'{}']

        # One-element list so the callback can flag the call by mutation.
        heartbeat_received = [False]
        object_id = str(ObjectId())
        httpretty.register_uri(httpretty.POST,
                               f'http://127.0.0.1/v1/_heartbeat/{object_id}',
                               body=heartbeat_callback)
        self.client.heartbeat(object_id)
        self.assertTrue(heartbeat_received[0])

    @httpretty.activate
    def test_add_tags(self):
        object_id = str(ObjectId())
        doc_fields = {
            'uuid': uuid.uuid4(),
            'name': 'hello',
            'storage_dir': f'/{object_id}',
        }

        def get_callback(request, uri, response_headers):
            # Reads ``doc_fields`` when the request arrives, so later
            # mutations of ``doc_fields`` are reflected in the response.
            response_headers['content-type'] = \
                'application/json; charset=utf-8'
            o = {'_id': object_id}
            o.update(doc_fields)
            return [200, response_headers, json_dumps(o)]

        httpretty.register_uri(httpretty.GET,
                               f'http://127.0.0.1/v1/_get/{object_id}',
                               body=get_callback)

        # test add_tags
        # From now on, the GET endpoint reports the document as already
        # tagged with ['abc', '123'].
        doc_fields['tags'] = ['abc', '123']

        def update_callback(request, uri, response_headers):
            content_type = request.headers.get('Content-Type').split(';', 1)[0]
            self.assertEqual(content_type, 'application/json')
            fields = json_loads(request.body)
            # Existing tags come first; only tags not already present
            # ('hello, world!', not '123') are appended.
            self.assertDictEqual(fields,
                                 {'tags': ['abc', '123', 'hello, world!']})
            o = {'_id': object_id}
            o.update(doc_fields)
            response_headers['content-type'] = \
                'application/json; charset=utf-8'
            return [200, response_headers, json_dumps(o)]

        httpretty.register_uri(httpretty.POST,
                               f'http://127.0.0.1/v1/_update/{object_id}',
                               body=update_callback)
        ret = self.client.add_tags(object_id, ['hello, world!', '123'])
        expected = {'_id': object_id}
        expected.update(doc_fields)
        self.assertDictEqual(ret, expected)

    @httpretty.activate
    def test_create_update_delete(self):
        object_id = str(ObjectId())
        doc_fields = {
            'uuid': uuid.uuid4(),
            'name': 'hello',
            'storage_dir': f'/{object_id}',
        }

        # test create
        def create_callback(request, uri, response_headers):
            content_type = request.headers.get('Content-Type').split(';', 1)[0]
            self.assertEqual(content_type, 'application/json')
            fields = json_loads(request.body)
            self.assertDictEqual(fields, doc_fields)
            o = {'_id': object_id}
            o.update(doc_fields)
            response_headers['content-type'] = \
                'application/json; charset=utf-8'
            return [200, response_headers, json_dumps(o)]

        httpretty.register_uri(httpretty.POST,
                               'http://127.0.0.1/v1/_create',
                               body=create_callback)
        ret = self.client.create(doc_fields)
        expected = {'_id': object_id}
        expected.update(doc_fields)
        self.assertDictEqual(ret, expected)
        self.assertEqual(self.client.get_storage_dir(object_id),
                         doc_fields['storage_dir'])

        # test update
        doc_fields['storage_dir'] = f'/new/{object_id}'

        def update_callback(request, uri, response_headers):
            content_type = request.headers.get('Content-Type').split(';', 1)[0]
            self.assertEqual(content_type, 'application/json')
            fields = json_loads(request.body)
            self.assertDictEqual(fields,
                                 {'storage_dir': doc_fields['storage_dir']})
            o = {'_id': object_id}
            o.update(doc_fields)
            response_headers['content-type'] = \
                'application/json; charset=utf-8'
            return [200, response_headers, json_dumps(o)]

        httpretty.register_uri(httpretty.POST,
                               f'http://127.0.0.1/v1/_update/{object_id}',
                               body=update_callback)
        ret = self.client.update(object_id,
                                 {'storage_dir': doc_fields['storage_dir']})
        expected = {'_id': object_id}
        expected.update(doc_fields)
        self.assertDictEqual(ret, expected)
        # The cached storage dir must follow the update.
        self.assertEqual(self.client.get_storage_dir(object_id),
                         doc_fields['storage_dir'])

        # test delete
        def delete_callback(request, uri, response_headers):
            self.assertEqual(request.body, b'')
            response_headers['content-type'] = \
                'application/json; charset=utf-8'
            return [200, response_headers, json_dumps([object_id])]

        httpretty.register_uri(httpretty.POST,
                               f'http://127.0.0.1/v1/_delete/{object_id}',
                               body=delete_callback)
        self.assertListEqual(self.client.delete(object_id), [object_id])
        # After deletion the cache entry must be gone. With no ``_get``
        # endpoint registered, the lookup presumably hits the real network
        # and is refused. NOTE(review): this assumes nothing listens on
        # 127.0.0.1:80.
        with pytest.raises(requests.exceptions.ConnectionError):
            _ = self.client.get_storage_dir(object_id)

    @httpretty.activate
    def test_set_finished(self):
        object_id = str(ObjectId())
        doc_fields = {
            '_id': object_id,
            'uuid': uuid.uuid4(),
            'name': 'hello',
            'status': 'COMPLETED',
            'storage_dir': f'/{object_id}',
        }

        # test set status only
        def status_only_callback(request, uri, response_headers):
            content_type = request.headers.get('Content-Type').split(';', 1)[0]
            self.assertEqual(content_type, 'application/json')
            fields = json_loads(request.body)
            self.assertDictEqual(fields, {'status': 'COMPLETED'})
            response_headers['content-type'] = \
                'application/json; charset=utf-8'
            return [200, response_headers, json_dumps(doc_fields)]

        httpretty.register_uri(
            httpretty.POST,
            f'http://127.0.0.1/v1/_set_finished/{object_id}',
            body=status_only_callback)
        ret = self.client.set_finished(object_id, 'COMPLETED')
        self.assertDictEqual(ret, doc_fields)

        # test set status with new fields
        def with_fields_callback(request, uri, response_headers):
            content_type = request.headers.get('Content-Type').split(';', 1)[0]
            self.assertEqual(content_type, 'application/json')
            fields = json_loads(request.body)
            self.assertDictEqual(fields, {
                'name': 'hello',
                'status': 'COMPLETED'
            })
            response_headers['content-type'] = \
                'application/json; charset=utf-8'
            return [200, response_headers, json_dumps(doc_fields)]

        httpretty.register_uri(
            httpretty.POST,
            f'http://127.0.0.1/v1/_set_finished/{object_id}',
            body=with_fields_callback)
        ret = self.client.set_finished(object_id, 'COMPLETED',
                                       {'name': 'hello'})
        self.assertDictEqual(ret, doc_fields)
        # The storage dir from the response must have been cached.
        self.assertEqual(self.client.get_storage_dir(object_id),
                         f'/{object_id}')

    @httpretty.activate
    def test_get_storage_dir(self):
        object_id = str(ObjectId())
        doc_fields = {'_id': object_id, 'storage_dir': f'/{object_id}'}

        def get_callback(request, uri, response_headers):
            response_headers['content-type'] = \
                'application/json; charset=utf-8'
            counter[0] += 1
            return [200, response_headers, json_dumps(doc_fields)]

        # Number of times the GET endpoint was hit.
        counter = [0]
        httpretty.register_uri(httpretty.GET,
                               f'http://127.0.0.1/v1/_get/{object_id}',
                               body=get_callback)
        self.assertEqual(self.client.get_storage_dir(object_id),
                         f'/{object_id}')
        self.assertEqual(counter[0], 1)
        # The second lookup must be served from the cache (no new request).
        self.assertEqual(self.client.get_storage_dir(object_id),
                         f'/{object_id}')
        self.assertEqual(counter[0], 1)

    @httpretty.activate
    def test_get_file(self):
        object_id = str(ObjectId())
        # Only the normalised path 'hello.txt' is registered, so this also
        # checks that '/./hello.txt' is normalised before building the URL.
        httpretty.register_uri(
            httpretty.GET,
            f'http://127.0.0.1/v1/_getfile/{object_id}/hello.txt',
            body=b'hello, world')
        self.assertEqual(self.client.get_file(object_id, '/./hello.txt'),
                         b'hello, world')