Example #1
0
    def forward(self, inputs):
        """Sample one node from the actor's distribution and eliminate it.

        The actor network scores the nodes of ``inputs``. One node is sampled
        from the resulting categorical distribution and eliminated from
        ``inputs`` via ``inputs.eliminate_node(node, reduce=True)``. When
        ``self.use_critic`` is set, the critic is evaluated on the graph both
        before and after the elimination.

        Args:
            inputs: graph object. Must expose ``n`` (node count), ``M``
                (adjacency matrix accepted by ``torch.FloatTensor``) and
                ``eliminate_node(node, reduce=True)``. After the elimination
                its ``n``/``M`` are read again as the reduced graph, so it is
                presumably reduced in place — confirm against the graph class.

        Returns:
            tuple: ``(node_selected, log_prob, r, critic_current, critic_next,
            inputs)``:

            - ``node_selected``: the node sampled from the actor's distribution.
            - ``log_prob``: its log-probability under that distribution.
            - ``r``: the value returned by ``eliminate_node`` (described by the
              original author as the number of edges added).
            - ``critic_current`` / ``critic_next``: the summed critic outputs
              before / after the elimination, or ``None`` when
              ``self.use_critic`` is false.
            - ``inputs``: the reduced graph object.
        """

        def graph_tensors(graph):
            # Constant all-ones node features plus a COO sparse adjacency
            # matrix, moved to the GPU when the model runs on CUDA.
            features = torch.FloatTensor(np.ones([graph.n, 1], dtype=np.float32))
            adj_M = utils.to_sparse(torch.FloatTensor(graph.M))
            if self.use_cuda:
                adj_M = adj_M.cuda()
                features = features.cuda()
            return features, adj_M

        # Defined up front so the return statement is valid when no critic is
        # used (previously this raised UnboundLocalError).
        critic_current = None
        critic_next = None

        features, adj_M = graph_tensors(inputs)

        # The actor yields a selection distribution over the graph's nodes.
        probs = self.actor(features, adj_M).view(-1)
        m = Categorical(probs)
        node_selected = m.sample()
        log_prob = m.log_prob(node_selected)

        if self.use_critic:  # value of the current state
            critic_current = self.critic(features, adj_M).sum()

        # Reduce the graph by eliminating the selected node.
        r = inputs.eliminate_node(node_selected, reduce=True)

        if self.use_critic:  # value of the reduced (next) state
            features, adj_M = graph_tensors(inputs)
            critic_next = self.critic(features, adj_M).sum()

        return node_selected, log_prob, r, critic_current, critic_next, inputs
Example #2
0
def spiral_tramsform(transform_fp, template_fp, ds_factors, seq_length,
                     dilation, device=None):
    """Load (or build and cache) mesh transforms and derive spiral indices.

    If ``transform_fp`` does not exist, the template mesh at ``template_fp``
    is downsampled with ``ds_factors`` via
    ``mesh_sampling.generate_transform_matrices``, and the resulting matrices
    are pickled to ``transform_fp``. Otherwise the cached pickle is loaded.

    Args:
        transform_fp: path of the cached transform pickle.
        template_fp: path of the template mesh file. It is only read when the
            cache file is missing.
        ds_factors: downsampling factors passed to
            ``generate_transform_matrices``.
        seq_length: per-level spiral sequence lengths, indexed by level.
        dilation: per-level spiral dilations, indexed by level.
        device: optional torch device. When given, every spiral index and
            sparse transform is moved there with ``.to(device)``. When
            ``None`` (the default) they are returned unchanged, which matches
            the previous behaviour.

    Returns:
        tuple: ``(spiral_indices_list, down_transform_list, up_transform_list,
        tmp)``:

        - ``spiral_indices_list``: spiral indices for every level except the
          last one.
        - ``down_transform_list`` / ``up_transform_list``: the transforms
          converted with ``utils.to_sparse``.
        - ``tmp``: the loaded/built dict. When written by this function it has
          the keys ``'vertices'``, ``'face'``, ``'adj'``,
          ``'down_transform'`` and ``'up_transform'``.
    """

    def _to_device(tensor):
        # No device move when ``device`` is None (the historical behaviour).
        return tensor if device is None else tensor.to(device)

    if not osp.exists(transform_fp):
        print('Generating transform matrices...')
        mesh = Mesh(filename=template_fp)
        _, A, D, U, F, V = mesh_sampling.generate_transform_matrices(
            mesh, ds_factors)
        tmp = {
            'vertices': V,
            'face': F,
            'adj': A,
            'down_transform': D,
            'up_transform': U
        }

        with open(transform_fp, 'wb') as fp:
            pickle.dump(tmp, fp)
        print('Done!')
        print('Transform matrices are saved in \'{}\''.format(transform_fp))
    else:
        # NOTE: pickle.load can execute arbitrary code embedded in the file;
        # only point transform_fp at cache files from a trusted source.
        with open(transform_fp, 'rb') as f:
            tmp = pickle.load(f, encoding='latin1')

    # Spiral indices are built for every level except the last one.
    spiral_indices_list = [
        _to_device(
            utils.preprocess_spiral(tmp['face'][idx], seq_length[idx],
                                    tmp['vertices'][idx], dilation[idx]))
        for idx in range(len(tmp['face']) - 1)
    ]

    down_transform_list = [
        _to_device(utils.to_sparse(down_transform))
        for down_transform in tmp['down_transform']
    ]
    up_transform_list = [
        _to_device(utils.to_sparse(up_transform))
        for up_transform in tmp['up_transform']
    ]

    return spiral_indices_list, down_transform_list, up_transform_list, tmp
Example #3
0
    # Cache-miss branch. The enclosing ``if`` header is not visible here
    # (presumably ``if not osp.exists(transform_fp):`` — confirm). It builds
    # the mesh hierarchy from the template and pickles it to transform_fp.
    mesh = Mesh(filename=template_fp)
    # Hard-coded downsampling factors, one entry per level
    # (presumably four levels each reduced by 4 — verify).
    ds_factors = [4, 4, 4, 4]
    _, A, D, U, F = mesh_sampling.generate_transform_matrices(mesh, ds_factors)
    # Unlike variants that also store vertices, this cache holds only faces,
    # adjacency and the down/up transforms.
    tmp = {'face': F, 'adj': A, 'down_transform': D, 'up_transform': U}

    with open(transform_fp, 'wb') as fp:
        pickle.dump(tmp, fp)
    print('Done!')
    print('Transform matrices are saved in \'{}\''.format(transform_fp))
else:
    # Cache hit: reuse the previously pickled transforms.
    # NOTE(review): pickle.load can execute arbitrary code from the file;
    # only load cache files from a trusted source.
    with open(transform_fp, 'rb') as f:
        tmp = pickle.load(f, encoding='latin1')

# Convert every mesh level's adjacency into an edge index, and every sampling
# operator into a sparse tensor, all placed on the target device.
edge_index_list = [
    utils.to_edge_index(adjacency).to(device) for adjacency in tmp['adj']
]
down_transform_list = [
    utils.to_sparse(matrix).to(device) for matrix in tmp['down_transform']
]
up_transform_list = [
    utils.to_sparse(matrix).to(device) for matrix in tmp['up_transform']
]

# Assemble the autoencoder from the CLI hyper-parameters and mesh operators.
model = AE(
    args.in_channels,
    args.out_channels,
    args.latent_channels,
    edge_index_list,
    down_transform_list,
    up_transform_list,
    K=args.K,
)
model = model.to(device)
print(model)