# Example #1
    def _step_sample(self, step, score_func, adjs, node_flags, log=True):
        """Perform one Langevin-MC sampling step on a batch of adjacency matrices.

        :param step: current step index (used only in the debug log line)
        :param score_func: callable ``(adjs, node_flags) -> score`` returning a
            tensor broadcastable against ``adjs``
        :param adjs: batch of (continuous) adjacency matrices
        :param node_flags: per-node validity flags; entries of padded nodes are masked out
        :param log: if True, emit a ``logging.debug`` line with step diagnostics
        :return: tuple ``(updated adjacency matrices, node_flags)``
        """
        adjs_c = adjs
        # Perturb with symmetric Gaussian noise, then zero out rows/cols of padded nodes.
        adjs_c = self._add_sym_normal_noise(adjs_c)
        adjs_c = mask_adjs(adjs_c, node_flags)
        check_adjs_symmetry(adjs_c)

        score = score_func(adjs_c, node_flags)
        # print((score - score.mean(0)).abs().sum().detach().cpu().item())
        # print((score.std(0).mean()).detach().cpu().item())
        # print(np.array2string(score.std().detach().cpu().numpy(), precision=2,
        #                                          separator='\t', prefix='\t'))
        check_adjs_symmetry(score)

        # Ascend the estimated score with a fixed step size.
        delta = self.grad_step_size * score
        new_adjs_c = adjs_c + delta

        adjs_c = new_adjs_c
        if log:
            # NOTE(review): the nested score_func call below re-evaluates the score on the
            # updated adjacencies purely for logging — it doubles the cost of a logged step.
            logging.debug(f"LG MC: step {step:5d}\t|" +
                          "score: {:+.2e}\t|new_score_d: {:+.2e}\t"
                          "|adj_mean: {:.2e}\t|adj_std: {:.2e}\t|delta_abs_mean: {:+.2e}\t|"
                          .format(score.norm(dim=[-1, -2]).mean().item(),
                                  score_func(self.adj_to_int(adjs_c, to_int=False)[0],
                                             node_flags).norm(dim=[-1, -2]).mean().item(),
                                  adjs_c.mean([0, 1, 2]).item(),
                                  adjs_c.std([0, 1, 2]).item(),
                                  (delta.abs().sum([1, 2]) / node_flags.sum(-1)**2).mean().item())
                          )
        return adjs_c, node_flags
 def forward(self, x, adjs, node_flags):
     """Estimate a symmetric score matrix from a batch of adjacency matrices.

     The ``x`` argument is accepted for interface compatibility but ignored;
     only ``adjs`` is used. The input is masked by ``self.mask`` (restricted to
     valid nodes via ``node_flags``), passed through ``self.conv_net`` as a
     single-channel image, symmetrized, and masked again.
     """
     masked_template = mask_adjs(self.mask.clone(), node_flags)
     inp = adjs * masked_template
     num_nodes = inp.size(2)
     batch = inp.size(0)
     conv_out = self.conv_net(inp.view(batch, 1, num_nodes, num_nodes))
     raw_score = conv_out.view(batch, num_nodes, num_nodes)
     # Enforce symmetry, then zero out entries of padded nodes.
     sym_score = raw_score + raw_score.transpose(-1, -2)
     return sym_score * masked_template
# Example #3
    def gen_init_sample(self, batch_size, max_node_num, node_flags=None):
        """Draw an initial symmetric, non-negative adjacency sample for the sampler.

        :param batch_size: number of graphs in the batch
        :param max_node_num: node count N of each (padded) graph
        :param node_flags: optional per-node validity flags; if None, they are
            derived from the generated sample via ``self.adj_to_int``
        :return: tuple ``(adjs_c, node_flags)`` where ``adjs_c`` is B x N x N,
            symmetric with a zero diagonal
        """
        # Sample |N(0, 1)| on the strict upper triangle, then mirror it.
        upper = torch.randn(
            (batch_size, max_node_num, max_node_num),
            dtype=torch.float32,
        ).triu(diagonal=1).abs().to(self.dev)
        adjs_c = upper + upper.transpose(-1, -2)

        if node_flags is None:
            _, node_flags = self.adj_to_int(adjs_c)
        else:
            adjs_c = mask_adjs(adjs_c, node_flags)
        return adjs_c, node_flags
# Example #4
 def forward(self, x, adjs, node_flags, viz=False, save_dir=None, title=''):
     """
     :param x: [num_classes * batch_size, N, F_i], batch of node features
     :param adjs: [num_classes * batch_size, C_i, N, N], batch of adjacency matrices
     :param node_flags: [num_classes * batch_size, N, F_i], batch of node_flags, denoting whether a node is effective in each graph
     :param viz: whether to visualize the intermediate channels
     :param title: the filename of the output figure (if viz==True)
     :param save_dir: the directory of the output figure (if viz==True)
     :return: score: [num_classes * batch_size, N, N], the estimated score
     """
     # Two input channels: the adjacency and its complement (1 - A).
     ori_adjs = adjs.unsqueeze(1)
     adjs = torch.cat([ori_adjs, 1. - ori_adjs], dim=1)  # B x 2 x N x N
     adjs = mask_adjs(adjs, node_flags)
     # Collect the channel stack produced by every GNN layer (including the input)
     # so the final readout can see all intermediate representations.
     temp_adjs = [adjs]
     # temp_x = [x] if x is not None else []
     for layer in self.gnn_list:
         x, adjs = layer(x, adjs, node_flags)
         temp_adjs.append(adjs)
         # temp_x.append(x)
     if viz:
         # Plot the first graph of each class block; batch is assumed to be laid
         # out as num_classes contiguous groups of batch_size graphs.
         batch_size = adjs.size(0) // self.num_classes
         for i in range(self.num_classes):
             plot_multi_channel_numpy_adjs(adjs=[
                 adjs[i * batch_size + 0].detach().cpu().numpy()
                 for adjs in temp_adjs
             ],
                                           save_dir=save_dir,
                                           title=f's_{i}_' + title)
             plot_multi_channel_numpy_adjs_1b1(adjs=[
                 adjs[i * batch_size + 0].detach().cpu().numpy()
                 for adjs in temp_adjs
             ],
                                               save_dir=save_dir,
                                               title=f's_{i}_' + title,
                                               fig_dir=f'figs_{i}')
             # break
     # Concatenate all per-layer channels and read out a scalar per node pair.
     stacked_adjs = torch.cat(temp_adjs, dim=1)
     # stacked_x = torch.cat(temp_x, dim=-1)  # B x N x sum_F_o
     # stacked_x_pair = node_feature_to_matrix(stacked_x)  # B x N x N x (2sum_F_o)
     mlp_in = stacked_adjs.permute(0, 2, 3, 1)
     # mlp_in = torch.cat([mlp_in, stacked_x_pair], dim=-1)  # B x N x N x (2sum_F_o + sum_C)
     out_shape = mlp_in.shape[:-1]
     mlp_out = self.final_read_score(mlp_in)
     score = mlp_out.view(*out_shape)
     return score * self.mask
# Example #5
    def forward(self, x, adjs, node_flags):
        """Run one layer: update node features and produce a new channel stack.

        :param x: B x N x F_i node features
        :param adjs: B x C_i x N x N adjacency channels
        :param node_flags: B x N validity flags
        :return: tuple ``(x_o, new_adjs)`` — B x N x F_o features and a
            symmetric, masked B x C_o x N x N channel stack
        """
        node_feat = self.multi_channel_gnn_module(x, adjs, node_flags)  # B x N x F_o
        pair_feat = node_feature_to_matrix(node_feat)  # B x N x N x 2F_o
        chan_last = adjs.permute(0, 2, 3, 1)  # B x N x N x C_i
        combined = torch.cat([chan_last, pair_feat], dim=-1)  # B x N x N x (2F_o+C_i)
        b, n_row, n_col, feat_dim = combined.shape
        flat_out = self.translate_mlp(combined.view(-1, feat_dim))
        out_adjs = flat_out.view(b, n_row, n_col, -1).permute(0, 3, 1, 2)
        # Symmetrize, then zero out padded nodes.
        out_adjs = out_adjs + out_adjs.transpose(-1, -2)
        out_adjs = mask_adjs(out_adjs, node_flags)
        return node_feat, out_adjs
 def forward(self, x, adjs, node_flags):
     """Refine ``adjs`` through each score network in turn.

     Each network's masked score nudges ``adjs`` by a factor of 0.1; the
     masked score of the final network is returned.
     """
     for score_net in self.score_net_list:
         raw = score_net(x, adjs, node_flags)
         score = mask_adjs(raw, node_flags)
         adjs = adjs + 0.1 * score
     return score