import os
import re

import numpy as np
import torch
import torch.nn.functional as F
import torchvision.utils as vutils
from PIL import Image, ImageDraw, ImageFont
from tensorboardX import SummaryWriter

# Project-local helpers; the module paths below are assumptions, since the
# original import block is not part of this excerpt.
from core.config import cfg_from_file, cfg_from_list
from utils.checkpoints import Checkpoint
from utils.colormap import Colorize


class BaseTrainer(object):

    def __del__(self):
        # commented out, because hangs on exit
        # (presumably some bug with threading in TensorboardX)
        """
        if not self.quiet:
            self.writer.close()
            self.writer_val.close()
        """
        pass

    def __init__(self, args, quiet=False):
        self.args = args
        self.quiet = quiet

        # reading the config
        if type(args.cfg_file) is str and os.path.isfile(args.cfg_file):
            cfg_from_file(args.cfg_file)
        if args.set_cfgs is not None:
            cfg_from_list(args.set_cfgs)

        self.start_epoch = 0
        self.best_score = -1e16
        self.checkpoint = Checkpoint(args.snapshot_dir, max_n=5)

        if not quiet:
            #self.model_id = "%s" % args.run
            logdir = os.path.join(args.logdir, 'train')
            logdir_val = os.path.join(args.logdir, 'val')
            self.writer = SummaryWriter(logdir)
            self.writer_val = SummaryWriter(logdir_val)

    def _define_checkpoint(self, name, model, optim):
        self.checkpoint.add_model(name, model, optim)

    def _load_checkpoint(self, suffix):
        if self.checkpoint.load(suffix):
            # recovering the epoch and the best score from the suffix
            tmpl = re.compile(r"^e(\d+)Xs([\.\d+\-]+)$")
            match = tmpl.match(suffix)
            if not match:
                print("Warning: epoch and score could not be recovered")
                return
            else:
                epoch, score = match.groups()
                self.start_epoch = int(epoch) + 1
                self.best_score = float(score)

    def checkpoint_epoch(self, score, epoch):
        # saves every epoch; keeps track of the best score seen so far
        if score > self.best_score:
            self.best_score = score

        print(">>> Saving checkpoint with score {:3.2e}, epoch {}".format(score, epoch))
        suffix = "e{:03d}Xs{:4.3f}".format(epoch, score)
        self.checkpoint.checkpoint(suffix)
        return True

    def checkpoint_best(self, score, epoch):
        # saves only if the score improved
        if score > self.best_score:
            print(">>> Saving checkpoint with score {:3.2e}, epoch {}".format(score, epoch))
            self.best_score = score
            suffix = "e{:03d}Xs{:4.3f}".format(epoch, score)
            self.checkpoint.checkpoint(suffix)
            return True

        return False

    @staticmethod
    def get_optim(params, cfg):

        if not hasattr(torch.optim, cfg.OPT):
            print("Optimiser {} not supported".format(cfg.OPT))
            raise NotImplementedError

        optim = getattr(torch.optim, cfg.OPT)

        if cfg.OPT == 'Adam':
            upd = torch.optim.Adam(params, lr=cfg.LR,
                                   betas=(cfg.BETA1, 0.999),
                                   weight_decay=cfg.WEIGHT_DECAY)
        elif cfg.OPT == 'SGD':
            print("Using SGD >>> learning rate = {:4.3e}, momentum = {:4.3e}, weight decay = {:4.3e}".format(cfg.LR, cfg.MOMENTUM, cfg.WEIGHT_DECAY))
            upd = torch.optim.SGD(params, lr=cfg.LR,
                                  momentum=cfg.MOMENTUM,
                                  weight_decay=cfg.WEIGHT_DECAY)
        else:
            upd = optim(params, lr=cfg.LR)

        upd.zero_grad()

        return upd

    @staticmethod
    def set_lr(optim, lr):
        for param_group in optim.param_groups:
            param_group['lr'] = lr

    def write_image(self, images, epoch):
        for i, group in enumerate(images):
            for j, image in enumerate(group):
                self.writer.add_image("{}/{}".format(i, j), image, epoch)

    def _visualise_grid(self, x_all, labels, t, ious=None, tag="visualisation",
                        scores=None, save_image=False, epoch=0, index=0, info=None):
        # adding the labels to images
        bs, ch, h, w = x_all.size()
        x_all_new = torch.zeros(bs, ch, h + 85, w)
        _, y_labels_idx = torch.max(labels, -1)
        classNamesOffset = len(self.classNames) - labels.size(1) - 1
        classNames = self.classNames[classNamesOffset:-1]
        for b in range(bs):
            label_idx = labels[b]
            predict_idx = torch.argmax(scores[b]).item()
            label_names = [name for i, name in enumerate(classNames) if label_idx[i].item()]
            predict = classNames[predict_idx]

            # assembling the caption: ground truth / prediction / per-class scores
            row2 = ["Ground truth: " + ", ".join(label_names), "Predict: " + predict, "", ""]
            for i in range(len(classNames)):
                row2.append("{} mask".format(classNames[i]))

            row3 = ["Input image", "Raw output", "PAMR", "Pseudo gt"]
            for i in range(len(classNames)):
                row3.append("score: {:.2f}".format(scores[b][i]))

            row_template = "{:<22}" * (4 + len(classNames))
            label_name = info[b][:200] + '\n' + info[b][200:] + '\n' + \
                    row_template.format(*row2) + '\n' + row_template.format(*row3)

            # padding the image on top and drawing the caption
            ndarr = x_all[b].mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy()
            arr = np.zeros((85, w, ch), dtype=ndarr.dtype)
            ndarr = np.concatenate((arr, ndarr), 0)
            im = Image.fromarray(ndarr)
            draw = ImageDraw.Draw(im)
            font = ImageFont.truetype("fonts/UbuntuMono-R.ttf", 20)
            draw.text((5, 1), label_name, (255, 255, 255), font=font)

            if save_image:
                path = "./logs/images/{}/{}/{}/{}".format(self.args.run, epoch, label_names[0], predict)
                if not os.path.exists(path):
                    os.makedirs(path)
                im.save("{}/{:0>4}.{:0>2}.jpg".format(path, index, b))

            im_np = np.array(im).astype(np.float32)
            x_all_new[b] = (torch.from_numpy(im_np) / 255.0).permute(2, 0, 1)

        if not save_image:
            summary_grid = vutils.make_grid(x_all_new, nrow=1, padding=8, pad_value=0.9)
            self.writer.add_image(tag, summary_grid, t)

    def _apply_cmap(self, mask_idx, mask_conf):
        palette = self.trainloader.dataset.get_palette()

        masks = []
        col = Colorize()
        mask_conf = mask_conf.float() / 255.0
        for mask, conf in zip(mask_idx.split(1), mask_conf.split(1)):
            m = col(mask).float()
            m = m * conf
            masks.append(m[None, ...])

        return torch.cat(masks, 0)

    def _mask_rgb(self, masks, image_norm, alpha=0.3):
        # visualising masks
        masks_conf, masks_idx = torch.max(masks, 1)
        masks_conf = masks_conf - F.relu(masks_conf - 1)
        masks_idx_rgb = self._apply_cmap(masks_idx.cpu(), masks_conf.cpu())
        return alpha * image_norm + (1 - alpha) * masks_idx_rgb

    def _init_norm(self):
        self.trainloader.dataset.set_norm(self.enc.normalize)
        self.valloader.dataset.set_norm(self.enc.normalize)
        self.trainloader_val.dataset.set_norm(self.enc.normalize)
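# ---------------------------------------------------------------------------
# Minimal sketch, not part of the original sources: a round-trip of the
# checkpoint suffix format shared by checkpoint_epoch()/checkpoint_best()
# and the regex in _load_checkpoint(). The helper name is hypothetical.
# ---------------------------------------------------------------------------
def _demo_checkpoint_suffix():
    suffix = "e{:03d}Xs{:4.3f}".format(17, 0.653)       # -> "e017Xs0.653"
    match = re.compile(r"^e(\d+)Xs([\.\d+\-]+)$").match(suffix)
    assert match is not None
    epoch, score = match.groups()
    # _load_checkpoint() resumes from the following epoch with the stored score
    return int(epoch) + 1, float(score)                 # (18, 0.653)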
# A second variant of BaseTrainer: distributed (DDP) training with
# source/target writers and teacher pseudo-labels.

import math

import torch.distributed as dist
from matplotlib import cm


class BaseTrainer(object):

    def __init__(self, args, cfg, main_process):
        self.args = args
        self.cfg = cfg

        self.start_epoch = 0
        self.best_score = -1e16
        self.checkpoint = Checkpoint(args.snapshot_dir, max_n=3)
        self.main_process = main_process

        # only the main process writes TensorBoard summaries
        self.writer = None
        self.writer_target = None
        if main_process:
            logdir = os.path.join(args.logdir, 'train')
            self.writer = SummaryWriter(logdir)
            self.writer_target = SummaryWriter(os.path.join(args.logdir, 'train_target'))

    def checkpoint_best(self, score, epoch):
        if score > self.best_score:
            print(">>> Saving checkpoint with score {:3.2e}, epoch {}".format(score, epoch))
            self.best_score = score
            self.checkpoint.checkpoint(score, epoch)
            return True

        return False

    @staticmethod
    def get_optim(params, cfg):

        if not hasattr(torch.optim, cfg.OPT):
            print("Optimiser {} not supported".format(cfg.OPT))
            raise NotImplementedError

        optim = getattr(torch.optim, cfg.OPT)

        if cfg.OPT == 'Adam':
            print("Using Adam >>> learning rate = {:4.3e}, momentum = {:4.3e}, weight decay = {:4.3e}"
                  .format(cfg.LR, cfg.MOMENTUM, cfg.WEIGHT_DECAY))
            upd = torch.optim.Adam(params, lr=cfg.LR,
                                   betas=(cfg.BETA1, 0.999),
                                   weight_decay=cfg.WEIGHT_DECAY)
        elif cfg.OPT == 'SGD':
            print("Using SGD >>> learning rate = {:4.3e}, momentum = {:4.3e}, weight decay = {:4.3e}"
                  .format(cfg.LR, cfg.MOMENTUM, cfg.WEIGHT_DECAY))
            upd = torch.optim.SGD(params, lr=cfg.LR,
                                  momentum=cfg.MOMENTUM,
                                  nesterov=cfg.OPT_NESTEROV,
                                  weight_decay=cfg.WEIGHT_DECAY)
        else:
            upd = optim(params, lr=cfg.LR)

        upd.zero_grad()
        return upd

    def _visualise(self, epoch, image, masks_gt, logits, writer, tag, image2=None):

        # gathering tensors from all DDP workers onto the CPU
        def gather_cpu(tensor):
            out_list = [tensor.clone() for _ in range(self.world_size)]
            dist.all_gather(out_list, tensor)
            out_tensor = torch.cat(out_list, 0)
            return out_tensor.cpu()

        image = gather_cpu(image)
        masks_gt = gather_cpu(masks_gt)
        for key, val in logits.items():
            if not val.is_contiguous():
                print("Tensor {} is not contiguous".format(key))
                #val = val.contiguous()
            else:
                logits[key] = gather_cpu(val)

        if image2 is not None:
            image2 = gather_cpu(image2)

        data_palette = self.loader_source.dataset.get_palette()

        def downsize(x, mode="bilinear"):
            x = x.float()
            if x.dim() == 3:
                x = x.unsqueeze(1)

            if mode == "nearest":
                x = F.interpolate(x, self.cfg.TB.IM_SIZE, mode="nearest")
            else:
                x = F.interpolate(x, self.cfg.TB.IM_SIZE, mode=mode, align_corners=True)

            return x.squeeze()

        # (currently unused)
        def compute_entpy_rgb(x):
            x = -(x * torch.log(1e-8 + x)).sum(1)
            x_min = x.min()
            x_max = x.max()
            x = (x - x_min) / (x_max - x_min)
            return self._error_rgb(x)

        visuals = []
        image_norm = downsize(self.denorm(image.clone())).cpu()
        visuals.append(image_norm)

        # GT mask
        masks_gt_rgb = downsize(self._apply_cmap(masks_gt, data_palette))
        masks_gt_rgb = 0.3 * image_norm + 0.7 * masks_gt_rgb
        visuals.append(masks_gt_rgb)

        # Pseudo GT from the teacher
        if "teacher_labels" in logits:
            pseudo_gt = logits["teacher_labels"].cpu()
            masks_gt_rgb = downsize(self._apply_cmap(pseudo_gt, data_palette))
            masks_gt_rgb = 0.3 * image_norm + 0.7 * masks_gt_rgb
            visuals.append(masks_gt_rgb)

        # Prediction
        masks = downsize(F.softmax(logits["logits_up"], 1)).cpu()
        rgb_mask = self._mask_rgb(masks, image_norm, data_palette)
        masks_conf, masks_idx = masks.max(1)
        rgb_mask = self._apply_cmap(masks_idx, data_palette)
        rgb_mask = 0.3 * image_norm + 0.7 * rgb_mask
        visuals.append(rgb_mask)

        # Confidence
        masks_conf_rgb = self._error_rgb(1 - masks_conf, cmap=cm.get_cmap('inferno'))
        masks_conf_rgb = 0.3 * image_norm + 0.7 * masks_conf_rgb
        visuals.append(masks_conf_rgb)

        if image2 is not None:
            image2_norm = downsize(self.denorm(image2.clone())).cpu()
            visuals.append(image2_norm)

        vis_extra = []

        def vlogits_rgb(vlogits, frames_, softmax=True):
            if softmax:
                vlogits = F.softmax(vlogits, 1)

            masks = downsize(vlogits)
            masks_conf, masks_idx = masks.max(1)

            rgb_mask = self._apply_cmap(masks_idx, data_palette)
            rgb_mask = 0.3 * frames_ + 0.7 * rgb_mask
            vis_extra.append(rgb_mask)

            masks_conf_rgb = self._error_rgb(1 - masks_conf, cmap=cm.get_cmap('inferno'))
            masks_conf_rgb = 0.3 * frames_ + 0.7 * masks_conf_rgb
            vis_extra.append(masks_conf_rgb)

        if "teacher_init" in logits:
            # slow logits
            vlogits_rgb(logits["teacher_init"].cpu(), image2_norm)

        if "teacher_aligned" in logits:
            frames_aligned = downsize(self.denorm(logits["frames_aligned"].cpu()))
            vlogits_rgb(logits["teacher_aligned"].cpu(), frames_aligned, softmax=False)

        if "teacher_refined" in logits:
            logits_ = logits["teacher_refined"].cpu()
            vlogits_rgb(logits_.cpu(), image_norm, softmax=False)

        if "teacher_conf" in logits:
            teach_conf = downsize(logits["teacher_conf"].cpu())
            teach_conf_rgb = self._error_rgb(1. - teach_conf, cmap=cm.get_cmap('inferno'))
            teach_conf_rgb = 0.3 * image_norm + 0.7 * teach_conf_rgb
            visuals.append(teach_conf_rgb)

        visuals += vis_extra
        visuals = [x.float() for x in visuals]
        visuals = torch.cat(visuals, -1)

        if self.main_process:
            self._visualise_grid(writer, visuals, epoch, tag)

        if "running_conf" in logits:
            _, C, _, _ = logits["logits_up"].size()
            confs = logits["running_conf"].view(-1, C).mean(0).tolist()
            for ii, conf in enumerate(confs):
                conf_key = "{:02d}".format(ii)
                writer.add_scalar('running_conf/{}'.format(conf_key), conf, epoch)

    def save_fixed_batch(self, key, batch):
        if self.fixed_batch is None:
            self.fixed_batch = {}

        if key in self.fixed_batch:
            print("Updating fixed batch: ", key)
            self.fixed_batch[key] = {}

        # detaching the batch from the GPU
        batch_items = []
        for el in batch:
            el = el.clone().cpu() if torch.is_tensor(el) else el
            batch_items.append(el)

        self.fixed_batch[key] = batch_items

    def has_fixed_batch(self, key):
        return self.fixed_batch is not None and key in self.fixed_batch

    def _mask_rgb(self, masks, image_norm, palette, alpha=0.3):
        # visualising masks
        masks_conf, masks_idx = torch.max(masks, 1)
        masks_conf = masks_conf - F.relu(masks_conf - 1)
        masks_idx_rgb = self._apply_cmap(masks_idx.cpu(), palette, mask_conf=masks_conf.cpu())
        return alpha * image_norm + (1 - alpha) * masks_idx_rgb

    def _apply_cmap(self, mask_idx, palette, mask_conf=None):
        # convert mask to RGB
        masks_rgb = []
        for mask in mask_idx.split(1, 0):
            mask = mask.cpu().numpy()[0].astype(np.uint32)
            im = Image.fromarray(mask).convert("P")
            im.putpalette(palette)
            mask_rgb = torch.as_tensor(np.array(im.convert("RGB")))
            mask_rgb = mask_rgb.permute(2, 0, 1)
            masks_rgb.append(mask_rgb[None, :, :, :])

        # cat back
        mask_rgb = torch.cat(masks_rgb, 0).float() / 255.0

        if mask_conf is not None:
            # modulating the colours by the normalised entropy
            mask_entropy = 1 - mask_conf * torch.log(1e-8 + mask_conf) / (0.5 * math.log(1e-8 + 0.5))
            mask_rgb *= mask_entropy[:, None, :, :]

        return mask_rgb

    def _error_rgb(self, error_mask, cmap=cm.get_cmap('jet')):
        error_np = error_mask.cpu().numpy()
        # remove alpha channel
        error_rgb = cmap(error_np)[:, :, :, :3]
        error_rgb = np.transpose(error_rgb, (0, 3, 1, 2))
        return torch.from_numpy(error_rgb)

    def _visualise_grid(self, writer, x_all, t, tag, ious=None, scores=None):
        bs, ch, h, w = x_all.size()
        x_all_new = torch.zeros(bs, ch, h, w)
        for b in range(bs):
            ndarr = x_all[b].mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy()
            im = Image.fromarray(ndarr)
            im_np = np.array(im).astype(np.float32)
            x_all_new[b] = (torch.from_numpy(im_np) / 255.0).permute(2, 0, 1)

        summary_grid = vutils.make_grid(x_all_new, nrow=1, padding=8, pad_value=0.9)
        writer.add_image(tag, summary_grid, t)

    def visualise_results(self, epoch, writer, tag, step_func):
        # visualising on the saved fixed batch
        self.net.eval()
        with torch.no_grad():
            step_func(epoch, self.fixed_batch[tag],
                      train=False, visualise=True,
                      writer=writer, tag=tag)
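# ---------------------------------------------------------------------------
# Minimal usage sketch, not part of the original sources: builds a toy
# optimiser config and feeds it to BaseTrainer.get_optim. SimpleNamespace
# stands in for the project's config object; all field values below are
# illustrative, and the helper name is hypothetical.
# ---------------------------------------------------------------------------
def _demo_get_optim():
    from types import SimpleNamespace

    cfg = SimpleNamespace(OPT='SGD', LR=1e-3, MOMENTUM=0.9,
                          OPT_NESTEROV=True, WEIGHT_DECAY=5e-4, BETA1=0.9)
    params = [torch.nn.Parameter(torch.zeros(4, 4))]
    # returns a torch.optim.SGD instance with Nesterov momentum,
    # gradients already zeroed
    return BaseTrainer.get_optim(params, cfg)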