import torch
import torch.nn.functional as F
# `scatter` matches the torch_scatter calling convention (out= and reduce=);
# this import is the assumed source of the helper used throughout.
from torch_scatter import scatter


def backward(ctx, logl_grad):
    # Recovering saved tensors from forward
    tree = ctx.saved_input
    beta, t_beta, A, B, Pi, SP = ctx.saved_tensors

    # Getting model info
    C, n_gen, device = A.size(0), A.size(-1), A.device

    # Creating parameter gradient tensors
    A_grad, B_grad, Pi_grad, SP_grad = (torch.zeros_like(A), torch.zeros_like(B),
                                        torch.zeros_like(Pi), torch.zeros_like(SP))

    eps = torch.zeros((tree['dim'], C, n_gen), device=device)
    roots = tree['levels'][0][0].unique(sorted=False)
    eps[roots] = beta[roots]

    for l in tree['levels']:
        # Computing eps_{u, ch(u)}(i, j)
        pos_ch = tree['pos'][l[1]]
        SP_ch = SP[pos_ch]
        A_ch = A[:, :, pos_ch].permute(2, 0, 1, 3)
        trans_ch = SP_ch.unsqueeze(1).unsqueeze(1) * A_ch
        eps_pa = eps[l[0]].unsqueeze(2)
        beta_ch = beta[l[1]].unsqueeze(1)
        t_beta_pa = t_beta[l[0]].unsqueeze(2)

        # Computing eps_{ch(u)}
        eps_joint = (eps_pa * trans_ch * beta_ch) / t_beta_pa
        eps_ch = eps_joint.sum(1)
        eps[l[1]] = eps_ch

        local_grad = logl_grad[tree['batch'][l[1]]]
        # Accumulating gradient in grad_A and grad_SP
        SP_grad = scatter((eps_ch.sum(1) - SP_ch) * local_grad,
                          index=pos_ch, dim=0, out=SP_grad)
        local_grad = local_grad.unsqueeze(0).unsqueeze(0)
        A_grad = scatter((eps_joint - A_ch * eps_ch.unsqueeze(1)).permute(1, 2, 0, 3) * local_grad,
                         index=pos_ch, dim=2, out=A_grad)

    eps_leaves = eps[tree['leaves']]
    local_grad = logl_grad[tree['batch'][tree['leaves']]].unsqueeze(1)
    Pi_grad = ((eps_leaves - Pi.unsqueeze(0)) * local_grad).sum(0)

    eps_B = eps.permute(1, 0, 2)
    B_grad = scatter(torch.ones_like(eps_B), index=tree['x'], dim=1, out=B_grad)
    B_grad -= B * tree['dim']
    local_grad = logl_grad[tree['batch']].unsqueeze(0)
    B_grad *= (eps_B * local_grad).sum(1, keepdim=True)

    return None, A_grad, B_grad, Pi_grad, SP_grad

def forward(ctx, x, tree, lambda_A, lambda_B, lambda_Pi):
    # Softmax reparameterization
    sm_A, sm_B, sm_Pi = [], [], []
    for i in range(lambda_A.size(-1)):
        sm_A.append(F.softmax(lambda_A[:, :, i], dim=0))
        sm_B.append(F.softmax(lambda_B[:, :, i], dim=1))
        sm_Pi.append(F.softmax(lambda_Pi[:, i], dim=0))
    A, B, Pi = (torch.stack(sm_A, dim=-1), torch.stack(sm_B, dim=-1),
                torch.stack(sm_Pi, dim=-1))

    # Getting model info
    C, n_gen, device = A.size(0), A.size(-1), A.device

    # Preliminary downward recursion: init
    prior = torch.zeros((tree['dim'], C, n_gen), device=device)

    # Preliminary downward recursion: base case
    prior[tree['roots']] = Pi

    # Preliminary downward recursion
    for l in tree['levels']:
        prior_pa = prior[l[0]].unsqueeze(1)
        prior_l = (A.unsqueeze(0) * prior_pa).sum(2)
        prior[l[1]] = prior_l

    # Upward recursion: init
    beta = prior * B[:, x[tree['inv_map']]].permute(1, 0, 2)
    t_beta = torch.zeros((tree['dim'], C, n_gen), device=device)
    log_likelihood = torch.zeros((tree['dim'], n_gen), device=device)

    # Upward recursion: base case
    beta_leaves_unnorm = beta[tree['leaves']]
    nu = beta_leaves_unnorm.sum(dim=1)
    beta[tree['leaves']] = beta_leaves_unnorm / nu.unsqueeze(1)
    log_likelihood[tree['leaves']] = nu.log()

    # Upward recursion
    for l in reversed(tree['levels']):
        # Computing beta_uv for the children
        beta_ch = beta[l[1]].unsqueeze(2)
        prior_l = prior[l[1]].unsqueeze(2)
        beta_uv = (A.unsqueeze(0) * beta_ch / prior_l).sum(1)
        t_beta[l[1]] = beta_uv

        # Computing beta on the current level
        pa_idx = l[0].unique(sorted=False)
        prev_beta = beta[pa_idx]
        beta = scatter(src=beta_uv, index=l[0], dim=0, out=beta, reduce='mul')
        beta_l_unnorm = prev_beta * beta[pa_idx]
        nu = beta_l_unnorm.sum(1)
        beta[pa_idx] = beta_l_unnorm / nu.unsqueeze(1)
        log_likelihood[pa_idx] = nu.log()

    ctx.saved_input = x, tree
    ctx.save_for_backward(prior, beta, t_beta, A, B, Pi)
    return scatter(log_likelihood, tree['trees_ind'], dim=0)

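# The forward/backward pairs in this file follow the torch.autograd.Function
# protocol: forward stashes tensors via ctx.save_for_backward (and non-tensor
# inputs on ctx directly), while backward returns one gradient per forward
# input, with None for inputs that need no gradient.  A minimal, toy
# illustration of that pattern (not part of the original code) is:

class _ToyLogLikelihood(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, theta):
        # Tensors needed by backward are saved here; the output is a
        # per-sample log-likelihood-style quantity.
        ctx.save_for_backward(x, theta)
        return (x * theta).sum(dim=-1).log()

    @staticmethod
    def backward(ctx, grad_out):
        x, theta = ctx.saved_tensors
        denom = (x * theta).sum(dim=-1, keepdim=True)
        # d log(sum_j x_ij * theta_j) / d theta_j = x_ij / denom_i,
        # accumulated over the batch; x gets no gradient, hence None.
        theta_grad = (grad_out.unsqueeze(-1) * x / denom).sum(dim=0)
        return None, theta_grad


# Usage: Function subclasses are invoked through .apply, never called directly.
_x = torch.rand(4, 3)
_theta = torch.rand(3, requires_grad=True)
_ToyLogLikelihood.apply(_x, _theta).sum().backward()
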
def forward(ctx, tree, lambda_A, lambda_B, lambda_Pi, lambda_SP):
    # Softmax reparameterization
    sm_A, sm_B, sm_Pi, sm_SP = [], [], [], []
    for i in range(lambda_A.size(-1)):
        sm_A.append(
            torch.cat([
                F.softmax(lambda_A[:, :, j, i], dim=0).unsqueeze(2)
                for j in range(lambda_A.size(2))
            ], dim=2))
        sm_B.append(F.softmax(lambda_B[:, :, i], dim=1))
        sm_Pi.append(F.softmax(lambda_Pi[:, i], dim=0))
        sm_SP.append(F.softmax(lambda_SP[:, i], dim=0))
    A, B, Pi, SP = (torch.stack(sm_A, dim=-1), torch.stack(sm_B, dim=-1),
                    torch.stack(sm_Pi, dim=-1), torch.stack(sm_SP, dim=-1))

    # Getting model info
    C, n_gen, device = A.size(0), A.size(-1), A.device

    # Upward recursion: init
    beta = torch.zeros((tree['dim'], C, n_gen), device=device)
    t_beta = torch.zeros((tree['dim'], C, n_gen), device=device)
    log_likelihood = torch.zeros((tree['dim'], n_gen), device=device)

    # Upward recursion: base case
    B_leaves = B[:, tree['x'][tree['leaves']]]
    beta_leaves = (Pi.unsqueeze(1) * B_leaves).permute(1, 0, 2)
    nu = beta_leaves.sum(dim=1)
    beta[tree['leaves']] = beta_leaves / nu.unsqueeze(1)
    log_likelihood[tree['leaves']] = nu.log()

    # Upward recursion
    for l in reversed(tree['levels']):
        # Computing beta_uv for the children
        pos_ch = tree['pos'][l[1]]
        SP_ch = SP[pos_ch].unsqueeze(1).unsqueeze(2)
        A_ch = A[:, :, pos_ch].permute(2, 0, 1, 3)
        beta_ch = beta[l[1]].unsqueeze(1)
        t_beta_ch = (SP_ch * A_ch * beta_ch).sum(2)
        t_beta = scatter(src=t_beta_ch, index=l[0], dim=0, out=t_beta)

        u_idx = l[0].unique(sorted=False)
        B_l = B[:, tree['x'][u_idx]].permute(1, 0, 2)
        beta_l = t_beta[u_idx] * B_l
        nu = beta_l.sum(dim=1)
        beta[u_idx] = beta_l / nu.unsqueeze(1)
        log_likelihood[u_idx] = nu.log()

    ctx.saved_input = tree
    ctx.save_for_backward(beta, t_beta, A, B, Pi, SP)
    return scatter(log_likelihood, tree['batch'], dim=0)

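# Read together with the base case, the loop above implements the upward
# recursion of a bottom-up tree model with positional switching parents SP.
# Reconstructed from the code (not stated explicitly in the source), with
# \nu_u the per-node normalizer, it computes
#
#   \beta_u(i) = \frac{1}{\nu_u} \, B(i, x_u) \sum_{l \in ch(u)} SP(l)
#                \sum_{j=1}^{C} A_l(i, j) \, \beta_{ch_l(u)}(j),
#
#   \nu_u = \sum_{i=1}^{C} B(i, x_u) \sum_{l \in ch(u)} SP(l)
#           \sum_{j=1}^{C} A_l(i, j) \, \beta_{ch_l(u)}(j).
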
def backward(ctx, logl_grad):
    # Recovering saved tensors from forward
    tree = ctx.saved_input
    prior, beta, t_beta, A, B, Pi = ctx.saved_tensors

    # Getting model info
    C, n_gen, device = A.size(0), A.size(-1), A.device

    # Creating parameter gradient tensors
    A_grad, B_grad = torch.zeros_like(A), torch.zeros_like(B)

    eps = torch.zeros((tree['dim'], C, n_gen), device=device)
    roots = tree['levels'][0][0].unique(sorted=False)
    eps_roots = beta[roots]
    eps[roots] = eps_roots

    for l in tree['levels']:
        # Computing eps_{u, pa(u)}(i, j)
        eps_pa = eps[l[0]].unsqueeze(1)
        pos_ch = tree['pos'][l[1]]
        A_ch = A[:, :, pos_ch].permute(2, 0, 1, 3)
        beta_ch = beta[l[1]].unsqueeze(2)
        eps_trans_pa = A_ch * eps_pa
        t_beta_ch = t_beta[l[1]].unsqueeze(1)
        prior_ch = prior[l[1]].unsqueeze(2)
        eps_joint = (beta_ch * eps_trans_pa) / (prior_ch * t_beta_ch)

        # Computing eps_u(i)
        eps_ch = eps_joint.sum(2)
        eps[l[1]] = eps_ch

        local_grad = logl_grad[tree['batch'][l[1]]].unsqueeze(0).unsqueeze(0)
        # Accumulating gradient in grad_A
        A_grad = scatter((eps_joint - eps_trans_pa).permute(1, 2, 0, 3) * local_grad,
                         index=pos_ch, dim=2, out=A_grad)

    eps_B = eps.permute(1, 0, 2)
    B_grad = scatter(torch.ones_like(eps_B), index=tree['x'], dim=1, out=B_grad)
    B_grad -= B * tree['dim']
    local_grad = logl_grad[tree['batch']].unsqueeze(0)
    B_grad *= (eps_B * local_grad).sum(1, keepdim=True)

    local_grad = logl_grad[tree['batch'][roots]].unsqueeze(1)
    Pi_grad = ((eps_roots - Pi.unsqueeze(0)) * local_grad).sum(0)

    return None, A_grad, B_grad, Pi_grad

def forward(ctx, x, tree, lambda_A, lambda_B, lambda_Pi):
    # Softmax reparameterization
    sm_A, sm_B, sm_Pi = [], [], []
    for i in range(lambda_A.size(-1)):
        sm_A.append(F.softmax(lambda_A[:, :, i], dim=0))
        sm_B.append(F.softmax(lambda_B[:, :, i], dim=1))
        sm_Pi.append(F.softmax(lambda_Pi[:, i], dim=0))
    A, B, Pi = (torch.stack(sm_A, dim=-1), torch.stack(sm_B, dim=-1),
                torch.stack(sm_Pi, dim=-1))

    # Getting model info
    C, n_gen, device = A.size(0), A.size(-1), A.device

    # Upward recursion: init
    beta = torch.zeros((tree['dim'], C, n_gen), device=device)
    t_beta = torch.zeros((tree['dim'], C, n_gen), device=device)
    log_likelihood = torch.zeros((tree['dim'], n_gen), device=device)

    # Upward recursion: base case
    Pi_leaves = Pi.unsqueeze(0)
    leaves_idx = tree['inv_map'][tree['leaves']]
    B_leaves = B[:, x[leaves_idx]].permute(1, 0, 2)
    beta_leaves = Pi_leaves * B_leaves
    nu = beta_leaves.sum(dim=1)
    beta[tree['leaves']] = beta_leaves / nu.unsqueeze(1)
    log_likelihood[tree['leaves']] = nu.log()

    # Upward recursion
    for l in reversed(tree['levels']):
        # Computing unnormalized beta_uv for the children
        beta_ch = beta[l[1]]
        t_beta_ch = (A.unsqueeze(0) * beta_ch.unsqueeze(1)).sum(2)
        t_beta = scatter(src=t_beta_ch, index=l[0], dim=0, out=t_beta, reduce='mean')

        u_idx = l[0].unique(sorted=False)
        B_l = B[:, x[tree['inv_map'][u_idx]]].permute(1, 0, 2)
        beta_l = B_l * t_beta[u_idx]
        nu = beta_l.sum(dim=1)
        beta[u_idx] = beta_l / nu.unsqueeze(1)
        log_likelihood[u_idx] = nu.log()

    ctx.saved_input = x, tree
    ctx.save_for_backward(beta, t_beta, A, B, Pi)
    return scatter(log_likelihood, tree['trees_ind'], dim=0)

def backward(ctx, likelihood, posterior_i):
    x, edge_index = ctx.saved_input
    posterior_il, posterior_i, Q, B = ctx.saved_tensors

    post_neigh = scatter(posterior_i[edge_index[1]], edge_index[0], dim=0,
                         reduce='mean').unsqueeze(1)
    Q_grad = (posterior_il - Q * post_neigh).sum(0)

    post_nodes = posterior_i.permute(1, 0, 2)
    B_grad = scatter(post_nodes - post_nodes * B[:, x],
                     index=x, dim=1,
                     out=torch.zeros_like(B, device=B.device))

    return None, None, None, Q_grad, B_grad

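# The backward passes in this file accumulate sufficient statistics with the
# torch_scatter `scatter` helper, using reduce= to pick the aggregation and
# out= to accumulate into a pre-allocated gradient buffer.  A small toy example
# of that calling convention (illustration only, not part of the original code):

_src = torch.tensor([[1.0], [2.0], [3.0], [4.0]])    # one value per edge
_index = torch.tensor([0, 0, 1, 2])                  # destination node of each edge

# Mean-aggregation, as used for post_neigh above: [[1.5], [3.0], [4.0]]
_node_mean = scatter(_src, _index, dim=0, reduce='mean')

# Default sum-aggregation into an existing buffer via out=, as used for the
# *_grad tensors: [[3.0], [3.0], [4.0]]
_buf = scatter(_src, _index, dim=0, out=torch.zeros(3, 1))
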
def backward(ctx, log_likelihood):
    # Recovering saved tensors from forward
    x, tree = ctx.saved_input
    beta, t_beta, A, B, Pi = ctx.saved_tensors

    # Getting model info
    C, n_gen, device = A.size(0), A.size(-1), A.device

    # Creating parameter gradient tensors
    A_grad, B_grad, Pi_grad = (torch.zeros_like(A), torch.zeros_like(B),
                               torch.zeros_like(Pi))

    # Downward recursion: init
    eps = torch.zeros((tree['dim'], C, n_gen), device=device)
    out_deg = torch.zeros(tree['dim'], device=device)

    # Downward recursion: base case
    eps[tree['roots']] = beta[tree['roots']]

    # Downward recursion
    for l in tree['levels']:
        # Computing eps_{u, ch_i(u)}(i, j)
        out_deg = scatter(torch.ones_like(l[1], dtype=out_deg.dtype, device=device),
                          dim=0, index=l[0], out=out_deg)
        t_beta_pa = t_beta[l[0]].unsqueeze(2)
        eps_pa = eps[l[0]].unsqueeze(2)
        beta_ch = beta[l[1]].unsqueeze(1)
        eps_joint = (eps_pa * A.unsqueeze(0) * beta_ch) / (
            t_beta_pa * out_deg[l[0]].view(-1, 1, 1, 1))
        eps_ch = eps_joint.sum(1)
        eps[l[1]] = eps_ch

        A_grad += (eps_joint - A.unsqueeze(0) * eps_ch.unsqueeze(1)).sum(0)

    eps_leaves = eps[tree['leaves']]
    Pi_grad = eps_leaves.sum(0) - tree['leaves'].size(0) * Pi

    eps_nodes = eps.permute(1, 0, 2)
    x_trees = x[tree['inv_map']]
    B_grad = scatter(eps_nodes - eps_nodes * B[:, x_trees],
                     index=x_trees, dim=1, out=B_grad)

    return None, None, A_grad, B_grad, Pi_grad

def forward(self, x, prev_h, edge_index, pos):
    Q_neigh, B = self._softmax_reparameterization()

    prev_h_neigh = prev_h[edge_index[1]].unsqueeze(1)
    trans_neigh = Q_neigh[:, :, pos].permute(2, 0, 1, 3)
    B_nodes = B[:, x[edge_index[0]]].permute(1, 0, 2).unsqueeze(2)  # edges x C x 1 x n_gen

    unnorm_posterior = B_nodes * trans_neigh * prev_h_neigh  # edges x C x C x n_gen
    likelihood = scatter(unnorm_posterior.sum([1, 2], keepdim=True),
                         edge_index[0], dim=0, reduce='mean')

    posterior_il = (unnorm_posterior /
                    (likelihood[edge_index[0]] + 1e-16)).detach()  # edges x C x C x n_gen
    posterior_i = scatter(posterior_il.sum(2), index=edge_index[0],
                          dim=0).detach()  # nodes x C x n_gen

    if self.training and not self.frozen:
        # nodes x C x n_gen, necessary for backpropagating in the new, detached graph
        B_nodes = B[:, x].permute(1, 0, 2)
        exp_likelihood = ((posterior_il * trans_neigh.log()).sum()
                          + (posterior_i * B_nodes.log()).sum())
        (-exp_likelihood).backward()

    likelihood = likelihood.log().squeeze()
    return likelihood, posterior_i

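# `self._softmax_reparameterization` is not shown in this file.  Based on the
# per-component softmax pattern used by the autograd forwards above, a plausible
# sketch is given below; the module class, its constructor arguments and the
# lambda shapes (C x C x L x n_gen for Q, C x M x n_gen for B) are assumptions,
# not part of the original code.

import torch.nn as nn


class NeighbourhoodParams(nn.Module):

    def __init__(self, C, M, L, n_gen):
        super().__init__()
        # Unconstrained parameters; the softmax maps them onto the simplex.
        self.lambda_Q = nn.Parameter(torch.randn(C, C, L, n_gen))
        self.lambda_B = nn.Parameter(torch.randn(C, M, n_gen))

    def _softmax_reparameterization(self):
        sm_Q, sm_B = [], []
        for j in range(self.lambda_Q.size(-1)):
            # States normalized on dim 0 for Q, emission symbols on dim 1 for B.
            sm_Q.append(F.softmax(self.lambda_Q[..., j], dim=0))
            sm_B.append(F.softmax(self.lambda_B[..., j], dim=1))
        return torch.stack(sm_Q, dim=-1), torch.stack(sm_B, dim=-1)
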
def forward(self, data):
    x, edge_index, batch = data.x, data.edge_index, data.batch

    log_likelihood = self.cgmm(x, edge_index)
    log_likelihood = scatter(log_likelihood, batch, dim=0)

    c_neurons = (log_likelihood @ self.contrastive).tanh()
    to_out, _ = self.gru(c_neurons, self.h0.repeat(1, batch.max() + 1, 1))
    out = self.output(to_out.flatten(start_dim=-2))

    return out

def backward(ctx, logl_grad, posterior):
    x = ctx.saved_input
    posterior, B, Pi = ctx.saved_tensors

    post_nodes = posterior.permute(1, 0, 2)
    B_grad = scatter(post_nodes - post_nodes * B[:, x],
                     index=x, dim=1,
                     out=torch.zeros_like(B, device=B.device))
    Pi_grad = posterior.sum(0) - posterior.size(0) * Pi

    return None, B_grad, Pi_grad

def backward(ctx, log_likelihood):
    # Recovering saved tensors from forward
    x, tree = ctx.saved_input
    prior, beta, t_beta, A, B, Pi = ctx.saved_tensors

    # Getting model info
    C, n_gen, device = A.size(0), A.size(-1), A.device

    # Creating parameter gradient tensors
    A_grad, B_grad = torch.zeros_like(A), torch.zeros_like(B)

    # Downward recursion: init
    eps = torch.zeros((tree['dim'], C, n_gen), device=device)

    # Downward recursion: base case
    eps[tree['roots']] = beta[tree['roots']]

    for l in tree['levels']:
        # Computing eps_{u, pa(u)}(i, j)
        eps_pa = eps[l[0]].unsqueeze(1)
        beta_ch = beta[l[1]].unsqueeze(2)
        eps_trans_pa = A.unsqueeze(0) * eps_pa
        t_beta_ch = t_beta[l[1]].unsqueeze(1)
        prior_ch = prior[l[1]].unsqueeze(2)
        eps_joint = (beta_ch * eps_trans_pa) / (prior_ch * t_beta_ch)

        # Computing eps_u(i)
        eps_ch = eps_joint.sum(2)
        eps[l[1]] = eps_ch

        A_grad += (eps_joint - eps_trans_pa).sum(0)

    eps_roots = eps[tree['roots']]
    Pi_grad = eps_roots.sum(0) - tree['roots'].size(0) * Pi

    eps_nodes = eps.permute(1, 0, 2)
    x_trees = x[tree['inv_map']]
    B_grad = scatter(eps_nodes - eps_nodes * B[:, x_trees],
                     index=x_trees, dim=1, out=B_grad)

    return None, None, A_grad, B_grad, Pi_grad

def forward(ctx, x, prev_h, edge_index, lambda_Q, lambda_B):
    # Softmax reparameterization
    Q, B = [], []
    for j in range(lambda_Q.size(-1)):
        Q.append(F.softmax(lambda_Q[:, :, j], dim=0))
        B.append(F.softmax(lambda_B[:, :, j], dim=1))
    Q, B = torch.stack(Q, dim=-1), torch.stack(B, dim=-1)

    prev_h_neigh = prev_h[edge_index[1]]
    prev_h_neigh_aggr = scatter(prev_h_neigh, edge_index[0], dim=0, reduce='mean')

    B_nodes = B[:, x].permute(1, 0, 2).unsqueeze(2)      # nodes x C x 1 x n_gen
    prev_h_neigh_aggr = prev_h_neigh_aggr.unsqueeze(1)   # nodes x 1 x C x n_gen

    unnorm_posterior = B_nodes * (Q * prev_h_neigh_aggr) + 1e-12
    posterior_il = unnorm_posterior / unnorm_posterior.sum([1, 2], keepdim=True)  # nodes x C x C x n_gen
    posterior_i = posterior_il.sum(2)

    ctx.saved_input = x, edge_index
    ctx.save_for_backward(posterior_il, posterior_i, Q, B)
    return unnorm_posterior.sum([1, 2]).log(), posterior_i