def test_cache_and_get_tensor(nparray, tensor_key):
    """Test that cache_tensor and get_tensor_from_cache work correctly.

    A tensor cached under ``tensor_key`` must be returned (element-wise equal)
    when it is looked up again with the same key.
    """
    db = TensorDB()
    db.cache_tensor({tensor_key: nparray})
    cached_nparray = db.get_tensor_from_cache(tensor_key)

    assert np.array_equal(nparray, cached_nparray)
def tensor_db():
    """Build a TensorDB pre-populated with two small integer tensors.

    Both entries use name ``'tensor_name'``, origin ``'agg'``, round 0 and
    report ``False``; they differ only in their tags (``('col1',)`` vs
    ``('col2',)``) and in the cached array contents.
    """
    # NOTE(review): named like a pytest fixture, but no @pytest.fixture
    # decorator is visible here — confirm whether it was intended to be one.
    first_key = TensorKey('tensor_name', 'agg', 0, False, ('col1', ))
    second_key = TensorKey('tensor_name', 'agg', 0, False, ('col2', ))

    database = TensorDB()
    database.cache_tensor({
        first_key: np.array([0, 1, 2, 3, 4]),
        second_key: np.array([2, 3, 4, 5, 6]),
    })
    return database
def test_get_aggregated_tensor_only_col(nparray, tensor_key):
    """Test that get_aggregated_tensor returns None if data is present for only one collaborator.

    Only the fixture tensor is cached. An aggregate is then requested for the
    same tensor name/round from origin ``'col2'`` with tags ``('model',)``,
    weighting ``col1`` and ``col2`` equally; the result must be ``None``.
    """
    db = TensorDB()
    db.cache_tensor({tensor_key: nparray})
    tensor_name, _origin, round_number, report, _tags = tensor_key
    # Separate name so the fixture parameter is not shadowed.
    col2_tensor_key = TensorKey(tensor_name, 'col2', round_number, report,
                                ('model', ))

    collaborator_weight_dict = {'col1': 0.5, 'col2': 0.5}
    agg_nparray, _agg_metadata_dict = db.get_aggregated_tensor(
        col2_tensor_key, collaborator_weight_dict, None)

    assert agg_nparray is None
def test_clean_up_not_clean_up_with_negative_argument(nparray, tensor_key):
    """Test that clean_up does not remove records when remove_older_than is negative."""
    db = TensorDB()

    db.cache_tensor({tensor_key: nparray})
    # Set the stored 'round' to 2 so that a regular clean_up would treat the
    # record as old (compare test_clean_up).
    db.tensor_db['round'] = 2
    db.clean_up(remove_older_than=-1)
    # Restore round 0 before the lookup.
    # NOTE(review): assumes the fixture key's round_number is 0 — confirm.
    db.tensor_db['round'] = 0
    cached_nparray = db.get_tensor_from_cache(tensor_key)

    assert np.array_equal(nparray, cached_nparray)
def test_clean_up(nparray, tensor_key):
    """Verify that clean_up removes records that are considered old.

    After caching one tensor, the stored ``'round'`` value is set to 2 and
    ``clean_up()`` runs with its default arguments; the key must then no
    longer resolve to a cached tensor.
    """
    database = TensorDB()
    database.cache_tensor({tensor_key: nparray})

    database.tensor_db['round'] = 2
    database.clean_up()

    assert database.get_tensor_from_cache(tensor_key) is None
def test_clean_up_not_old(nparray, tensor_key):
    """Verify that clean_up keeps records that are not old.

    A freshly cached tensor must still be retrievable, element-wise equal to
    the original, after ``clean_up()`` runs with its default arguments.
    """
    database = TensorDB()
    database.cache_tensor({tensor_key: nparray})
    database.clean_up()

    retrieved = database.get_tensor_from_cache(tensor_key)
    assert np.array_equal(nparray, retrieved)
def test_get_aggregated_tensor_raise_wrong_weights(nparray, tensor_key):
    """Verify get_aggregated_tensor rejects weights that do not sum to 1.0.

    The weights used here add up to 1.3, so an AssertionError is expected.
    """
    database = TensorDB()
    database.cache_tensor({tensor_key: nparray})

    bad_weights = {'col1': 0.5, 'col2': 0.8}
    with pytest.raises(AssertionError):
        database.get_aggregated_tensor(tensor_key, bad_weights, None)
def test_get_aggregated_tensor_directly(nparray, tensor_key):
    """Verify get_aggregated_tensor hands back a cached tensor directly.

    The same array is cached under the fixture key and under a variant of it
    with origin ``'col2'`` and tags ``('model',)``. Requesting the variant with
    an empty weight dict and ``None`` as the third argument must yield an
    array equal to ``nparray``.
    """
    database = TensorDB()
    database.cache_tensor({tensor_key: nparray})

    name, _, round_num, is_report, _ = tensor_key
    model_key = TensorKey(name, 'col2', round_num, is_report, ('model', ))
    database.cache_tensor({model_key: nparray})

    result, _ = database.get_aggregated_tensor(model_key, {}, None)
    assert np.array_equal(nparray, result)
# Example #9
# 0
def test_get_aggregated_tensor(nparray, tensor_key):
    """Test that get_aggregated_tensor aggregates using WeightedAverage.

    The same array is cached under the fixture key and under a variant of it
    with origin ``'col2'`` and tags ``('model',)``. Aggregating the variant
    with equal weights for ``col1`` and ``col2`` via ``WeightedAverage()`` is
    expected to produce an array equal to ``nparray``.
    """
    db = TensorDB()
    db.cache_tensor({tensor_key: nparray})
    tensor_name, _origin, round_number, report, _tags = tensor_key
    # Separate name so the fixture parameter is not shadowed.
    col2_tensor_key = TensorKey(tensor_name, 'col2', round_number, report,
                                ('model', ))
    db.cache_tensor({col2_tensor_key: nparray})

    collaborator_weight_dict = {'col1': 0.5, 'col2': 0.5}
    agg_nparray, _agg_metadata_dict = db.get_aggregated_tensor(
        col2_tensor_key, collaborator_weight_dict, WeightedAverage())

    assert np.array_equal(nparray, agg_nparray)
def test_tensor_from_cache_empty(tensor_key):
    """Test that get_tensor_from_cache returns None if the key is not in the db."""
    db = TensorDB()
    cached_nparray = db.get_tensor_from_cache(tensor_key)
    assert cached_nparray is None