def forward(self, *input):
        """Returns the log-probability of the marginal probability of each node in the graph"""
        edge_logits = self.model(*input)
        if len(edge_logits.shape) == 4:  # Flatten 2D data
            two_dim_data = True
            b, c, h, w = edge_logits.shape
            edge_logits = edge_logits.view(b, c, h * w).transpose(
                1, 2).contiguous().view(b * h * w, c)
        else:
            two_dim_data = False

        # Log-softmax within each sibling group: normalize so that each
        # group of sibling logits sums to one in probability space.
        lse = scatter_logsumexp(edge_logits, self.sibling_mask, dim=1)
        scaled_logits = edge_logits - lse[:, self.sibling_mask]

        # Accumulate conditional log-probabilities along each node's
        # root-to-node path to obtain its marginal log-probability.
        marginal_logits = scaled_logits.clone()
        depth = self.path_matrix.shape[0]
        for d in range(depth):
            parent_logits = scaled_logits[:, self.path_matrix[d]]
            # Index 0 marks "no ancestor at this depth": contribute log 1 = 0.
            parent_logits[:, self.path_matrix[d] == 0] = 0
            marginal_logits = marginal_logits + parent_logits

        if two_dim_data:  # Un-flatten 2D data
            _, n_out = marginal_logits.shape
            marginal_logits = marginal_logits.view(b, h * w, n_out).transpose(
                1, 2).contiguous().view(b, n_out, h, w)

        return marginal_logits
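For reference, a minimal sketch of what scatter_logsumexp contributes here (assuming torch and torch_scatter are available): a numerically stable logsumexp per index group, which serves as the normalizer of a grouped log-softmax.

# Minimal sketch: scatter_logsumexp reduces each index group of `src`
# with a numerically stable logsumexp.
import torch
from torch_scatter import scatter_logsumexp

src = torch.tensor([0.2, 1.3, -0.7, 2.0])
index = torch.tensor([0, 0, 1, 1])   # two "sibling" groups
lse = scatter_logsumexp(src, index)  # shape (2,)
log_softmax = src - lse[index]       # per-group log-softmax
print(torch.exp(log_softmax))        # each group sums to 1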
Example no. 2
    def copynet_logits(self, logits, a_g, graph_ids):
        # Stack generation logits and aggregated copy scores along a new
        # last dimension: (batch_size, trg_len, vocab_size, 2).
        gen_probs = torch.empty(
            (logits.size(0), logits.size(1), logits.size(2), 2),
            device=logits.device)
        gen_probs[:, :, :, 0] = logits

        # Aggregate copy attention scores per vocabulary id into a
        # (batch_size, trg_len, vocab_size) tensor; ids that receive no
        # score are left at -inf.
        gen_probs[:, :, :,
                  1] = scatter_logsumexp(a_g,
                                         graph_ids.unsqueeze(1).expand_as(a_g),
                                         dim=2,
                                         dim_size=gen_probs.size(2))

        # Mix the two score sources in log space.
        return torch.logsumexp(gen_probs, 3)
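As a sanity check on the final logsumexp: reducing two stacked log-scores over the new dimension adds the underlying probabilities, i.e. log(exp(a) + exp(b)). A small sketch with stand-in tensors:

# Sketch: logsumexp over stacked log-scores is addition in probability
# space; here it stands in for mixing generation and copy scores.
import torch

a = torch.randn(3)  # stand-in generation log-scores
b = torch.randn(3)  # stand-in copy log-scores
mixed = torch.logsumexp(torch.stack([a, b], dim=-1), dim=-1)
assert torch.allclose(mixed, torch.log(a.exp() + b.exp()))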
Example no. 3
def test_logsumexp(dtype, device):
    src = tensor([0.5, 0, 0.5, -2.1, 3.2, 7, -1, float('-inf')], dtype, device)
    index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device)

    out = scatter_logsumexp(src, index)

    out0 = torch.logsumexp(torch.tensor([0.5, 0.5], dtype=dtype), dim=-1)
    out1 = torch.logsumexp(torch.tensor([0, -2.1, 3.2], dtype=dtype), dim=-1)
    out2 = torch.logsumexp(torch.tensor(7, dtype=dtype), dim=-1)
    # Index 3 receives no inputs; this test expects the empty bin to be
    # filled with the dtype minimum.
    out3 = torch.tensor(torch.finfo(dtype).min, dtype=dtype)
    # logsumexp([-1, -inf]) == -1: the -inf entry contributes nothing.
    out4 = torch.tensor(-1, dtype=dtype)

    expected = torch.stack([out0, out1, out2, out3, out4], dim=0)
    assert torch.allclose(out, expected)
Example no. 4
def test_logsumexp():
    src = torch.tensor([0.5, 0, 0.5, -2.1, 3.2, 7, -1, -100])
    index = torch.tensor([0, 1, 0, 1, 1, 2, 4, 4])

    out = scatter_logsumexp(src, index)

    out0 = torch.logsumexp(torch.tensor([0.5, 0.5]), dim=-1)
    out1 = torch.logsumexp(torch.tensor([0, -2.1, 3.2]), dim=-1)
    out2 = torch.logsumexp(torch.tensor(7, dtype=torch.float), dim=-1)
    # Index 3 receives no inputs; logsumexp over an empty tensor is -inf.
    out3 = torch.logsumexp(torch.tensor([], dtype=torch.float), dim=-1)
    out4 = torch.tensor(-1, dtype=torch.float)

    expected = torch.stack([out0, out1, out2, out3, out4], dim=0)
    assert torch.allclose(out, expected)
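Note that Examples 3 and 4 expect different fill values for the empty bin at index 3: torch.finfo(dtype).min in one, -inf in the other, presumably reflecting different torch_scatter versions. The -inf expectation matches plain PyTorch, as a one-line check shows:

import torch
print(torch.logsumexp(torch.tensor([]), dim=0))  # tensor(-inf)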
Example no. 5
def test_logsumexp():
    inputs = torch.tensor([
        0.5, 0.5, 0.0, -2.1, 3.2, 7.0, -1.0, -100.0,
        float('-inf'),
        float('-inf'), 0.0
    ])
    inputs.requires_grad_()
    index = torch.tensor([0, 0, 1, 1, 1, 2, 4, 4, 5, 6, 6])
    splits = [2, 3, 1, 0, 2, 1, 2]  # group sizes per index; index 3 is empty

    outputs = scatter_logsumexp(inputs, index)

    for src, out in zip(inputs.split(splits), outputs.unbind()):
        assert out.tolist() == torch.logsumexp(src, dim=0).tolist()

    outputs.backward(torch.randn_like(outputs))  # smoke-test the backward pass
Example no. 6
def scatter_logsumexp(device: Type[draw_devices],
                      token_sizes: Type[draw_token_sizes],
                      dim: Type[draw_embedding_dims], *, timer: TimerSuit):
    device = device()
    token_size, num = token_sizes(), token_sizes()
    if num > token_size:
        token_size, num = num, token_size
    in_dim = dim()

    inputs = torch.randn((token_size, in_dim),
                         requires_grad=True,
                         device=device)
    index1 = torch.randint(0, num, (token_size, ), device=device)
    # The torch_scatter baseline below takes an index expanded to the
    # full input shape.
    index2 = index1[:, None].expand_as(inputs)

    with timer.rua_forward:
        actual = rua.scatter_logsumexp(tensor=inputs, index=index1)

    with timer.naive_forward:
        expected = torch_scatter.scatter_logsumexp(src=inputs,
                                                   index=index2,
                                                   dim=0)

    with timer.rua_backward:
        torch.autograd.grad(
            actual,
            inputs,
            torch.ones_like(actual),
            create_graph=False,
        )

    with timer.naive_backward:
        torch.autograd.grad(
            expected,
            inputs,
            torch.ones_like(expected),
            create_graph=False,
        )
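The benchmark only times the two calls; for a self-contained look at the expanded-index pattern used in the naive baseline above, a small deterministic sketch:

# Sketch: torch_scatter's scatter_logsumexp with an index expanded to
# the full input shape, reducing along dim 0.
import torch
import torch_scatter

inputs = torch.randn(4, 2)
index = torch.tensor([0, 1, 1, 2])
out = torch_scatter.scatter_logsumexp(inputs,
                                      index[:, None].expand_as(inputs),
                                      dim=0)
print(out.shape)  # torch.Size([3, 2])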
Example no. 7
#     factorStates_to_varIndices, factorToVar_edge_index = test_create_factorStates_to_varIndices(factorToVar_double_list=[[0,1], [0,1]], numVars=3)
#     factor_beliefs = torch.tensor(range(8)).float()
#     factor_beliefs = factor_beliefs.view(2,2,2)

    print()

    print("factor_beliefs:", factor_beliefs)
    mapped_factor_beliefs = test_map_beliefs(beliefs=factor_beliefs, map_type='factor', facToVar_edge_idx=factorToVar_edge_index)
    print("mapped_factor_beliefs:", mapped_factor_beliefs)
    
    num_edges = factorToVar_edge_index.shape[1]
    print("mapped_factor_beliefs.view(mapped_factor_beliefs.numel())).shape:", mapped_factor_beliefs.view(mapped_factor_beliefs.numel()).shape)
    print("factorStates_to_varIndices.shape:", factorStates_to_varIndices.shape)
    # Route padding entries (marked -1) to an extra dummy bin, then
    # marginalize factor beliefs in log space and exponentiate back.
    factorStates_to_varIndices[torch.where(factorStates_to_varIndices == -1)] = num_edges*2
    marginalized_states_fast = torch.exp(scatter_logsumexp(
        src=torch.log(mapped_factor_beliefs.view(mapped_factor_beliefs.numel())),
        index=factorStates_to_varIndices,
        dim_size=num_edges*2 + 1))
    
    print("marginalized_states_fast:", marginalized_states_fast)
    
    # Drop the dummy padding bin before reshaping.
    old_marginalized_states = marginalized_states_fast[:-1].view((2, num_edges)).permute(1, 0)
    new_marginalized_states = marginalized_states_fast[:-1].view(num_edges, belief_repeats, var_cardinality)
    print("old_marginalized_states:", old_marginalized_states)
    print("new_marginalized_states:", new_marginalized_states)
    
    old_var_beliefs = scatter_('add', old_marginalized_states, factorToVar_edge_index[1])
    new_var_beliefs = scatter_('add', new_marginalized_states, factorToVar_edge_index[1])
    
    print("old_var_beliefs:", old_var_beliefs)
    print("new_var_beliefs:", new_var_beliefs)
    
    old_mapped_var_beliefs = test_map_beliefs(beliefs=old_var_beliefs, map_type='var', facToVar_edge_idx=factorToVar_edge_index)