Ejemplo n.º 1
0
    def forward(self, data):
        """
        Run the node-prediction GNN over the clusters found in the input.

        Input:
            data[0]: (Nx5) cluster tensor, rows laid out as
                     (x, y, z, batch_id, cluster_id)
        Output:
            dictionary containing every entry produced by
            ``self.node_predictor`` (e.g. 'node_pred') together with:
                'clust_ids':  [tensor of the label id of each cluster]
                'batch_ids':  [tensor of the batch id of each cluster]
                'edge_index': [graph connectivity fed to the predictor]
            ``self.default_return(device)`` is returned instead when the size
            filter keeps no cluster or the graph has no edges.
        """
        labels = data[0]
        dev = labels.device

        # Group the points into EM clusters
        clusters = form_clusters_new(labels)

        # Optionally drop clusters under the configured size threshold
        if self.remove_compton:
            keep_mask = filter_compton(clusters, self.compton_thresh)
            keep = np.where(keep_mask)[0]
            if len(keep) == 0:
                return self.default_return(dev)
            clusters = clusters[keep]

        # Per-cluster label ids and batch ids
        cluster_ids = get_cluster_label(labels, clusters)
        cluster_batches = get_cluster_batch(labels, clusters)

        # Complete graph between clusters (other topologies are a TODO)
        graph = complete_graph(cluster_batches, device=dev)
        if graph.shape[0] == 0:
            return self.default_return(dev)

        # Node and edge features for the GNN
        node_feats = cluster_vtx_features(labels, clusters, device=dev)
        edge_feats = cluster_edge_features(labels, clusters, graph, device=dev)

        batch_tensor = torch.tensor(cluster_batches).to(dev)

        prediction = self.node_predictor(node_feats, graph, edge_feats,
                                         batch_tensor)

        result = dict(prediction)
        result['clust_ids'] = [torch.tensor(cluster_ids)]
        result['batch_ids'] = [torch.tensor(cluster_batches)]
        result['edge_index'] = [graph]
        return result
Ejemplo n.º 2
0
    def forward(self, data):
        """
        Predict edge weights on a complete graph built from the clusters in
        ``data[0]``.

        inputs data:
            data[0] - dbscan data
        output:
            dictionary returned by ``self.edge_predictor`` (e.g. holding
            'edge_pred'). When no cluster survives the Compton filter, or the
            graph has no edges, ``{'edge_pred': [e]}`` is returned instead,
            where ``e`` is an empty tensor on the same device as ``data[0]``.
        """
        device = data[0].device

        def empty_prediction():
            # Create the placeholder directly on ``device``. ``Tensor.to``
            # returns a new tensor rather than moving the original, so calling
            # it without keeping the result leaves the tensor on the CPU.
            e = torch.tensor([], requires_grad=True, device=device)
            return {'edge_pred': [e]}

        # need to form graph, then pass through GNN
        clusts = form_clusters_new(data[0])

        # remove compton clusters; if no cluster is kept, return early
        if self.remove_compton:
            selection = filter_compton(
                clusts, self.compton_thresh)  # non-compton looking clusters
            if not len(selection):
                return empty_prediction()

            clusts = clusts[selection]
            # NOTE(review): if filter_compton returns a boolean mask rather
            # than indices, len(selection) counts every cluster, so also guard
            # against the case where the mask keeps nothing.
            if not len(clusts):
                return empty_prediction()

        # form graph
        batch = get_cluster_batch(data[0], clusts)
        edge_index = complete_graph(batch, device=device)

        if not edge_index.shape[0]:
            return empty_prediction()

        # obtain vertex directions
        x = cluster_vtx_dirs(data[0], clusts, device=device)

        # obtain edge directions
        e = cluster_edge_dirs(data[0], clusts, edge_index, device=device)

        # get x batch
        xbatch = torch.tensor(batch).to(device)

        # get output
        return self.edge_predictor(x, edge_index, e, xbatch)
Ejemplo n.º 3
0
    def forward(self, data):
        """
        Classify clusters with three edge-convolution layers followed by
        ``self.predictor``.

        inputs data:
            data[0] - dbscan data
        output:
            log-softmax (over dim 1) of the node output of ``self.predictor``;
            an empty tensor on the input's device when no cluster survives
            the Compton filter or the graph has no edges.
        """
        # Keep every tensor on the same device as the input. Previously
        # ``cuda=True`` was hard-coded for the graph and the features, while
        # ``batch`` was moved only for CUDA inputs, which mixed devices on
        # CPU runs.
        use_cuda = data[0].is_cuda
        device = data[0].device

        # need to form graph, then pass through GNN
        clusts = form_clusters_new(data[0])

        # remove compton clusters (should we?)
        # if no cluster fits this condition, return
        selection = filter_compton(clusts)  # non-compton looking clusters
        if not len(selection):
            # Built directly on ``device``: ``Tensor.cuda()`` is not in-place,
            # so the old ``x.cuda()`` call had no effect.
            return torch.tensor([], requires_grad=True, device=device)

        clusts = clusts[selection]

        # form complete graph
        batch = get_cluster_batch(data[0], clusts)
        edge_index = complete_graph(batch, cuda=use_cuda)
        if not len(edge_index):
            return torch.tensor([], requires_grad=True, device=device)

        batch = torch.tensor(batch).to(device)

        # obtain vertex features
        x = cluster_vtx_features_old(data[0], clusts, cuda=use_cuda)

        # go through layers
        x = self.econv1(x, edge_index)
        x = self.econv2(x, edge_index)
        x = self.econv3(x, edge_index)

        # only the node output is used; edge and global outputs are ignored
        x, _, _ = self.predictor(x,
                                 edge_index,
                                 edge_attr=None,
                                 u=None,
                                 batch=batch)
        return F.log_softmax(x, dim=1)
Ejemplo n.º 4
0
    def forward(self, out, data0, data1):
        """
        Compute the edge classification loss and clustering metrics.

        out:
            output of the GNN model; ``out[0][0]`` is read as the edge
            prediction tensor (presumably shape (E, C) with one row of
            channel logits per edge - TODO confirm against the edge predictor)
        data0:
            ``data0[0]`` - DBSCAN data
        data1:
            ``data1[0]`` - groups data
        Returns:
            dict with 'loss' and 'accuracy'. On a graph with edges it also
            contains 'ARI', 'AMI', 'SBD', 'purity', 'efficiency' and
            'edge_count', and 'accuracy' is the ARI.
        """
        edge_pred = out[0][0]
        data0 = data0[0]
        data1 = data1[0]

        device = data0.device

        # first decide what true edges should be
        clusts = form_clusters_new(data0)

        # remove compton clusters
        # if no cluster fits this condition, return
        if self.remove_compton:
            selection = filter_compton(
                clusts, self.compton_thresh)  # non-compton looking clusters
            if not len(selection):
                total_loss = self.lossfn(edge_pred, edge_pred)
                return {'accuracy': 1., 'loss': total_loss}

            clusts = clusts[selection]

        # form graph
        batch = get_cluster_batch(data0, clusts)
        edge_index = complete_graph(batch, device=device)

        if not edge_index.shape[0]:
            total_loss = self.lossfn(edge_pred, edge_pred)
            return {'accuracy': 0., 'loss': total_loss}
        group = get_cluster_label(data1, clusts)

        # determine true assignments (one class label per edge)
        edge_assn = edge_assignment(edge_index,
                                    batch,
                                    group,
                                    device=device,
                                    dtype=torch.long)

        edge_assn = edge_assn.view(-1)

        # total loss on batch
        total_loss = self.lossfn(edge_pred, edge_assn)

        # Per-edge score: channel-1 logit minus channel-0 logit. The loss
        # above pairs each row of edge_pred with one entry of the per-edge
        # target edge_assn, so edges run along dim 0 and channels along dim 1.
        # The former edge_pred[1, :] - edge_pred[0, :] compared edge 1 with
        # edge 0 instead.
        fe = edge_pred[:, 1] - edge_pred[:, 0]
        cs = assign_clusters_UF(edge_index, fe, len(clusts), thresh=0.0)

        ari, ami, sbd, pur, eff = DBSCAN_cluster_metrics2(cs, clusts, group)

        # edge_index is indexed as (2, E): the edge count is its second dim
        edge_ct = edge_index.shape[1]

        return {
            'ARI': ari,
            'AMI': ami,
            'SBD': sbd,
            'purity': pur,
            'efficiency': eff,
            'accuracy': ari,
            'loss': total_loss,
            'edge_count': edge_ct
        }