예제 #1
0
    def forward(self, data):
        """Score every point of a batch of point clouds with a dynamic-kNN network.

        A kNN graph (k=30) is built on the input positions for ``conv1``, then
        rebuilt on ``conv1``'s output for ``conv2``. The ``lin0`` features are
        max-pooled per graph, broadcast back to that graph's points, and
        concatenated with the per-point ``x1``/``x2`` features for the MLP head.

        Args:
            data: PyG-style batch providing ``pos`` and ``batch``.

        Returns:
            Per-point log-probabilities (``F.log_softmax`` over the last dim).
        """
        # NOTE(review): CUDA placement is hard-coded, as in the original code.
        pos = data.pos.cuda()
        batch = data.batch.cuda()

        edge_index = knn_graph(pos, k=30, batch=batch)
        x1 = self.conv1(pos, edge_index)
        # Dynamic graph: neighbourhoods are recomputed in feature space.
        edge_index = knn_graph(x1, k=30, batch=batch)
        x2 = self.conv2(x1, edge_index)

        x2max = F.relu(self.lin0(x2))
        x2max = global_max_pool(x2max, batch)
        # Broadcast each graph's pooled feature to its own points. Indexing by
        # ``batch`` is correct for graphs of any size; the previous
        # repeat(...).view(...) trick only worked when every graph had exactly
        # (num_points / num_graphs) points.
        globalfeats = x2max[batch]

        concat_features = torch.cat((x1, x2, globalfeats), dim=1)

        x = F.relu(self.lin1(concat_features))
        x = F.relu(self.lin2(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin3(x)
        return F.log_softmax(x, dim=-1)
예제 #2
0
    def forward(self, data):
        """Classify point clouds with 3-D feature layers, FPS downsampling and an MLP head.

        ``dsc3d`` and ``dd`` run on a kNN graph of the full cloud. Farthest-point
        sampling then keeps 20% of the points, a new kNN graph is built on the
        sampled positions, and ``dd2`` runs on the sampled points and their
        features. Features are averaged per cloud before ``nn1``/``nn2``/``sm``.

        Args:
            data: PyG-style batch providing ``pos`` and ``batch``.

        Returns:
            Output of ``self.sm`` applied to the MLP head, one row per cloud.
        """
        pos, batch = data.pos, data.batch

        # Build first edges (kNN on raw coordinates, no self-loops).
        edge_index = knn_graph(pos, self.k, batch, loop=False)

        # Extract features in 3-D.
        _, _, features_3d = self.dsc3d(pos, edge_index)
        features_3d = torch.sigmoid(features_3d)
        _, _, features_dd = self.dd(pos, edge_index, features_3d)
        features_dd = torch.sigmoid(features_dd)

        # Farthest-point sampling: keep 20% of the points (drop 80%).
        index = fps(pos, batch=batch, ratio=0.2)
        pos = pos[index]
        features = features_dd[index]
        batch = batch[index]
        # TODO: try building this graph on `features` instead of `pos`.
        edge_index = knn_graph(pos, self.k, batch, loop=False)

        # Extract features again on the sampled points. Pass the *sampled*
        # features so they line up with the sampled `pos` / `edge_index`
        # (previously the full, unpooled `features_dd` was passed).
        _, _, features_dd2 = self.dd2(pos, edge_index, features)
        features_dd2 = torch.sigmoid(features_dd2)

        # Mean over the points of each of `self.batch_size` clouds.
        ys = features_dd2.view(self.batch_size, -1, self.out_size_2)
        ys = ys.mean(dim=1).view(-1, self.out_size_2)
        y1 = F.elu(self.nn1(ys))
        y2 = self.sm(self.nn2(y1))

        return y2
예제 #3
0
    def forward(self, pts, batch_ids):
        """Per-point classification with a dynamic-graph EdgeConv stack.

        Input:
            - pts: (B*N, 3) point coordinates
            - batch_ids: (B*N,) graph index of every point
        Return:
            - out: (B*N, C) per-point log-probabilities (``F.log_softmax``)
        """
        batch_size = int(batch_ids.max()) + 1
        out = pts
        edge_conv_outs = []
        for edge_conv in self.edge_convs.children():
            # Dynamically update the graph: rebuild the kNN graph on the current
            # features. The previous code always used `pts`, so the graph never
            # changed between layers.
            edge_index = knn_graph(out, k=self.K, batch=batch_ids)
            out = edge_conv(out, edge_index)
            edge_conv_outs.append(out)
        # Skip connections: concatenate the outputs of every EdgeConv layer.
        conv_cats = torch.cat(edge_conv_outs, dim=-1)
        out = self.glb_aggr(conv_cats)  # global aggregation
        # Per-graph max over points (original note: B x 1024), then broadcast
        # back to every point (B*N rows).
        glb_feats = scatter_('max', out, index=batch_ids, dim_size=batch_size)
        glb_feats = glb_feats[batch_ids]
        out = torch.cat([glb_feats, conv_cats], dim=-1)
        out = self.fc(out)
        return F.log_softmax(out, dim=-1)
예제 #4
0
파일: DirCNN.py 프로젝트: Yannick-S/MTRepo
    def forward(self, data):
        """Classify equally-sized point clouds with two ``dd`` layers and an MLP head.

        Both ``dd`` and ``dd2`` share one kNN graph (no self-loops) on the raw
        positions. After ``nn1`` the features are reshaped to
        ``(num_clouds, self.nr_points, -1)`` and max-pooled over points, then
        passed through ReLU + BatchNorm stages and ``nn4``/``sm``.

        Args:
            data: PyG-style batch providing ``pos`` and ``batch``.

        Returns:
            Output of ``self.sm``, one row per cloud.
        """
        pos, batch = data.pos, data.batch
        num_clouds = int(pos.size(0) / self.nr_points)

        neighbours = knn_graph(pos, self.k, batch, loop=False)

        _, _, feats, _ = self.dd(pos, neighbours, None)
        _, _, feats, _ = self.dd2(pos, neighbours, feats)

        hidden = self.nn1(feats).view(num_clouds, self.nr_points, -1)
        hidden = hidden.max(dim=1).values  # max-pool over each cloud's points
        hidden = self.bn1(torch.nn.functional.relu(hidden))
        hidden = self.bn2(torch.nn.functional.relu(self.nn2(hidden)))
        hidden = self.bn3(torch.nn.functional.relu(self.nn3(hidden)))
        return self.sm(self.nn4(hidden))
def get_icosahedron_weights(nodes, depth):
    """Build kNN edge indices for each level of an icosahedral UNet hierarchy.

    Starting from the icosahedron order whose node count is ``nodes``, one
    graph is built per level, stepping the order down by one each time.

    Args:
        nodes (int): number of nodes at the initial (finest) level.
        depth (int): number of levels, i.e. the depth of the UNet.

    Returns:
        tuple: ``(edge_list, weight_list)``. ``edge_list`` holds one
        ``edge_index`` tensor per level, ordered from the coarsest to the
        finest level. ``weight_list`` holds the matching edge weights in the
        same order; they are currently always ``None`` (unweighted graphs).
    """
    edge_list = []
    weight_list = []
    order = icosahedron_order_calculator(nodes)
    for _ in range(depth):
        nodes = icosahedron_nodes_calculator(order)
        order_initial = icosahedron_order_calculator(nodes)
        coords = torch.from_numpy(get_ico_coords(int(order_initial)))
        # k=5 for order 0, k=6 otherwise; for order > 0 the 12 longest edges
        # are removed below.
        edge_index = knn_graph(coords, 6 if order else 5)
        if order:
            dist = torch.norm(coords[edge_index[0]] - coords[edge_index[1]],
                              p=2,
                              dim=1)
            # Turn the 12 longest edges into self-loops, then strip them.
            # Presumably these are the spurious 6th neighbours of the 12
            # degree-5 vertices.
            _, extra_idx = torch.topk(dist, 12)
            edge_index[0, extra_idx] = edge_index[1, extra_idx]
            edge_index, _ = remove_self_loops(edge_index)
        edge_list.append(edge_index)
        weight_list.append(None)  # unweighted graph
        order -= 1
    # Levels were generated finest-first; return both lists coarsest-first.
    return edge_list[::-1], weight_list[::-1]
예제 #6
0
    def forward(self, x):
        """Apply ``graph_conv`` + ReLU over a kNN graph built from point locations.

        Args:
            x: tuple ``(x_loc, x_feat)``. The two are concatenated along dim 1
                and transposed, so they appear to be channel-first
                ``(B, C, N)`` tensors with 3 location channels (assumption;
                confirm with callers).

        Returns:
            tuple ``(x_loc, x_new_feat)`` in the same channel-first layout;
            ``x_new_feat`` has ``graph_conv``'s output width as channels.
        """
        loc, feat = x
        # Switch to point-major layout: (B, N, C).
        node_feat = torch.cat([loc, feat], dim=1).transpose(-2, -1)
        loc = loc.transpose(-2, -1)

        num_graphs, num_points = node_feat.shape[:2]
        # Graph id of every flattened point: 0 repeated N times, then 1, ...
        graph_ids = torch.arange(num_graphs, dtype=torch.int64,
                                 device=self.device).repeat_interleave(num_points)

        flat_feat = node_feat.contiguous().view(-1, node_feat.size(-1))
        flat_loc = loc.contiguous().view(-1, 3)
        edge_index = gnn.knn_graph(x=flat_loc, k=self.k, batch=graph_ids)
        flat_feat = self.relu(self.graph_conv(flat_feat, edge_index))

        new_feat = flat_feat.view(num_graphs, -1, flat_feat.size(1)).transpose(-2, -1)
        new_loc = flat_loc.view(num_graphs, -1, 3).transpose(-2, -1)
        return (new_loc, new_feat)
예제 #7
0
	def create_test_knn_graph(self, embeds, batch, args, gin_preds):
		"""Build a kNN graph among test samples, restricted to each predicted super-class.

		Samples are grouped by ``argmax(gin_preds, dim=1)``. Within each group a
		kNN graph with self-loops (``k = args.knn_value``) is built on the group's
		embeddings. Its group-local node ids are then mapped back to the sample
		positions in the batch.

		Args:
			embeds: per-sample embeddings, indexed by sample position.
			batch: only enumerated to obtain sample positions; its elements are unused.
			args: namespace providing ``knn_value``.
			gin_preds: per-sample class scores; the argmax over dim 1 is the super-class.

		Returns:
			An int64 CUDA edge-index tensor using batch-level indices. Groups are
			concatenated along dim 1 in order of first appearance of each super-class.
		"""
		super_class_preds = torch.argmax(gin_preds, dim=1).cpu().numpy()

		# Group sample positions by predicted super-class (dicts keep first-appearance order).
		groups = {}
		for i, _ in enumerate(batch):
			groups.setdefault(super_class_preds[i], []).append(i)

		knn_value = args.knn_value
		all_edges = []
		for members in groups.values():
			super_class_embeds = torch.cat([embeds[i].unsqueeze(0) for i in members], dim=0)
			super_class_knn = knn_graph(super_class_embeds, knn_value, loop=True)
			# Map group-local ids to batch positions with one vectorized lookup,
			# instead of a Python double loop that synced the device per element.
			local_to_batch = torch.tensor(members, dtype=torch.long, device=super_class_knn.device)
			all_edges.append(local_to_batch[super_class_knn].cuda())

		return torch.cat(all_edges, dim=1)
예제 #8
0
    def forward(self, data):
        r"""Run the densely-linked dynamic-graph filter stack and compute its losses.

        data ~ Data(x, y, [z, ][batch, ])

        A kNN graph (no self-loops) is rebuilt on the current features before
        every filter. Each filter's output is concatenated onto its input, except
        for the last filter, whose output replaces it.

        Returns:
            ``(x, reg_loss + loss, mse_loss)``:
            - ``loss`` is selected by ``self.loss_type`` ("mse" or "chamfer").
            - ``mse_loss`` is always the MSE against ``data.y``.
            - ``reg_loss`` is ``self.reg(...) * self.reg_coeff``, or a zero
              tensor when ``self.reg`` is None.

        Raises:
            NotImplementedError: for any other ``self.loss_type``.
        """
        target, batch, x = data.y, data.batch, data.x

        for idx, (layer, act) in enumerate(zip(self.filters, self.activation)):
            edge_index = knn_graph(x, k=self.k, batch=batch, loop=False)
            out = layer(x, edge_index)
            is_last = idx == self.nfilters - 1
            x = out if is_last else torch.cat((x, out), dim=-1)
            if self.has_activation:
                x = act(x)

        if self.loss_type == "mse":
            loss = mse(x, target)
        elif self.loss_type == "chamfer":
            loss = chamfer_measure(x, target, batch)
        else:
            raise NotImplementedError
        mse_loss = mse(x, target)

        if self.reg is None:
            reg_loss = torch.tensor([0.0]).to(loss)
        else:
            reg_loss = self.reg(x, k=self.k, batch=batch) * self.reg_coeff
        return x, reg_loss + loss, mse_loss
예제 #9
0
    def forward(self, x, batch=None):
        """Message passing over a kNN graph built in a learned spatial embedding.

        ``first_dense`` splits ``x`` into spatial and learned features. A kNN
        graph (no self-loops) is built on the spatial part, and ``cdist``
        distances are mapped through ``self.kernel``. Each messenger's output on
        the learned features is concatenated after the raw input before
        ``second_dense``.
        """
        spatial, learned = self.first_dense(x)

        edge_index = knn_graph(spatial, self.n_neighbors, batch, loop=False)

        # NOTE(review): cdist compares every node against every edge endpoint.
        # For a 2-D `spatial` this yields an (N, E) matrix, not one distance
        # per edge; this is likely the shape problem mentioned below.
        neighbor_coords = index_select(spatial, 0, edge_index[1])
        weights = self.kernel(cdist(spatial, neighbor_coords))

        # Original author's note: the weights were not meant to be used because
        # of a broadcasting error, yet they are still passed to message().
        messages = [x]
        messages.extend(messenger.message(learned, weights) for messenger in self.messengers)

        # Keep the raw input alongside all messenger outputs.
        return self.second_dense(torch.cat(messages, dim=1))
예제 #10
0
    def forward(self, data):
        """Classify point clouds with one ``dsc`` layer, with periodic debug dumps.

        Every call prints ``self.counter``. When ``(self.counter + 1) % 600 == 0``
        or ``self.counter > 600``, it also prints feature statistics and plots
        the first ``nr_points`` points coloured by their first three channels.

        Args:
            data: PyG-style batch providing ``pos`` and ``batch``.

        Returns:
            Output of ``self.sm`` applied to the MLP head, one row per cloud.
        """
        print("counter:", self.counter)
        pos, batch = data.pos, data.batch

        edge_index = knn_graph(pos, self.k, batch, loop=False)
        y = torch.sigmoid(self.dsc(pos, edge_index))

        # NOTE(review): `nr_points` is not defined in this method; presumably a
        # module-level constant. Confirm it exists.
        if (self.counter + 1) % 600 == 0 or self.counter > 600:
            y3 = y.view(-1, nr_points, self.filter_nr)
            print(y3.std(dim=1).view(-1, self.filter_nr))
            print(y3.std(dim=1).mean())
            print(y3.max(dim=1)[0].view(-1, self.filter_nr))
            # Tensor.numpy() raises on CUDA tensors, so move to CPU first.
            color = y[:nr_points, :3].detach().cpu().numpy()
            plot_point_cloud(pos[:nr_points, :].detach().cpu().numpy(), color=color)

        # Mean over each cloud's `nr_points` points.
        y = y.view(-1, nr_points, self.filter_nr)
        y = y.mean(dim=1).view(-1, self.filter_nr)
        y1 = F.elu(self.nn1(y))
        y2 = self.sm(self.nn2(y1))

        self.counter += 1
        return y2
예제 #11
0
    def forward(self, data):
        """Graph-level prediction from a kNN (k=100) graph over the node features.

        Three convolutions share one edge-dropped kNN graph. After each one,
        ``gap``/``gmp`` readouts (presumably global mean/max pooling) are
        concatenated. The three readouts feed a batch-normed leaky-ReLU MLP.

        Args:
            data: PyG-style batch providing ``x`` and ``batch``.

        Returns:
            Flattened 1-D output; a sigmoid is applied when ``self.classification``.
        """
        x, batch = data.x, data.batch
        edge_index = knn_graph(x, 100, batch)
        # dropout_adj drops edges whenever training=True, which is its default.
        # Tie it to the module mode so the graph stays intact at eval time.
        edge_index, _ = dropout_adj(edge_index, p=0.3, training=self.training)

        readouts = []
        for conv in (self.conv1, self.conv2, self.conv3):
            x = F.leaky_relu(conv(x, edge_index))
            readouts.append(torch.cat([gap(x, batch), gmp(x, batch)], dim=1))
        x = torch.cat(readouts, dim=1)

        x = self.batchnorm1(x)
        x = F.leaky_relu(self.linear1(x))
        x = self.drop(x)
        x = F.leaky_relu(self.linear2(x))
        x = F.leaky_relu(self.linear3(x))
        x = F.leaky_relu(self.linear4(x))
        x = F.leaky_relu(self.linear5(x))

        x = self.out(x)
        if self.classification:
            x = torch.sigmoid(x)
        return x.view(-1)
예제 #12
0
    def forward(self, pos, batch):
        """Two-stage dynamic-kNN (k=20) point-cloud classifier.

        Before each of ``conv1`` and ``conv2``, a kNN graph is rebuilt on the
        current representation. ``lin0`` features are max-pooled per graph and
        classified by a dropout MLP.

        Returns:
            Per-graph log-probabilities (``F.log_softmax`` over the last dim).
        """
        h = pos
        for conv in (self.conv1, self.conv2):
            h = conv(h, knn_graph(h, k=20, batch=batch))

        h = global_max_pool(F.relu(self.lin0(h)), batch)
        h = F.relu(self.lin2(F.relu(self.lin1(h))))
        h = self.lin3(F.dropout(h, p=0.5, training=self.training))
        return F.log_softmax(h, dim=-1)
예제 #13
0
    def forward(self, data):
        """Graph-level prediction with a configurable-depth point-layer stack.

        ``point1`` and then ``point_depth - 1`` layers from ``pointfkt`` run on
        one edge-dropped kNN (k=100) graph. All their outputs are concatenated,
        pooled with ``gap``/``gmp``, and passed through ``lin_depth`` leaky-ReLU
        linear layers and ``self.out``.

        Args:
            data: PyG-style batch providing ``x`` and ``batch``.

        Returns:
            Flattened 1-D output; a sigmoid is applied when ``self.classification``.
        """
        x, batch = data.x, data.batch
        edge_index = knn_graph(x, 100, batch)
        # dropout_adj drops edges whenever training=True, which is its default.
        # Tie it to the module mode so the graph stays intact at eval time.
        edge_index, _ = dropout_adj(edge_index, p=0.3, training=self.training)

        y = self.point1(x, edge_index)  # dim = n_intermediate
        pointlist = [y]
        for f in range(self.point_depth - 1):
            y = self.pointfkt[f](y, edge_index)
            pointlist.append(y)

        y = torch.cat(pointlist, dim=1)  # dim = n_intermediate * point_depth
        y = torch.cat([gap(y, batch), gmp(y, batch)], dim=1)

        x = self.batchnorm1(y)
        for g in range(self.lin_depth):
            x = F.leaky_relu(self.linearfkt[g](x))
            # Dropout after layers g = 1, 4, 7, ..., but only while at least one
            # more linearfkt layer follows (g < lin_depth - 1).
            if (g - 1) % 3 == 0 and self.lin_depth - 1 > g:
                x = self.drop[g](x)

        x = self.out(x)
        if self.classification:
            x = torch.sigmoid(x)
        return x.view(-1)
예제 #14
0
def get_normal(inputs, batch_size, num_points, k=10):
    """Estimate per-point normals from the covariance of each point's k-neighbourhood.

    Args:
        inputs: point coordinates reshapeable to ``(batch_size * num_points, 3)``.
        batch_size (int): number of point clouds.
        num_points (int): points per cloud; every cloud must have this count.
        k (int): neighbourhood size, including the point itself (self-loops).

    Returns:
        Tensor of shape ``(batch_size, num_points, 3)``: column 2 of the ``U``
        factor returned by ``batch_svd`` for each point's 3x3 covariance.
    """
    # PyTorch Geometric treats the batch as one large disconnected graph.
    x = inputs.reshape(-1, 3)
    # Build the batch vector on the input's device (it was hard-coded to CUDA,
    # which broke CPU inputs).
    batch = torch.arange(batch_size, device=x.device).repeat_interleave(num_points)
    edge_index = knn_graph(x, k, batch=batch, loop=True)
    row, col = edge_index

    # Centre each neighbourhood on its mean.
    knn_row = x.index_select(0, row)  # neighbour coordinates
    mean_v = scatter_('mean', knn_row, col, dim_size=x.size(0))
    out = knn_row - mean_v.index_select(0, col)

    # Assumes knn_graph returns exactly k edges per point, grouped by point,
    # so the flat edge list reshapes to (B, N, k, 3).
    out = out.reshape(batch_size, num_points, k, 3)

    # Covariance: average outer product over the k neighbours -> (B, N, 3, 3).
    Cmat = torch.sum(torch.matmul(out.unsqueeze(-1), out[:, :, :, None, :]),
                     2) / k

    # Batched SVD (original note: "size of f must be less than 32").
    Cmat = Cmat.reshape(batch_size * num_points, 3, 3)
    U, _, _ = batch_svd(Cmat)
    # Column 2 of U is taken as the normal; this presumes batch_svd orders
    # singular values descending. Confirm against its implementation.
    nor = U[:, :, 2]
    return nor.reshape(batch_size, num_points, 3)
예제 #15
0
    def forward(self, data):
        """Graph-level prediction with a configurable-depth convolution stack.

        ``conv1`` and then ``conv_depth - 1`` layers from ``convfkt`` run on one
        edge-dropped kNN (k=100) graph. After each one, ``gap``/``gmp`` readouts
        are concatenated. The readouts pass through ``lin_depth`` leaky-ReLU
        linear layers and ``self.out``.

        Args:
            data: PyG-style batch providing ``x`` and ``batch``.

        Returns:
            Flattened 1-D output; a sigmoid is applied when ``self.classification``.
        """
        x, batch = data.x, data.batch
        edge_index = knn_graph(x, 100, batch)
        # dropout_adj drops edges whenever training=True, which is its default.
        # Tie it to the module mode so the graph stays intact at eval time.
        edge_index, _ = dropout_adj(edge_index, p=0.3, training=self.training)

        x = F.leaky_relu(self.conv1(x, edge_index))
        convlist = [torch.cat([gap(x, batch), gmp(x, batch)], dim=1)]
        for f in range(self.conv_depth - 1):
            x = F.leaky_relu(self.convfkt[f](x, edge_index))
            convlist.append(torch.cat([gap(x, batch), gmp(x, batch)], dim=1))

        x = torch.cat(convlist, dim=1)

        x = self.batchnorm1(x)
        for g in range(self.lin_depth):
            x = F.leaky_relu(self.linearfkt[g](x))
            # Dropout after layers g = 1, 4, 7, ..., but only while at least one
            # more linearfkt layer follows (g < lin_depth - 1).
            if (g - 1) % 3 == 0 and self.lin_depth - 1 > g:
                x = self.drop[g](x)

        x = self.out(x)
        if self.classification:
            x = torch.sigmoid(x)
        return x.view(-1)
예제 #16
0
    def __call__(self, sample):
        """Attach dilated Euclidean kNN edge indices to each hierarchy level of ``sample``.

        Edge-index attributes and position attributes are matched by sorted
        name, one per level. For each level, a kNN graph with
        ``k * self._d`` neighbours is built and every ``self._d``-th edge is
        kept. The result is written either to ``euclidean_edge_index`` variants
        of the key, or over the key itself when ``self._override`` is set.
        Attributes whose names contain ``pos_`` are removed when
        ``self._no_pos`` is set.

        Note: a single-element ``self._k`` is expanded in place to one entry
        per position level on first use.
        """
        names = dir(sample)
        if self._visualization:
            edge_keys = sorted(name for name in names if name == 'edge_index')
        else:
            edge_keys = sorted(name for name in names
                               if 'edge_index' in name and 'dilated' not in name)
        pos_keys = sorted(name for name in names if 'pos' in name)

        # One k for every hierarchy level.
        if len(self._k) == 1:
            self._k = [self._k[0]] * len(pos_keys)

        for level, key in enumerate(edge_keys):
            edges = knn_graph(sample[pos_keys[level]], k=self._k[level] * self._d)
            # Keep every d-th edge (dilation).
            keep = list(range(edges.shape[1])[0::self._d])
            target_key = key if self._override else key.replace('edge_index', 'euclidean_edge_index')
            sample[target_key] = edges[:, keep]

        if self._no_pos:
            for name in sorted(n for n in dir(sample) if 'pos_' in n):
                delattr(sample, name)

        return sample
예제 #17
0
    def forward(self, data):
        """Classify point clouds: ds1 layer, graclus pooling, feature-space kNN, dd2, MLP head.

        1. ``ds1`` runs on a kNN graph of the raw positions.
        2. Nodes are clustered with ``graclus``; positions are averaged and
           features max-pooled per cluster.
        3. A new kNN graph is built on the norm (over dim 2) of the pooled
           features, and ``dd2`` runs on it.
        4. ``nn1`` output is max-pooled per graph, then passes through
           ReLU + BatchNorm stages and ``nn4``/``sm``.

        Args:
            data: PyG-style batch providing ``pos`` and ``batch``.

        Returns:
            Output of ``self.sm``, one row per graph.
        """
        pos, batch = data.pos, data.batch

        # Build first edges.
        edge_index = knn_graph(pos, self.k, batch, loop=False)

        # Extract features in 3-D.
        _, _, features_dd, _ = self.ds1(pos, edge_index, None)

        # Graclus clustering on the kNN graph, then pool per cluster.
        cluster = graclus(edge_index)
        pos_gra, batch_gra = avg_pool_x(cluster, pos, batch)
        features_gra, _ = max_pool_x(cluster, features_dd, batch)

        # kNN in feature space. `features_gra` must be at least 3-D for
        # norm(dim=2). The norm is computed without autograd tracking.
        with torch.no_grad():
            edge_index_gra = knn_graph(features_gra.norm(dim=2),
                                       self.k,
                                       batch_gra,
                                       loop=False)

        _, _, features_dd2, _ = self.dd2(pos_gra, edge_index_gra, features_gra)

        # Per-graph max readout: max_pool_x with cluster = batch_gra.
        y1_pool, _ = max_pool_x(batch_gra, self.nn1(features_dd2), batch_gra)
        y = self.bn1(torch.nn.functional.relu(y1_pool))
        y = self.bn2(torch.nn.functional.relu(self.nn2(y)))
        y = self.bn3(torch.nn.functional.relu(self.nn3(y)))

        return self.sm(self.nn4(y))
예제 #18
0
def gen_knn_index(node_num=100, dim=1, k=8, device=None, **kwargs):
    """Generate a kNN ``edge_index`` over ``node_num`` uniformly random points.

    Args:
        node_num: number of random points drawn with ``torch.rand``.
        dim: dimensionality of each point.
        k: neighbours per node, passed to ``knn_graph``.
        device: target device; ``getDevice()`` is used when this is falsy.
        **kwargs: accepted and ignored.

    Returns:
        ``(edge_index, node_num)``.
    """
    target_device = device or getDevice()
    points = torch.rand([node_num, dim], dtype=torch.float, device=target_device)
    return knn_graph(points, k), node_num
예제 #19
0
 def compute_edges(self, graph):
     """Build a kNN graph, with self-loops, over the visual node features ``graph.x_v``.

     ``k`` is ``self.pre_compute_edges``, and ``flow="target_to_source"`` is
     forwarded to ``knn_graph``. Background on the flow argument:
     1. https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/transforms/knn_graph.html?highlight=knn_graph
     2. https://github.com/rusty1s/pytorch_geometric/issues/126
     The original experiments assumed L2-normalized visual descriptors.
     """
     visual_feats = graph.x_v
     return knn_graph(x=visual_feats, k=self.pre_compute_edges,
                      loop=True, flow="target_to_source")
예제 #20
0
 def forward(self, pts):
     """Build an independent kNN graph (k = ``self.K``) for every cloud in ``pts``.

     Args:
         pts: 3-D tensor, unpacked as ``(batch_size, num_pts, _)``.

     Returns:
         List with one ``edge_index`` per cloud. Each is moved with
         ``.cuda()`` when ``pts`` is a CUDA tensor.
     """
     _batch_size, _num_pts, _ = pts.shape  # enforces a 3-D input
     move_to_gpu = pts.is_cuda
     edge_indices = []
     for cloud in pts:
         cloud_edges = knn_graph(cloud, self.K)
         edge_indices.append(cloud_edges.cuda() if move_to_gpu else cloud_edges)
     return edge_indices
    def __call__(self, data):
        """Replace ``data.edge_index`` with a kNN graph built on ``data.pos``.

        ``data.edge_attr`` is cleared. ``k``, ``loop`` and ``flow`` come from the
        transform's attributes, and ``data.batch`` is used when present. The
        graph is passed through ``to_undirected`` when
        ``self.force_undirected`` is set.
        """
        data.edge_attr = None
        graph_ids = data.batch if 'batch' in data else None
        knn_edges = knn_graph(data.pos, self.k, graph_ids, loop=self.loop, flow=self.flow)

        data.edge_index = (to_undirected(knn_edges, num_nodes=data.num_nodes)
                           if self.force_undirected else knn_edges)
        return data
예제 #22
0
def gen_knn_graph(node_data: Union[int, torch.Tensor],
                  k=10,
                  pos_feature: bool = True) -> Data:
    """Build a kNN ``Data`` graph (self-loops included) annotated by ``Distance``.

    Args:
        node_data: forwarded to ``gen_graph_data``.
        k: neighbours per node.
        pos_feature: NOTE(review): currently unused by this function.

    Returns:
        ``Data`` with ``x``, ``edge_index`` and ``pos``, transformed by
        ``Distance(norm=False, cat=False)``.
    """
    # The edge_index produced by gen_graph_data is discarded and replaced by kNN.
    _, pos, node_feat = gen_graph_data(node_data)
    knn_edges = knn_graph(pos, k=k, loop=True)
    return Distance(norm=False, cat=False)(Data(x=node_feat, edge_index=knn_edges, pos=pos))
예제 #23
0
    def forward(self, data):
        """Apply the GAT stack on per-layer dynamic kNN graphs (k=32) and compute the loss.

        Returns:
            ``(loss, x)``. ``x`` is ``self.cls`` output, assumed to already be
            probabilities. ``loss`` is ``self.criterion`` on ``log(x + 1e-8)``
            against ``data.y``.
        """
        target, batch, h = data.y, data.batch, data.x
        for layer, activation in zip(self.gats, self.activation):
            # Dynamic graph: neighbourhoods are recomputed from the current features.
            neighbours = knn_graph(h, k=32, batch=batch, loop=False)
            h = activation(layer(h, edge_index=neighbours))

        probs = self.cls(h)  # assume we have normalized/softmaxed prob here.
        loss = self.criterion((probs + 1e-8).log(), target)
        return loss, probs
예제 #24
0
    def forward(self, data):
        """Apply the filter stack on per-layer dynamic kNN graphs (k=32) and compute MSE.

        Returns:
            ``(x, mse(x, data.y))``.
        """
        target, batch, h = data.y, data.batch, data.x
        for layer, act in zip(self.filters, self.activation):
            neighbours = knn_graph(h, k=32, batch=batch, loop=False)
            h = layer(h, edge_index=neighbours)
            h = act(h) if self.has_activation else h
        return h, mse(h, target)
예제 #25
0
    def forward(self, batch):
        """Project node features, then apply three graph layers on dynamic kNN graphs (k=10).

        Before each layer, a kNN graph (no self-loops,
        ``flow='source_to_target'``) is rebuilt on the current features. Each
        layer receives zero-width edge and global attribute tensors.

        Args:
            batch: PyG-style batch providing ``x``, ``u`` and ``batch``.

        Returns:
            ``self.output`` applied to the final node features.
        """
        x, u, graph_ids = batch.x, batch.u, batch.batch
        x = self.project(x)

        k = 10
        for layer in (self.layer_1, self.layer_2, self.layer_3):
            edge_index = gnn.knn_graph(x, k, graph_ids, loop=False, flow='source_to_target')
            # Zero-width placeholders: no edge features, no global features.
            edge_attr = torch.zeros(edge_index.size(1), 0, device=x.device)
            u = torch.zeros(u.size(0), 0, device=x.device)
            x, _, _ = layer(x, edge_index, edge_attr, u, graph_ids)

        return self.output(x)
예제 #26
0
    def forward(self, data):
        """Apply edge-conditioned filters on per-layer dynamic kNN graphs (k=32) and compute MSE.

        Each filter receives ``edge_attr = x[edge_index[0]] - x[edge_index[1]]``.

        Returns:
            ``(x, mse(x, data.y))``.
        """
        target, batch, h = data.y, data.batch, data.x
        for layer, act in zip(self.filters, self.activation):
            neighbours = knn_graph(h, k=32, batch=batch, loop=False)
            first, second = neighbours
            # NOTE(review): under PyG's default source_to_target flow, index 0 is
            # the source j and index 1 the target i, so this is x_j - x_i. The
            # original comment said x_i - x_j; confirm the intended sign.
            h = layer(h, edge_index=neighbours, edge_attr=h[first] - h[second])
            if self.has_activation:
                h = act(h)
        return h, mse(h, target)
예제 #27
0
    def forward(self, data):
        """Classify point clouds with one ``dsc`` layer, mean pooling and a small MLP.

        Args:
            data: PyG-style batch providing ``pos`` and ``batch``.

        Returns:
            Output of ``self.sm`` applied to the MLP head, one row per cloud.
        """
        pos, batch = data.pos, data.batch
        neighbours = knn_graph(pos, self.k, batch, loop=False)

        activations = torch.sigmoid(self.dsc(pos, neighbours))
        # Mean over each cloud's `self.nr_points` points.
        pooled = activations.view(-1, self.nr_points, self.filter_nr).mean(dim=1)
        pooled = pooled.view(-1, self.filter_nr)

        hidden = F.elu(self.nn1(pooled))
        return self.sm(self.nn2(hidden))
예제 #28
0
 def forward(self, x, k, edge_index=None, batch=None):
     r"""Graph-Laplacian regularization term $R=||X^T L X||_F$.

     When ``edge_index`` is None, a kNN graph (no self-loops) is built on ``x``
     with ``k`` neighbours. The random-walk-normalized Laplacian from
     ``get_laplacian`` is propagated over ``x``. The result is the mean, over
     rows, of the squared norm of ``res`` along its last dimension.
     """
     node_count = x.shape[-2]
     graph = knn_graph(x, k=k, batch=batch, loop=False) if edge_index is None else edge_index
     lap_index, lap_val = get_laplacian(graph,
                                        normalization="rw",
                                        num_nodes=node_count)
     res = self.propagate(edge_index=lap_index, x=x, edge_weight=lap_val)
     # p="fro" over a single dim is the Euclidean norm of each row.
     row_norms = torch.norm(res, dim=-1, p="fro")
     return (row_norms ** 2).mean()
예제 #29
0
    def __init__(self, input_dim: int, output_dim: int,
                 adjacency_matrix: torch.Tensor, position: torch.Tensor,
                 neighbors: int):
        """Set up a PyG-based GCN layer over either a kNN graph or a dense adjacency.

        Args:
            input_dim: width of the input node features (BatchNorm and linear inputs).
            output_dim: output width of ``GCN_liner_out_1``.
            adjacency_matrix: dense adjacency; only used when ``neighbors <= 0``.
            position: node coordinates, used for the kNN graph and for
                ``Spatial_Distance``.
            neighbors: if > 0, the kNN graph uses ``neighbors + 1`` neighbours
                including the self-loop; otherwise the graph comes from
                ``adjacency_matrix`` plus the identity.
        """
        super(GCNLayer_PyG, self).__init__()
        self.BN = nn.BatchNorm1d(input_dim)
        self.Activition1 = nn.LeakyReLU(inplace=True)

        self.GCN_liner_out_1 = nn.Sequential(nn.Linear(input_dim, output_dim))
        self.GCN_liner_theta_1 = nn.Sequential(nn.Linear(input_dim, 128))

        self.position = position
        self.neighbors = neighbors

        # NOTE(review): `device` is not defined in this method; presumably a
        # module-level global. Confirm.
        self.a = nn.Parameter(
            torch.ones(size=(1, 1), requires_grad=True, device=device))
        self.b = nn.Parameter(
            torch.ones(size=(1, 1), requires_grad=True, device=device))
        self.lambda_ = nn.Parameter(torch.zeros(1))

        self.theta1 = nn.Sequential(nn.Linear(input_dim, 1))  # commented-out alternative: ,nn.Sigmoid()
        if self.neighbors > 0:
            # Fixed-size kNN neighbourhoods; +1 because loop=True counts the node itself.
            self.neighbors = self.neighbors + 1  # self-loop
            # Chained assignment: stores the kNN edge_index on self.edge_index and
            # unpacks it, so self.col = edge_index[0] and self.row = edge_index[1].
            # NOTE(review): the else-branch unpacks in the opposite order
            # (self.row = index[0]); confirm which convention callers expect.
            self.col, self.row = self.edge_index = knn_graph(self.position,
                                                             self.neighbors,
                                                             batch=None,
                                                             loop=True)  #
        else:
            # Variable-size neighbourhoods taken from the dense adjacency matrix.
            self.I = torch.eye(adjacency_matrix.shape[0],
                               adjacency_matrix.shape[0],
                               requires_grad=False,
                               device=device,
                               dtype=torch.float32)
            # Presumably a 0/1 mask of non-zero adjacency entries (ceil(a * 1e-5)
            # is 1 for 0 < a <= 1e5 and 0 for a == 0). Confirm the value range.
            self.mask = torch.ceil(adjacency_matrix * 0.00001)
            # Adding the identity gives every node a self-loop in the sparse index.
            self.index, _ = dense_to_sparse(adjacency_matrix.contiguous() +
                                            self.I)
            self.row, self.col = self.index

        ########################spatial distance##########################
        # Squared Euclidean distance between the two endpoints of every kNN edge.
        if self.neighbors > 0:
            self.Spatial_Distance = torch.square(
                torch.norm(self.position[self.col] - self.position[self.row],
                           dim=-1))
예제 #30
0
    def forward(self, data):
        """Classify point clouds with ``dsc3d`` + ``dd`` features, mean pooling and an MLP.

        Both feature layers share one kNN graph (no self-loops) on the raw
        positions, and each is followed by a sigmoid. Features are averaged over
        the points of each of ``self.batch_size`` clouds.

        Args:
            data: PyG-style batch providing ``pos`` and ``batch``.

        Returns:
            Output of ``self.sm`` applied to the MLP head, one row per cloud.
        """
        pos, batch = data.pos, data.batch
        neighbours = knn_graph(pos, self.k, batch, loop=False)

        _, _, feats = self.dsc3d(pos, neighbours)
        feats = torch.sigmoid(feats)
        _, _, feats = self.dd(pos, neighbours, feats)
        feats = torch.sigmoid(feats)

        pooled = feats.view(self.batch_size, -1, self.out_size).mean(dim=1)
        pooled = pooled.view(-1, self.out_size)
        hidden = F.elu(self.nn1(pooled))
        return self.sm(self.nn2(hidden))