def test_cache_and_get_tensor(nparray, tensor_key):
    """Test that cache_tensor and get_tensor_from_cache round-trip an array."""
    database = TensorDB()
    database.cache_tensor({tensor_key: nparray})
    # Reading back with the same key must yield an array equal to the stored one.
    assert np.array_equal(nparray, database.get_tensor_from_cache(tensor_key))
def tensor_db():
    """Prepare a TensorDB holding two cached tensors.

    The two keys are identical except for their tags, ('col1', ) vs ('col2', ).
    """
    database = TensorDB()
    entries = {
        TensorKey('tensor_name', 'agg', 0, False, ('col1', )): np.array([0, 1, 2, 3, 4]),
        TensorKey('tensor_name', 'agg', 0, False, ('col2', )): np.array([2, 3, 4, 5, 6]),
    }
    database.cache_tensor(entries)
    return database
def test_get_aggregated_tensor_only_col(nparray, tensor_key):
    """Test that get_aggregated_tensor returns None.

    Data is present for only one collaborator, so no aggregate is expected.
    """
    database = TensorDB()
    database.cache_tensor({tensor_key: nparray})

    name, _origin, round_num, is_report, _tags = tensor_key
    # Request a key with origin 'col2' and tags ('model', ) while both
    # collaborators carry equal weight.
    query_key = TensorKey(name, 'col2', round_num, is_report, ('model', ))
    weights = {'col1': 0.5, 'col2': 0.5}

    aggregated, _metadata = database.get_aggregated_tensor(query_key, weights, None)
    assert aggregated is None
def test_clean_up_not_clean_up_with_negative_argument(nparray, tensor_key):
    """Test that clean_up keeps records when remove_older_than is negative."""
    database = TensorDB()
    database.cache_tensor({tensor_key: nparray})

    # Set the stored 'round' value to 2, then clean up with a negative threshold.
    database.tensor_db['round'] = 2
    database.clean_up(remove_older_than=-1)
    # Reset 'round' to 0 before reading the record back.
    database.tensor_db['round'] = 0

    assert np.array_equal(nparray, database.get_tensor_from_cache(tensor_key))
def test_clean_up(nparray, tensor_key):
    """Test that clean_up removes old records."""
    database = TensorDB()
    database.cache_tensor({tensor_key: nparray})
    # Bump the stored 'round' value so the cached record is considered old.
    database.tensor_db['round'] = 2
    database.clean_up()
    assert database.get_tensor_from_cache(tensor_key) is None
def test_clean_up_not_old(nparray, tensor_key):
    """Test that clean_up does not remove records that are not old."""
    database = TensorDB()
    database.cache_tensor({tensor_key: nparray})
    # The stored 'round' value is left as-is before cleaning up.
    database.clean_up()
    assert np.array_equal(nparray, database.get_tensor_from_cache(tensor_key))
def test_get_aggregated_tensor_raise_wrong_weights(nparray, tensor_key):
    """Test that get_aggregated_tensor raises when weights do not sum to 1.0."""
    database = TensorDB()
    database.cache_tensor({tensor_key: nparray})
    bad_weights = {'col1': 0.5, 'col2': 0.8}  # sums to 1.3, not 1.0
    with pytest.raises(AssertionError):
        database.get_aggregated_tensor(tensor_key, bad_weights, None)
def test_get_aggregated_tensor_directly(nparray, tensor_key):
    """Test that get_aggregated_tensor returns the cached tensor directly.

    An empty collaborator weight dict and no aggregation function are passed,
    and the tensor cached under the requested key is expected back unchanged.
    """
    database = TensorDB()
    database.cache_tensor({tensor_key: nparray})

    name, _origin, round_num, is_report, _tags = tensor_key
    col2_key = TensorKey(name, 'col2', round_num, is_report, ('model', ))
    database.cache_tensor({col2_key: nparray})

    result, _metadata = database.get_aggregated_tensor(col2_key, {}, None)
    assert np.array_equal(nparray, result)
def test_get_aggregated_tensor(nparray, tensor_key):
    """Test that get_aggregated_tensor aggregates using WeightedAverage.

    Collaborators 'col1' and 'col2' are weighted equally (0.5 each) and the
    same array is cached under both keys, so the aggregate must equal that
    array.
    """
    db = TensorDB()
    db.cache_tensor({tensor_key: nparray})

    tensor_name, _origin, round_number, report, _tags = tensor_key
    col2_tensor_key = TensorKey(
        tensor_name, 'col2', round_number, report, ('model', ))
    db.cache_tensor({col2_tensor_key: nparray})

    collaborator_weight_dict = {'col1': 0.5, 'col2': 0.5}
    agg_nparray, _agg_metadata_dict = db.get_aggregated_tensor(
        col2_tensor_key, collaborator_weight_dict, WeightedAverage())

    # NOTE(review): both cached arrays are identical, so this assertion cannot
    # detect incorrect weighting. Consider caching distinct arrays once the
    # key/tag layout get_aggregated_tensor expects is confirmed.
    assert np.array_equal(nparray, agg_nparray)
def test_tensor_from_cache_empty(tensor_key):
    """Test that get_tensor_from_cache returns None for a key not in the db."""
    empty_db = TensorDB()
    assert empty_db.get_tensor_from_cache(tensor_key) is None