def _step_sample(self, step, score_func, adjs, node_flags, log=True):
    """Run one noisy score-ascent update on a batch of adjacency matrices.

    The input is perturbed with ``self._add_sym_normal_noise``, restricted to
    the valid nodes with ``mask_adjs``, and then moved along the estimated
    score by ``self.grad_step_size``.

    :param step: iteration index; used only in the debug log line
    :param score_func: callable ``(adjs, node_flags) -> score``
    :param adjs: batch of adjacency matrices (the log statistics reduce over
        dims [0, 1, 2], so presumably B x N x N -- TODO confirm)
    :param node_flags: node mask passed to ``mask_adjs`` and ``score_func``
    :param log: if True, emit a DEBUG summary of this step
    :return: (updated adjacency batch, node_flags unchanged)
    """
    adjs_c = mask_adjs(self._add_sym_normal_noise(adjs), node_flags)
    check_adjs_symmetry(adjs_c)
    score = score_func(adjs_c, node_flags)
    check_adjs_symmetry(score)
    delta = self.grad_step_size * score
    adjs_c = adjs_c + delta
    # The summary re-evaluates score_func on the updated batch and forces
    # several device syncs via .item(). Only pay for that when the root logger
    # (the one logging.debug writes to) would actually emit a DEBUG record.
    if log and logging.getLogger().isEnabledFor(logging.DEBUG):
        logging.debug(
            f"LG MC: step {step:5d}\t|"
            + "score: {:+.2e}\t|new_score_d: {:+.2e}\t"
            "|adj_mean: {:.2e}\t|adj_std: {:.2e}\t|delta_abs_mean: {:+.2e}\t|".format(
                score.norm(dim=[-1, -2]).mean().item(),
                score_func(
                    self.adj_to_int(adjs_c, to_int=False)[0], node_flags
                ).norm(dim=[-1, -2]).mean().item(),
                adjs_c.mean([0, 1, 2]).item(),
                adjs_c.std([0, 1, 2]).item(),
                # Mean absolute change per entry, normalised by (#nodes)^2.
                (delta.abs().sum([1, 2]) / node_flags.sum(-1) ** 2).mean().item(),
            )
        )
    return adjs_c, node_flags
def forward(self, x, adjs, node_flags):
    """Score a batch of adjacency matrices with a 2-D convolutional network.

    :param x: ignored; the score is computed from ``adjs`` alone
    :param adjs: batch of adjacency matrices; it is viewed as
        (batch, 1, N, N) before ``self.conv_net``, so presumably (batch, N, N)
    :param node_flags: node mask used to restrict ``self.mask`` via ``mask_adjs``
    :return: the symmetrised conv output (``s + s^T``) times the restricted mask
    """
    # Clone defensively so self.mask is never aliased by the masked copy
    # (NOTE(review): only matters if mask_adjs works in place -- confirm).
    valid = mask_adjs(self.mask.clone(), node_flags)
    masked = adjs * valid
    batch, n = masked.size(0), masked.size(2)
    raw = self.conv_net(masked.view(batch, 1, n, n)).view(batch, n, n)
    # Symmetrise so entries (i, j) and (j, i) receive the same score.
    return (raw + raw.transpose(-1, -2)) * valid
def gen_init_sample(self, batch_size, max_node_num, node_flags=None):
    """Draw a random symmetric, non-negative initial adjacency batch.

    Strictly-upper-triangular entries are |N(0, 1)| samples (drawn on the
    default device, then moved to ``self.dev``) and are mirrored into the
    lower triangle, so the diagonal is zero.

    :param batch_size: number of graphs
    :param max_node_num: number of nodes N per graph
    :param node_flags: optional node mask; when given, the sample is masked
        with it; when omitted, flags are derived via ``self.adj_to_int``
    :return: (adjacency batch of shape (batch_size, N, N), node_flags)
    """
    shape = (batch_size, max_node_num, max_node_num)
    upper = torch.randn(shape, dtype=torch.float32).triu(diagonal=1).abs().to(self.dev)
    sample = upper + upper.transpose(-1, -2)
    if node_flags is not None:
        return mask_adjs(sample, node_flags), node_flags
    _, derived_flags = self.adj_to_int(sample)
    return sample, derived_flags
def forward(self, x, adjs, node_flags, viz=False, save_dir=None, title=''):
    """Estimate an edge-wise score for a batch of graphs.

    The single-channel input adjacency is expanded into two channels
    (``adjs`` and ``1 - adjs``) and passed through every layer of
    ``self.gnn_list``. The adjacency channels produced by all layers,
    including the input, are then concatenated and read out per node pair
    by ``self.final_read_score``.

    :param x: node features forwarded to the first layer of ``self.gnn_list``
        (NOTE(review): presumably [num_classes * batch_size, N, F_i] or None
        -- confirm against callers)
    :param adjs: [num_classes * batch_size, N, N], batch of single-channel
        adjacency matrices; the two-channel layer input is built here
    :param node_flags: [num_classes * batch_size, N], node mask denoting
        whether a node is effective in each graph
    :param viz: whether to plot the intermediate channels of the first graph
        of each class
    :param save_dir: the directory of the output figures (if viz==True)
    :param title: suffix of the output figure names (if viz==True)
    :return: score: [num_classes * batch_size, N, N], the estimated score
        multiplied by ``self.mask`` (it is not masked by node_flags here)
    """
    ori_adjs = adjs.unsqueeze(1)
    adjs = torch.cat([ori_adjs, 1. - ori_adjs], dim=1)  # B x 2 x N x N
    adjs = mask_adjs(adjs, node_flags)
    temp_adjs = [adjs]
    for layer in self.gnn_list:
        x, adjs = layer(x, adjs, node_flags)
        temp_adjs.append(adjs)

    if viz:
        # The batch is laid out class by class; plot graph 0 of each class.
        batch_size = adjs.size(0) // self.num_classes
        for i in range(self.num_classes):
            plot_multi_channel_numpy_adjs(adjs=[
                layer_adjs[i * batch_size + 0].detach().cpu().numpy() for layer_adjs in temp_adjs
            ], save_dir=save_dir, title=f's_{i}_' + title)
            plot_multi_channel_numpy_adjs_1b1(adjs=[
                layer_adjs[i * batch_size + 0].detach().cpu().numpy() for layer_adjs in temp_adjs
            ], save_dir=save_dir, title=f's_{i}_' + title, fig_dir=f'figs_{i}')

    stacked_adjs = torch.cat(temp_adjs, dim=1)  # B x sum_C x N x N
    mlp_in = stacked_adjs.permute(0, 2, 3, 1)  # B x N x N x sum_C
    out_shape = mlp_in.shape[:-1]
    mlp_out = self.final_read_score(mlp_in)
    score = mlp_out.view(*out_shape)
    return score * self.mask
def forward(self, x, adjs, node_flags):
    """Update node features and adjacency channels for one GNN layer.

    Node features come from ``self.multi_channel_gnn_module``. Each node
    pair's feature vector is the current adjacency channels concatenated
    with the pairwise node features; ``self.translate_mlp`` maps it to the
    new channels. The result is symmetrised and masked.

    :param x: B x N x F_i node features
    :param adjs: B x C_i x N x N adjacency channels
    :param node_flags: B x N node mask
    :return: (x_o: B x N x F_o, new_adjs: B x C_o x N x N)
    """
    node_feats = self.multi_channel_gnn_module(x, adjs, node_flags)  # B x N x F_o
    pair_feats = node_feature_to_matrix(node_feats)  # B x N x N x 2F_o
    edge_in = torch.cat(
        [adjs.permute(0, 2, 3, 1), pair_feats], dim=-1
    )  # B x N x N x (C_i + 2F_o)
    *lead_dims, in_dim = edge_in.shape
    edge_out = self.translate_mlp(edge_in.view(-1, in_dim))
    channels = edge_out.view(*lead_dims, -1).permute(0, 3, 1, 2)  # B x C_o x N x N
    symmetric = channels + channels.transpose(-1, -2)
    return node_feats, mask_adjs(symmetric, node_flags)
def forward(self, x, adjs, node_flags, step_size=0.1):
    """Apply the score networks in sequence, refining ``adjs`` between them.

    Each network's output is masked with ``node_flags``. The adjacency seen
    by the next network is then nudged by ``step_size`` times that score.

    :param x: node features passed unchanged to every score network
    :param adjs: adjacency batch fed to the first network
    :param node_flags: node mask applied to every score
    :param step_size: refinement step applied to ``adjs`` after each network
        (defaults to the previously hard-coded 0.1)
    :return: the masked score produced by the last network
    :raises ValueError: if ``self.score_net_list`` is empty
    """
    score = None
    for score_net in self.score_net_list:
        score = mask_adjs(score_net(x, adjs, node_flags), node_flags)
        adjs = adjs + score * step_size
    if score is None:
        # Previously this surfaced as an UnboundLocalError on `score`.
        raise ValueError("score_net_list is empty; no score can be computed")
    return score