def get_features(model, dataloader, device):
    """Collect pooled features and clicked-label masks for every batch.

    The model is switched to eval mode and moved to ``device``. For each
    batch, ``model(batch['image'])['feats']`` is pooled with ``sp_pool`` using
    ``batch['labels']`` and stored as a squeezed numpy array.

    Args:
        model: callable returning a mapping with a ``'feats'`` entry.
        dataloader: iterable of batches; each batch provides ``'image'``,
            ``'labels'`` and ``'labels_clicked'``.
        device: device passed to ``model.to`` and ``utls.batch_to_device``.

    Returns:
        ``[features, labels_pos_mask]``: two lists with one entry per batch.
        Each mask is a boolean array with one slot per distinct value in
        ``batch['labels']``, set to True at the clicked label indices.
    """
    model.eval()
    model.to(device)

    pooled_feats = []
    clicked_masks = []

    print('getting features')
    progress = tqdm.tqdm(total=len(dataloader))
    for batch in dataloader:
        batch = utls.batch_to_device(batch, device)

        with torch.no_grad():
            out = model(batch['image'])

        # flatten the per-sample lists of clicked labels into one list
        clicked = []
        for group in batch['labels_clicked']:
            clicked.extend(group)

        # one boolean slot per distinct label value found in the batch
        n_labels = np.unique(batch['labels'].cpu().numpy()).shape[0]
        mask = np.zeros(n_labels).astype(bool)
        mask[clicked] = True
        clicked_masks.append(mask)

        pooled = sp_pool(out['feats'], batch['labels'])
        pooled_feats.append(pooled.detach().cpu().numpy().squeeze())

        progress.update(1)
    progress.close()

    return [pooled_feats, clicked_masks]
def do_prev_clusters(model, device, dataloader, *args):
    """Build one cluster preview image per batch.

    For each batch the un-normalised image is converted to a uint8 array,
    ground-truth boundaries are painted with ``(255, 0, 0)``, and the result
    is passed with the label map and the model's ``'clusters'`` output to
    ``im_utils.make_tiled_clusters``.

    Args:
        model: callable invoked as ``model(batch, *args)``; must return a
            mapping with a ``'clusters'`` entry.
        device: device the model and the batches are moved to.
        dataloader: iterable of batches providing ``'image_unnormal'``,
            ``'label/segmentation'``, ``'labels'`` and ``'frame_name'``.
        *args: extra positional arguments forwarded to the model.

    Returns:
        dict mapping the first ``'frame_name'`` of each batch to its preview.
    """
    model.eval()
    model.to(device)

    previews = {}
    progress = tqdm(total=len(dataloader))
    for batch in dataloader:
        batch = utls.batch_to_device(batch, device)

        # forward
        with torch.no_grad():
            out = model(batch, *args)

        # move axis 0 behind the other axes, then cast to uint8
        frame = batch['image_unnormal'].cpu().squeeze().numpy()
        frame = np.rollaxis(frame, 0, 3).astype(np.uint8)

        # paint the ground-truth boundaries onto the frame
        gt = batch['label/segmentation'].cpu().squeeze().numpy()
        gt_contour = segmentation.find_boundaries(gt)
        frame[gt_contour, ...] = (255, 0, 0)

        sp_labels = batch['labels'].cpu().squeeze().numpy()
        assignments = out['clusters'].cpu().squeeze().numpy()

        frame_key = batch['frame_name'][0]
        previews[frame_key] = im_utils.make_tiled_clusters(
            frame, sp_labels, assignments)

        progress.update(1)
    progress.close()

    return previews
def get_features(model,
                 dataloader,
                 device,
                 return_assign=False,
                 return_obj_preds=False,
                 feat_field='pooled_feats'):
    """Collect per-batch features, and optionally assignments / predictions.

    The model is switched to eval mode, moved to ``device`` and called on
    every (device-transferred) batch without gradient tracking.

    Args:
        model: callable returning a mapping for each batch.
        dataloader: iterable of batches providing ``'labels'`` and
            ``'labels_clicked'``.
        device: target device.
        return_assign: when true, also return the concatenated argmax (over
            dim 1) of ``res['clusters']``.
        return_obj_preds: when true, also return a per-batch list of
            ``sigmoid(res['rho_hat_pooled'])`` arrays.
        feat_field: key of the model output stored as the feature array.

    Returns:
        list ``[features, labels_pos_mask]``, followed by the concatenated
        assignments if ``return_assign`` and then the prediction list if
        ``return_obj_preds``.
    """
    model.eval()
    model.to(device)
    squash = torch.nn.Sigmoid()

    feats_per_batch = []
    masks_per_batch = []
    assign_per_batch = []
    preds_per_batch = []

    print('getting features')
    progress = tqdm(total=len(dataloader))
    for batch in dataloader:
        batch = utls.batch_to_device(batch, device)

        with torch.no_grad():
            out = model(batch)

            if return_assign:
                hard = out['clusters'].argmax(dim=1)
                assign_per_batch.append(hard.cpu().numpy())

            if return_obj_preds:
                scores = squash(out['rho_hat_pooled'])
                preds_per_batch.append(scores.cpu().numpy())

        # flatten the per-sample lists of clicked labels into one list
        clicked = []
        for group in batch['labels_clicked']:
            clicked.extend(group)

        # one boolean slot per distinct label value found in the batch
        n_labels = np.unique(batch['labels'].cpu().numpy()).shape[0]
        mask = np.zeros(n_labels).astype(bool)
        mask[clicked] = True
        masks_per_batch.append(mask)

        feats_per_batch.append(
            out[feat_field].detach().cpu().numpy().squeeze())

        progress.update(1)
    progress.close()

    collected = [feats_per_batch, masks_per_batch]
    if return_assign:
        collected.append(np.concatenate(assign_per_batch))
    if return_obj_preds:
        collected.append(preds_per_batch)

    return collected
def prepare_all(self, all_edges_nn=None, feat_field='pooled_feats'):
    """Run the model over ``self.dl`` and cache per-frame outputs.

    For each frame the rows belonging to it are sliced out of the batched
    model outputs and stored under its frame index. Once all batches are
    processed, the caches are converted to lists ordered by sorted frame
    index, and the model is put back into training mode.

    Populates:
        self.obj_preds: sigmoid of ``res['rho_hat_pooled']`` (numpy).
        self.feats_csml: ``res['siam_feats']`` slices, stored as returned by
            the model (no detach / cpu / numpy conversion).
        self.assignments: argmax over dim 1 of ``res['clusters']`` (numpy).
        self.feats: squeezed ``res['proj_pooled_feats']`` (numpy).

    Args:
        all_edges_nn: not used by this method.
        feat_field: not used by this method; features are always read from
            ``'proj_pooled_feats'``.
    """
    print('preparing features for linkAgent')

    # form initial cluster centres
    self.obj_preds = dict()
    self.feats_csml = dict()
    self.feats = dict()
    self.assignments = dict()

    edges_list = utls.make_edges_ccl(self.model,
                                     self.dl,
                                     self.device,
                                     return_signed=True)

    print('getting features')
    pbar = tqdm.tqdm(total=len(self.dl))
    for data in self.dl:
        data = utls.batch_to_device(data, self.device)
        # NOTE(review): only the edges of the first frame of the batch are
        # passed to the model -- confirm this is intended for batch sizes > 1
        edges_ = edges_list[data['frame_idx'][0]].edge_index

        with torch.no_grad():
            res = self.model(data, edges_nn=edges_.to(self.device))

        start = 0
        for i, f in enumerate(data['frame_idx']):
            # a frame owns as many consecutive rows as it has distinct labels
            end = start + torch.unique(data['labels'][i]).numel()

            self.obj_preds[f] = self.sigmoid(
                res['rho_hat_pooled'][start:end]).detach().cpu().numpy()
            self.feats_csml[f] = res['siam_feats'][start:end]
            self.assignments[f] = res['clusters'][start:end].argmax(
                dim=1).detach().cpu().numpy()
            self.feats[f] = res['proj_pooled_feats'][start:end].detach(
            ).cpu().numpy().squeeze()

            # `end` is already an absolute offset, so the next frame starts
            # exactly there (the previous `start += end` overshot from the
            # second frame of a batch onwards)
            start = end

        pbar.update(1)
    pbar.close()

    # all four caches share the same keys; order them by frame index
    frame_order = sorted(self.feats)
    self.obj_preds = [self.obj_preds[k] for k in frame_order]
    self.feats_csml = [self.feats_csml[k] for k in frame_order]
    self.feats = [self.feats[k] for k in frame_order]
    self.assignments = [self.assignments[k] for k in frame_order]

    self.model.train()
def do_prev_rags(model, device, dataloader, couple_graphs):
    """
    Generate preview images on region adjacency graphs

    For each batch, the first graph is restricted to its edges flagged
    ``'adjacent'``, the model is run on the batch with those edges, and a
    preview is assembled by concatenating (along axis 1) the input image, the
    rendered graph and the colourised cluster map.

    Args:
        model: callable invoked as ``model(data, edges)``; must return a
            mapping with ``'probas_preds'`` and ``'clusters'``.
        device: device the batches are moved to (the model itself is not
            moved here).
        dataloader: iterable of batches providing ``'graph'``,
            ``'image_unnormal'``, ``'label/segmentation'``, ``'labels'``,
            ``'frame_idx'`` and ``'frame_name'``.
        couple_graphs: graph whose node ``frame_idx`` carries a ``'clst'``
            tensor of cluster predictions.

    Returns:
        dict mapping the first ``'frame_name'`` of each batch to its preview.
    """
    model.eval()

    prevs = {}

    pbar = tqdm(total=len(dataloader))
    for data in dataloader:
        data = utls.batch_to_device(data, device)
        graph = data['graph'][0]

        # keep only adjacent edges
        edges_rag = [e for e in graph.edges() if graph.edges[e]['adjacent']]
        rag = graph.edge_subgraph(edges_rag).copy()

        # forward
        # NOTE(review): the edge tensor is created on the default device and
        # is not moved to `device` here -- confirm the model handles that
        with torch.no_grad():
            res = model(data, torch.tensor(edges_rag))

        probas = res['probas_preds'].detach().cpu().squeeze().numpy()

        im = data['image_unnormal'].cpu().squeeze().numpy().astype(np.uint8)
        im = np.rollaxis(im, 0, 3)

        # computed once; the original converted this tensor twice
        truth = data['label/segmentation'].cpu().squeeze().numpy()
        labels = data['labels'].cpu().squeeze().numpy()

        predictions = couple_graphs.nodes[data['frame_idx']
                                          [0]]['clst'].cpu().numpy()
        predictions = utls.to_onehot(predictions, res['clusters'].shape[1])
        clusters_colorized = im_utils.make_clusters(labels, predictions)

        rag_im = im_utils.my_show_rag(rag, im, labels, probas, truth=truth)

        plot = np.concatenate((im, rag_im, clusters_colorized), axis=1)
        prevs[data['frame_name'][0]] = plot

        pbar.update(1)
    pbar.close()

    return prevs
def do_update(self, model, dataloader, device, clst_field='clusters'):
    """Recompute distributions, targets and the convergence state.

    Runs ``model`` in eval mode over ``dataloader``, concatenates the
    ``clst_field`` outputs, derives targets with ``target_distribution`` and
    splits both tensors into chunks whose sizes are the number of distinct
    ``'labels'`` values of each dataset sample. The hard assignments are
    appended to ``self.assignments``. From the second call on, the fraction
    of assignments that changed is stored in ``self.ratio_changed``, and
    ``self.converged`` is set when it falls below ``self.thr_assign``.
    The model is switched back to training mode before returning.

    Args:
        model: callable returning a mapping that contains ``clst_field``.
        dataloader: loader to iterate; its ``dataset`` attribute is iterated
            as well to obtain the split sizes.
        device: device the batches are moved to.
        clst_field: key of the model output holding the distributions.
    """
    print('updating targets...')

    model.eval()

    batch_distribs = []
    for batch in dataloader:
        batch = utls.batch_to_device(batch, device)
        with torch.no_grad():
            out = model(batch)
        batch_distribs.append(out[clst_field])

    distrib = torch.cat(batch_distribs)
    # a separate concatenation is handed to target_distribution
    tgt = target_distribution(torch.cat(batch_distribs))

    # chunk sizes: number of distinct labels of each dataset sample
    splits = []
    for sample in dataloader.dataset:
        splits.append(np.unique(sample['labels']).size)

    self.tgt_distribs = torch.split(tgt.cpu().detach(), splits)
    self.distribs = torch.split(distrib.cpu().detach(), splits)

    hard_per_chunk = []
    for chunk in self.distribs:
        hard_per_chunk.append(
            torch.argmax(chunk, dim=-1).cpu().detach().numpy())
    self.assignments.append(np.concatenate(hard_per_chunk, axis=0))

    if len(self.assignments) > 1:
        latest = self.assignments[-1]
        previous = self.assignments[-2]

        n_changed = np.sum(latest != previous)
        self.ratio_changed = n_changed / latest.size
        print('ratio_changed: {}'.format(self.ratio_changed))

        if self.ratio_changed < self.thr_assign:
            self.converged = True

    model.train()
from ksptrack.siamese import utils as utls
from ksptrack.siamese.distrib_buffer import DistribBuffer

# Driver script: load a checkpointed Siamese model, build the distribution
# buffer, then compute the triplet loss once per batch.

# hard-coded to CUDA
device = torch.device('cuda')

# dataset wrapped in a shuffled DataLoader with 2 samples per batch
dl = Loader(pjoin('/home/ubelix/artorg/lejeune/data/medical-labeling',
                  'Dataset30'),
            normalization='rescale',
            resize_shape=512)
dl = DataLoader(dl, collate_fn=dl.collate_fn, batch_size=2, shuffle=True)

model = Siamese(15, 15, backbone='unet')

# NOTE(review): the data directory is 'Dataset30' while the checkpoint is
# read from a 'Dataset20' run -- confirm the mismatch is intended
run_path = '/home/ubelix/artorg/lejeune/runs/siamese_dec/Dataset20'
cp_path = pjoin(run_path, 'checkpoints', 'init_dec.pth.tar')
# map_location keeps every storage where it was deserialised; the model is
# moved to `device` explicitly below
state_dict = torch.load(cp_path, map_location=lambda storage, loc: storage)
model.load_state_dict(state_dict)
model.to(device)
model.train()

criterion = TripletLoss()

distrib_buff = DistribBuffer(10, thr_assign=0.0001)
distrib_buff.maybe_update(model, dl, device)

# The loss is computed for every batch; no backward pass or optimizer step
# appears in this part of the script.
for data in dl:
    data = utls.batch_to_device(data, device)

    # the first element returned by the buffer is discarded
    _, targets = distrib_buff[data['frame_idx']]

    res = model(data)
    loss = criterion(res['proj_pooled_feats'], targets, data['graph'])
def train_one_epoch(model,
                    dataloaders,
                    optimizers,
                    device,
                    distrib_buff,
                    lr_sch,
                    cfg,
                    probas=None,
                    edges_list=None):
    """Train ``model`` for one pass over ``dataloaders['train']``.

    For each batch the total loss is the weighted sum of:
      * ``cfg.alpha`` * clustering loss (KLDivLoss), unless ``cfg.fix_clst``;
      * ``cfg.lambda_`` * object-prediction loss (BCEWithLogitsLoss) against
        ``probas >= 0.5``, if ``cfg.clf``;
      * ``cfg.gamma`` * reconstruction loss (MSELoss), always;
      * ``cfg.beta`` * pairwise loss (RAGTripletLoss), if ``cfg.pw``.
    Every optimizer and every learning-rate scheduler is stepped once per
    batch.

    Args:
        model: network called as ``model(data, edges_nn=...)``.
        dataloaders: mapping with a ``'train'`` loader.
        optimizers: mapping of optimizers; all are zeroed and stepped.
        device: device the batches are moved to.
        distrib_buff: indexable by ``data['frame_idx']``; the second element
            of the returned pair is used as clustering target.
        lr_sch: mapping of learning-rate schedulers.
        cfg: configuration exposing ``fix_clst``, ``clf``, ``pw``, ``alpha``,
            ``beta``, ``gamma``, ``lambda_`` and ``batch_size``.
        probas: indexable by frame index. Despite the ``None`` default it is
            required whenever the loader yields at least one batch.
        edges_list: indexable by frame index, items exposing ``edge_index``.
            Required in the same way as ``probas``.

    Returns:
        dict with keys ``'loss'``, ``'loss_pw'``, ``'loss_clst'``,
        ``'loss_obj_pred'`` and ``'loss_recons'``: the running sums divided
        by ``cfg.batch_size * number_of_batches``. All values are ``0.0``
        when the train loader is empty.
    """
    model.train()

    train_loader = dataloaders['train']
    n_batches = len(train_loader)

    if n_batches == 0:
        # nothing to train on: previously this path ended in a division by
        # zero (and an unbound `loss_`)
        return {
            'loss': 0.0,
            'loss_pw': 0.0,
            'loss_clst': 0.0,
            'loss_obj_pred': 0.0,
            'loss_recons': 0.0
        }

    running_loss = 0.0
    running_clst = 0.0
    running_recons = 0.0
    running_pw = 0.0
    running_obj_pred = 0.0

    # NOTE(review): KLDivLoss expects its input in log-space and
    # reduction='mean' averages over all elements instead of the batch --
    # confirm res['clusters'] and this reduction are what is intended
    criterion_clst = torch.nn.KLDivLoss(reduction='mean')
    criterion_pw = RAGTripletLoss()
    criterion_obj_pred = torch.nn.BCEWithLogitsLoss()
    criterion_recons = torch.nn.MSELoss()

    loss_ = 0.0
    pbar = tqdm(total=n_batches)
    for i, data in enumerate(train_loader):
        data = utls.batch_to_device(data, device)

        # forward
        with torch.set_grad_enabled(True):
            for optimizer in optimizers.values():
                optimizer.zero_grad()

            _, targets = distrib_buff[data['frame_idx']]
            probas_ = torch.cat([probas[f] for f in data['frame_idx']])
            # only the edges of the first frame of the batch are used
            edges_ = edges_list[data['frame_idx'][0]].edge_index.to(
                probas_.device)

            with torch.autograd.detect_anomaly():
                res = model(data, edges_nn=edges_)

            loss = 0

            if not cfg.fix_clst:
                loss_clst = criterion_clst(res['clusters'],
                                           targets.to(res['clusters']))
                loss += cfg.alpha * loss_clst

            if cfg.clf:
                loss_obj_pred = criterion_obj_pred(
                    res['rho_hat_pooled'].squeeze(),
                    (probas_ >= 0.5).float())
                loss += cfg.lambda_ * loss_obj_pred

            # `sigmoid` is not defined in this function; it has to be
            # provided by the enclosing module
            loss_recons = criterion_recons(sigmoid(res['output']),
                                           data['image_noaug'])
            loss += cfg.gamma * loss_recons

            if cfg.pw:
                loss_pw = criterion_pw(res['siam_feats'], edges_)
                loss += cfg.beta * loss_pw

            with torch.autograd.detect_anomaly():
                loss.backward()

            for optimizer in optimizers.values():
                optimizer.step()

            for scheduler in lr_sch.values():
                scheduler.step()

        running_loss += loss.cpu().detach().numpy()
        running_recons += loss_recons.cpu().detach().numpy()
        if not cfg.fix_clst:
            running_clst += loss_clst.cpu().detach().numpy()
        if cfg.clf:
            running_obj_pred += loss_obj_pred.cpu().detach().numpy()
        if cfg.pw:
            running_pw += loss_pw.cpu().detach().numpy()

        loss_ = running_loss / ((i + 1) * cfg.batch_size)
        pbar.set_description('lss {:.6e}'.format(loss_))
        pbar.update(1)

    pbar.close()

    # shared denominator of the per-term averages
    n_seen = cfg.batch_size * n_batches

    out = {
        'loss': loss_,
        'loss_pw': running_pw / n_seen,
        'loss_clst': running_clst / n_seen,
        'loss_obj_pred': running_obj_pred / n_seen,
        'loss_recons': running_recons / n_seen
    }

    return out