# Shared imports for the snippets below.
import torch
from torch_scatter import scatter_logsumexp


def forward(self, *input):
    """Returns the marginal log-probability of each node in the graph."""
    edge_logits = self.model(*input)

    if len(edge_logits.shape) == 4:
        # Flatten 2D data: (b, c, h, w) -> (b * h * w, c).
        two_dim_data = True
        b, c, h, w = edge_logits.shape
        edge_logits = edge_logits.view(b, c, h * w).transpose(
            1, 2).contiguous().view(b * h * w, c)
    else:
        two_dim_data = False

    # Log-normalize each group of sibling logits in one numerically
    # stable step.
    lse = scatter_logsumexp(edge_logits, self.sibling_mask, dim=1)
    scaled_logits = edge_logits - lse[:, self.sibling_mask]

    # Sum the conditional log-probabilities along each root-to-node path:
    # path_matrix[d] holds every node's ancestor at depth d, with 0 acting
    # as a "no ancestor" placeholder that is masked out below.
    marginal_logits = scaled_logits.clone()
    depth = self.path_matrix.shape[0]
    for d in range(depth):
        parent_logits = scaled_logits[:, self.path_matrix[d]]
        parent_logits[:, self.path_matrix[d] == 0] = 0
        marginal_logits = marginal_logits + parent_logits

    if two_dim_data:
        # Un-flatten 2D data: (b * h * w, n_out) -> (b, n_out, h, w).
        _, n_out = marginal_logits.shape
        marginal_logits = marginal_logits.view(b, h * w, n_out).transpose(
            1, 2).contiguous().view(b, n_out, h, w)

    return marginal_logits
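
# A minimal sketch of the auxiliary tensors forward() relies on, for a
# 5-node tree. The tree layout and the sibling_mask / path_matrix values
# here are illustrative assumptions, not taken from the original model:
#
#         0 (root placeholder)
#        /   \
#       1     2
#      / \
#     3   4
#
# sibling_mask groups nodes that share a parent; path_matrix row d maps
# each node to its ancestor at depth d (0 = no ancestor at that depth).
edge_logits = torch.randn(2, 5)                 # (batch, num_nodes)
sibling_mask = torch.tensor([0, 1, 1, 2, 2])    # node -> sibling group
path_matrix = torch.tensor([[0, 0, 0, 1, 1]])   # depth-1 ancestor per node

lse = scatter_logsumexp(edge_logits, sibling_mask, dim=1)
scaled = edge_logits - lse[:, sibling_mask]     # log P(node | parent)

parent = scaled[:, path_matrix[0]]
parent[:, path_matrix[0] == 0] = 0              # placeholder contributes 0
marginal = scaled + parent                      # log P(node) for leaves 3, 4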
def copynet_logits(self, logits, a_g, graph_ids):
    # Stack generation and copy scores so both modes can be merged with a
    # single logsumexp over the last dimension.
    gen_probs = torch.empty(
        (logits.size(0), logits.size(1), logits.size(2), 2),
        device=logits.device)
    gen_probs[:, :, :, 0] = logits  # (batch_size, trg_len, vocab_size) with -inf
    # Pool the copy scores a_g over all graph nodes that map to the same
    # vocabulary id.
    gen_probs[:, :, :, 1] = scatter_logsumexp(
        a_g, graph_ids.unsqueeze(1).expand_as(a_g), dim=2,
        dim_size=gen_probs.size(2))
    return torch.logsumexp(gen_probs, 3)
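
# Hedged usage sketch of the CopyNet-style mixture above. The shapes and
# names (batch, trg_len, vocab, nodes) are assumptions for illustration;
# graph_ids maps each graph node to its vocabulary id.
batch, trg_len, vocab, nodes = 2, 3, 10, 6
logits = torch.randn(batch, trg_len, vocab)           # generation scores
a_g = torch.randn(batch, trg_len, nodes)              # copy attention scores
graph_ids = torch.randint(0, vocab, (batch, nodes))   # node -> vocab id

gen_probs = torch.empty(batch, trg_len, vocab, 2)
gen_probs[..., 0] = logits
gen_probs[..., 1] = scatter_logsumexp(
    a_g, graph_ids.unsqueeze(1).expand_as(a_g), dim=2, dim_size=vocab)
combined = torch.logsumexp(gen_probs, dim=3)          # (batch, trg_len, vocab)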
# Version of the test parametrized over `dtype` and `device` by the test
# suite; `tensor` is assumed to be the suite's helper for building a
# tensor with a given dtype and device.
def test_logsumexp(dtype, device):
    src = tensor([0.5, 0, 0.5, -2.1, 3.2, 7, -1, float('-inf')], dtype, device)
    index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device)

    out = scatter_logsumexp(src, index)

    out0 = torch.logsumexp(torch.tensor([0.5, 0.5], dtype=dtype), dim=-1)
    out1 = torch.logsumexp(torch.tensor([0, -2.1, 3.2], dtype=dtype), dim=-1)
    out2 = torch.logsumexp(torch.tensor(7, dtype=dtype), dim=-1)
    # Index 3 receives no inputs, so its slot keeps the fill value.
    out3 = torch.tensor(torch.finfo(dtype).min, dtype=dtype)
    # logsumexp([-1, -inf]) == -1.
    out4 = torch.tensor(-1, dtype=dtype)

    expected = torch.stack([out0, out1, out2, out3, out4], dim=0).to(device)
    assert torch.allclose(out, expected)
def test_logsumexp():
    src = torch.tensor([0.5, 0, 0.5, -2.1, 3.2, 7, -1, -100])
    index = torch.tensor([0, 1, 0, 1, 1, 2, 4, 4])

    out = scatter_logsumexp(src, index)

    out0 = torch.logsumexp(torch.tensor([0.5, 0.5]), dim=-1)
    out1 = torch.logsumexp(torch.tensor([0, -2.1, 3.2]), dim=-1)
    out2 = torch.logsumexp(torch.tensor(7, dtype=torch.float), dim=-1)
    out3 = torch.logsumexp(torch.tensor([], dtype=torch.float), dim=-1)
    out4 = torch.tensor(-1, dtype=torch.float)

    expected = torch.stack([out0, out1, out2, out3, out4], dim=0)
    assert torch.allclose(out, expected)
# Variant that checks every group against a dense torch.logsumexp and
# finishes with a gradient smoke test; `splits` lists the group size per
# output index (index 3 is empty).
def test_logsumexp():
    inputs = torch.tensor([
        0.5, 0.5, 0.0, -2.1, 3.2, 7.0, -1.0, -100.0,
        float('-inf'),
        float('-inf'), 0.0,
    ])
    inputs.requires_grad_()
    index = torch.tensor([0, 0, 1, 1, 1, 2, 4, 4, 5, 6, 6])
    splits = [2, 3, 1, 0, 2, 1, 2]

    outputs = scatter_logsumexp(inputs, index)

    for src, out in zip(inputs.split(splits), outputs.unbind()):
        assert out.tolist() == torch.logsumexp(src, dim=0).tolist()

    outputs.backward(torch.randn_like(outputs))
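
# Quick numeric sanity check of the grouped reduction the tests above
# verify (values chosen for illustration):
src = torch.tensor([1.0, 2.0, 3.0])
index = torch.tensor([0, 0, 1])
print(scatter_logsumexp(src, index))
# -> tensor([2.3133, 3.0000]), i.e. [log(e**1 + e**2), log(e**3)]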
# Benchmark comparing rua.scatter_logsumexp against the naive
# torch_scatter implementation; `rua`, `torch_scatter`, `draw_devices`,
# `draw_token_sizes`, `draw_embedding_dims`, and `TimerSuit` come from
# the surrounding benchmark harness.
def scatter_logsumexp(device: Type[draw_devices],
                      token_sizes: Type[draw_token_sizes],
                      dim: Type[draw_embedding_dims], *,
                      timer: TimerSuit):
    device = device()
    token_size, num = token_sizes(), token_sizes()
    if num > token_size:
        token_size, num = num, token_size
    in_dim = dim()

    inputs = torch.randn((token_size, in_dim), requires_grad=True,
                         device=device)
    index1 = torch.randint(0, num, (token_size,), device=device)
    index2 = index1[:, None].expand_as(inputs)

    with timer.rua_forward:
        actual = rua.scatter_logsumexp(tensor=inputs, index=index1)

    with timer.naive_forward:
        expected = torch_scatter.scatter_logsumexp(src=inputs, index=index2,
                                                   dim=0)

    with timer.rua_backward:
        torch.autograd.grad(
            actual, inputs, torch.ones_like(actual),
            create_graph=False,
        )

    with timer.naive_backward:
        torch.autograd.grad(
            expected, inputs, torch.ones_like(expected),
            create_graph=False,
        )
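
# The benchmark expands index1 into index2 only because the two APIs
# differ; torch_scatter's composite ops broadcast a 1-D index across the
# feature dimension themselves, so both call styles should agree. A small
# check under that assumption:
import torch_scatter

src = torch.randn(8, 4)
idx = torch.randint(0, 3, (8,))
out_a = torch_scatter.scatter_logsumexp(src, idx, dim=0)
out_b = torch_scatter.scatter_logsumexp(src, idx[:, None].expand_as(src),
                                        dim=0)
assert torch.allclose(out_a, out_b)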
# Debug script for factor-to-variable message passing in belief
# propagation. States marked -1 in factorStates_to_varIndices are routed
# to an extra scratch bucket (index num_edges * 2) and dropped after the
# scatter.
# factorStates_to_varIndices, factorToVar_edge_index = test_create_factorStates_to_varIndices(
#     factorToVar_double_list=[[0, 1], [0, 1]], numVars=3)
# factor_beliefs = torch.tensor(range(8)).float()
# factor_beliefs = factor_beliefs.view(2, 2, 2)
print()
print("factor_beliefs:", factor_beliefs)
mapped_factor_beliefs = test_map_beliefs(beliefs=factor_beliefs,
                                         map_type='factor',
                                         facToVar_edge_idx=factorToVar_edge_index)
print("mapped_factor_beliefs:", mapped_factor_beliefs)

num_edges = factorToVar_edge_index.shape[1]
print("mapped_factor_beliefs.view(mapped_factor_beliefs.numel()).shape:",
      mapped_factor_beliefs.view(mapped_factor_beliefs.numel()).shape)
print("factorStates_to_varIndices.shape:", factorStates_to_varIndices.shape)

# Redirect "unused" states (-1) to the scratch bucket.
factorStates_to_varIndices[torch.where(factorStates_to_varIndices == -1)] = num_edges * 2

# Marginalize factor beliefs over states in log-space, then exponentiate.
marginalized_states_fast = torch.exp(scatter_logsumexp(
    src=torch.log(mapped_factor_beliefs.view(mapped_factor_beliefs.numel())),
    index=factorStates_to_varIndices,
    dim_size=num_edges * 2 + 1))
print("marginalized_states_fast:", marginalized_states_fast)

# Drop the scratch bucket, then reshape: old layout (2, num_edges) vs new
# layout (num_edges, belief_repeats, var_cardinality).
old_marginalized_states = marginalized_states_fast[:-1].view((2, num_edges)).permute(1, 0)
new_marginalized_states = marginalized_states_fast[:-1].view(num_edges, belief_repeats,
                                                             var_cardinality)
print("old_marginalized_states:", old_marginalized_states)
print("new_marginalized_states:", new_marginalized_states)

# Aggregate the per-edge messages into variable beliefs.
old_var_beliefs = scatter_('add', old_marginalized_states, factorToVar_edge_index[1])
new_var_beliefs = scatter_('add', new_marginalized_states, factorToVar_edge_index[1])
print("old_var_beliefs:", old_var_beliefs)
print("new_var_beliefs:", new_var_beliefs)
old_mapped_var_beliefs = test_map_beliefs(beliefs=old_var_beliefs, map_type='var',
                                          facToVar_edge_idx=factorToVar_edge_index)
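
# Minimal reproduction (assumed values) of the sentinel trick used above:
# entries indexed -1 are routed to one extra bucket that is sliced off
# after the scatter.
src = torch.log(torch.tensor([1.0, 2.0, 3.0, 4.0]))
idx = torch.tensor([0, 1, -1, 1])
idx[idx == -1] = 2                                   # scratch bucket
out = torch.exp(scatter_logsumexp(src, idx, dim_size=3)[:-1])
print(out)                                           # ~tensor([1., 6.])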