Ejemplo n.º 1
0
def get_features(model, dataloader, device):
    """Collect pooled features and clicked-label masks for every batch.

    The model is switched to eval mode and moved to ``device``; each batch
    is forwarded without gradient tracking using ``data['image']`` as input.

    Args:
        model: callable returning a mapping that contains a ``'feats'`` entry.
        dataloader: iterable of batches providing the ``'image'``,
            ``'labels'`` and ``'labels_clicked'`` entries.
        device: device passed to ``model.to`` and ``utls.batch_to_device``.

    Returns:
        A two-element list ``[features, labels_pos_mask]`` with one entry per
        batch: the squeezed numpy array produced by ``sp_pool`` and a boolean
        mask that is True at the clicked label indices.
    """
    model.eval()
    model.to(device)
    print('getting features')

    pooled_per_batch = []
    clicked_masks = []

    progress = tqdm.tqdm(total=len(dataloader))
    for batch in dataloader:
        batch = utls.batch_to_device(batch, device)
        with torch.no_grad():
            outputs = model(batch['image'])

        # flatten the per-sample lists of clicked labels into one index list
        clicked = []
        for group in batch['labels_clicked']:
            clicked.extend(group)

        # one mask slot per distinct label value found in the batch
        n_labels = np.unique(batch['labels'].cpu().numpy()).shape[0]
        mask = np.zeros(n_labels).astype(bool)
        mask[clicked] = True
        clicked_masks.append(mask)

        pooled = sp_pool(outputs['feats'], batch['labels'])
        pooled_per_batch.append(pooled.detach().cpu().numpy().squeeze())
        progress.update(1)
    progress.close()

    return [pooled_per_batch, clicked_masks]
Ejemplo n.º 2
0
def do_prev_clusters(model, device, dataloader, *args):
    """Build one tiled cluster preview image per batch.

    Args:
        model: callable invoked as ``model(data, *args)``; its result must
            contain a ``'clusters'`` entry.
        device: device passed to ``model.to`` and ``utls.batch_to_device``.
        dataloader: iterable of batches providing ``'image_unnormal'``,
            ``'label/segmentation'``, ``'labels'`` and ``'frame_name'``.
        *args: extra positional arguments forwarded to the model.

    Returns:
        Dict mapping the first ``'frame_name'`` of each batch to the image
        returned by ``im_utils.make_tiled_clusters``.
    """
    model.eval()
    model.to(device)

    previews = {}
    progress = tqdm(total=len(dataloader))

    for batch in dataloader:
        batch = utls.batch_to_device(batch, device)

        with torch.no_grad():
            outputs = model(batch, *args)

        # roll axis 0 to the end and cast to uint8
        image = batch['image_unnormal'].cpu().squeeze().numpy()
        image = np.rollaxis(image, 0, 3).astype(np.uint8)

        # paint the ground-truth boundaries with (255, 0, 0)
        ground_truth = batch['label/segmentation'].cpu().squeeze().numpy()
        boundary = segmentation.find_boundaries(ground_truth)
        image[boundary, ...] = (255, 0, 0)

        superpixels = batch['labels'].cpu().squeeze().numpy()
        cluster_probs = outputs['clusters'].cpu().squeeze().numpy()
        tiled = im_utils.make_tiled_clusters(image, superpixels,
                                             cluster_probs)
        previews[batch['frame_name'][0]] = tiled

        progress.update(1)
    progress.close()

    return previews
Ejemplo n.º 3
0
def get_features(model,
                 dataloader,
                 device,
                 return_assign=False,
                 return_obj_preds=False,
                 feat_field='pooled_feats'):
    """Collect per-batch features and clicked-label masks from the model.

    The model is put in eval mode, moved to ``device`` and run on every
    batch without gradient tracking.

    Args:
        model: callable taking the whole batch and returning a mapping.
        dataloader: iterable of batches providing ``'labels'`` and
            ``'labels_clicked'``.
        device: device passed to ``model.to`` and ``utls.batch_to_device``.
        return_assign: when true, also return the argmax (over dim 1) of the
            ``'clusters'`` output, concatenated over all batches.
        return_obj_preds: when true, also return a list with the sigmoid of
            the ``'rho_hat_pooled'`` output of each batch.
        feat_field: key of the model output used as features.

    Returns:
        ``[features, labels_pos_mask]``, followed by the assignments and/or
        the objectness predictions when the corresponding flag is set (in
        that order).
    """
    model.eval()
    model.to(device)
    print('getting features')

    squash = torch.nn.Sigmoid()
    collected_feats = []
    clicked_masks = []
    hard_assignments = []
    objectness = []

    progress = tqdm(total=len(dataloader))
    for batch in dataloader:
        batch = utls.batch_to_device(batch, device)
        with torch.no_grad():
            outputs = model(batch)

        if return_assign:
            hard_assignments.append(
                outputs['clusters'].argmax(dim=1).cpu().numpy())

        if return_obj_preds:
            objectness.append(
                squash(outputs['rho_hat_pooled']).cpu().numpy())

        # flatten the per-sample lists of clicked labels into one index list
        clicked = []
        for group in batch['labels_clicked']:
            clicked.extend(group)

        # one mask slot per distinct label value found in the batch
        n_labels = np.unique(batch['labels'].cpu().numpy()).shape[0]
        mask = np.zeros(n_labels).astype(bool)
        mask[clicked] = True
        clicked_masks.append(mask)

        collected_feats.append(
            outputs[feat_field].detach().cpu().numpy().squeeze())
        progress.update(1)
    progress.close()

    out = [collected_feats, clicked_masks]
    if return_assign:
        out.append(np.concatenate(hard_assignments))
    if return_obj_preds:
        out.append(objectness)

    return out
Ejemplo n.º 4
0
    def prepare_all(self, all_edges_nn=None, feat_field='pooled_feats'):
        """Run the model over ``self.dl`` and cache its per-frame outputs.

        For every frame the following attributes are filled, and finally
        converted to lists ordered by sorted frame index:

        * ``self.obj_preds``: sigmoid of ``res['rho_hat_pooled']`` (numpy)
        * ``self.feats_csml``: slice of ``res['siam_feats']`` (left as is)
        * ``self.assignments``: argmax over dim 1 of ``res['clusters']``
        * ``self.feats``: squeezed ``res['proj_pooled_feats']`` (numpy)

        The model is switched back to train mode before returning.

        Args:
            all_edges_nn: not used by this method.
            feat_field: not used by this method.
        """
        print('preparing features for linkAgent')
        # per-frame results keyed by frame index; flattened to lists below
        self.obj_preds = dict()
        self.feats_csml = dict()
        self.feats = dict()
        self.assignments = dict()

        # NOTE(review): the model is not put in eval mode in this method;
        # presumably utls.make_edges_ccl takes care of it -- confirm.
        edges_list = utls.make_edges_ccl(self.model,
                                         self.dl,
                                         self.device,
                                         return_signed=True)

        print('getting features')
        pbar = tqdm.tqdm(total=len(self.dl))
        for data in self.dl:
            data = utls.batch_to_device(data, self.device)
            # only the edges of the first frame of the batch are used
            edges_ = edges_list[data['frame_idx'][0]].edge_index

            with torch.no_grad():
                res = self.model(data, edges_nn=edges_.to(self.device))

            # The model outputs stack the entries of all frames of the batch;
            # frame i owns as many consecutive rows as it has distinct labels.
            start = 0
            for i, f in enumerate(data['frame_idx']):
                end = start + torch.unique(data['labels'][i]).numel()
                self.obj_preds[f] = self.sigmoid(
                    res['rho_hat_pooled'][start:end]).detach().cpu().numpy()
                self.feats_csml[f] = res['siam_feats'][start:end]
                self.assignments[f] = res['clusters'][start:end].argmax(
                    dim=1).detach().cpu().numpy()
                self.feats[f] = res['proj_pooled_feats'][start:end].detach(
                ).cpu().numpy().squeeze()
                # next frame starts where this one ended (the previous
                # `start += end` overshot from the third frame onwards)
                start = end

            pbar.update(1)
        pbar.close()

        self.obj_preds = [
            self.obj_preds[k] for k in sorted(self.obj_preds.keys())
        ]
        self.feats_csml = [
            self.feats_csml[k] for k in sorted(self.feats_csml.keys())
        ]
        self.feats = [self.feats[k] for k in sorted(self.feats.keys())]
        self.assignments = [
            self.assignments[k] for k in sorted(self.assignments.keys())
        ]

        self.model.train()
Ejemplo n.º 5
0
def do_prev_rags(model, device, dataloader, couple_graphs):
    """
    Generate preview images on region adjacency graphs

    For each batch, the edges of the first graph flagged ``'adjacent'`` are
    kept, the model is run on the batch with those edges, and a preview is
    assembled by concatenating (along axis 1) the input image, the RAG
    rendering and the colorized clusters.

    Args:
        model: callable invoked as ``model(data, edges)``; its result must
            contain ``'probas_preds'`` and ``'clusters'``.
        device: device passed to ``utls.batch_to_device``.
        dataloader: iterable of batches providing ``'graph'``,
            ``'image_unnormal'``, ``'label/segmentation'``, ``'labels'``,
            ``'frame_idx'`` and ``'frame_name'``.
        couple_graphs: graph whose nodes, indexed by frame index, hold a
            ``'clst'`` tensor.

    Returns:
        Dict mapping the first ``'frame_name'`` of each batch to its preview.
    """
    model.eval()

    previews = {}
    progress = tqdm(total=len(dataloader))
    for batch in dataloader:
        batch = utls.batch_to_device(batch, device)

        # keep only adjacent edges of the first graph of the batch
        graph = batch['graph'][0]
        adjacent_edges = [
            e for e in graph.edges() if graph.edges[e]['adjacent']
        ]
        rag = graph.edge_subgraph(adjacent_edges).copy()

        with torch.no_grad():
            outputs = model(batch, torch.tensor(adjacent_edges))

        edge_probas = outputs['probas_preds'].detach().cpu().squeeze().numpy()

        image = batch['image_unnormal'].cpu().squeeze().numpy()
        image = np.rollaxis(image.astype(np.uint8), 0, 3)
        superpixels = batch['labels'].cpu().squeeze().numpy()
        ground_truth = batch['label/segmentation'].cpu().squeeze().numpy()

        # cluster ids stored on the frame's node, expanded to one-hot
        frame = batch['frame_idx'][0]
        node_clusters = couple_graphs.nodes[frame]['clst'].cpu().numpy()
        onehot = utls.to_onehot(node_clusters, outputs['clusters'].shape[1])
        colorized = im_utils.make_clusters(superpixels, onehot)

        rag_overlay = im_utils.my_show_rag(rag,
                                           image,
                                           superpixels,
                                           edge_probas,
                                           truth=ground_truth)
        previews[batch['frame_name'][0]] = np.concatenate(
            (image, rag_overlay, colorized), axis=1)

        progress.update(1)
    progress.close()

    return previews
Ejemplo n.º 6
0
    def do_update(self, model, dataloader, device, clst_field='clusters'):
        """Recompute the stored distributions and check for convergence.

        The model is run in eval mode over ``dataloader``; the outputs found
        under ``clst_field`` are concatenated, passed through
        ``target_distribution`` and split per dataset sample. The argmax
        assignments are appended to ``self.assignments``; from the second
        update on, the fraction of changed assignments is stored in
        ``self.ratio_changed`` and ``self.converged`` is set once it drops
        below ``self.thr_assign``. The model is put back in train mode.

        Args:
            model: callable taking a batch and returning a mapping.
            dataloader: loader to iterate; its ``dataset`` is also iterated
                to obtain the per-sample split sizes.
            device: device passed to ``utls.batch_to_device``.
            clst_field: key of the model output holding the distributions.
        """
        print('updating targets...')
        model.eval()

        per_batch = []
        for batch in dataloader:
            batch = utls.batch_to_device(batch, device)
            with torch.no_grad():
                outputs = model(batch)
            per_batch.append(outputs[clst_field])

        distrib = torch.cat(per_batch)
        tgt = target_distribution(torch.cat(per_batch))

        # NOTE(review): split sizes follow dataset order while the stacked
        # outputs follow loader order -- assumes both agree; confirm.
        sizes = [
            np.unique(sample['labels']).size
            for sample in dataloader.dataset
        ]
        self.tgt_distribs = torch.split(tgt.cpu().detach(), sizes)
        self.distribs = torch.split(distrib.cpu().detach(), sizes)

        hard = np.concatenate(
            [
                torch.argmax(d, dim=-1).cpu().detach().numpy()
                for d in self.distribs
            ],
            axis=0)
        self.assignments.append(hard)

        if len(self.assignments) > 1:
            latest = self.assignments[-1]
            previous = self.assignments[-2]
            n_changed = np.sum(latest != previous)
            self.ratio_changed = n_changed / latest.size
            print('ratio_changed: {}'.format(self.ratio_changed))
            if self.ratio_changed < self.thr_assign:
                self.converged = True
        model.train()
Ejemplo n.º 7
0
    from ksptrack.siamese import utils as utls
    from ksptrack.siamese.distrib_buffer import DistribBuffer

    # Smoke-test style driver: load a dataset and a checkpoint, refresh the
    # distribution buffer, then evaluate the triplet loss on every batch.
    device = torch.device('cuda')

    # Dataset wrapper first, then a shuffled loader (batches of 2) that
    # reuses the dataset's own collate function.
    dl = Loader(pjoin('/home/ubelix/artorg/lejeune/data/medical-labeling',
                      'Dataset30'),
                normalization='rescale',
                resize_shape=512)
    dl = DataLoader(dl, collate_fn=dl.collate_fn, batch_size=2, shuffle=True)

    model = Siamese(15, 15, backbone='unet')

    # NOTE(review): the data comes from 'Dataset30' while the checkpoint is
    # read from a 'Dataset20' run directory -- confirm this is intended.
    run_path = '/home/ubelix/artorg/lejeune/runs/siamese_dec/Dataset20'
    cp_path = pjoin(run_path, 'checkpoints', 'init_dec.pth.tar')
    # map_location keeps the loaded tensors on the storage they were
    # deserialized to; the model is moved to `device` afterwards.
    state_dict = torch.load(cp_path, map_location=lambda storage, loc: storage)
    model.load_state_dict(state_dict)
    model.to(device)
    model.train()
    criterion = TripletLoss()

    # NOTE(review): the buffer is updated from a shuffled loader; verify that
    # DistribBuffer does not rely on dataset order.
    distrib_buff = DistribBuffer(10, thr_assign=0.0001)
    distrib_buff.maybe_update(model, dl, device)

    for data in dl:
        data = utls.batch_to_device(data, device)

        # second element returned by the buffer is used as the loss targets
        _, targets = distrib_buff[data['frame_idx']]
        res = model(data)
        # the loss is only computed here; no backward/optimizer step follows
        loss = criterion(res['proj_pooled_feats'], targets, data['graph'])
Ejemplo n.º 8
0
def train_one_epoch(model,
                    dataloaders,
                    optimizers,
                    device,
                    distrib_buff,
                    lr_sch,
                    cfg,
                    probas=None,
                    edges_list=None):
    """Train ``model`` for one pass over ``dataloaders['train']``.

    The total loss is the weighted sum of up to four terms:

    * clustering (KL divergence), weight ``cfg.alpha``, skipped when
      ``cfg.fix_clst`` is set;
    * objectness (BCE with logits), weight ``cfg.lambda_``, when ``cfg.clf``;
    * reconstruction (MSE), weight ``cfg.gamma``, always;
    * pairwise (``RAGTripletLoss``), weight ``cfg.beta``, when ``cfg.pw``.

    Every optimizer and every learning-rate scheduler is stepped once per
    batch.

    Args:
        model: network called as ``model(data, edges_nn=...)``.
        dataloaders: mapping with a ``'train'`` loader.
        optimizers: mapping of optimizers.
        device: device passed to ``utls.batch_to_device``.
        distrib_buff: indexable by ``data['frame_idx']``; the second element
            of the result is used as clustering target.
        lr_sch: mapping of learning-rate schedulers.
        cfg: configuration exposing ``fix_clst``, ``clf``, ``pw``, ``alpha``,
            ``beta``, ``gamma``, ``lambda_`` and ``batch_size``.
        probas: per-frame probabilities indexed by frame index. Despite the
            ``None`` default it is indexed unconditionally.
        edges_list: per-frame objects exposing ``edge_index``. Despite the
            ``None`` default it is indexed unconditionally.

    Returns:
        Dict with the keys ``'loss'``, ``'loss_pw'``, ``'loss_clst'``,
        ``'loss_obj_pred'`` and ``'loss_recons'``. Each value is the
        accumulated loss divided by ``cfg.batch_size * len(loader)``; all
        are zero when the loader is empty.
    """
    model.train()

    running_loss = 0.0
    running_clst = 0.0
    running_recons = 0.0
    running_pw = 0.0
    running_obj_pred = 0.0
    # defined up front so that an empty loader does not raise NameError
    loss_ = 0.0

    # NOTE(review): PyTorch's KLDivLoss expects log-probabilities as input and
    # documents 'batchmean' as the mathematically correct reduction --
    # confirm res['clusters'] and reduction='mean' are intended.
    criterion_clst = torch.nn.KLDivLoss(reduction='mean')
    criterion_pw = RAGTripletLoss()
    criterion_obj_pred = torch.nn.BCEWithLogitsLoss()
    criterion_recons = torch.nn.MSELoss()

    train_loader = dataloaders['train']
    pbar = tqdm(total=len(train_loader))
    for i, data in enumerate(train_loader):
        data = utls.batch_to_device(data, device)

        # forward
        with torch.set_grad_enabled(True):
            for opt in optimizers.values():
                opt.zero_grad()

            _, targets = distrib_buff[data['frame_idx']]
            # distinct name: the batch index `i` is still needed below
            probas_ = torch.cat([probas[idx] for idx in data['frame_idx']])
            # only the edges of the first frame of the batch are used
            edges_ = edges_list[data['frame_idx'][0]].edge_index.to(
                probas_.device)
            with torch.autograd.detect_anomaly():
                res = model(data, edges_nn=edges_)

            loss = 0

            if not cfg.fix_clst:
                loss_clst = criterion_clst(res['clusters'],
                                           targets.to(res['clusters']))
                loss += cfg.alpha * loss_clst

            if cfg.clf:
                loss_obj_pred = criterion_obj_pred(
                    res['rho_hat_pooled'].squeeze(), (probas_ >= 0.5).float())
                loss += cfg.lambda_ * loss_obj_pred

            # `sigmoid` is expected to be defined at module level
            loss_recons = criterion_recons(sigmoid(res['output']),
                                           data['image_noaug'])
            loss += cfg.gamma * loss_recons

            if cfg.pw:
                loss_pw = criterion_pw(res['siam_feats'], edges_)
                loss += cfg.beta * loss_pw

            with torch.autograd.detect_anomaly():
                loss.backward()

            for opt in optimizers.values():
                opt.step()

            for sch in lr_sch.values():
                sch.step()

        running_loss += loss.cpu().detach().numpy()
        running_recons += loss_recons.cpu().detach().numpy()
        if not cfg.fix_clst:
            running_clst += loss_clst.cpu().detach().numpy()
        if cfg.clf:
            running_obj_pred += loss_obj_pred.cpu().detach().numpy()
        if cfg.pw:
            running_pw += loss_pw.cpu().detach().numpy()
        loss_ = running_loss / ((i + 1) * cfg.batch_size)
        pbar.set_description('lss {:.6e}'.format(loss_))
        pbar.update(1)

    pbar.close()

    # Same normalisation for every reported term; guard against an empty
    # loader to avoid a division by zero.
    n_seen = cfg.batch_size * len(train_loader)
    denom = n_seen if n_seen else 1

    out = {
        'loss': loss_,
        'loss_pw': running_pw / denom,
        'loss_clst': running_clst / denom,
        'loss_obj_pred': running_obj_pred / denom,
        # epoch average, consistent with the other entries (previously the
        # raw tensor of the last batch was returned)
        'loss_recons': running_recons / denom,
    }

    return out