Example #1
0
def main():
    """Evaluate a DeepLabv3 model on the validation split and report metrics.

    Options come from ``get_argparser()`` followed by ``modify_command_options()``.
    If ``opts.ckpt`` names an existing file, its ``"model_state"`` entry is loaded
    into the model; otherwise the freshly constructed model is evaluated.
    The streamed segmentation metrics are printed. When ``opts.save_path`` is set,
    they are also written to ``score.txt`` in that directory, together with
    per-sample PNGs: ``<idx>_image.png``, ``<idx>_target.png``,
    ``<idx>_pred.png`` and ``<idx>_overlay.png``.
    """
    opts = get_argparser().parse_args()
    opts = modify_command_options(opts)

    os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu_id
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Device: %s" % device)

    # Seed every RNG source so repeated runs are reproducible.
    torch.manual_seed(opts.random_seed)
    torch.cuda.manual_seed(opts.random_seed)
    np.random.seed(opts.random_seed)
    random.seed(opts.random_seed)

    # Set up dataloader. Images are batched only with crop_val, presumably
    # because uncropped validation images can differ in size.
    _, val_dst = get_dataset(opts)
    val_batch_size = opts.batch_size if opts.crop_val else 1
    val_loader = data.DataLoader(val_dst, batch_size=val_batch_size, shuffle=False,
                                 num_workers=opts.num_workers)
    print("Dataset: %s, Val set: %d" % (opts.dataset, len(val_dst)))

    # Set up model
    print("Backbone: %s" % opts.backbone)
    model = DeepLabv3(num_classes=opts.num_classes, backbone=opts.backbone, pretrained=True,
                      momentum=opts.bn_mom, output_stride=opts.output_stride,
                      use_separable_conv=opts.use_separable_conv)
    if opts.use_gn:
        print("[!] Replace BatchNorm with GroupNorm!")
        model = utils.convert_bn2gn(model)

    if torch.cuda.device_count() > 1:  # Parallel
        print("%d GPU parallel" % (torch.cuda.device_count()))
        model = torch.nn.DataParallel(model)
        # Load checkpoints into the wrapped module so state_dict keys carry no
        # DataParallel "module." prefix.
        model_ref = model.module
    else:
        model_ref = model
    model = model.to(device)

    # Set up metrics
    metrics = StreamSegMetrics(opts.num_classes)

    if opts.save_path is not None:
        utils.mkdir(opts.save_path)

    # Restore
    if opts.ckpt is not None and os.path.isfile(opts.ckpt):
        # Load onto the CPU so a checkpoint saved from a GPU run also works on a
        # CPU-only machine; load_state_dict then copies the weights onto the
        # device the model already lives on.
        checkpoint = torch.load(opts.ckpt, map_location=torch.device('cpu'))
        model_ref.load_state_dict(checkpoint["model_state"])
        print("Model restored from %s" % opts.ckpt)
        del checkpoint  # free memory
    else:
        # NOTE: no checkpoint found; evaluation continues with the freshly
        # constructed weights.
        print("[!] Retrain")

    label2color = utils.Label2Color(cmap=utils.color_map(opts.dataset))  # convert labels to images
    # Invert the input normalization for saving; these mean/std values must match
    # the dataset transform (presumably the ImageNet statistics).
    denorm = utils.Denormalize(mean=[0.485, 0.456, 0.406],
                               std=[0.229, 0.224, 0.225])
    model.eval()
    metrics.reset()
    idx = 0  # running sample counter used to name the saved files

    if opts.save_path is not None:
        import matplotlib
        matplotlib.use('Agg')  # headless backend: figures are only written to disk
        import matplotlib.pyplot as plt

    with torch.no_grad():
        for images, labels in tqdm(val_loader):
            images = images.to(device, dtype=torch.float32)
            labels = labels.to(device, dtype=torch.long)

            outputs = model(images)
            # argmax over dim 1 (presumably the class axis) gives the label map
            preds = outputs.detach().max(dim=1)[1].cpu().numpy()
            targets = labels.cpu().numpy()

            metrics.update(targets, preds)
            if opts.save_path is None:
                continue

            for image, target, pred in zip(images, targets, preds):
                # CHW float tensor -> HWC uint8 image
                image = (denorm(image.detach().cpu().numpy()) * 255).transpose(1, 2, 0).astype(np.uint8)
                target = label2color(target).astype(np.uint8)
                pred = label2color(pred).astype(np.uint8)

                Image.fromarray(image).save(os.path.join(opts.save_path, '%d_image.png' % idx))
                Image.fromarray(target).save(os.path.join(opts.save_path, '%d_target.png' % idx))
                Image.fromarray(pred).save(os.path.join(opts.save_path, '%d_pred.png' % idx))

                # Overlay the colorized prediction on the input, with no axes or padding.
                fig = plt.figure()
                plt.imshow(image)
                plt.axis('off')
                plt.imshow(pred, alpha=0.7)
                ax = plt.gca()
                ax.xaxis.set_major_locator(matplotlib.ticker.NullLocator())
                ax.yaxis.set_major_locator(matplotlib.ticker.NullLocator())
                plt.savefig(os.path.join(opts.save_path, '%d_overlay.png' % idx),
                            bbox_inches='tight', pad_inches=0)
                plt.close(fig)
                idx += 1

    score = metrics.get_results()
    print(metrics.to_str(score))
    if opts.save_path is not None:
        with open(os.path.join(opts.save_path, 'score.txt'), mode='w') as f:
            f.write(metrics.to_str(score))
print("Val images: ", len(val_dataset.images))

start_epoch = 0

if opts.resume is not None:
  # Load onto the CPU so a checkpoint written by a GPU run can be resumed on
  # any machine; Module.load_state_dict copies the weights onto the model's
  # device and Optimizer.load_state_dict casts its state to the params' device.
  state = torch.load(opts.resume, map_location='cpu')
  start_epoch = state['epoch'] + 1  # continue after the last completed epoch
  model.load_state_dict(state['state_dict'])
  optimizer.load_state_dict(state['optimizer'])
  # NOTE(review): this replaces the current command-line options with the ones
  # stored in the checkpoint — confirm that is intended.
  opts = state['opts']
  print("resuming")

# Epoch loop (fragment): resumes at start_epoch and runs up to num_epochs.
# Relies on names defined outside this snippet: metrics, model, train_loader,
# device, optimizer, criterion, num_epochs, time.
for epoch in range(start_epoch, num_epochs):

  # Training phase: clear the metric accumulator and put the model in train mode.
  metrics.reset()
  model.train()
  train_loss = 0.0  # NOTE(review): not updated anywhere in the visible lines
  for step, (images, labels) in enumerate(train_loader):
    
    start = time.time()  # step start time; not read within this fragment
    # Move the batch to `device` (images as float32, labels as int64).
    images = images.to(device, dtype=torch.float32)
    labels = labels.to(device, dtype=torch.long)
    # Drop dim 1 when it has size 1 — assumes labels may arrive as (N, 1, H, W); confirm against the dataset.
    labels = labels.squeeze(1)
    
    # Forward pass and loss.
    optimizer.zero_grad()
    outputs = model(images)

    loss = criterion(outputs, labels)
    # NOTE(review): the snippet ends here; loss.backward() / optimizer.step()
    # are not visible, so as shown the parameters are never updated.
Example #3
0
class Trainer():
    """Train and periodically validate a DeepLabV3 / DeepLabV3+ segmentation model.

    Args:
        data_loader: ``(train_loader, val_loader)`` pair.
        opts: options namespace. ``__init__`` reads model, num_classes,
            output_stride, separable_conv, lr, weight_decay, lr_policy,
            total_itrs, step_size and loss_type. ``train``/``validate`` read
            further fields such as device, ckpt, continue_training,
            interval_val, output and dataset.

    Raises:
        ValueError: if ``opts.lr_policy`` or ``opts.loss_type`` is not recognized.
    """

    def __init__(self, data_loader, opts):
        self.train_loader = data_loader[0]
        self.val_loader = data_loader[1]

        # Set up model
        model_map = {
            'deeplabv3_resnet50': network.deeplabv3_resnet50,
            'deeplabv3plus_resnet50': network.deeplabv3plus_resnet50,
            'deeplabv3_resnet101': network.deeplabv3_resnet101,
            'deeplabv3plus_resnet101': network.deeplabv3plus_resnet101,
            'deeplabv3_mobilenet': network.deeplabv3_mobilenet,
            'deeplabv3plus_mobilenet': network.deeplabv3plus_mobilenet
        }

        self.model = model_map[opts.model](num_classes=opts.num_classes, output_stride=opts.output_stride)
        if opts.separable_conv and 'plus' in opts.model:
            network.convert_to_separable_conv(self.model.classifier)

        def set_bn_momentum(model, momentum=0.1):
            for m in model.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.momentum = momentum

        # PyTorch's BatchNorm momentum is the weight given to the *new* batch
        # statistic (running = (1 - momentum) * running + momentum * batch), so
        # 0.01 here corresponds to a decay of 0.99 in the TensorFlow convention.
        set_bn_momentum(self.model.backbone, momentum=0.01)

        # Set up metrics
        self.metrics = StreamSegMetrics(opts.num_classes)

        # Set up optimizer: the backbone learns at 1/10 of the base rate and the
        # classifier head at the full rate.
        self.optimizer = torch.optim.SGD(params=[
            {'params': self.model.backbone.parameters(), 'lr': 0.1 * opts.lr},
            {'params': self.model.classifier.parameters(), 'lr': opts.lr},
        ], lr=opts.lr, momentum=0.9, weight_decay=opts.weight_decay)

        if opts.lr_policy == 'poly':
            self.scheduler = utils.PolyLR(self.optimizer, opts.total_itrs, power=0.9)
        elif opts.lr_policy == 'step':
            self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=opts.step_size, gamma=0.1)
        else:
            # Fail fast: the scheduler is stepped in train() and checkpointed in validate().
            raise ValueError("Unknown lr_policy: %r" % (opts.lr_policy,))

        # Set up criterion; label 255 is excluded from the loss.
        if opts.loss_type == 'focal_loss':
            self.criterion = utils.FocalLoss(ignore_index=255, size_average=True)
        elif opts.loss_type == 'cross_entropy':
            self.criterion = nn.CrossEntropyLoss(ignore_index=255, reduction='mean')
        else:
            raise ValueError("Unknown loss_type: %r" % (opts.loss_type,))

        self.best_mean_iu = 0  # best validation mean IU seen so far
        self.iteration = 0     # current training iteration (used in file names)

    def _label_accuracy_score(self, label_trues, label_preds, n_class):
        """Return accuracy scores for paired ground-truth / predicted label maps.

        Args:
            label_trues: iterable of ground-truth label arrays.
            label_preds: iterable of predicted label arrays, paired with ``label_trues``.
            n_class: number of classes. Pixels whose true label lies outside
                ``[0, n_class)`` (e.g. an ignore value of 255) are skipped.

        Returns:
            ``(acc, acc_cls, mean_iu, fwavacc)``:
            - ``acc``: overall pixel accuracy.
            - ``acc_cls``: mean per-class accuracy.
            - ``mean_iu``: mean intersection-over-union.
            - ``fwavacc``: frequency-weighted IU.
            Classes absent from both labels and predictions yield NaN per-class
            terms, which ``np.nanmean`` skips.
        """
        def _fast_hist(label_true, label_pred, n_class):
            # Confusion matrix: rows index the true class, columns the predicted class.
            mask = (label_true >= 0) & (label_true < n_class)
            hist = np.bincount(n_class * label_true[mask].astype(int) +
                               label_pred[mask],
                               minlength=n_class ** 2).reshape(n_class, n_class)
            return hist

        hist = np.zeros((n_class, n_class))
        for lt, lp in zip(label_trues, label_preds):
            hist += _fast_hist(lt.flatten(), lp.flatten(), n_class)
        diag = np.diag(hist)
        # Empty rows/columns legitimately produce 0/0; silence those warnings and
        # let nanmean drop the resulting NaNs.
        with np.errstate(divide='ignore', invalid='ignore'):
            acc = diag.sum() / hist.sum()
            acc_cls = np.nanmean(diag / hist.sum(axis=1))
            iu = diag / (hist.sum(axis=1) + hist.sum(axis=0) - diag)
            mean_iu = np.nanmean(iu)
            freq = hist.sum(axis=1) / hist.sum()
        fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
        return acc, acc_cls, mean_iu, fwavacc

    def _save_ckpt(self, path):
        """Save model, optimizer and scheduler state plus the best score to ``path``."""
        # train() wraps the model in DataParallel; store the unwrapped weights so
        # the checkpoint loads into a bare model (as train() does on restore).
        model = getattr(self.model, 'module', self.model)
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        torch.save({
            "cur_itrs": self.iteration,
            "model_state": model.state_dict(),
            "optimizer_state": self.optimizer.state_dict(),
            "scheduler_state": self.scheduler.state_dict(),
            "best_score": self.best_mean_iu,
        }, path)
        print("Model saved as %s" % path)

    def validate(self, opts):
        """Evaluate on the validation loader, save visualizations and a checkpoint.

        Accumulates the validation loss and segmentation scores. It then writes
        up to 9 sample visualizations, tiled, to
        ``<opts.output>/visualization_viz/iter<iteration>.jpg``. The best mean IU
        is tracked: the model is saved to ``checkpoints/best_*.pth`` when it
        improves and to ``checkpoints/latest_*.pth`` otherwise. The model's
        train/eval mode is restored on exit, including when an error is raised.

        Returns:
            dict with keys ``acc``, ``acc_cls``, ``mean_iu``, ``fwavacc`` and
            ``val_loss`` (mean of the per-batch losses).

        Raises:
            ValueError: if the loss becomes NaN.
        """
        training = self.model.training
        self.model.eval()
        try:
            n_class = opts.num_classes
            val_loss = 0.0
            n_batches = 0
            visualizations = []
            label_trues, label_preds = [], []
            with torch.no_grad():
                for data, target in tqdm(self.val_loader):
                    data = data.to(opts.device, dtype=torch.float)
                    target = target.to(opts.device, dtype=torch.long)
                    score = self.model(data)

                    loss = float(self.criterion(score, target).item())
                    if np.isnan(loss):
                        raise ValueError('loss is nan while validating')
                    val_loss += loss
                    n_batches += 1

                    imgs = data.cpu()
                    lbl_pred = score.max(1)[1].cpu().numpy()
                    lbl_true = target.cpu()
                    for img, lt, lp in zip(imgs, lbl_true, lbl_pred):
                        img, lt = self.val_loader.dataset.untransform(img, lt)
                        label_trues.append(lt)
                        label_preds.append(lp)
                        if len(visualizations) < 9:
                            visualizations.append(utils.fcn_utils.visualize_segmentation(
                                lbl_pred=lp, lbl_true=lt, img=img, n_class=n_class))

            acc, acc_cls, mean_iu, fwavacc = self._label_accuracy_score(label_trues, label_preds, n_class)

            # Nothing to tile when the validation loader is empty.
            if visualizations:
                out = os.path.join(opts.output, 'visualization_viz')
                os.makedirs(out, exist_ok=True)
                out_file = os.path.join(out, 'iter%012d.jpg' % self.iteration)
                imageio.imwrite(out_file, utils.fcn_utils.get_tile_image(visualizations))

            val_loss /= max(n_batches, 1)

            suffix = '%s_%s_os%d.pth' % (opts.model, opts.dataset, opts.output_stride)
            if mean_iu > self.best_mean_iu:
                self.best_mean_iu = mean_iu
                self._save_ckpt('checkpoints/best_' + suffix)
            else:
                self._save_ckpt('checkpoints/latest_' + suffix)

            self.metrics.reset()
        finally:
            if training:
                self.model.train()

        return {
            'acc': acc,
            'acc_cls': acc_cls,
            'mean_iu': mean_iu,
            'fwavacc': fwavacc,
            'val_loss': val_loss,
        }

    def train(self, opts):
        """Run the training loop until ``opts.total_itrs`` iterations are done.

        If ``opts.ckpt`` is an existing file, the model weights are restored from
        it. With ``opts.continue_training``, the optimizer, scheduler, iteration
        counter and best score are restored as well. The model is then wrapped
        in ``nn.DataParallel`` and moved to ``opts.device``. Each iteration steps
        the optimizer and the LR scheduler, and :meth:`validate` runs every
        ``opts.interval_val`` iterations.

        Raises:
            RuntimeError: if the training loader yields no batches (which would
                otherwise loop forever).
        """
        print("Dataset: %s, Train set: %d, Val set: %d" % (
            opts.dataset, len(self.train_loader.dataset), len(self.val_loader.dataset)))

        os.makedirs('./checkpoints', exist_ok=True)

        # Restore
        cur_itrs = 0
        cur_epochs = 0
        checkpoint = None
        if opts.ckpt is not None and os.path.isfile(opts.ckpt):
            print('[!] Retrain from a checkpoint')
            # https://github.com/VainF/DeepLabV3Plus-Pytorch/issues/8#issuecomment-605601402, @PytaichukBohdan
            checkpoint = torch.load(opts.ckpt, map_location=torch.device('cpu'))
            getattr(self.model, 'module', self.model).load_state_dict(checkpoint["model_state"])
        else:
            print('[!] Retrain from base network')

        if not isinstance(self.model, nn.DataParallel):
            self.model = nn.DataParallel(self.model)
        # Move the model before restoring optimizer state: Optimizer.load_state_dict
        # casts its buffers to the device of the parameters they belong to.
        self.model.to(opts.device)

        if checkpoint is not None:
            if opts.continue_training:
                self.optimizer.load_state_dict(checkpoint["optimizer_state"])
                self.scheduler.load_state_dict(checkpoint["scheduler_state"])
                cur_itrs = checkpoint["cur_itrs"]
                # Restore the best score too; otherwise the first validation after
                # resuming always counts as "best" and overwrites the best checkpoint.
                self.best_mean_iu = checkpoint["best_score"]
                self.iteration = cur_itrs
                print("Training state restored from %s" % opts.ckpt)
            print("Model restored from %s" % opts.ckpt)
            del checkpoint  # free memory

        #==========   Train Loop   ==========#
        while cur_itrs < opts.total_itrs:
            self.model.train()
            cur_epochs += 1
            got_batch = False
            for images, labels in self.train_loader:
                got_batch = True
                cur_itrs += 1
                self.iteration = cur_itrs

                images = images.to(opts.device, dtype=torch.float32)
                labels = labels.to(opts.device, dtype=torch.long)

                self.optimizer.zero_grad()
                outputs = self.model(images)
                loss = self.criterion(outputs, labels)
                loss.backward()
                self.optimizer.step()
                # PolyLR is built with total_itrs as its horizon, so the schedule
                # advances once per iteration (assumes opts.step_size for StepLR
                # is likewise counted in iterations).
                self.scheduler.step()

                if cur_itrs % opts.interval_val == 0:
                    print('Epoch %d, Itrs %d/%d' % (cur_epochs, cur_itrs, opts.total_itrs))
                    self.validate(opts)

                if cur_itrs >= opts.total_itrs:
                    return
            if not got_batch:
                raise RuntimeError('train_loader yielded no batches')