Beispiel #1
0
def test_document_sparse_embedding(
    scipy_sparse_matrix,
    return_sparse_ndarray_cls_type,
    return_scipy_class_type,
    return_expected_type,
):
    """Read a scipy sparse embedding back through ``get_sparse_embedding``.

    The matrix is assigned to ``embedding`` on a fresh ``Document`` and then
    retrieved with the sparse backend named by
    ``return_sparse_ndarray_cls_type`` ('scipy', 'torch' or 'tf'; any other
    value passes ``None`` as the class type). The result must be non-None and
    an instance of ``return_expected_type``.
    """
    doc = Document()
    doc.embedding = scipy_sparse_matrix

    backend = return_sparse_ndarray_cls_type
    extra_kwargs = {}
    if backend == 'scipy':
        from jina.types.ndarray.sparse.scipy import SparseNdArray as sparse_cls

        # sp_format is only forwarded for the scipy backend
        extra_kwargs['sp_format'] = return_scipy_class_type
    elif backend == 'torch':
        from jina.types.ndarray.sparse.pytorch import SparseNdArray as sparse_cls
    elif backend == 'tf':
        from jina.types.ndarray.sparse.tensorflow import SparseNdArray as sparse_cls
    else:
        sparse_cls = None

    result = doc.get_sparse_embedding(
        sparse_ndarray_cls_type=sparse_cls, **extra_kwargs
    )
    assert result is not None
    assert isinstance(result, return_expected_type)
    if backend == 'torch':
        # the torch result must still report itself as sparse
        assert result.is_sparse
Beispiel #2
0
def test_document_sparse_embedding(
    scipy_sparse_matrix,
    return_sparse_ndarray_cls_type,
    return_scipy_class_type,
    return_expected_type,
    field,
):
    """Round-trip a scipy sparse matrix through a Document's ``blob``/``embedding``.

    The matrix is assigned to ``field`` ('blob' or 'embedding') on a fresh
    ``Document``. It is read back with ``get_sparse_blob`` or
    ``get_sparse_embedding``, using the sparse backend named by
    ``return_sparse_ndarray_cls_type`` ('scipy', 'torch' or 'tf'; any other
    value passes ``None`` as the class type).

    The result must be non-None and of ``return_expected_type``. For the three
    named backends, its dense form must equal the original matrix's dense form.

    Raises:
        ValueError: if ``field`` is neither 'blob' nor 'embedding'.
    """
    d = Document()
    setattr(d, field, scipy_sparse_matrix)
    cls_type = None
    sparse_kwargs = {}
    if return_sparse_ndarray_cls_type == 'scipy':
        from jina.types.ndarray.sparse.scipy import SparseNdArray

        cls_type = SparseNdArray
        # sp_format is only forwarded for the scipy backend
        sparse_kwargs['sp_format'] = return_scipy_class_type
    elif return_sparse_ndarray_cls_type == 'torch':
        from jina.types.ndarray.sparse.pytorch import SparseNdArray

        cls_type = SparseNdArray
    elif return_sparse_ndarray_cls_type == 'tf':
        from jina.types.ndarray.sparse.tensorflow import SparseNdArray

        cls_type = SparseNdArray

    if field == 'blob':
        field_sparse = d.get_sparse_blob(sparse_ndarray_cls_type=cls_type,
                                         **sparse_kwargs)
    elif field == 'embedding':
        field_sparse = d.get_sparse_embedding(sparse_ndarray_cls_type=cls_type,
                                              **sparse_kwargs)
    else:
        # Fail loudly instead of hitting a NameError on an unbound field_sparse.
        raise ValueError(f'unsupported field: {field!r}')

    assert field_sparse is not None
    assert isinstance(field_sparse, return_expected_type)
    if return_sparse_ndarray_cls_type == 'torch':
        assert field_sparse.is_sparse

    expected_dense = scipy_sparse_matrix.todense()
    if return_sparse_ndarray_cls_type == 'scipy':
        np.testing.assert_equal(field_sparse.todense(), expected_dense)
    elif return_sparse_ndarray_cls_type == 'torch':
        np.testing.assert_equal(field_sparse.to_dense().numpy(),
                                expected_dense)
    elif return_sparse_ndarray_cls_type == 'tf':
        # Previously keyed on return_scipy_class_type, so this branch never ran.
        # Imported lazily, like the backend classes above.
        import tensorflow as tf

        np.testing.assert_equal(tf.sparse.to_dense(field_sparse).numpy(),
                                expected_dense)