Example No. 1
    def forward(self, out, clusters, groups, primary):
        """
        out:
            array output from the DataParallel gather function
            out[0] - n_gpus tensors of edge indexes
            out[1] - n_gpus tensors of predicted edge weights from model forward
            out[2] - n_gpus arrays of group ids for each cluster
            out[3] - n_gpus number of iterations
        clusters:
            n_gpus Nx5 tensors of (x, y, z, batch_id, cluster_id)
        groups:
            n_gpus Nx5 tensors of (x, y, z, batch_id, group_id)
        primary:
            n_gpus tensors of (x, y, z) coordinates of the origins of EM primaries
        """
        total_loss, total_acc, total_primary_fdr, total_primary_acc, total_iter = 0., 0., 0., 0., 0
        ngpus = len(clusters)
        for i in range(ngpus):
            data0 = clusters[i]
            data1 = groups[i]
            data2 = primary[i]

            clusts = form_clusters_new(data0)

            # remove compton clusters
            # if no cluster fits this condition, return
            if self.remove_compton:
                selection = filter_compton(
                    clusts)  # non-compton looking clusters
                if not len(selection):
                    edge_pred = out[1][i]
                    total_loss += self.lossfn(edge_pred, edge_pred)
                    total_acc += 1.
                    continue
                clusts = clusts[selection]

            # process group data
            data_grp = data1

            # form primary/secondary bipartite graph
            primaries = assign_primaries(data2, clusts, data0)
            batch = get_cluster_batch(data0, clusts)
            # edge_index = primary_bipartite_incidence(batch, primaries)
            group = get_cluster_label(data_grp, clusts)

            primaries_true = assign_primaries(data2,
                                              clusts,
                                              data1,
                                              use_labels=True)
            primary_fdr, primary_tdr, primary_acc = analyze_primaries(
                primaries, primaries_true)
            total_primary_fdr += primary_fdr
            total_primary_acc += primary_acc

            niter = out[3][i][0]  # number of iterations
            total_iter += niter
            for j in range(niter):
                # determine true assignments
                edge_index = out[0][i][j]
                edge_assn = edge_assignment(edge_index,
                                            batch,
                                            group,
                                            cuda=True)

                edge_pred = out[1][i][j]

                # flatten targets and predictions before balancing and loss
                edge_assn = edge_assn.view(-1)
                edge_pred = edge_pred.view(-1)

                if self.balance:
                    edge_assn, edge_pred = self.balance_classes(
                        edge_assn, edge_pred)

                total_loss += self.lossfn(edge_pred, edge_assn)

            # compute accuracy of assignment
            # need to multiply by batch size to be accurate
            #total_acc = (np.max(batch) + 1) * torch.tensor(secondary_matching_vox_efficiency(edge_index, edge_assn, edge_pred, primaries, clusts, len(clusts)))
            # use out['matched']
            total_acc += torch.tensor(
                secondary_matching_vox_efficiency2(out[2][i], group, primaries,
                                                   clusts))

        return {
            'primary_fdr': total_primary_fdr / ngpus,
            'primary_acc': total_primary_acc / ngpus,
            'accuracy': total_acc / ngpus,
            'loss': total_loss / ngpus,
            'n_iter': total_iter
        }
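Example No. 1 calls self.balance_classes before computing the loss, but that helper is not shown in these snippets. Below is a minimal sketch of one way such balancing could work, by subsampling the majority class; the function is an assumption for illustration, not the repository's implementation.

import torch

def balance_classes(edge_assn, edge_pred):
    # Hypothetical helper: subsample the majority class so that 0/1 edge
    # labels contribute equally to the loss. The repository's own
    # balance_classes may differ.
    ind0 = (edge_assn == 0).nonzero(as_tuple=True)[0]
    ind1 = (edge_assn == 1).nonzero(as_tuple=True)[0]
    n = min(len(ind0), len(ind1))
    if n == 0:
        return edge_assn, edge_pred
    keep0 = ind0[torch.randperm(len(ind0))[:n]]
    keep1 = ind1[torch.randperm(len(ind1))[:n]]
    keep = torch.cat([keep0, keep1])
    return edge_assn[keep], edge_pred[keep]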
Example No. 2
    def forward(self, out, clusters, groups, primary):
        """
        out:
            array output from the DataParallel gather function
            out[0] - n_gpus tensors of edge indexes
            out[1] - n_gpus tensors of predicted edge weights from model forward
            out[2] - n_gpus arrays of group ids for each cluster
            out[3] - n_gpus number of iterations
        clusters:
            n_gpus Nx5 tensors of (x, y, z, batch_id, cluster_id)
        groups:
            n_gpus Nx5 tensors of (x, y, z, batch_id, group_id)
        primary:
            n_gpus tensors of (x, y, z) coordinates of the origins of EM primaries
        """
        total_loss, total_acc, total_primary_fdr, total_primary_acc, total_iter = 0., 0., 0., 0., 0
        total_ari, total_ami, total_sbd, total_pur, total_eff = 0., 0., 0., 0., 0.
        ngpus = len(clusters)
        for i in range(ngpus):
            data0 = clusters[i]
            data1 = groups[i]
            data2 = primary[i]

            clusts = form_clusters_new(data0)

            # remove compton clusters
            # if no cluster fits this condition, return
            if self.remove_compton:
                selection = filter_compton(
                    clusts)  # non-compton looking clusters
                if not len(selection):
                    edge_pred = out[1][i][0]
                    total_loss += self.lossfn(edge_pred, edge_pred)
                    total_acc += 1.
                    continue
                clusts = clusts[selection]

            # process group data
            data_grp = data1

            # form primary/secondary bipartite graph
            primaries = assign_primaries(data2, clusts, data0)
            batch = get_cluster_batch(data0, clusts)
            # edge_index = primary_bipartite_incidence(batch, primaries)
            group = get_cluster_label(data_grp, clusts)

            primaries_true = assign_primaries(data2,
                                              clusts,
                                              data1,
                                              use_labels=True)
            primary_fdr, primary_tdr, primary_acc = analyze_primaries(
                primaries, primaries_true)
            total_primary_fdr += primary_fdr
            total_primary_acc += primary_acc

            niter = out[3][i][0]  # number of iterations
            total_iter += niter

            # loop over iterations and add loss at each iter.
            for j in range(niter):
                # determine true assignments
                edge_index = out[0][i][j]
                edge_assn = edge_assignment(edge_index,
                                            batch,
                                            group,
                                            cuda=True,
                                            dtype=torch.long)

                # get edge predictions (2 channels)
                edge_pred = out[1][i][j]

                edge_assn = edge_assn.view(-1)

                total_loss += self.lossfn(edge_pred, edge_assn)

            # compute accuracy of assignment
            total_acc += secondary_matching_vox_efficiency2(
                out[2][i], group, primaries, clusts)

            # get clustering metrics
            #print(out[2][i].shape)
            ari, ami, sbd, pur, eff = DBSCAN_cluster_metrics2(
                out[2][i].cpu().numpy(), clusts, group)
            total_ari += ari
            total_ami += ami
            total_sbd += sbd
            total_pur += pur
            total_eff += eff

        return {
            'primary_fdr': total_primary_fdr / ngpus,
            'primary_acc': total_primary_acc / ngpus,
            'ARI': total_ari / ngpus,
            'AMI': total_ami / ngpus,
            'SBD': total_sbd / ngpus,
            'purity': total_pur / ngpus,
            'efficiency': total_eff / ngpus,
            'accuracy': total_acc / ngpus,
            'loss': total_loss / ngpus,
            'n_iter': total_iter
        }
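Examples No. 2 and No. 3 report purity and efficiency through DBSCAN_cluster_metrics2, whose implementation is not shown here. The sketch below is a hypothetical re-implementation of how purity and efficiency can be computed from predicted cluster ids and true group ids, assuming non-negative integer ids; the repository's metrics may be defined differently.

import numpy as np

def purity_efficiency_sketch(pred_ids, true_ids):
    # Hypothetical re-implementation for illustration only.
    # pred_ids, true_ids: 1-D arrays of non-negative integer ids, one per cluster.
    pred_ids = np.asarray(pred_ids, dtype=np.int64)
    true_ids = np.asarray(true_ids, dtype=np.int64)
    # purity: for each predicted cluster, the fraction of its members that
    # belong to its dominant true group
    pur = [np.max(np.bincount(true_ids[pred_ids == p])) / np.sum(pred_ids == p)
           for p in np.unique(pred_ids)]
    # efficiency: for each true group, the fraction of its members captured
    # by its dominant predicted cluster
    eff = [np.max(np.bincount(pred_ids[true_ids == g])) / np.sum(true_ids == g)
           for g in np.unique(true_ids)]
    return float(np.mean(pur)), float(np.mean(eff))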
Example No. 3
    def forward(self, out, data0, data1):
        """
        out:
            dictionary output from GNN Model
            keys:
                'edge_pred': predicted edge weights from model forward
        data:
            data0 - DBSCAN data
            data1 - groups data
        """
        edge_pred = out[0][0]
        data0 = data0[0]
        data1 = data1[0]

        device = data0.device

        # first decide what true edges should be
        # need to form graph, then pass through GNN
        # clusts = form_clusters(data0)
        clusts = form_clusters_new(data0)

        # remove compton clusters
        # if no cluster fits this condition, return
        if self.remove_compton:
            selection = filter_compton(
                clusts, self.compton_thresh)  # non-compton looking clusters
            if not len(selection):
                total_loss = self.lossfn(edge_pred, edge_pred)
                return {'accuracy': 1., 'loss': total_loss}

            clusts = clusts[selection]

        # process group data
        # data_grp = process_group_data(data1, data0)
        data_grp = data1

        # form graph
        batch = get_cluster_batch(data0, clusts)
        edge_index = complete_graph(batch, device=device)

        if not edge_index.shape[0]:
            total_loss = self.lossfn(edge_pred, edge_pred)
            return {'accuracy': 0., 'loss': total_loss}
        group = get_cluster_label(data_grp, clusts)

        # determine true assignments
        edge_assn = edge_assignment(edge_index,
                                    batch,
                                    group,
                                    device=device,
                                    dtype=torch.long)

        edge_assn = edge_assn.view(-1)

        # total loss on batch
        total_loss = self.lossfn(edge_pred, edge_assn)

        # compute assigned clusters
        fe = edge_pred[1, :] - edge_pred[0, :]
        cs = assign_clusters_UF(edge_index, fe, len(clusts), thresh=0.0)

        ari, ami, sbd, pur, eff = DBSCAN_cluster_metrics2(cs, clusts, group)

        edge_ct = edge_index.shape[1]

        return {
            'ARI': ari,
            'AMI': ami,
            'SBD': sbd,
            'purity': pur,
            'efficiency': eff,
            'accuracy': ari,
            'loss': total_loss,
            'edge_count': edge_ct
        }
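Example No. 3 turns edge scores into a clustering with assign_clusters_UF, i.e. a union-find pass over edges whose score favours the "on" class. Below is a minimal sketch of that idea; the helper is hypothetical, since assign_clusters_UF itself is not shown in these examples.

import numpy as np

def assign_clusters_union_find(edge_index, edge_scores, n_nodes, thresh=0.0):
    # Hypothetical sketch: merge the two endpoints of every edge whose score
    # exceeds thresh, then relabel the connected components with contiguous ids.
    parent = np.arange(n_nodes)

    def find(x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path compression
            x = parent[x]
        return x

    edge_index = np.asarray(edge_index)
    for (a, b), s in zip(edge_index.T, np.asarray(edge_scores)):
        if s > thresh:
            ra, rb = find(int(a)), find(int(b))
            if ra != rb:
                parent[ra] = rb
    roots = np.array([find(i) for i in range(n_nodes)])
    _, labels = np.unique(roots, return_inverse=True)
    return labels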
Example No. 4
    def forward(self, out, data0, data1):
        """
        out : dictionary, with
            'edge_pred': torch.tensor with edge prediction weights
            'complex'  : simplicial complex of Freudenthal triangulation
            'edges'    : torch tensor with edges used
        data:
            data0 - groups data
            data1 - 5-types data
        """
        dev = data0[0].device

        data_grps = data0[0]
        data_seg = data1[0]
        edge_pred = out['edge_pred'][0][0]
        X = out['complex']
        edge_index = out['edges']

        # 1. compute MST on edge weights edge_pred[:,1] - edge_pred[:,0]
        ft = edge_pred[:, 1] - edge_pred[:, 0]
        # if using MST in loss
        if self.mst:
            # loss is only on MST
            ce_inds = X.CriticalEdgeInds(ft)  # critical edges form MST
            active_edge_index = edge_index[:, ce_inds]
            active_edge_pred = edge_pred[ce_inds, :]
        else:
            # loss is on all edges
            active_edge_index = edge_index
            active_edge_pred = edge_pred

        # 2. get edge labels
        # edge_assignment inputs:
        #   edge_index: torch tensor of edges
        #   batches: torch tensor of batch id for each node
        #   groups: torch tensor of group ids for each node
        if self.em_only:
            # select voxels with 5-types classification > 1
            sel = data_seg[:, -1] > 1
            data_grps = data_grps[sel, :]

        batch = data_grps[:, -2]  # get batch from data
        group = data_grps[:, -1]  # get groups from data
        edge_assn = edge_assignment(active_edge_index,
                                    batch,
                                    group,
                                    cuda=False,
                                    dtype=torch.long,
                                    device=dev)

        # 3. compute loss, only on critical edges
        # extract critical edges
        loss = self.lossfn(active_edge_pred, edge_assn)

        loss_terms = {'loss_raw': loss.detach().cpu().item()}

        # 3a. add regularization (optional)
        if self.reg:
            ph_out = X(ft)
            if self.reg_ph0 > 0:
                penh0 = self.reg_ph0 * self.regh0(ph_out)
                loss_terms['reg_ph0'] = penh0.detach().cpu().item()
                loss = loss + penh0
            if self.reg_ph1 > 0:
                penh1 = self.reg_ph1 * self.regh1(ph_out)
                loss_terms['reg_ph1'] = penh1.detach().cpu().item()
                loss = loss + penh1

        # 4. compute predicted clustering with some threshold (0?)
        clusts = X.GetClusters(
            ft, 0.0)  # clusters based on edge being more likely than not
        clusts = np.array(clusts)

        # print(clusts)
        # 5. compute clustering metrics vs. group id.
        group = group.cpu().detach().numpy()
        # print(group)
        sbd = SBD(clusts, group)
        ami = AMI(clusts, group)
        ari = ARI(clusts, group)
        pur, eff = purity_efficiency(clusts, group)

        return {
            'SBD': sbd,
            'AMI': ami,
            'ARI': ari,
            'purity': pur,
            'efficiency': eff,
            'accuracy': ari,
            'loss': loss,
            **loss_terms
        }
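All of these examples rely on edge_assignment to build the true edge labels from per-node (or per-cluster) batch and group ids, as described in the comment block inside Example No. 4. Below is a minimal sketch of the labelling rule those inputs suggest; the real edge_assignment accepts extra options (cuda, device, dtype) that are omitted here.

import torch

def edge_assignment_sketch(edge_index, batch, group):
    # Minimal sketch: an edge is "true" (label 1) when its two endpoints
    # share both the same batch id and the same group id, otherwise 0.
    # edge_index: 2 x E LongTensor of node indices; batch, group: per-node id tensors.
    src, dst = edge_index[0], edge_index[1]
    same_batch = batch[src] == batch[dst]
    same_group = group[src] == group[dst]
    return (same_batch & same_group).long()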
Example No. 5
    def forward(self, edge_pred, data0, data1, data2):
        """
        edge_pred:
            predicted edge weights from model forward
        data:
            data0 - 5-types data
            data1 - groups data
            data2 - primary data
        """
        data0 = data0[0]
        data1 = data1[0]
        data2 = data2[0]
        # first decide what true edges should be
        # need to form graph, then pass through GNN
        # clusts = form_clusters(data0)
        clusts = form_clusters_new(data0)

        # remove track-like particles
        # types = get_cluster_label(data0, clusts)
        # selection = types > 1 # 0 or 1 are track-like
        # clusts = clusts[selection]

        # remove compton clusters
        # if no cluster fits this condition, return
        selection = filter_compton(clusts)  # non-compton looking clusters
        if not len(selection):
            total_loss = self.lossfn(edge_pred, edge_pred)
            return {'accuracy': 1., 'loss_seg': total_loss}

        clusts = clusts[selection]

        # process group data
        # data_grp = process_group_data(data1, data0)
        data_grp = data1

        # form primary/secondary bipartite graph
        primaries = assign_primaries(data2, clusts, data0)
        batch = get_cluster_batch(data0, clusts)
        edge_index = primary_bipartite_incidence(batch, primaries)
        group = get_cluster_label(data_grp, clusts)

        primaries_true = assign_primaries(data2,
                                          clusts,
                                          data1,
                                          use_labels=True)
        print("primaries (est):  ", primaries)
        print("primaries (true): ", primaries_true)

        # determine true assignments
        edge_assn = edge_assignment(edge_index, batch, group, cuda=True)

        edge_assn = edge_assn.view(-1)
        edge_pred = edge_pred.view(-1)

        if self.balance:
            # weight edges so that 0/1 labels appear equally often
            ind0 = edge_assn == 0
            ind1 = edge_assn == 1
            # number in each class
            n0 = torch.sum(ind0).float()
            n1 = torch.sum(ind1).float()
            print("n0 = ", n0, " n1 = ", n1)
            # weights to balance classes
            w0 = n1 / (n0 + n1)
            w1 = n0 / (n0 + n1)
            print("w0 = ", w0, " w1 = ", w1)
            edge_assn[ind0] = w0 * edge_assn[ind0]
            edge_assn[ind1] = w1 * edge_assn[ind1]
            edge_pred = edge_pred.clone()
            edge_pred[ind0] = w0 * edge_pred[ind0]
            edge_pred[ind1] = w1 * edge_pred[ind1]

        total_loss = self.lossfn(edge_pred, edge_assn)

        # compute accuracy of assignment
        # need to multiply by batch size to be accurate
        total_acc = (np.max(batch) + 1) * torch.tensor(
            secondary_matching_vox_efficiency(edge_index, edge_assn, edge_pred,
                                              primaries, clusts, len(clusts)))

        return {'accuracy': total_acc, 'loss_seg': total_loss}
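Example No. 5 builds its graph with primary_bipartite_incidence, connecting the assigned primary clusters to the remaining clusters. Below is a minimal sketch of such a construction; it is hypothetical, and the repository's function may order or filter the edges differently.

import torch

def primary_bipartite_incidence_sketch(batch, primaries):
    # Hypothetical sketch: connect each primary cluster to every non-primary
    # cluster that lives in the same batch. Returns a 2 x E LongTensor of
    # (primary, secondary) index pairs.
    batch = torch.as_tensor(batch)
    primary_ids = sorted(int(p) for p in primaries)
    others = [i for i in range(len(batch)) if i not in set(primary_ids)]
    edges = [(p, o) for p in primary_ids for o in others
             if int(batch[p]) == int(batch[o])]
    if not edges:
        return torch.empty((2, 0), dtype=torch.long)
    return torch.tensor(edges, dtype=torch.long).t()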