def test_binary_torch_func(tensor_list, other_tensor_list, func_data):
    """Test a torch function that takes two tensors.

    ``func`` is applied once to two masked tensors built from
    ``tensor_list`` and ``other_tensor_list``. Each resulting item is then
    compared against ``func`` applied pairwise to the corresponding
    unbatched tensors.

    ``func_data`` is unpacked as ``(func, _, same_base_name)``. When
    ``same_base_name`` is truthy, the second masked tensor is built with
    ``base_name='N'``; otherwise it is built with ``'M'``.
    """
    func, _, same_base_name = func_data
    masked_tensor = maskedtensor.from_list(tensor_list, dims=(0, 1))
    other_base_name = 'N' if same_base_name else 'M'
    other_masked_tensor = maskedtensor.from_list(other_tensor_list,
                                                 dims=(0, 1),
                                                 base_name=other_base_name)
    res_mt = list(func(masked_tensor, other_masked_tensor))
    binary_list = zip(tensor_list, other_tensor_list)
    # Materialize so the number of reference results can be checked.
    res_lst = list(apply_binary_list_tensors(binary_list, func))
    # zip() below stops at the shorter input; without this check a result
    # missing on either side would silently pass the test.
    assert len(res_mt) == len(res_lst)
    for t_mt, t_lst in zip(res_mt, res_lst):
        assert t_mt.size() == t_lst.size()
        assert torch.allclose(t_mt, t_lst,
                              atol=ATOL), torch.norm(t_mt - t_lst,
                                                     p=float('inf'))
# Example No. 2
# 0
def masked_batch():
    """Build a masked tensor from BATCH_SIZE random square matrices.

    Each matrix has shape (N_VERTICES, N_VERTICES) and is filled in place
    with ``normal_()``. The list is passed to ``maskedtensor.from_list``
    with ``dims=(0, 1)``.
    """
    samples = []
    for _ in range(BATCH_SIZE):
        samples.append(torch.empty((N_VERTICES, N_VERTICES)).normal_())
    return maskedtensor.from_list(samples, dims=(0, 1))
def test_accuracy_func(tensor_list, func_data):
    """Test a reducing function on a masked batch.

    The value ``func`` returns for the masked tensor must match, within
    ATOL, the sum of ``func`` applied to each tensor of ``tensor_list``
    individually.
    """
    func, _unused = func_data
    batched = maskedtensor.from_list(tensor_list, dims=(0, 1))
    actual = func(batched)
    expected = sum(apply_list_tensors(tensor_list, func))
    # The failure message (max absolute error) is only evaluated on failure.
    assert torch.allclose(actual, expected, atol=ATOL), torch.norm(
        actual - expected, p=float('inf'))
def test_score_func(score_list, func_data):
    """Test a score function on a masked batch against the stacked batch.

    ``func`` applied to the masked tensor built from ``score_list`` must
    match, within ATOL, ``func`` applied to ``torch.stack(score_list)``.
    """
    func, _unused = func_data
    batched = maskedtensor.from_list(score_list, dims=(0, 1))
    actual = func(batched)
    expected = func(torch.stack(score_list))
    # The failure message (max absolute error) is only evaluated on failure.
    assert torch.allclose(actual, expected, atol=ATOL), torch.norm(
        actual - expected, p=float('inf'))
def test_torch_func(tensor_list, func_data):
    """Test a torch function that takes one tensor.

    ``func`` is applied once to the masked tensor built from
    ``tensor_list``. Each resulting item is then compared, by size and
    value (within ATOL), against ``func`` applied to the corresponding
    unbatched tensor.
    """
    func, _ = func_data
    masked_tensor = maskedtensor.from_list(tensor_list, dims=(0, 1))
    res_mt = list(func(masked_tensor))
    # Materialize so the number of reference results can be checked.
    res_lst = list(apply_list_tensors(tensor_list, func))
    # zip() below stops at the shorter input; without this check a result
    # missing on either side would silently pass the test.
    assert len(res_mt) == len(res_lst)
    for t_mt, t_lst in zip(res_mt, res_lst):
        assert t_mt.size() == t_lst.size()
        assert torch.allclose(t_mt, t_lst,
                              atol=ATOL), torch.norm(t_mt - t_lst,
                                                     p=float('inf'))
# Example No. 6
# 0
def batch(request):
    """Build a masked tensor of perturbed identity matrices.

    One ``perturb(torch.eye(n, n))`` is built per size ``n`` in
    N_VERTICES_RANGE. When ``request.param`` is truthy, each perturbed
    matrix is transposed with ``torch.t``. ``request`` presumably comes
    from a parametrized pytest fixture — confirm against the decorator.
    """
    transpose = request.param
    tensor_lst = []
    for size in N_VERTICES_RANGE:
        perturbed = perturb(torch.eye(size, size))
        tensor_lst.append(torch.t(perturbed) if transpose else perturbed)
    return maskedtensor.from_list(tensor_lst, dims=(0, 1))
# Example No. 7
# 0
def correct_batch():
    """Return ``from_list`` of identity matrices, one per size in N_VERTICES_RANGE."""
    identities = list(map(torch.eye, N_VERTICES_RANGE))
    return from_list(identities, dims=(0, 1))
# Example No. 8
# 0
def batch():
    """Return ``from_list`` of random square matrices, one per size in N_VERTICES_RANGE.

    Each (n, n) matrix is filled in place with ``normal_()``.
    """
    def _random_square(size):
        return torch.empty(size, size).normal_()

    matrices = [_random_square(size) for size in N_VERTICES_RANGE]
    return from_list(matrices, dims=(0, 1))