Esempio n. 1
0
    # loading the model
    args = get_arguments(sys.argv[1:])

    # reading the config
    cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    # initialising the dirs
    check_dir(args.mask_output_dir, "vis")
    check_dir(args.mask_output_dir, "crf")

    # Loading the model
    model = get_model(cfg.NET, num_classes=cfg.TEST.NUM_CLASSES)
    checkpoint = Checkpoint(args.snapshot_dir, max_n=5)
    checkpoint.add_model('enc', model)
    checkpoint.load(args.resume)

    for p in model.parameters():
        p.requires_grad = False

    # setting the evaluation mode
    model.eval()

    assert hasattr(model, 'normalize')
    transform = tf.Compose([np.asarray, model.normalize])

    WriterClass, DatasetClass = get_inference_io(cfg.TEST.METHOD)

    dataset = DatasetClass(args.infer_list, cfg.TEST, transform=transform)
Esempio n. 2
0
class BaseTrainer(object):

    def __del__(self):
        # commented out, because hangs on exit
        # (presumably some bug with threading in TensorboardX)
        """
        if not self.quiet:
            self.writer.close()
            self.writer_val.close()
        """
        pass

    def __init__(self, args, quiet=False):
        self.args = args
        self.quiet = quiet

        # config
        # Reading the config
        if type(args.cfg_file) is str \
                and os.path.isfile(args.cfg_file):

            cfg_from_file(args.cfg_file)
            if args.set_cfgs is not None:
                cfg_from_list(args.set_cfgs)

        self.start_epoch = 0
        self.best_score = -1e16
        self.checkpoint = Checkpoint(args.snapshot_dir, max_n = 5)

        if not quiet:
            #self.model_id = "%s" % args.run
            logdir = os.path.join(args.logdir, 'train')
            logdir_val = os.path.join(args.logdir, 'val')

            self.writer = SummaryWriter(logdir)
            self.writer_val = SummaryWriter(logdir_val)

    def _define_checkpoint(self, name, model, optim):
        self.checkpoint.add_model(name, model, optim)

    def _load_checkpoint(self, suffix):
        if self.checkpoint.load(suffix):
            # loading the epoch and the best score
            tmpl = re.compile("^e(\d+)Xs([\.\d+\-]+)$")
            match = tmpl.match(suffix)
            if not match:
                print("Warning: epoch and score could not be recovered")
                return
            else:
                epoch, score = match.groups()
                self.start_epoch = int(epoch) + 1
                self.best_score = float(score)

    def checkpoint_epoch(self, score, epoch):

        if score > self.best_score:
            self.best_score = score

        print(">>> Saving checkpoint with score {:3.2e}, epoch {}".format(score, epoch))
        suffix = "e{:03d}Xs{:4.3f}".format(epoch, score)
        self.checkpoint.checkpoint(suffix)

        return True

    def checkpoint_best(self, score, epoch):

        if score > self.best_score:
            print(">>> Saving checkpoint with score {:3.2e}, epoch {}".format(score, epoch))
            self.best_score= score

            suffix = "e{:03d}Xs{:4.3f}".format(epoch, score)
            self.checkpoint.checkpoint(suffix)

            return True

        return False

    @staticmethod
    def get_optim(params, cfg):

        if not hasattr(torch.optim, cfg.OPT):
            print("Optimiser {} not supported".format(cfg.OPT))
            raise NotImplementedError

        optim = getattr(torch.optim, cfg.OPT)

        if cfg.OPT == 'Adam':
            upd = torch.optim.Adam(params, lr=cfg.LR, \
                                   betas=(cfg.BETA1, 0.999), \
                                   weight_decay=cfg.WEIGHT_DECAY)
        elif cfg.OPT == 'SGD':
            print("Using SGD >>> learning rate = {:4.3e}, momentum = {:4.3e}, weight decay = {:4.3e}".format(cfg.LR, cfg.MOMENTUM, cfg.WEIGHT_DECAY))
            upd = torch.optim.SGD(params, lr=cfg.LR, \
                                  momentum=cfg.MOMENTUM, \
                                  weight_decay=cfg.WEIGHT_DECAY)

        else:
            upd = optim(params, lr=cfg.LR)

        upd.zero_grad()

        return upd
    
    @staticmethod
    def set_lr(optim, lr):
        for param_group in optim.param_groups:
            param_group['lr'] = lr

    def write_image(self, images, epoch):
        for i, group in enumerate(images):
            for j, image in enumerate(group):
                self.writer.add_image("{}/{}".format(i, j), image, epoch)

    def _visualise_grid(self, x_all, labels, t, ious=None, tag="visualisation", scores=None, save_image=False, epoch=0, index=0, info=None):

        # adding the labels to images
        bs, ch, h, w = x_all.size()
        x_all_new = torch.zeros(bs, ch, h + 85, w)
        _, y_labels_idx = torch.max(labels, -1)
        classNamesOffset = len(self.classNames) - labels.size(1) - 1
        classNames = self.classNames[classNamesOffset:-1]
        for b in range(bs):
            label_idx = labels[b]
            predict_idx = torch.argmax(scores[b]).item()
            label_names = [name for i,name in enumerate(classNames) if label_idx[i].item()]
            predict = classNames[predict_idx]


            # label_name =  +  + '\n'
            row2 = ["Ground truth: " + ", ".join(label_names), "Predict: " + predict, "", ""]
            for i in range(len(classNames)):
                row2.append("{} mask".format(classNames[i]))
            row3 = ["Input image", "Raw output", "PAMR", "Pseudo gt"]
            for i in range(len(classNames)):
                row3.append("score: {:.2f}".format(scores[b][i]))
            row_template = "{:<22}" * (4+len(classNames))
            label_name = info[b][:200] + '\n' + info[b][200:] + '\n' + row_template.format(*row2) + '\n' + row_template.format(*row3)

            ndarr = x_all[b].mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy()
            arr = np.zeros((85, w, ch), dtype=ndarr.dtype)
            ndarr = np.concatenate((arr, ndarr), 0)
            im = Image.fromarray(ndarr)
            draw = ImageDraw.Draw(im)

            font = ImageFont.truetype("fonts/UbuntuMono-R.ttf", 20)

            draw.text((5, 1), label_name, (255,255,255), font=font)

            if save_image:
                path = "./logs/images/{}/{}/{}/{}".format(self.args.run, epoch, label_names[0], predict)
                if not os.path.exists(path):
                    os.makedirs(path)
                im.save("{}/{:0>4}.{:0>2}.jpg".format(path, index, b))

            im_np = np.array(im).astype(np.float)
            x_all_new[b] = (torch.from_numpy(im_np)/255.0).permute(2,0,1)

        if not save_image:
            summary_grid = vutils.make_grid(x_all_new, nrow=1, padding=8, pad_value=0.9)
            self.writer.add_image(tag, summary_grid, t)

    def _apply_cmap(self, mask_idx, mask_conf):
        palette = self.trainloader.dataset.get_palette()

        masks = []
        col = Colorize()
        mask_conf = mask_conf.float() / 255.0
        for mask, conf in zip(mask_idx.split(1), mask_conf.split(1)):
            m = col(mask).float()
            m = m * conf
            masks.append(m[None, ...])

        return torch.cat(masks, 0)

    def _mask_rgb(self, masks, image_norm, alpha=0.3):
        # visualising masks
        masks_conf, masks_idx = torch.max(masks, 1)
        masks_conf = masks_conf - F.relu(masks_conf - 1, 0)

        masks_idx_rgb = self._apply_cmap(masks_idx.cpu(), masks_conf.cpu())
        return alpha * image_norm + (1 - alpha) * masks_idx_rgb

    def _init_norm(self):
        self.trainloader.dataset.set_norm(self.enc.normalize)
        self.valloader.dataset.set_norm(self.enc.normalize)
        self.trainloader_val.dataset.set_norm(self.enc.normalize)