Example 1
class Tester():
    def __init__(self, args):
        self.args = args
        self.args.start_epoch = 0
        self.args.cuda = True
        # data transforms
        input_transform = transform.Compose([
            transform.ToTensor(),
            transform.Normalize([.490, .490, .490], [.247, .247, .247])]) # TODO: change mean and std
        
        # dataset
        testset = SegmentationDataset(
                    os.path.join(args.imagelist_path, 'test_stage2.csv'),
                    args.image_path,
                    args.masks_path,
                    input_transform=input_transform, 
                    transform_chain=Compose([Resize(self.args.size, self.args.size)], p=1),
                    base_size=480, is_flip=True, is_clahe=True, is_sh_sc_ro=True
        )
        # dataloader
        kwargs = {'num_workers': args.workers}  # 'pin_memory': True left disabled here
        self.testloader = data.DataLoader(testset, batch_size=args.batch_size,
                                           drop_last=False, shuffle=False, **kwargs)
        self.nclass = 1
        model = EncNet(
            nclass=self.nclass, backbone=args.backbone,
            aux=args.aux, se_loss=args.se_loss, norm_layer=SyncBatchNorm
        )
        print(model)

        self.model = model

        # resuming checkpoint
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
            args.start_epoch = checkpoint['epoch']
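            # k[7:] strips the 'module.' prefix that DataParallel adds to saved parameter names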
            state_dict = {k[7:]: v for k, v in checkpoint['state_dict'].items()}
            self.model.load_state_dict(state_dict)
            self.best_pred = checkpoint['best_pred']
            if 'best_loss' in checkpoint.keys(): 
                self.best_loss = checkpoint['best_loss']
            else:
                self.best_loss = 0
            print("=> loaded checkpoint '{}' (epoch {}, best pred: {}, best loss, {})"
                  .format(args.resume, checkpoint['epoch'], self.best_pred, self.best_loss))
        
        self.model = DataParallelModel(self.model).cuda()

        # map each test-time-augmentation mode index to the function that
        # reverts that augmentation on the prediction before it is saved
        self.mode2func = {
            0 : lambda x, y: (x, y),
            1 : apply_hflip,
            2 : lambda x, y: (x, y),
            3 : lambda x, y: apply_revert_shscro(x, y, angle=5, scale=0.9, dx=0, dy=0),
            4 : lambda x, y: apply_revert_shscro(x, y, angle=10, scale=0.9, dx=0, dy=0),
            5 : lambda x, y: apply_revert_shscro(x, y, angle=15, scale=0.9, dx=0, dy=0),
            6 : lambda x, y: apply_revert_shscro(x, y, angle=20, scale=0.9, dx=0, dy=0),
            7 : lambda x, y: apply_revert_shscro(x, y, angle=-5, scale=0.9, dx=0, dy=0),
            8 : lambda x, y: apply_revert_shscro(x, y, angle=-10, scale=0.9, dx=0, dy=0),
            9 : lambda x, y: apply_revert_shscro(x, y, angle=-15, scale=0.9, dx=0, dy=0),
            10 : lambda x, y: apply_revert_shscro(x, y, angle=-20, scale=0.9, dx=0, dy=0),
        }

    def predict(self):
        self.model.eval()
        tbar = tqdm(self.testloader)
        img_ids = []
        encode_pixels = []
        for i, (img_id, image, _, mode) in enumerate(tbar):
            image = image.cuda()
            with torch.no_grad():
                outputs = self.model(image)
            
            preds_ten = [v[0].data.cpu() for v in outputs]
            cls_preds_ten = [v[1].data.cpu() for v in outputs]
            preds_ten = torch.cat(preds_ten)
            cls_preds_ten = torch.cat(cls_preds_ten)
            preds = torch.sigmoid(preds_ten).data.cpu().numpy()[:, 0, :, :]
            mask_pred = torch.sigmoid(cls_preds_ten).data.cpu().numpy().reshape(-1)

            l_img_id = list(img_id)
            img_ids += l_img_id
            for k, imid in enumerate(l_img_id):
                npy_file = os.path.join(self.args.pred_path,
                        str(imid) + f'_{mode[k].item()}.npy')
                if mode[k].item() < 2:
                    np.save(npy_file, cv2.resize(self.mode2func[mode[k].item()](preds[k], None)[0], (1024, 1024)))
                encode_pixels.append(mask_pred[k])
        pd.DataFrame({'ImageId' : img_ids, 'EncodedPixels' : encode_pixels}).to_csv(
            os.path.join(self.args.pred_path, 'stage2_new_model_submit_16.csv'), index=False)


    def __del__(self):
        del self.model
        gc.collect()
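
A minimal driver sketch for the Tester above. The attribute names mirror what the snippet reads from args; the values, paths, and checkpoint filename are placeholders, not taken from the original project.

from argparse import Namespace

args = Namespace(
    imagelist_path='/data/lists', image_path='/data/images', masks_path='/data/masks',
    pred_path='/data/preds', size=512, workers=4, batch_size=8,
    backbone='resnet50', aux=False, se_loss=False,
    resume='checkpoints/encnet_best.pth.tar',
)
tester = Tester(args)   # builds the dataloader and loads the checkpoint
tester.predict()        # writes per-image .npy masks and the submission CSV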
Example 2
 def __init__(self, args):
     self.args = args
     # data transforms
     input_transform = transform.Compose([
         transform.ToTensor(),
         transform.Normalize([.485, .456, .406], [.229, .224, .225])
     ])
     # dataset
     data_kwargs = {
         'transform': input_transform,
         'base_size': args.base_size,
         'crop_size': args.crop_size
     }
     trainset = get_dataset(args.dataset,
                            split=args.train_split,
                            mode='train',
                            **data_kwargs)
     valset = get_dataset(
         args.dataset,
         split='val',
         mode='ms_val' if args.multi_scale_eval else 'fast_val',
         **data_kwargs)
     # dataloader
     kwargs = {'num_workers': args.workers, 'pin_memory': True}
     self.trainloader = data.DataLoader(trainset,
                                        batch_size=args.batch_size,
                                        drop_last=True,
                                        shuffle=True,
                                        **kwargs)
     if self.args.multi_scale_eval:
         kwargs['collate_fn'] = test_batchify_fn
     self.valloader = data.DataLoader(valset,
                                      batch_size=args.test_batch_size,
                                      drop_last=False,
                                      shuffle=False,
                                      **kwargs)
     self.nclass = trainset.num_class
     # model
     if args.norm_layer == 'bn':
         norm_layer = BatchNorm2d
     elif args.norm_layer == 'sync_bn':
         assert args.multi_gpu, "SyncBatchNorm can only be used when multi GPUs are available!"
         norm_layer = SyncBatchNorm
     else:
         raise ValueError('Invalid norm_layer {}'.format(args.norm_layer))
     model = get_segmentation_model(
         args.model,
         dataset=args.dataset,
         backbone=args.backbone,
         aux=args.aux,
         se_loss=args.se_loss,
         norm_layer=norm_layer,
         base_size=args.base_size,
         crop_size=args.crop_size,
         multi_grid=True,
         multi_dilation=[2, 4, 8],
         only_pam=True,
     )
     print(model)
     # optimizer using different LR
     params_list = [
         {
             'params': model.pretrained.parameters(),
             'lr': args.lr
         },
     ]
     if hasattr(model, 'head'):
         params_list.append({
             'params': model.head.parameters(),
             'lr': args.lr
         })
     if hasattr(model, 'auxlayer'):
         params_list.append({
             'params': model.auxlayer.parameters(),
             'lr': args.lr
         })
     optimizer = torch.optim.SGD(params_list,
                                 lr=args.lr,
                                 momentum=args.momentum,
                                 weight_decay=args.weight_decay)
     # criterions
     self.criterion = SegmentationMultiLosses()
     self.model, self.optimizer = model, optimizer
     # using cuda
     if args.multi_gpu:
         self.model = DataParallelModel(self.model).cuda()
         self.criterion = DataParallelCriterion(self.criterion).cuda()
     else:
         self.model = self.model.cuda()
         self.criterion = self.criterion.cuda()
     self.single_device_model = self.model.module if self.args.multi_gpu else self.model
     self.best_pred = 0.0
     # resuming checkpoint
     if args.resume is not None:
         if not os.path.isfile(args.resume):
             raise RuntimeError("=> no checkpoint found at '{}'".format(
                 args.resume))
         checkpoint = torch.load(args.resume)
         args.start_epoch = checkpoint['epoch']
         self.single_device_model.load_state_dict(checkpoint['state_dict'])
         if not args.ft and not (args.only_val or args.only_vis
                                 or args.only_infer):
             self.optimizer.load_state_dict(checkpoint['optimizer'])
         self.best_pred = checkpoint['best_pred']
         print("=> loaded checkpoint '{}' (epoch {}), best_pred {}".format(
             args.resume, checkpoint['epoch'], checkpoint['best_pred']))
     # clear start epoch if fine-tuning
     if args.ft:
         args.start_epoch = 0
     # lr scheduler
     self.lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
         optimizer, 0.6)
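
Example 2 creates an ExponentialLR scheduler but the training loop that steps it lies outside the snippet. A standalone sketch of the intended per-epoch decay, using a dummy parameter and the same gamma of 0.6:

import torch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.01)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.6)
for epoch in range(3):
    # ... one training epoch would run here ...
    optimizer.step()
    scheduler.step()                       # lr <- lr * 0.6 after every epoch
    print(epoch, scheduler.get_last_lr())  # roughly [0.006], [0.0036], [0.00216]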
Example 3
class Trainer():
    def __init__(self, args):
        self.args = args
        args.log_name = str(args.checkname)
        self.logger = utils.create_logger(args.log_root, args.log_name)
        # data transforms
        input_transform = transform.Compose([
            transform.ToTensor(),
            transform.Normalize([.485, .456, .406], [.229, .224, .225])
        ])
        # dataset
        data_kwargs = {
            'transform': input_transform,
            'base_size': args.base_size,
            'crop_size': args.crop_size,
            'logger': self.logger,
            'scale': args.scale
        }
        trainset = get_segmentation_dataset(args.dataset,
                                            split='train',
                                            mode='train',
                                            **data_kwargs)
        testset = get_segmentation_dataset(args.dataset,
                                           split='val',
                                           mode='val',
                                           **data_kwargs)
        # dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True} \
            if args.cuda else {}
        self.trainloader = data.DataLoader(trainset,
                                           batch_size=args.batch_size,
                                           drop_last=True,
                                           shuffle=True,
                                           **kwargs)
        self.valloader = data.DataLoader(testset,
                                         batch_size=args.batch_size,
                                         drop_last=False,
                                         shuffle=False,
                                         **kwargs)
        self.nclass = trainset.num_class
        # model
        model = get_segmentation_model(args.model,
                                       dataset=args.dataset,
                                       backbone=args.backbone,
                                       aux=args.aux,
                                       se_loss=args.se_loss,
                                       norm_layer=BatchNorm2d,
                                       base_size=args.base_size,
                                       crop_size=args.crop_size,
                                       multi_grid=args.multi_grid,
                                       multi_dilation=args.multi_dilation)
        #print(model)
        self.logger.info(model)
        # optimizer using different LR
        params_list = [
            {
                'params': model.pretrained.parameters(),
                'lr': args.lr
            },
        ]
        if hasattr(model, 'head'):
            params_list.append({
                'params': model.head.parameters(),
                'lr': args.lr * 10
            })
        if hasattr(model, 'auxlayer'):
            params_list.append({
                'params': model.auxlayer.parameters(),
                'lr': args.lr * 10
            })
        optimizer = torch.optim.SGD(params_list,
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
        self.criterion = SegmentationMultiLosses(nclass=self.nclass)
        #self.criterion = SegmentationLosses(se_loss=args.se_loss, aux=args.aux,nclass=self.nclass)

        self.model, self.optimizer = model, optimizer
        # using cuda
        if args.cuda:
            self.model = DataParallelModel(self.model).cuda()
            self.criterion = DataParallelCriterion(self.criterion).cuda()
        # finetune from a trained model
        if args.ft:
            args.start_epoch = 0
            checkpoint = torch.load(args.ft_resume)
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'],
                                                  strict=False)
            else:
                self.model.load_state_dict(checkpoint['state_dict'],
                                           strict=False)
            self.logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                args.ft_resume, checkpoint['epoch']))
        self.best_pred = 0.0
        # resuming checkpoint
        if args.resume:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            self.logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        # lr scheduler
        self.scheduler = utils.LR_Scheduler(args.lr_scheduler,
                                            args.lr,
                                            args.epochs,
                                            len(self.trainloader),
                                            logger=self.logger,
                                            lr_step=args.lr_step)

    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.trainloader)

        for i, (image, target) in enumerate(tbar):
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            if torch_ver == "0.3":
                image = Variable(image)
                target = Variable(target)
            outputs = self.model(image)
            loss = self.criterion(outputs, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
        self.logger.info('Train loss: %.3f' % (train_loss / (i + 1)))

        if self.args.no_val:
            # after epoch 100, save a checkpoint every 5 epochs
            filename = "checkpoint_%s.pth.tar" % (epoch + 1)
            is_best = False
            if epoch > 99:
                if not epoch % 5:
                    utils.save_checkpoint(
                        {
                            'epoch': epoch + 1,
                            'state_dict': self.model.module.state_dict(),
                            'optimizer': self.optimizer.state_dict(),
                            'best_pred': self.best_pred,
                        }, self.args, is_best, filename)

    def validation(self, epoch):
        # Fast test during the training
        def eval_batch(model, image, target):
            outputs = model(image)
            outputs = gather(outputs, 0, dim=0)
            pred = outputs[0]
            target = target.cuda()
            correct, labeled = utils.batch_pix_accuracy(pred.data, target)
            inter, union = utils.batch_intersection_union(
                pred.data, target, self.nclass)
            return correct, labeled, inter, union

        is_best = False
        self.model.eval()
        total_inter, total_union, total_correct, total_label = 0, 0, 0, 0
        tbar = tqdm(self.valloader, desc='\r')

        for i, (image, target) in enumerate(tbar):
            if torch_ver == "0.3":
                image = Variable(image, volatile=True)
                correct, labeled, inter, union = eval_batch(
                    self.model, image, target)
            else:
                with torch.no_grad():
                    correct, labeled, inter, union = eval_batch(
                        self.model, image, target)

            total_correct += correct
            total_label += labeled
            total_inter += inter
            total_union += union
            pixAcc = 1.0 * total_correct / (np.spacing(1) + total_label)
            IoU = 1.0 * total_inter / (np.spacing(1) + total_union)
            mIoU = IoU.mean()
            tbar.set_description('pixAcc: %.3f, mIoU: %.3f' % (pixAcc, mIoU))
        self.logger.info('pixAcc: %.3f, mIoU: %.3f' % (pixAcc, mIoU))

        new_pred = (pixAcc + mIoU) / 2
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            utils.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, self.args, is_best)
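
The running pixAcc / mIoU bookkeeping in validation() above boils down to a few accumulations. Below is a toy numeric version with made-up per-batch statistics for a 3-class problem; the utils.batch_* helpers are assumed to return exactly these kinds of counts.

import numpy as np

total_correct, total_label = 0, 0
total_inter, total_union = np.zeros(3), np.zeros(3)
batches = [
    (900, 1000, np.array([300., 250., 200.]), np.array([400., 300., 350.])),
    (850, 1000, np.array([280., 240., 210.]), np.array([380., 320., 330.])),
]
for correct, labeled, inter, union in batches:
    total_correct += correct
    total_label += labeled
    total_inter += inter
    total_union += union
pixAcc = 1.0 * total_correct / (np.spacing(1) + total_label)   # 0.875 here
IoU = 1.0 * total_inter / (np.spacing(1) + total_union)
print('pixAcc: %.3f, mIoU: %.3f' % (pixAcc, IoU.mean()))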
Example 4
class Trainer():
    def __init__(self, args):
        self.args = args
        # data transforms
        input_transform = transform.Compose([
            transform.ToTensor(),
            transform.Normalize([.485, .456, .406], [.229, .224, .225])
        ])
        # dataset
        data_kwargs = {
            'transform': input_transform,
            'base_size': args.base_size,
            'crop_size': args.crop_size
        }
        trainset = get_segmentation_dataset(args.dataset,
                                            split=args.train_split,
                                            mode='train',
                                            **data_kwargs)
        testset = get_segmentation_dataset(args.dataset,
                                           split='val',
                                           mode='val',
                                           **data_kwargs)
        # dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True} \
            if args.cuda else {}
        self.trainloader = data.DataLoader(trainset,
                                           batch_size=args.batch_size,
                                           drop_last=True,
                                           shuffle=True,
                                           **kwargs)
        self.valloader = data.DataLoader(testset,
                                         batch_size=args.batch_size,
                                         drop_last=False,
                                         shuffle=False,
                                         **kwargs)
        self.nclass = trainset.num_class
        # model
        model = get_segmentation_model(
            args.model,
            dataset=args.dataset,
            backbone=args.backbone,
            dilated=args.dilated,
            lateral=args.lateral,
            jpu=args.jpu,
            aux=args.aux,
            se_loss=args.se_loss,
            norm_layer=torch.nn.BatchNorm2d,  ## BatchNorm2d
            base_size=args.base_size,
            crop_size=args.crop_size)
        print(model)
        # optimizer using different LR
        params_list = [
            {
                'params': model.pretrained.parameters(),
                'lr': args.lr
            },
        ]
        if hasattr(model, 'jpu'):
            params_list.append({
                'params': model.jpu.parameters(),
                'lr': args.lr * 10
            })
        if hasattr(model, 'head'):
            params_list.append({
                'params': model.head.parameters(),
                'lr': args.lr * 10
            })
        if hasattr(model, 'auxlayer'):
            params_list.append({
                'params': model.auxlayer.parameters(),
                'lr': args.lr * 10
            })
        optimizer = torch.optim.SGD(params_list,
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
        # criterions
        self.criterion = SegmentationLosses(se_loss=args.se_loss,
                                            aux=args.aux,
                                            nclass=self.nclass,
                                            se_weight=args.se_weight,
                                            aux_weight=args.aux_weight)
        self.model, self.optimizer = model, optimizer
        # using cuda
        if args.cuda:
            self.model = DataParallelModel(self.model).cuda()
            self.criterion = DataParallelCriterion(self.criterion).cuda()
        self.best_pred = 0.0
        # resuming checkpoint
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        # clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0
        # lr scheduler
        self.scheduler = utils.LR_Scheduler(args.lr_scheduler, args.lr,
                                            args.epochs, len(self.trainloader))

    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.trainloader)
        for i, (image, target) in enumerate(tbar):
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            if torch_ver == "0.3":
                image = Variable(image)
                target = Variable(target)
            outputs = self.model(image)
            ## original
            loss = self.criterion(outputs, target)
            loss.backward()
            '''
            ## modified loss
            criterion = JointEdgeSegLoss(classes=num_classes, ignore_index=args.dataset_cls.ignore_label, upper_bound=args.wt_bound,
                edge_weight=args.edge_weight, seg_weight=args.seg_weight, att_weight=args.att_weight, dual_weight=args.dual_weight)

            train_main_loss = AverageMeter()
            train_edge_loss = AverageMeter()
            train_seg_loss = AverageMeter()
            train_att_loss = AverageMeter()
            train_dual_loss = AverageMeter()

            main_loss = None
            loss_dict = None
            self.criterion((seg_out, edge_out), gts)


            if args.seg_weight > 0:
                log_seg_loss = loss_dict['seg_loss'].mean().clone().detach_()
                train_seg_loss.update(log_seg_loss.item(), batch_pixel_size)
                main_loss = loss_dict['seg_loss']

            if args.edge_weight > 0:
                log_edge_loss = loss_dict['edge_loss'].mean().clone().detach_()
                train_edge_loss.update(log_edge_loss.item(), batch_pixel_size)
                if main_loss is not None:
                    main_loss += loss_dict['edge_loss']
                else:
                    main_loss = loss_dict['edge_loss']
            
            if args.att_weight > 0:
                log_att_loss = loss_dict['att_loss'].mean().clone().detach_()
                train_att_loss.update(log_att_loss.item(), batch_pixel_size)
                if main_loss is not None:
                    main_loss += loss_dict['att_loss']
                else:
                    main_loss = loss_dict['att_loss']

            if args.dual_weight > 0:
                log_dual_loss = loss_dict['dual_loss'].mean().clone().detach_()
                train_dual_loss.update(log_dual_loss.item(), batch_pixel_size)
                if main_loss is not None:
                    main_loss += loss_dict['dual_loss']
                else:
                    main_loss = loss_dict['dual_loss']
            '''

            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            utils.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, self.args, is_best)

    def validation(self, epoch):
        # Fast test during the training
        def eval_batch(model, image, target):
            outputs = model(image)
            outputs = gather(outputs, 0, dim=0)
            pred = outputs[0]
            target = target.cuda()
            correct, labeled = utils.batch_pix_accuracy(pred.data, target)
            inter, union = utils.batch_intersection_union(
                pred.data, target, self.nclass)
            return correct, labeled, inter, union

        is_best = False
        self.model.eval()
        total_inter, total_union, total_correct, total_label = 0, 0, 0, 0
        tbar = tqdm(self.valloader, desc='\r')
        for i, (image, target) in enumerate(tbar):
            if torch_ver == "0.3":
                image = Variable(image, volatile=True)
                correct, labeled, inter, union = eval_batch(
                    self.model, image, target)
            else:
                with torch.no_grad():
                    correct, labeled, inter, union = eval_batch(
                        self.model, image, target)

            total_correct += correct
            total_label += labeled
            total_inter += inter
            total_union += union
            pixAcc = 1.0 * total_correct / (np.spacing(1) + total_label)
            IoU = 1.0 * total_inter / (np.spacing(1) + total_union)
            mIoU = IoU.mean()
            tbar.set_description('pixAcc: %.3f, mIoU: %.3f' % (pixAcc, mIoU))

        new_pred = (pixAcc + mIoU) / 2
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            utils.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, self.args, is_best)
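
The params_list construction above gives the head and auxiliary layers a 10x learning rate. A standalone sketch, with nn.Linear modules standing in for model.pretrained and model.head, showing how each dict becomes its own param_group on the SGD optimizer:

import torch
import torch.nn as nn

backbone = nn.Linear(8, 8)   # stand-in for model.pretrained
head = nn.Linear(8, 2)       # stand-in for model.head
base_lr = 0.001
params_list = [
    {'params': backbone.parameters(), 'lr': base_lr},
    {'params': head.parameters(), 'lr': base_lr * 10},
]
optimizer = torch.optim.SGD(params_list, lr=base_lr,
                            momentum=0.9, weight_decay=1e-4)
for group in optimizer.param_groups:
    print(group['lr'])   # 0.001 for the backbone group, 0.01 for the head group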
Example 5
    def __init__(self, args):
        self.args = args
        self.args.start_epoch = 0
        self.args.cuda = True
        # data transforms
        input_transform = transform.Compose([
            transform.ToTensor(),
            transform.Normalize([.490, .490, .490], [.247, .247, .247])
        ])  # TODO: change mean and std

        # dataset
        train_chain = Compose([
            HorizontalFlip(p=0.5),
            OneOf([
                ElasticTransform(
                    alpha=300, sigma=300 * 0.05, alpha_affine=300 * 0.03),
                GridDistortion(),
                OpticalDistortion(distort_limit=2, shift_limit=0.5),
            ], p=0.3),
            RandomSizedCrop(
                min_max_height=(900, 1024), height=1024, width=1024, p=0.5),
            ShiftScaleRotate(rotate_limit=20, p=0.5),
            Resize(self.args.size, self.args.size)
        ], p=1)

        val_chain = Compose([Resize(self.args.size, self.args.size)], p=1)
        num_fold = self.args.num_fold
        df_train = pd.read_csv(os.path.join(args.imagelist_path, 'train.csv'))
        df_val = pd.read_csv(os.path.join(args.imagelist_path, 'val.csv'))
        df_full = pd.concat((df_train, df_val), ignore_index=True, axis=0)
        df_full['lbl'] = (df_full['mask_name'].astype(str) == '-1').astype(int)
        skf = StratifiedKFold(8, shuffle=True, random_state=777)
        train_ids, val_ids = list(
            skf.split(df_full['mask_name'], df_full['lbl']))[num_fold]

        df_test = pd.read_csv(
            os.path.join(args.imagelist_path, 'test_true.csv'))

        df_new_train = pd.concat((df_full.iloc[train_ids], df_test),
                                 ignore_index=True,
                                 axis=0,
                                 sort=False)
        df_new_val = df_full.iloc[val_ids]

        df_new_train.to_csv(f'/tmp/train_new_pneumo_{num_fold}.csv')
        df_new_val.to_csv(f'/tmp/val_new_pneumo_{num_fold}.csv')

        trainset = SegmentationDataset(f'/tmp/train_new_pneumo_{num_fold}.csv',
                                       args.image_path,
                                       args.masks_path,
                                       input_transform=input_transform,
                                       transform_chain=train_chain,
                                       base_size=1024)
        testset = SegmentationDataset(f'/tmp/val_new_pneumo_{num_fold}.csv',
                                      args.image_path,
                                      args.masks_path,
                                      input_transform=input_transform,
                                      transform_chain=val_chain,
                                      base_size=1024)

        imgs = trainset.mask_img_map[:, [0, 3]]
        weights = make_weights_for_balanced_classes(imgs, 2)
        weights = torch.DoubleTensor(weights)
        train_sampler = (torch.utils.data.sampler.WeightedRandomSampler(
            weights, len(weights)))

        # dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.trainloader = data.DataLoader(
            trainset,
            batch_size=args.batch_size,
            drop_last=True,
            sampler=train_sampler,  # the weighted sampler replaces shuffle=True
            **kwargs)
        self.valloader = data.DataLoader(testset,
                                         batch_size=args.batch_size,
                                         drop_last=False,
                                         shuffle=False,
                                         **kwargs)

        self.nclass = 1
        if self.args.model == 'unet':
            model = UNet(n_classes=self.nclass, norm_layer=SyncBatchNorm)
            params_list = [
                {
                    'params': model.parameters(),
                    'lr': args.lr
                },
            ]
        elif self.args.model == 'encnet':
            model = EncNet(
                nclass=self.nclass,
                backbone=args.backbone,
                aux=args.aux,
                se_loss=args.se_loss,
                norm_layer=SyncBatchNorm  #nn.BatchNorm2d
            )

            # optimizer using different LR
            params_list = [
                {
                    'params': model.pretrained.parameters(),
                    'lr': args.lr
                },
            ]
            if hasattr(model, 'head'):
                params_list.append({
                    'params': model.head.parameters(),
                    'lr': args.lr * 10
                })
            if hasattr(model, 'auxlayer'):
                params_list.append({
                    'params': model.auxlayer.parameters(),
                    'lr': args.lr * 10
                })

        print(model)
        optimizer = torch.optim.SGD(params_list,
                                    lr=args.lr,
                                    momentum=0.9,
                                    weight_decay=args.wd)

        # criterions
        if self.nclass == 1:
            self.criterion = SegmentationLossesBCE(se_loss=args.se_loss,
                                                   aux=args.aux,
                                                   nclass=self.nclass,
                                                   se_weight=args.se_weight,
                                                   aux_weight=args.aux_weight,
                                                   use_dice=args.use_dice)
        else:
            self.criterion = SegmentationLosses(
                se_loss=args.se_loss,
                aux=args.aux,
                nclass=self.nclass,
                se_weight=args.se_weight,
                aux_weight=args.aux_weight,
            )
        self.model, self.optimizer = model, optimizer

        self.best_pred = 0.0
        self.model = DataParallelModel(self.model).cuda()
        self.criterion = DataParallelCriterion(self.criterion).cuda()

        # resuming checkpoint
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)  #, map_location='cpu')
            self.args.start_epoch = checkpoint['epoch']
            state_dict = checkpoint['state_dict']
            self.model.load_state_dict(state_dict)
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            for g in self.optimizer.param_groups:
                g['lr'] = args.lr
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            print(f'Best dice: {checkpoint["best_pred"]}')
            print(f'LR: {get_lr(self.optimizer):.5f}')

        self.scheduler = ReduceLROnPlateau(self.optimizer,
                                           mode='min',
                                           factor=0.8,
                                           patience=4,
                                           threshold=0.001,
                                           threshold_mode='abs',
                                           min_lr=0.00001)
        self.logger = Logger(args.logger_dir)
        self.step_train = 0
        self.best_loss = 20  # high initial value for the lowest-validation-loss tracker
        self.step_val = 0
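
Example 5 balances samples with and without a mask via make_weights_for_balanced_classes plus a WeightedRandomSampler; that helper is not shown, so the sketch below reproduces the general idea on synthetic binary labels.

import torch
from torch.utils.data import WeightedRandomSampler

labels = torch.tensor([0] * 90 + [1] * 10)       # imbalanced binary labels
class_counts = torch.bincount(labels).float()
weights = (1.0 / class_counts)[labels]           # one weight per sample
sampler = WeightedRandomSampler(weights.double(), num_samples=len(weights))
drawn = labels[torch.tensor(list(sampler))]
print(drawn.float().mean())                      # close to 0.5 instead of 0.1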
Example 6
class Trainer():
    def __init__(self, args):
        if args.se_loss:
            args.checkname = args.checkname + "_se"

        self.args = args
        # data transforms
        input_transform = transform.Compose([
            transform.ToTensor(),
            transform.Normalize([.485, .456, .406], [.229, .224, .225])])
        # dataset
        data_kwargs = {'transform': input_transform, 'base_size': args.base_size,
                       'crop_size': args.crop_size}
        trainset = get_segmentation_dataset(args.dataset, split='train', mode='train',
                                           **data_kwargs)
        testset = get_segmentation_dataset(args.dataset, split='val', mode='val',
                                           **data_kwargs)
        # dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': False} \
            if args.cuda else {}
        self.trainloader = data.DataLoader(trainset, batch_size=args.batch_size,
                                           drop_last=True, shuffle=True, **kwargs)
        self.valloader = data.DataLoader(testset, batch_size=args.batch_size,
                                         drop_last=False, shuffle=False, **kwargs)
        self.nclass = trainset.num_class
        # model
        model = get_segmentation_model(args.model, dataset=args.dataset,
                                       backbone = args.backbone, aux = args.aux,
                                       se_loss = args.se_loss, norm_layer = BatchNorm2d,
                                       base_size=args.base_size, crop_size=args.crop_size)
        print(model)

        # count parameter number
        pytorch_total_params = sum(p.numel() for p in model.parameters())
        print("Total number of parameters: %d"%pytorch_total_params)

        # optimizer using different LR
        params_list = [{'params': model.pretrained.parameters(), 'lr': args.lr},]
        if hasattr(model, 'head'):
            if args.diflr:
                params_list.append({'params': model.head.parameters(), 'lr': args.lr*10})
            else:
                params_list.append({'params': model.head.parameters(), 'lr': args.lr})
        if hasattr(model, 'auxlayer'):
            if args.diflr:
                params_list.append({'params': model.auxlayer.parameters(), 'lr': args.lr*10})
            else:
                params_list.append({'params': model.auxlayer.parameters(), 'lr': args.lr})

        optimizer = torch.optim.SGD(params_list,
                    lr=args.lr,
                    momentum=args.momentum,
                    weight_decay=args.weight_decay)

        #optimizer = torch.optim.ASGD(params_list,
        #                            lr=args.lr,
        #                            weight_decay=args.weight_decay)

        # criterions
        self.criterion = SegmentationLosses(se_loss=args.se_loss, aux=args.aux,
                                            nclass=self.nclass)
        self.model, self.optimizer = model, optimizer
        # using cuda
        if args.cuda:
            self.model = DataParallelModel(self.model).cuda()
            self.criterion = DataParallelCriterion(self.criterion).cuda()
        self.best_pred = 0.0
        # resuming checkpoint
        if args.resume is not None and len(args.resume)>0:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                # load weights for the same model
                # self.model.module.load_state_dict(checkpoint['state_dict'])

                # model and checkpoint may have different structures,
                # so copy only the parameters whose names match
                pretrained_dict = checkpoint['state_dict']
                model_dict = self.model.module.state_dict()

                for name, param in pretrained_dict.items():
                    if name not in model_dict:
                        continue
                    if isinstance(param, Parameter):
                        # backwards compatibility for serialized parameters
                        param = param.data
                    model_dict[name].copy_(param)

            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))

        # clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

        # lr scheduler
        self.scheduler = utils.LR_Scheduler(args.lr_scheduler, args.lr,
                                            args.epochs, len(self.trainloader), lr_step=args.lr_step)

    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.trainloader)
        for i, (image, target) in enumerate(tbar):
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            if torch_ver == "0.3":
                image = Variable(image)
                target = Variable(target)
            outputs = self.model(image)
            loss = self.criterion(outputs, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            utils.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.module.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, self.args, is_best)


    def validation(self, epoch):
        # Fast test during the training
        def eval_batch(model, image, target):
            outputs = model(image)
            outputs = gather(outputs, 0, dim=0)
            pred = outputs[0]
            target = target.cuda()
            correct, labeled = utils.batch_pix_accuracy(pred.data, target)
            inter, union = utils.batch_intersection_union(pred.data, target, self.nclass)
            return correct, labeled, inter, union

        is_best = False
        self.model.eval()
        total_inter, total_union, total_correct, total_label = 0, 0, 0, 0
        tbar = tqdm(self.valloader, desc='\r')
        for i, (image, target) in enumerate(tbar):
            if torch_ver == "0.3":
                image = Variable(image, volatile=True)
                correct, labeled, inter, union = eval_batch(self.model, image, target)
            else:
                with torch.no_grad():
                    correct, labeled, inter, union = eval_batch(self.model, image, target)

            total_correct += correct
            total_label += labeled
            total_inter += inter
            total_union += union
            pixAcc = 1.0 * total_correct / (np.spacing(1) + total_label)
            IoU = 1.0 * total_inter / (np.spacing(1) + total_union)
            mIoU = IoU.mean()
            tbar.set_description(
                'pixAcc: %.3f, mIoU: %.3f' % (pixAcc, mIoU))

        new_pred = (pixAcc + mIoU)/2
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            utils.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.module.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, self.args, is_best)
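
The resume branch in Example 6 copies checkpoint parameters one by one when the model and checkpoint structures differ. An equivalent pattern (a sketch, not the author's code) filters the checkpoint down to matching keys and loads with strict=False:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
checkpoint_state = dict(model.state_dict())           # stands in for torch.load(...)['state_dict']
checkpoint_state['extra.weight'] = torch.zeros(3, 3)  # a key the current model does not have
model_state = model.state_dict()
filtered = {k: v for k, v in checkpoint_state.items()
            if k in model_state and v.shape == model_state[k].shape}
missing, unexpected = model.load_state_dict(filtered, strict=False)
print(missing, unexpected)                            # both empty: mismatched keys were skipped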
Example 7
def train(cfg, logger, logdir):
    # Setup seeds
    init_seed(11733, en_cudnn=False)

    # Setup Augmentations
    train_augmentations = cfg["training"].get("train_augmentations", None)
    t_data_aug = get_composed_augmentations(train_augmentations)
    val_augmentations = cfg["validating"].get("val_augmentations", None)
    v_data_aug = get_composed_augmentations(val_augmentations)

    # Setup Dataloader

    path_n = cfg["model"]["path_num"]

    data_loader = get_loader(cfg["data"]["dataset"])
    data_path = cfg["data"]["path"]

    t_loader = data_loader(data_path, split=cfg["data"]["train_split"], augmentations=t_data_aug, path_num=path_n)
    v_loader = data_loader(data_path, split=cfg["data"]["val_split"], augmentations=v_data_aug, path_num=path_n)

    trainloader = data.DataLoader(t_loader,
                                  batch_size=cfg["training"]["batch_size"],
                                  num_workers=cfg["training"]["n_workers"],
                                  shuffle=True,
                                  drop_last=True  )
    valloader = data.DataLoader(v_loader,
                                batch_size=cfg["validating"]["batch_size"],
                                num_workers=cfg["validating"]["n_workers"] )

    logger.info("Using training seting {}".format(cfg["training"]))
    
    # Setup Metrics
    running_metrics_val = runningScore(t_loader.n_classes)

    # Setup Model and Loss
    loss_fn = get_loss_function(cfg["training"])
    teacher = get_model(cfg["teacher"], t_loader.n_classes)
    model = get_model(cfg["model"],t_loader.n_classes, loss_fn, cfg["training"]["resume"],teacher)
    logger.info("Using loss {}".format(loss_fn))

    # Setup optimizer
    optimizer = get_optimizer(cfg["training"], model)

    # Setup Multi-GPU
    model = DataParallelModel(model).cuda()

    # Initialize training parameters
    cnt_iter = 0
    best_iou = 0.0
    time_meter = averageMeter()

    while cnt_iter <= cfg["training"]["train_iters"]:
        for (f_img, labels) in trainloader:
            cnt_iter += 1
            model.train()
            optimizer.zero_grad()

            start_ts = time.time()
            outputs = model(f_img, labels, pos_id=cnt_iter % path_n)

            seg_loss = gather(outputs, 0)
            seg_loss = torch.mean(seg_loss)

            seg_loss.backward()
            time_meter.update(time.time() - start_ts)

            optimizer.step()

            if (cnt_iter + 1) % cfg["training"]["print_interval"] == 0:
                fmt_str = "Iter [{:d}/{:d}]  Loss: {:.4f}  Time/Image: {:.4f}"
                print_str = fmt_str.format(
                                            cnt_iter + 1,
                                            cfg["training"]["train_iters"],
                                            seg_loss.item(),
                                            time_meter.avg / cfg["training"]["batch_size"], )

                print(print_str)
                logger.info(print_str)
                time_meter.reset()

            if (cnt_iter + 1) % cfg["training"]["val_interval"] == 0 or (cnt_iter + 1) == cfg["training"]["train_iters"]:
                model.eval()
                with torch.no_grad():
                    for i_val, (f_img_val, labels_val) in tqdm(enumerate(valloader)):
                        
                        outputs = model(f_img_val, pos_id=i_val % path_n)
                        outputs = gather(outputs, 0, dim=0)
                        
                        pred = outputs.data.max(1)[1].cpu().numpy()
                        gt = labels_val.data.cpu().numpy()

                        running_metrics_val.update(gt, pred)

                score, class_iou = running_metrics_val.get_scores()
                for k, v in score.items():
                    print(k, v)
                    logger.info("{}: {}".format(k, v))

                for k, v in class_iou.items():
                    logger.info("{}: {}".format(k, v))

                running_metrics_val.reset()

                if score["Mean IoU : \t"] >= best_iou:
                    best_iou = score["Mean IoU : \t"]
                    state = {
                        "epoch": cnt_iter + 1,
                        "model_state": clean_state_dict(model.module.state_dict(),'teacher'),
                        "best_iou": best_iou,
                    }
                    save_path = os.path.join(logdir,
                        "{}_{}_best_model.pkl".format(cfg["model"]["arch"], cfg["data"]["dataset"]),
                    )
                    torch.save(state, save_path)
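
train() in Example 7 reads everything from a nested cfg dict, typically parsed from a YAML file. The sketch below reconstructs its apparent shape from the keys the function accesses; all values are placeholders, and get_loss_function / get_optimizer presumably read further keys from cfg["training"] that are not visible in the snippet.

cfg = {
    "model": {"arch": "example_arch", "path_num": 2},
    "teacher": {"arch": "teacher_arch"},
    "data": {"dataset": "cityscapes", "path": "/data/cityscapes",
             "train_split": "train", "val_split": "val"},
    "training": {"batch_size": 8, "n_workers": 4, "train_iters": 40000,
                 "print_interval": 50, "val_interval": 1000,
                 "resume": None, "train_augmentations": None},
    "validating": {"batch_size": 1, "n_workers": 2, "val_augmentations": None},
}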
Example 8
    def __init__(self, args):
        self.args = args
        # data transforms
        input_transform = transform.Compose([
            transform.ToTensor(),
            transform.Normalize([.485, .456, .406], [.229, .224, .225])
        ])
        # dataset
        data_kwargs = {
            'transform': input_transform,
            'base_size': args.base_size,
            'crop_size': args.crop_size
        }
        trainset = get_segmentation_dataset(args.dataset,
                                            split=args.train_split,
                                            mode='train',
                                            **data_kwargs)
        testset = get_segmentation_dataset(args.dataset,
                                           split='val',
                                           mode='val',
                                           **data_kwargs)
        # dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True} \
            if args.cuda else {}
        self.trainloader = data.DataLoader(trainset,
                                           batch_size=args.batch_size,
                                           drop_last=False,
                                           shuffle=True,
                                           **kwargs)
        self.valloader = data.DataLoader(testset,
                                         batch_size=args.batch_size,
                                         drop_last=False,
                                         shuffle=False,
                                         **kwargs)
        self.nclass = trainset.num_class
        # model
        model = get_segmentation_model(args.model,
                                       dataset=args.dataset,
                                       backbone=args.backbone,
                                       dilated=args.dilated,
                                       multi_grid=args.multi_grid,
                                       stride=args.stride,
                                       lateral=args.lateral,
                                       jpu=args.jpu,
                                       aux=args.aux,
                                       se_loss=args.se_loss,
                                       norm_layer=SyncBatchNorm,
                                       base_size=args.base_size,
                                       crop_size=args.crop_size)
        # print(model)
        # optimizer using different LR
        params_list = [
            {
                'params': model.pretrained.parameters(),
                'lr': args.lr
            },
        ]
        if hasattr(model, 'jpu') and model.jpu:
            params_list.append({
                'params': model.jpu.parameters(),
                'lr': args.lr * 10
            })
        if hasattr(model, 'head'):
            params_list.append({
                'params': model.head.parameters(),
                'lr': args.lr * 10
            })
        if hasattr(model, 'auxlayer'):
            params_list.append({
                'params': model.auxlayer.parameters(),
                'lr': args.lr * 10
            })
        optimizer = torch.optim.SGD(params_list,
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

        class_balance_weight = None  # fall back to an unweighted loss for other datasets
        if args.dataset == "pcontext60":
            class_balance_weight = torch.tensor([
                1.3225e-01, 2.0757e+00, 1.8146e+01, 5.5052e+00, 2.2060e+00,
                2.8054e+01, 2.0566e+00, 1.8598e+00, 2.4027e+00, 9.3435e+00,
                3.5990e+00, 2.7487e-01, 1.4216e+00, 2.4986e+00, 7.7258e-01,
                4.9020e-01, 2.9067e+00, 1.2197e+00, 2.2744e+00, 2.0444e+01,
                3.0057e+00, 1.8167e+01, 3.7405e+00, 5.6749e-01, 3.2631e+00,
                1.5007e+00, 5.5519e-01, 1.0056e+01, 1.8952e+01, 2.6792e-01,
                2.7479e-01, 1.8309e+00, 2.0428e+01, 1.4788e+01, 1.4908e+00,
                1.9113e+00, 2.6166e+02, 2.3233e-01, 1.9096e+01, 6.7025e+00,
                2.8756e+00, 6.8804e-01, 4.4140e+00, 2.5621e+00, 4.4409e+00,
                4.3821e+00, 1.3774e+01, 1.9803e-01, 3.6944e+00, 1.0397e+00,
                2.0601e+00, 5.5811e+00, 1.3242e+00, 3.0088e-01, 1.7344e+01,
                2.1569e+00, 2.7216e-01, 5.8731e-01, 1.9956e+00, 4.4004e+00
            ])

        elif args.dataset == "ade20k":
            class_balance_weight = torch.tensor([
                0.0772, 0.0431, 0.0631, 0.0766, 0.1095, 0.1399, 0.1502, 0.1702,
                0.2958, 0.3400, 0.3738, 0.3749, 0.4059, 0.4266, 0.4524, 0.5725,
                0.6145, 0.6240, 0.6709, 0.6517, 0.6591, 0.6818, 0.9203, 0.9965,
                1.0272, 1.0967, 1.1202, 1.2354, 1.2900, 1.5038, 1.5160, 1.5172,
                1.5036, 2.0746, 2.1426, 2.3159, 2.2792, 2.6468, 2.8038, 2.8777,
                2.9525, 2.9051, 3.1050, 3.1785, 3.3533, 3.5300, 3.6120, 3.7006,
                3.6790, 3.8057, 3.7604, 3.8043, 3.6610, 3.8268, 4.0644, 4.2698,
                4.0163, 4.0272, 4.1626, 4.3702, 4.3144, 4.3612, 4.4389, 4.5612,
                5.1537, 4.7653, 4.8421, 4.6813, 5.1037, 5.0729, 5.2657, 5.6153,
                5.8240, 5.5360, 5.6373, 6.6972, 6.4561, 6.9555, 7.9239, 7.3265,
                7.7501, 7.7900, 8.0528, 8.5415, 8.1316, 8.6557, 9.0550, 9.0081,
                9.3262, 9.1391, 9.7237, 9.3775, 9.4592, 9.7883, 10.6705,
                10.2113, 10.5845, 10.9667, 10.8754, 10.8274, 11.6427, 11.0687,
                10.8417, 11.0287, 12.2030, 12.8830, 12.5082, 13.0703, 13.8410,
                12.3264, 12.9048, 12.9664, 12.3523, 13.9830, 13.8105, 14.0345,
                15.0054, 13.9801, 14.1048, 13.9025, 13.6179, 17.0577, 15.8351,
                17.7102, 17.3153, 19.4640, 17.7629, 19.9093, 16.9529, 19.3016,
                17.6671, 19.4525, 20.0794, 18.3574, 19.1219, 19.5089, 19.2417,
                20.2534, 20.0332, 21.7496, 21.5427, 20.3008, 21.1942, 22.7051,
                23.3359, 22.4300, 20.9934, 26.9073, 31.7362, 30.0784
            ])
        elif args.dataset == "cocostuff":
            class_balance_weight = torch.tensor([
                4.8557e-02, 6.4709e-02, 3.9255e+00, 9.4797e-01, 1.2703e+00,
                1.4151e+00, 7.9733e-01, 8.4903e-01, 1.0751e+00, 2.4001e+00,
                8.9736e+00, 5.3036e+00, 6.0410e+00, 9.3285e+00, 1.5952e+00,
                3.6090e+00, 9.8772e-01, 1.2319e+00, 1.9194e+00, 2.7624e+00,
                2.0548e+00, 1.2058e+00, 3.6424e+00, 2.0789e+00, 1.7851e+00,
                6.7138e+00, 2.1315e+00, 6.9813e+00, 1.2679e+02, 2.0357e+00,
                2.2933e+01, 2.3198e+01, 1.7439e+01, 4.1294e+01, 7.8678e+00,
                4.3444e+01, 6.7543e+01, 1.0066e+01, 6.7520e+00, 1.3174e+01,
                3.3499e+00, 6.9737e+00, 2.1482e+00, 1.9428e+01, 1.3240e+01,
                1.9218e+01, 7.6836e-01, 2.6041e+00, 6.1822e+00, 1.4070e+00,
                4.4074e+00, 5.7792e+00, 1.0321e+01, 4.9922e+00, 6.7408e-01,
                3.1554e+00, 1.5832e+00, 8.9685e-01, 1.1686e+00, 2.6487e+00,
                6.5354e-01, 2.3801e-01, 1.9536e+00, 1.5862e+00, 1.7797e+00,
                2.7385e+01, 1.2419e+01, 3.9287e+00, 7.8897e+00, 7.5737e+00,
                1.9758e+00, 8.1962e+01, 3.6922e+00, 2.0039e+00, 2.7333e+00,
                5.4717e+00, 3.9048e+00, 1.9184e+01, 2.2689e+00, 2.6091e+02,
                4.7366e+01, 2.3844e+00, 8.3310e+00, 1.4857e+01, 6.5076e+00,
                2.0854e-01, 1.0425e+00, 1.7386e+00, 1.1973e+01, 5.2862e+00,
                1.7341e+00, 8.6124e-01, 9.3702e+00, 2.8545e+00, 6.0123e+00,
                1.7560e-01, 1.8128e+00, 1.3784e+00, 1.3699e+00, 2.3728e+00,
                6.2819e-01, 1.3097e+00, 4.7892e-01, 1.0268e+01, 1.2307e+00,
                5.5662e+00, 1.2867e+00, 1.2745e+00, 4.7505e+00, 8.4029e+00,
                1.8679e+00, 1.0519e+01, 1.1240e+00, 1.4975e-01, 2.3146e+00,
                4.1265e-01, 2.5896e+00, 1.4537e+00, 4.5575e+00, 7.8143e+00,
                1.4603e+01, 2.8812e+00, 1.8868e+00, 7.8131e+01, 1.9323e+00,
                7.4980e+00, 1.2446e+01, 2.1856e+00, 3.0973e+00, 4.1270e-01,
                4.9016e+01, 7.1001e-01, 7.4035e+00, 2.3395e+00, 2.9207e-01,
                2.4156e+00, 3.3211e+00, 2.1300e+00, 2.4533e-01, 1.7081e+00,
                4.6621e+00, 2.9199e+00, 1.0407e+01, 7.6207e-01, 2.7806e-01,
                3.7711e+00, 1.1852e-01, 8.8280e+00, 3.1700e-01, 6.3765e+01,
                6.6032e+00, 5.2177e+00, 4.3596e+00, 6.2965e-01, 1.0207e+00,
                1.1731e+01, 2.3935e+00, 9.2767e+00, 1.1023e-01, 3.6947e+00,
                1.3943e+00, 2.3407e+00, 1.2112e-01, 2.8518e+00, 2.8195e+00,
                1.0078e+00, 1.6614e+00, 6.5307e-01, 1.9070e+01, 2.7231e+00,
                6.0769e-01
            ])

        # criterions
        self.criterion = SegmentationLosses(se_loss=args.se_loss,
                                            aux=args.aux,
                                            nclass=self.nclass,
                                            se_weight=args.se_weight,
                                            aux_weight=args.aux_weight,
                                            weight=class_balance_weight)
        self.model, self.optimizer = model, optimizer
        # using cuda
        if args.cuda:
            self.model = DataParallelModel(self.model).cuda()
            self.criterion = DataParallelCriterion(self.criterion).cuda()
        # resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        # clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0
        # lr scheduler
        self.scheduler = utils.LR_Scheduler(args.lr_scheduler, args.lr,
                                            args.epochs, len(self.trainloader))
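The hard-coded class_balance_weight tables above are per-class loss weights for imbalanced datasets. As a rough illustration of how such a table can be produced (an assumption here, not necessarily the recipe behind these exact numbers), median-frequency balancing looks like this:

import numpy as np

def median_frequency_weights(pixel_counts):
    # pixel_counts: 1-D array with the number of labelled pixels per class
    # over the training set. Rare classes end up with weights > 1.
    freq = pixel_counts / pixel_counts.sum()
    return np.median(freq) / freq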
Esempio n. 9
0
def main(args):
    writer = SummaryWriter(log_dir=args.tensorboard_log_dir)
    w, h = map(int, args.input_size.split(','))
    w_target, h_target = map(int, args.input_size_target.split(','))

    joint_transform = joint_transforms.Compose([
        joint_transforms.FreeScale((h, w)),
        joint_transforms.RandomHorizontallyFlip(),
    ])
    normalize = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*normalize),
    ])
    target_transform = extended_transforms.MaskToTensor()
    restore_transform = standard_transforms.ToPILImage()

    if '5' in args.data_dir:
        dataset = GTA5DataSetLMDB(
            args.data_dir, args.data_list,
            joint_transform=joint_transform,
            transform=input_transform, target_transform=target_transform,
        )
    else:
        dataset = CityscapesDataSetLMDB(
            args.data_dir, args.data_list,
            joint_transform=joint_transform,
            transform=input_transform, target_transform=target_transform,
        )
    loader = data.DataLoader(
        dataset, batch_size=args.batch_size,
        shuffle=True, num_workers=args.num_workers, pin_memory=True
    )
    val_dataset = CityscapesDataSetLMDB(
        args.data_dir_target, args.data_list_target,
        # joint_transform=joint_transform,
        transform=input_transform, target_transform=target_transform
    )
    val_loader = data.DataLoader(
        val_dataset, batch_size=args.batch_size,
        shuffle=False, num_workers=args.num_workers, pin_memory=True
    )


    upsample = nn.Upsample(size=(h_target, w_target),
                           mode='bilinear', align_corners=True)

    net = PSP(
        nclass=args.n_classes, backbone='resnet101',
        root=args.model_path_prefix, norm_layer=BatchNorm2d,
    )

    params_list = [
        {'params': net.pretrained.parameters(), 'lr': args.learning_rate},
        {'params': net.head.parameters(), 'lr': args.learning_rate*10},
        {'params': net.auxlayer.parameters(), 'lr': args.learning_rate*10},
    ]
    optimizer = torch.optim.SGD(params_list,
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    criterion = SegmentationLosses(nclass=args.n_classes, aux=True, ignore_index=255)
    # criterion = SegmentationMultiLosses(nclass=args.n_classes, ignore_index=255)

    net = DataParallelModel(net).cuda()
    criterion = DataParallelCriterion(criterion).cuda()

    logger = utils.create_logger(args.tensorboard_log_dir, 'PSP_train')
    scheduler = utils.LR_Scheduler(args.lr_scheduler, args.learning_rate,
                                   args.num_epoch, len(loader), logger=logger,
                                   lr_step=args.lr_step)

    net_eval = Eval(net)

    num_batches = len(loader)
    best_pred = 0.0
    for epoch in range(args.num_epoch):

        loss_rec = AverageMeter()
        data_time_rec = AverageMeter()
        batch_time_rec = AverageMeter()

        tem_time = time.time()
        for batch_index, batch_data in enumerate(loader):
            scheduler(optimizer, batch_index, epoch, best_pred)
            show_fig = (batch_index+1) % args.show_img_freq == 0
            iteration = batch_index+1+epoch*num_batches

            net.train()
            img, label, name = batch_data
            img = img.cuda()
            label_cuda = label.cuda()
            data_time_rec.update(time.time()-tem_time)

            output = net(img)
            loss = criterion(output, label_cuda)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_rec.update(loss.item())
            writer.add_scalar('A_seg_loss', loss.item(), iteration)
            batch_time_rec.update(time.time()-tem_time)
            tem_time = time.time()

            if (batch_index+1) % args.print_freq == 0:
                print(
                    f'Epoch [{epoch+1:d}/{args.num_epoch:d}][{batch_index+1:d}/{num_batches:d}]\t'
                    f'Time: {batch_time_rec.avg:.2f}   '
                    f'Data: {data_time_rec.avg:.2f}   '
                    f'Loss: {loss_rec.avg:.2f}'
                )
            # if show_fig:
            #     # base_lr = optimizer.param_groups[0]["lr"]
            #     output = torch.argmax(output[0][0], dim=1).detach()[0, ...].cpu()
            #     # fig, axes = plt.subplots(2, 1, figsize=(12, 14))
            #     # axes = axes.flat
            #     # axes[0].imshow(colorize_mask(output.numpy()))
            #     # axes[0].set_title(name[0])
            #     # axes[1].imshow(colorize_mask(label[0, ...].numpy()))
            #     # axes[1].set_title(f'seg_true_{base_lr:.6f}')
            #     # writer.add_figure('A_seg', fig, iteration)
            #     output_mask = np.asarray(colorize_mask(output.numpy()))
            #     label = np.asarray(colorize_mask(label[0,...].numpy()))
            #     image_out = np.concatenate([output_mask, label])
            #     writer.add_image('A_seg', image_out, iteration)

        mean_iu = test_miou(net_eval, val_loader, upsample,
                            './style_seg/dataset/info.json')
        torch.save(
            net.module.state_dict(),
            os.path.join(args.save_path_prefix, f'{epoch:d}_{mean_iu*100:.0f}.pth')
        )

    writer.close()
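utils.LR_Scheduler above is called once per iteration with the batch index, epoch and current best score, and decays the learning rate over all iterations. A minimal standalone sketch of the polynomial ("poly") policy such schedulers commonly offer, with hypothetical names, keeping the 10x head/auxlayer rate set up in params_list:

def poly_lr_step(optimizer, base_lrs, cur_iter, max_iter, power=0.9):
    # base_lrs: one base learning rate per param group,
    # e.g. [lr, lr * 10, lr * 10] for backbone, head and auxlayer.
    factor = (1.0 - cur_iter / float(max_iter)) ** power
    for group, base_lr in zip(optimizer.param_groups, base_lrs):
        group['lr'] = base_lr * factor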
def train(args):
    weight_dir = args.log_root  # os.path.join(args.log_root, 'weights')
    log_dir = os.path.join(
        args.log_root, 'logs',
        'SS-Net-{}'.format(time.strftime("%Y-%m-%d-%H-%M-%S",
                                         time.localtime())))

    data_dir = os.path.join(args.data_root, args.dataset)

    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # 1. Setup DataLoader
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    print("> # +++++++++++++++++++++++++++++++++++++++++++++++++++++++ #")
    print("> 0. Setting up DataLoader...")
    net_h, net_w = int(args.img_row * args.crop_ratio), int(args.img_col *
                                                            args.crop_ratio)
    augment_train = Compose([
        RandomHorizontallyFlip(),
        RandomSized((0.5, 0.75)),
        RandomRotate(5),
        RandomCrop((net_h, net_w))
    ])
    augment_valid = Compose([
        RandomHorizontallyFlip(),
        Scale((args.img_row, args.img_col)),
        CenterCrop((net_h, net_w))
    ])

    train_loader = CityscapesLoader(data_dir,
                                    gt='gtFine',
                                    split='train',
                                    img_size=(args.img_row, args.img_col),
                                    is_transform=True,
                                    augmentations=augment_train)

    valid_loader = CityscapesLoader(data_dir,
                                    gt='gtFine',
                                    split='val',
                                    img_size=(args.img_row, args.img_col),
                                    is_transform=True,
                                    augmentations=augment_valid)

    num_classes = train_loader.n_classes

    tra_loader = data.DataLoader(train_loader,
                                 batch_size=args.batch_size,
                                 num_workers=int(multiprocessing.cpu_count() /
                                                 2),
                                 shuffle=True)

    val_loader = data.DataLoader(valid_loader,
                                 batch_size=args.batch_size,
                                 num_workers=int(multiprocessing.cpu_count() /
                                                 2))

    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # 2. Setup Model
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    print("> # +++++++++++++++++++++++++++++++++++++++++++++++++++++++ #")
    print("> 1. Setting up Model...")
    model = RetinaNet(num_classes=num_classes, input_size=(net_h, net_w))
    # model = torch.nn.DataParallel(model, device_ids=[0,1,2]).cuda()
    model = DataParallelModel(model,
                              device_ids=args.device_ids).cuda()  # multi-gpu

    # 2.1 Setup Optimizer
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # Check if model has custom optimizer
    if hasattr(model.module, 'optimizer'):
        print('> Using custom optimizer')
        optimizer = model.module.optimizer
    else:
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.learning_rate,
                                    momentum=0.90,
                                    weight_decay=5e-4,
                                    nesterov=True)
        # optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate, weight_decay=1e-5)

    # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.1)
    # scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

    # 2.2 Setup Loss
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    class_weight = np.array([
        0.05570516, 0.32337477, 0.08998544, 1.03602707, 1.03413147, 1.68195437,
        5.58540548, 3.56563995, 0.12704978, 1., 0.46783719, 1.34551528,
        5.29974114, 0.28342531, 0.9396095, 0.81551811, 0.42679146, 3.6399074,
        2.78376194
    ], dtype=float)
    class_weight = torch.from_numpy(class_weight).float().cuda()

    sem_loss = bootstrapped_cross_entropy2d
    sem_loss = DataParallelCriterion(sem_loss, device_ids=args.device_ids)
    se_loss = SemanticEncodingLoss(num_classes=19,
                                   ignore_label=250,
                                   alpha=0.50).cuda()
    se_loss_parallel = DataParallelCriterion(se_loss,
                                             device_ids=args.device_ids)
    """
    # multi-gpu
    bootstrapped_cross_entropy2d = ContextBootstrappedCELoss2D(num_classes=num_classes,
                                                               ignore=250,
                                                               kernel_size=5,
                                                               padding=4,
                                                               dilate=2,
                                                               use_gpu=True)
    loss_sem = DataParallelCriterion(bootstrapped_cross_entropy2d, device_ids=[0, 1]) 
    """

    # 2.3 Setup Metrics
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # !!!!! Here Metrics !!!!!
    metrics = RunningScore(num_classes)  # num_classes = 19 for Cityscapes gtFine

    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # 3. Resume Model
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    print("> # +++++++++++++++++++++++++++++++++++++++++++++++++++++++ #")
    print("> 2. Model state init or resume...")
    args.start_epoch = 1
    args.start_iter = 0
    beat_map = 0.
    if args.resume is not None:
        full_path = os.path.join(os.path.join(weight_dir, 'train_model'),
                                 args.resume)
        if os.path.isfile(full_path):
            print("> Loading model and optimizer from checkpoint '{}'".format(
                args.resume))

            checkpoint = torch.load(full_path)

            args.start_epoch = checkpoint['epoch']
            args.start_iter = checkpoint['iter']
            beat_map = checkpoint['beat_map']
            model.load_state_dict(checkpoint['model_state'])  # weights
            optimizer.load_state_dict(
                checkpoint['optimizer_state'])  # gradient state
            del checkpoint

            print("> Loaded checkpoint '{}' (epoch {}, iter {})".format(
                args.resume, args.start_epoch, args.start_iter))

        else:
            print("> No checkpoint found at '{}'".format(full_path))
            raise Exception("> No checkpoint found at '{}'".format(full_path))
    else:
        # init_weights(model, pi=0.01,
        #              pre_trained=os.path.join(args.log_root, 'resnet50_imagenet.pth'))

        if args.pre_trained is not None:
            print("> Loading weights from pre-trained model '{}'".format(
                args.pre_trained))
            full_path = os.path.join(args.log_root, args.pre_trained)

            pre_weight = torch.load(full_path)
            prefix = "module.fpn.base_net."

            model_dict = model.state_dict()
            pretrained_dict = {(prefix + k): v
                               for k, v in pre_weight.items()
                               if (prefix + k) in model_dict}

            model_dict.update(pretrained_dict)
            model.load_state_dict(model_dict)

            del pre_weight
            del model_dict
            del pretrained_dict

    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # 4. Train Model
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # 4.0. Setup tensor-board for visualization
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    writer = None
    if args.tensor_board:
        writer = SummaryWriter(log_dir=log_dir, comment="SSnet_Cityscapes")
        # dummy_input = Variable(torch.rand(1, 3, args.img_row, args.img_col).cuda(), requires_grad=True)
        # writer.add_graph(model, dummy_input)

    print("> # +++++++++++++++++++++++++++++++++++++++++++++++++++++++ #")
    print("> 3. Model Training start...")
    topk_init = 512
    num_batches = int(
        math.ceil(
            len(tra_loader.dataset.files[tra_loader.dataset.split]) /
            float(tra_loader.batch_size)))

    # lr_period = 20 * num_batches

    for epoch in np.arange(args.start_epoch - 1, args.num_epochs):
        # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
        # 4.1 Mini-Batch Training
        # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
        model.train()
        topk_base = topk_init

        if epoch == args.start_epoch - 1:
            pbar = tqdm(np.arange(args.start_iter, num_batches))
            start_iter = args.start_iter
        else:
            pbar = tqdm(np.arange(num_batches))
            start_iter = 0

        lr = args.learning_rate

        # lr = adjust_learning_rate(optimizer, init_lr=args.learning_rate, decay_rate=0.1, curr_epoch=epoch,
        #                           epoch_step=20, start_decay_at_epoch=args.start_decay_at_epoch,
        #                           total_epoch=args.num_epochs, mode='exp')

        # scheduler.step()
        # for train_i, (images, gt_masks) in enumerate(tra_loader):  # One mini-Batch datasets, One iteration
        for train_i, (images, gt_masks) in zip(range(start_iter, num_batches),
                                               tra_loader):

            full_iter = (epoch * num_batches) + train_i + 1

            lr = poly_lr_scheduler(optimizer,
                                   init_lr=args.learning_rate,
                                   iter=full_iter,
                                   lr_decay_iter=1,
                                   max_iter=args.num_epochs * num_batches,
                                   power=0.9)

            # lr = args.learning_rate * cosine_annealing_lr(lr_period, full_iter)
            # optimizer = set_optimizer_lr(optimizer, lr)

            images = images.cuda().requires_grad_()
            se_labels = se_loss.unique_encode(gt_masks)
            se_labels = se_labels.cuda()
            gt_masks = gt_masks.cuda()

            topk_base = poly_topk_scheduler(init_topk=topk_init,
                                            iter=full_iter,
                                            topk_decay_iter=1,
                                            max_iter=args.num_epochs *
                                            num_batches,
                                            power=0.95)

            optimizer.zero_grad()

            se, sem_seg_pred = model(images)

            # --------------------------------------------------- #
            # Compute loss
            # --------------------------------------------------- #
            topk = topk_base * 512
            train_loss = sem_loss(input=sem_seg_pred,
                                  target=gt_masks,
                                  K=topk,
                                  weight=None)
            train_se_loss = se_loss_parallel(predicts=se,
                                             enc_cls_target=se_labels,
                                             size_average=True,
                                             reduction='elementwise_mean')

            loss = train_loss + args.alpha * train_se_loss
            loss.backward()  # back-propagation

            torch.nn.utils.clip_grad_norm_(model.parameters(), 1e3)
            optimizer.step()  # parameter update based on the current gradient

            pbar.update(1)
            pbar.set_description("> Epoch [%d/%d]" %
                                 (epoch + 1, args.num_epochs))
            pbar.set_postfix(Train_Loss=train_loss.item(),
                             Train_SE_Loss=train_se_loss.item(),
                             TopK=topk_base)
            # pbar.set_postfix(Train_Loss=train_loss.item(), TopK=topk_base)

            # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
            # 4.1.1 Verbose training process
            # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
            if (train_i + 1) % args.verbose_interval == 0:
                # ---------------------------------------- #
                # 1. Training Losses
                # ---------------------------------------- #
                loss_log = "Epoch [%d/%d], Iter: %d Loss1: \t %.4f " % (
                    epoch + 1, args.num_epochs, train_i + 1, loss.item())

                # ---------------------------------------- #
                # 2. Training Metrics
                # ---------------------------------------- #
                sem_seg_pred = F.softmax(sem_seg_pred, dim=1)
                pred = sem_seg_pred.data.max(1)[1].cpu().numpy()
                gt = gt_masks.data.cpu().numpy()

                metrics.update(
                    gt,
                    pred)  # accumulate the metrics (confusion_matrix and ious)
                score, _ = metrics.get_scores()

                metric_log = ""
                for k, v in score.items():
                    metric_log += " {}: \t {:.4f}, ".format(k, v)
                metrics.reset()  # reset the metrics for each train_i steps

                logs = loss_log + metric_log

                if args.tensor_board:
                    writer.add_scalar('Training/Train_Loss', train_loss.item(),
                                      full_iter)
                    writer.add_scalar('Training/Train_SE_Loss',
                                      train_se_loss.item(), full_iter)
                    writer.add_scalar('Training/Loss', loss.item(), full_iter)
                    writer.add_scalar('Training/Lr', lr, full_iter)
                    writer.add_scalars('Training/Metrics', score, full_iter)
                    writer.add_text('Training/Text', logs, full_iter)

                    for name, param in model.named_parameters():
                        writer.add_histogram(name,
                                             param.clone().cpu().data.numpy(),
                                             full_iter)
            """
            # each 2000 iterations save model
            if (train_i + 1) % args.iter_interval_save_model == 0:
                pbar.set_postfix(Loss=train_loss.item(), lr=lr)

                state = {"epoch": epoch + 1,
                         "iter": train_i + 1,
                         'beat_map': beat_map,
                         "model_state": model.state_dict(),
                         "optimizer_state": optimizer.state_dict()}

                save_dir = os.path.join(os.path.join(weight_dir, 'train_model'),
                                        "ssnet_model_sem_se_{}epoch_{}iter.pkl".format(epoch+1, train_i+1))
                torch.save(state, save_dir)
            """

        # end of this training phase
        state = {
            "epoch": epoch + 1,
            "iter": num_batches,
            'beat_map': beat_map,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict()
        }

        save_dir = os.path.join(
            os.path.join(args.log_root, 'train_model'),
            "ssnet_model_sem_se_{}_{}epoch_{}iter.pkl".format(
                args.model_details, epoch + 1, num_batches))
        torch.save(state, save_dir)

        # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
        # 4.2 Mini-Batch Validation
        # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
        model.eval()

        val_loss = 0.0
        vali_count = 0

        with torch.no_grad():
            for i_val, (images_val, gt_masks_val) in enumerate(val_loader):
                vali_count += 1

                images_val = images_val.cuda()
                se_labels_val = se_loss.unique_encode(gt_masks_val)
                se_labels_val = se_labels_val.cuda()
                gt_masks_val = gt_masks_val.cuda()

                se_val, sem_seg_pred_val = model(images_val)

                # !!!!!! Loss !!!!!!
                topk_val = topk_base * 512
                loss = sem_loss(sem_seg_pred_val, gt_masks_val, topk_val, weight=None) + \
                       args.alpha * se_loss_parallel(predicts=se_val, enc_cls_target=se_labels_val,
                                                     size_average=True, reduction='elementwise_mean')
                val_loss += loss.item()

                # accumulating the confusion matrix and ious
                sem_seg_pred_val = F.softmax(sem_seg_pred_val, dim=1)
                pred = sem_seg_pred_val.data.max(1)[1].cpu().numpy()
                gt = gt_masks_val.data.cpu().numpy()
                metrics.update(gt, pred)

            # ---------------------------------------- #
            # 1. Validation Losses
            # ---------------------------------------- #
            val_loss /= vali_count

            loss_log = "Epoch [%d/%d], Loss: \t %.4f" % (
                epoch + 1, args.num_epochs, val_loss)

            # ---------------------------------------- #
            # 2. Validation Metrics
            # ---------------------------------------- #
            metric_log = ""
            score, _ = metrics.get_scores()
            for k, v in score.items():
                metric_log += " {}: \t {:.4f}, ".format(k, v)
            metrics.reset()  # reset the metrics

            logs = loss_log + metric_log

            pbar.set_postfix(
                Vali_Loss=val_loss, Lr=lr,
                Vali_mIoU=score['Mean_IoU'])  # Train_Loss=train_loss.item()

            if args.tensor_board:
                writer.add_scalar('Validation/Loss', val_loss, epoch)
                writer.add_scalars('Validation/Metrics', score, epoch)
                writer.add_text('Validation/Text', logs, epoch)

                for name, param in model.named_parameters():
                    writer.add_histogram(name,
                                         param.clone().cpu().data.numpy(),
                                         epoch)

        # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
        # 4.3 End of one Epoch
        # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
        # !!!!! Here choose suitable Metric for the best model selection !!!!!

        if score['Mean_IoU'] >= beat_map:
            beat_map = score['Mean_IoU']
            state = {
                "epoch": epoch + 1,
                "beat_map": beat_map,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict()
            }

            save_dir = os.path.join(
                weight_dir,
                "SSnet_best_sem_se_{}_model.pkl".format(args.model_details))
            torch.save(state, save_dir)

        # Note that step should be called after validate()
        pbar.close()

    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # 4.4 End of Training process
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    if args.tensor_board:
        # export scalar datasets to JSON for external processing
        # writer.export_scalars_to_json("{}/all_scalars.json".format(log_dir))
        writer.close()
    print("> # +++++++++++++++++++++++++++++++++++++++++++++++++++++++ #")
    print("> Training Done!!!")
    print("> # +++++++++++++++++++++++++++++++++++++++++++++++++++++++ #")
class Trainer():
    def __init__(self, args):
        self.args = args
        # data transforms
        input_transform = transform.Compose([
            transform.ToTensor(),
            transform.Normalize([.485, .456, .406], [.229, .224, .225])
        ])
        # dataset
        trainset = get_segmentation_dataset(args.dataset,
                                            split='train',
                                            transform=input_transform)
        testset = get_segmentation_dataset(args.dataset,
                                           split='val',
                                           transform=input_transform)
        # dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True} \
            if args.cuda else {}
        self.trainloader = data.DataLoader(trainset,
                                           batch_size=args.batch_size,
                                           drop_last=True,
                                           shuffle=True,
                                           **kwargs)
        self.valloader = data.DataLoader(testset,
                                         batch_size=args.batch_size,
                                         drop_last=False,
                                         shuffle=False,
                                         **kwargs)
        self.nclass = trainset.num_class
        # model
        model = get_segmentation_model(args.model,
                                       dataset=args.dataset,
                                       backbone=args.backbone,
                                       aux=args.aux,
                                       se_loss=args.se_loss,
                                       norm_layer=BatchNorm2d)
        #print(model)
        teacher_model = get_segmentation_model('encnet',
                                               dataset=args.dataset,
                                               backbone='resnet50',
                                               aux=True,
                                               se_loss=True,
                                               norm_layer=BatchNorm2d)
        #print(teacher_model)
        checkpoint = torch.load(args.resume_teacher)
        teacher_model.load_state_dict(checkpoint)
        self.teacher_model = teacher_model
        self.teacher_model.eval()

        # optimizer using different LR
        params_list = [
            {
                'params': model.pretrained.parameters(),
                'lr': args.lr
            },
        ]
        if hasattr(model, 'head'):
            params_list.append({
                'params': model.head.parameters(),
                'lr': args.lr * 10
            })
        if hasattr(model, 'auxlayer'):
            params_list.append({
                'params': model.auxlayer.parameters(),
                'lr': args.lr * 10
            })
        optimizer = torch.optim.SGD(params_list,
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
        # clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0
        # criterions
        self.criterion = SegmentationLosses(se_loss=args.se_loss,
                                            aux=args.aux,
                                            nclass=self.nclass)
        self.criterion_kd = KDLosses(se_loss=args.se_loss,
                                     aux=args.aux,
                                     nclass=self.nclass)
        #self.criterion_kd = torch.nn.L1Loss()

        self.model, self.optimizer = model, optimizer
        # using cuda
        if args.cuda:
            self.model = DataParallelModel(self.model).cuda()
            self.teacher_model = DataParallelModel(self.teacher_model).cuda()
            self.criterion = DataParallelCriterion(self.criterion).cuda()
            self.criterion_kd = DataParallelCriterionKD(
                self.criterion_kd).cuda()
        # resuming checkpoint
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        # lr scheduler
        self.scheduler = utils.LR_Scheduler(args.lr_scheduler, args.lr,
                                            args.epochs, len(self.trainloader))
        self.best_pred = 0.0

    def training(self, epoch):
        train_loss = 0.0
        teacher_loss = 0.0
        self.model.train()
        tbar = tqdm(self.trainloader)
        for i, (image, target) in enumerate(tbar):
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            if torch_ver == "0.3":
                image = Variable(image)
                target = Variable(target)
            outputs = self.model(image)
            with torch.no_grad():
                teacher_outputs = self.teacher_model(image)
                teacher_targets = []
                for teacher_output in teacher_outputs:
                    pred1, se_pred, pred2 = tuple(teacher_output)
                    teacher_targets.append(pred1)
                teacher_target = torch.cat(tuple(teacher_targets), 0)
                teacher_target = teacher_target.detach()

            loss_seg = self.criterion(outputs, target)
            loss_seg.backward(retain_graph=True)
            train_loss += loss_seg.item()

            loss_kd = self.criterion_kd(outputs, teacher_target)
            loss_kd.backward()
            teacher_loss += loss_kd.item()
            loss = loss_seg + loss_kd
            #loss.backward()
            self.optimizer.step()
            tbar.set_description('Train loss: %.3f, Teacher loss: %.3f' %
                                 (train_loss / (i + 1), teacher_loss /
                                  (i + 1)))

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            utils.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, self.args, is_best)

    def validation(self, epoch):
        # Fast test during the training
        def eval_batch(model, image, target):
            outputs = model(image)
            outputs = gather(outputs, 0, dim=0)
            pred = outputs[0]
            target = target.cuda()
            correct, labeled = utils.batch_pix_accuracy(pred.data, target)
            inter, union = utils.batch_intersection_union(
                pred.data, target, self.nclass)
            return correct, labeled, inter, union

        is_best = False
        self.model.eval()
        total_inter, total_union, total_correct, total_label = 0, 0, 0, 0
        tbar = tqdm(self.valloader, desc='\r')
        for i, (image, target) in enumerate(tbar):
            if torch_ver == "0.3":
                image = Variable(image, volatile=True)
                correct, labeled, inter, union = eval_batch(
                    self.model, image, target)
            else:
                with torch.no_grad():
                    correct, labeled, inter, union = eval_batch(
                        self.model, image, target)

            total_correct += correct
            total_label += labeled
            total_inter += inter
            total_union += union
            pixAcc = 1.0 * total_correct / (np.spacing(1) + total_label)
            IoU = 1.0 * total_inter / (np.spacing(1) + total_union)
            mIoU = IoU.mean()
            tbar.set_description('pixAcc: %.3f, mIoU: %.3f' % (pixAcc, mIoU))

        new_pred = (pixAcc + mIoU) / 2
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            utils.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, self.args, is_best)
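Trainer only defines the constructor plus training() and validation(); the driving loop lives outside it. A typical driver, sketched here with a hypothetical Options parser standing in for the real argument handling:

if __name__ == "__main__":
    args = Options().parse()                # hypothetical argparse wrapper
    trainer = Trainer(args)
    for epoch in range(args.start_epoch, args.epochs):
        trainer.training(epoch)
        if not args.no_val:
            trainer.validation(epoch)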
Esempio n. 12
0
def test(args):
    # data transforms
    input_transform = transform.Compose([
        transform.ToTensor(),
        transform.Normalize([.485, .456, .406], [.229, .224, .225])
    ])

    # dataset
    if args.eval:  # set split='val' for validation set testing
        testset = get_edge_dataset(args.dataset,
                                   split='val',
                                   mode='testval',
                                   transform=input_transform,
                                   crop_size=args.crop_size)
    else:  # set split='vis' for visualization
        testset = get_edge_dataset(args.dataset,
                                   split='vis',
                                   mode='vis',
                                   transform=input_transform,
                                   crop_size=args.crop_size)

    # output folder
    if args.eval:
        outdir_list_side5 = []
        outdir_list_fuse = []
        for i in range(testset.num_class):
            outdir_side5 = '%s/%s/%s_val/side5/class_%03d' % (
                args.dataset, args.model, args.checkname, i + 1)
            if not os.path.exists(outdir_side5):
                os.makedirs(outdir_side5)
            outdir_list_side5.append(outdir_side5)

            outdir_fuse = '%s/%s/%s_val/fuse/class_%03d' % (
                args.dataset, args.model, args.checkname, i + 1)
            if not os.path.exists(outdir_fuse):
                os.makedirs(outdir_fuse)
            outdir_list_fuse.append(outdir_fuse)

    else:
        outdir = '%s/%s/%s_vis' % (args.dataset, args.model, args.checkname)
        if not os.path.exists(outdir):
            os.makedirs(outdir)

    # dataloader
    loader_kwargs = {'num_workers': args.workers, 'pin_memory': True} \
        if args.cuda else {}
    test_data = data.DataLoader(testset,
                                batch_size=args.test_batch_size,
                                drop_last=False,
                                shuffle=False,
                                collate_fn=test_batchify_fn,
                                **loader_kwargs)

    model = get_edge_model(
        args.model,
        dataset=args.dataset,
        backbone=args.backbone,
        norm_layer=BatchNorm2d,
        crop_size=args.crop_size,
    )

    # resuming checkpoint
    if args.resume is None or not os.path.isfile(args.resume):
        raise RuntimeError("=> no checkpoint found at '{}'".format(
            args.resume))
    checkpoint = torch.load(args.resume)
    # strict=False, so that it is compatible with old pytorch saved models
    model.load_state_dict(checkpoint['state_dict'], strict=False)

    if args.cuda:
        model = DataParallelModel(model).cuda()
    print(model)

    model.eval()
    tbar = tqdm(test_data)

    if args.eval:
        for i, (images, im_paths, im_sizes) in enumerate(tbar):
            with torch.no_grad():
                images = [image.unsqueeze(0) for image in images]
                images = torch.cat(images, 0)
                outputs = model(images.float())

                num_gpus = len(os.environ['CUDA_VISIBLE_DEVICES'].split(','))
                if num_gpus == 1:
                    outputs = [outputs]

                # extract the side5 output and fuse output from outputs
                side5_list = []
                fuse_list = []
                for i in range(len(outputs)):  # iterate over per-GPU outputs
                    im_size = tuple(im_sizes[i].numpy())
                    output = outputs[i]

                    side5 = output[0].squeeze_()
                    side5 = side5.sigmoid_().cpu().numpy()
                    side5 = side5[:, 0:im_size[1], 0:im_size[0]]

                    fuse = output[1].squeeze_()
                    fuse = fuse.sigmoid_().cpu().numpy()
                    fuse = fuse[:, 0:im_size[1], 0:im_size[0]]

                    side5_list.append(side5)
                    fuse_list.append(fuse)

                for predict, impath in zip(side5_list, im_paths):
                    for i in range(predict.shape[0]):
                        predict_c = predict[i]
                        path = os.path.join(outdir_list_side5[i], impath)
                        io.imsave(path, predict_c)

                for predict, impath in zip(fuse_list, im_paths):
                    for i in range(predict.shape[0]):
                        predict_c = predict[i]
                        path = os.path.join(outdir_list_fuse[i], impath)
                        io.imsave(path, predict_c)
    else:
        for i, (images, masks, im_paths, im_sizes) in enumerate(tbar):
            with torch.no_grad():
                images = [image.unsqueeze(0) for image in images]
                images = torch.cat(images, 0)
                outputs = model(images.float())

                num_gpus = len(os.environ['CUDA_VISIBLE_DEVICES'].split(','))
                if num_gpus == 1:
                    outputs = [outputs]

                # extract the side5 output and fuse output from outputs
                side5_list = []
                fuse_list = []
                for i in range(len(outputs)):  # iterate over per-GPU outputs
                    im_size = tuple(im_sizes[i].numpy())
                    output = outputs[i]

                    side5 = output[0].squeeze_()
                    side5 = side5.sigmoid_().cpu().numpy()
                    side5 = side5[:, 0:im_size[1], 0:im_size[0]]

                    fuse = output[1].squeeze_()
                    fuse = fuse.sigmoid_().cpu().numpy()
                    fuse = fuse[:, 0:im_size[1], 0:im_size[0]]

                    side5_list.append(side5)
                    fuse_list.append(fuse)

                # visualize ground truth
                for gt, impath in zip(masks, im_paths):
                    outname = os.path.splitext(impath)[0] + '_gt.png'
                    path = os.path.join(outdir, outname)
                    visualize_prediction(args.dataset, path, gt)

                # visualize side5 output
                for predict, impath in zip(side5_list, im_paths):
                    outname = os.path.splitext(impath)[0] + '_side5.png'
                    path = os.path.join(outdir, outname)
                    visualize_prediction(args.dataset, path, predict)

                # visualize fuse output
                for predict, impath in zip(fuse_list, im_paths):
                    outname = os.path.splitext(impath)[0] + '_fuse.png'
                    path = os.path.join(outdir, outname)
                    visualize_prediction(args.dataset, path, predict)
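Because test images keep their original sizes, the DataLoader above relies on test_batchify_fn to avoid stacking tensors of different shapes. A generic collate of that kind, written from scratch purely for illustration (not the repo's implementation), simply keeps every field as a list:

def list_collate(batch):
    # batch: list of samples, e.g. [(image, im_path, im_size), ...]
    # Returns one list per field instead of stacked tensors.
    return tuple(list(field) for field in zip(*batch))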
Esempio n. 13
0
class Trainer():
    def __init__(self, args):

        self.args = args
        if not self.args.tblogger:
            self.tblogger = SummaryWriter('./tensorboardX/')

        # data transforms
        input_transform = transform.Compose([
            transform.ToTensor(),
            transform.Normalize([.485, .456, .406], [.229, .224, .225])
        ])
        # dataset
        data_kwargs = {
            'transform': input_transform,
            'base_size': args.base_size,
            'crop_size': args.crop_size
        }
        trainset = get_segmentation_dataset(args.dataset,
                                            split=args.train_split,
                                            mode='train',
                                            **data_kwargs)
        testset = get_segmentation_dataset(args.dataset,
                                           split='val',
                                           mode='val',
                                           **data_kwargs)
        print('trainset:%d' % len(trainset))
        print('testset:%d' % len(testset))
        # dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True} \
            if args.cuda else {}
        self.trainloader = data.DataLoader(trainset,
                                           batch_size=args.batch_size,
                                           drop_last=True,
                                           shuffle=True,
                                           **kwargs)
        self.valloader = data.DataLoader(testset,
                                         batch_size=args.batch_size,
                                         drop_last=False,
                                         shuffle=False,
                                         **kwargs)
        self.nclass = trainset.num_class
        # model
        model = get_segmentation_model(args.model,
                                       dataset=args.dataset,
                                       backbone=args.backbone,
                                       dilated=args.dilated,
                                       lateral=args.lateral,
                                       jpu=args.jpu,
                                       aux=args.aux,
                                       se_loss=args.se_loss,
                                       norm_layer=SyncBatchNorm,
                                       base_size=args.base_size,
                                       crop_size=args.crop_size)
        print(model)

        # model.apply(inplace_relu)
        # optimizer using different LR
        params_list = [
            {
                'params': model.pretrained.parameters(),
                'lr': args.lr
            },
        ]
        if hasattr(model, 'jpu'):
            params_list.append({
                'params': model.jpu.parameters(),
                'lr': args.lr * 10
            })
        if hasattr(model, 'head'):
            params_list.append({
                'params': model.head.parameters(),
                'lr': args.lr * 10
            })
        if hasattr(model, 'auxlayer'):
            params_list.append({
                'params': model.auxlayer.parameters(),
                'lr': args.lr * 10
            })
        optimizer = torch.optim.SGD(params_list,
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
        # criterions
        self.criterion = SegmentationLosses(se_loss=args.se_loss,
                                            aux=args.aux,
                                            nclass=self.nclass,
                                            se_weight=args.se_weight,
                                            aux_weight=args.aux_weight)

        self.model, self.optimizer = model, optimizer
        # self.model, self.optimizer = amp.initialize(self.model.cuda(), self.optimizer, opt_level="O1")
        # using cuda
        if args.cuda:
            # self.model = torch.nn.parallel.DistributedDataParallel(self.model,device_ids=[0,1])
            # self.model = DDP(self.model)
            # self.criterion = torch.nn.parallel.DistributedDataParallel(self.criterion, find_unused_parameters=True)
            self.model = DataParallelModel(self.model).cuda()
            self.criterion = DataParallelCriterion(self.criterion).cuda()
            # self.model = self.model.cuda()
            # self.criterion = self.criterion.cuda()
        # resuming checkpoint

        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        # clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0
        # lr scheduler
        self.scheduler = utils.LR_Scheduler(args.lr_scheduler, args.lr,
                                            args.epochs, len(self.trainloader))

        self.total_loss = 0

    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.trainloader)
        for i, (image, target) in enumerate(tbar):
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            if torch_ver == "0.3":
                image = Variable(image)
                target = Variable(target)
            outputs = self.model(image)
            loss = self.criterion(outputs, target)
            # with amp.scale_loss(loss, self.optimizer) as scaled_loss:
            #     scaled_loss.backward()
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))

            if not self.args.tblogger and i % 100 == 0:
                global_step = epoch * len(self.trainloader) + i + 1
                self.tblogger.add_scalar('Train loss', train_loss / (i + 1),
                                         global_step)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            utils.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                },
                self.args,
                is_best,
                filename='checkpoint_{}.pth.tar'.format(epoch))

    def validation(self, epoch):
        # Fast test during the training
        def eval_batch(model, image, target):
            outputs = model(image)
            outputs = gather(outputs, 0, dim=0)
            pred = outputs[0]
            target = target.cuda()
            correct, labeled = utils.batch_pix_accuracy(pred.data, target)
            inter, union = utils.batch_intersection_union(
                pred.data, target, self.nclass)
            return correct, labeled, inter, union

        is_best = False
        self.model.eval()
        total_inter, total_union, total_correct, total_label = 0, 0, 0, 0
        tbar = tqdm(self.valloader, desc='\r')
        for i, (image, target) in enumerate(tbar):
            if torch_ver == "0.3":
                image = Variable(image, volatile=True)
                correct, labeled, inter, union = eval_batch(
                    self.model, image, target)
            else:
                with torch.no_grad():
                    correct, labeled, inter, union = eval_batch(
                        self.model, image, target)

            total_correct += correct
            total_label += labeled
            total_inter += inter
            total_union += union
            pixAcc = 1.0 * total_correct / (np.spacing(1) + total_label)
            IoU = 1.0 * total_inter / (np.spacing(1) + total_union)
            mIoU = IoU.mean()
            tbar.set_description('pixAcc: %.3f, mIoU: %.3f' % (pixAcc, mIoU))

        # new_pred = (pixAcc + mIoU)/2
        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
        utils.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': self.model.module.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': new_pred,
            }, self.args, is_best)
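utils.batch_pix_accuracy and utils.batch_intersection_union return the per-batch counts that the validation loop accumulates into pixAcc and mIoU. A from-scratch sketch of those counts (the +1 shift that treats label 0 and negatives as "ignore" is an assumption about the label convention, not a statement about the repo's exact helpers):

import torch

def batch_scores(logits, target, nclass):
    # logits: [N, C, H, W], target: [N, H, W] with integer class labels.
    pred = torch.argmax(logits, dim=1) + 1
    target = target + 1
    valid = target > 0
    labeled = valid.sum().item()
    correct = ((pred == target) & valid).sum().item()
    pred = pred * valid.long()
    inter = pred * (pred == target).long()
    area_inter = torch.histc(inter.float(), bins=nclass, min=1, max=nclass)
    area_pred = torch.histc(pred.float(), bins=nclass, min=1, max=nclass)
    area_target = torch.histc(target.float(), bins=nclass, min=1, max=nclass)
    return correct, labeled, area_inter, area_pred + area_target - area_inter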
Esempio n. 14
0
    def __init__(self, args):
        self.args = args
        args.log_name = str(args.checkname)
        root_dir = getattr(args, "data_root", '../datasets')
        wo_head = getattr(args, "resume_wo_head", False)

        self.logger = utils.create_logger(args.log_root, args.log_name)
        # data transforms
        input_transform = transform.Compose([
            transform.ToTensor(),
            transform.Normalize([.485, .456, .406], [.229, .224, .225])
        ])
        # dataset
        data_kwargs = {
            'transform': input_transform,
            'base_size': args.base_size,
            'crop_size': args.crop_size,
            'logger': self.logger,
            'scale': args.scale
        }
        trainset = get_segmentation_dataset(args.dataset,
                                            split='train',
                                            mode='train',
                                            root=root_dir,
                                            **data_kwargs)
        testset = get_segmentation_dataset(args.dataset,
                                           split='val',
                                           mode='val',
                                           root=root_dir,
                                           **data_kwargs)
        # dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True} \
            if args.cuda else {}
        self.trainloader = data.DataLoader(trainset,
                                           batch_size=args.batch_size,
                                           drop_last=True,
                                           shuffle=True,
                                           **kwargs)
        self.valloader = data.DataLoader(testset,
                                         batch_size=args.batch_size,
                                         drop_last=False,
                                         shuffle=False,
                                         **kwargs)
        self.nclass = trainset.num_class

        # model
        model = get_segmentation_model(args.model,
                                       dataset=args.dataset,
                                       backbone=args.backbone,
                                       aux=args.aux,
                                       se_loss=args.se_loss,
                                       norm_layer=BatchNorm2d,
                                       base_size=args.base_size,
                                       crop_size=args.crop_size,
                                       multi_grid=args.multi_grid,
                                       multi_dilation=args.multi_dilation)
        #print(model)
        self.logger.info(model)
        # optimizer using different LR

        if not args.wo_backbone:
            params_list = [
                {
                    'params': model.pretrained.parameters(),
                    'lr': args.lr
                },
            ]
        else:
            params_list = []

        if hasattr(model, 'head'):
            params_list.append({
                'params': model.head.parameters(),
                'lr': args.lr * 10
            })
        if hasattr(model, 'auxlayer'):
            params_list.append({
                'params': model.auxlayer.parameters(),
                'lr': args.lr * 10
            })
        optimizer = torch.optim.SGD(params_list,
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
        self.criterion = SegmentationMultiLosses(nclass=self.nclass)
        #self.criterion = SegmentationLosses(se_loss=args.se_loss, aux=args.aux,nclass=self.nclass)

        self.model, self.optimizer = model, optimizer

        # using cuda
        if args.cuda:
            self.model = DataParallelModel(self.model).cuda()
            self.criterion = DataParallelCriterion(self.criterion).cuda()

        # finetune from a trained model
        if args.ft:
            args.start_epoch = 0
            checkpoint = torch.load(args.ft_resume)
            if wo_head:
                print("WITHout HEAD !!!!!!!!!!")
                from collections import OrderedDict
                new = OrderedDict()
                for k, v in checkpoint['state_dict'].items():
                    if not k.startswith("head"):
                        new[k] = v
                checkpoint['state_dict'] = new
            else:
                print("With HEAD !!!!!!!!!!")

            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'],
                                                  strict=False)
            else:
                self.model.load_state_dict(checkpoint['state_dict'],
                                           strict=False)
            # self.logger.info("=> loaded checkpoint '{}' (epoch {})".format(args.ft_resume, checkpoint['epoch']))
        # resuming checkpoint
        if args.resume:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            self.logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        # lr scheduler
        self.scheduler = utils.LR_Scheduler(args.lr_scheduler,
                                            args.lr,
                                            args.epochs,
                                            len(self.trainloader),
                                            logger=self.logger,
                                            lr_step=args.lr_step)
        self.best_pred = 0.0
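The params_list pattern above gives the randomly initialized head and auxlayer a 10x larger learning rate than the pretrained backbone. A self-contained toy showing how torch.optim.SGD treats those per-group rates (the Linear modules here are stand-ins for model.pretrained and model.head):

import torch

backbone = torch.nn.Linear(8, 8)   # stands in for model.pretrained
head = torch.nn.Linear(8, 2)       # stands in for model.head
optimizer = torch.optim.SGD(
    [{'params': backbone.parameters(), 'lr': 1e-3},
     {'params': head.parameters(), 'lr': 1e-2}],
    lr=1e-3, momentum=0.9, weight_decay=1e-4)
print([g['lr'] for g in optimizer.param_groups])   # [0.001, 0.01]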
Esempio n. 15
0
    def __init__(self, args):
        self.args = args
        self.args.start_epoch = 0
        self.args.cuda = True
        # data transforms
        input_transform = transform.Compose([
            transform.ToTensor(),
            transform.Normalize([.490, .490, .490], [.247, .247, .247])]) # TODO: change mean and std
        
        # dataset
        testset = SegmentationDataset(
                    os.path.join(args.imagelist_path, 'test_stage2.csv'),
                    args.image_path,
                    args.masks_path,
                    input_transform=input_transform, 
                    transform_chain=Compose([Resize(self.args.size, self.args.size)], p=1),
                    base_size=480, is_flip=True, is_clahe=True, is_sh_sc_ro=True
        )
        # dataloader
        kwargs = {'num_workers': args.workers }#, 'pin_memory': True} 
        self.testloader = data.DataLoader(testset, batch_size=args.batch_size,
                                           drop_last=False, shuffle=False, **kwargs)
        self.nclass = 1
        model = EncNet(
            nclass=self.nclass, backbone=args.backbone,
            aux=args.aux, se_loss=args.se_loss, norm_layer=SyncBatchNorm
        )
        print(model)

        self.model = model

        # resuming checkpoint
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
            args.start_epoch = checkpoint['epoch']
            state_dict = {k[7:] : v for k,v in checkpoint['state_dict'].items()}
            self.model.load_state_dict(state_dict)
            self.best_pred = checkpoint['best_pred']
            if 'best_loss' in checkpoint.keys(): 
                self.best_loss = checkpoint['best_loss']
            else:
                self.best_loss = 0
            print("=> loaded checkpoint '{}' (epoch {}, best pred: {}, best loss, {})"
                  .format(args.resume, checkpoint['epoch'], self.best_pred, self.best_loss))
        
        self.model = DataParallelModel(self.model).cuda()

        self.mode2func = {
            0 : lambda x, y: (x, y),
            1 : apply_hflip,
            2 : lambda x, y: (x, y),
            3 : lambda x, y: apply_revert_shscro(x, y, angle=5, scale=0.9, dx=0, dy=0),
            4 : lambda x, y: apply_revert_shscro(x, y, angle=10, scale=0.9, dx=0, dy=0),
            5 : lambda x, y: apply_revert_shscro(x, y, angle=15, scale=0.9, dx=0, dy=0),
            6 : lambda x, y: apply_revert_shscro(x, y, angle=20, scale=0.9, dx=0, dy=0),
            7 : lambda x, y: apply_revert_shscro(x, y, angle=-5, scale=0.9, dx=0, dy=0),
            8 : lambda x, y: apply_revert_shscro(x, y, angle=-10, scale=0.9, dx=0, dy=0),
            9 : lambda x, y: apply_revert_shscro(x, y, angle=-15, scale=0.9, dx=0, dy=0),
            10 : lambda x, y: apply_revert_shscro(x, y, angle=-20, scale=0.9, dx=0, dy=0),
        }
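mode2func above maps a test-time-augmentation mode id to the function that undoes that augmentation on the prediction before the results are merged; apply_hflip and apply_revert_shscro are project helpers not shown here. A hedged sketch of the same revert-then-average idea using only torch, with a horizontal flip as the single augmentation:

import torch

# Illustrative reverse-TTA: mode 0 = identity, mode 1 = horizontal flip.
revert = {
    0: lambda p: p,
    1: lambda p: torch.flip(p, dims=[-1]),  # undo the flip along the width axis
}

def average_tta(preds_by_mode):
    # preds_by_mode: {mode_id: prediction tensor of shape (B, 1, H, W)}
    reverted = [revert[m](p) for m, p in preds_by_mode.items()]
    return torch.stack(reverted, dim=0).mean(dim=0)

p0 = torch.rand(2, 1, 8, 8)
p1 = torch.flip(p0, dims=[-1])      # pretend this prediction came from a flipped input
assert torch.allclose(average_tta({0: p0, 1: p1}), p0)  # the flip is undone before averaging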
class Trainer():
    def __init__(self, args):
        self.args = args
        args.log_name = str(args.checkname)
        self.logger = utils.create_logger(args.log_root, args.log_name)
        # data transforms
        input_transform = transform.Compose([
            transform.ToTensor(),
            # transform.Normalize([.485, .456, .406], [.229, .224, .225])
            ])
        # dataset
        data_kwargs = {'transform': input_transform, 'base_size': args.base_size, 'crop_size': args.crop_size, 'logger': self.logger, 'scale': args.scale}
        trainset = get_segmentation_dataset(args.dataset, split='train', mode='train', **data_kwargs)
        testset = get_segmentation_dataset(args.dataset, split='val', mode='val', **data_kwargs)
        # dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True} if args.cuda else {}
        self.trainloader = data.DataLoader(trainset, batch_size=args.batch_size, drop_last=True, shuffle=True, **kwargs)
        self.valloader = data.DataLoader(testset, batch_size=args.batch_size, drop_last=False, shuffle=False, **kwargs)
        self.nclass = trainset.num_class

        self.confusion_matrix_weather = utils.ConfusionMatrix(7)
        self.confusion_matrix_timeofday = utils.ConfusionMatrix(4)

        # model
        model = get_segmentation_model(args.model, dataset=args.dataset, backbone=args.backbone, aux=args.aux, se_loss=args.se_loss,
                                       # norm_layer=BatchNorm2d, # for multi-gpu
                                       base_size=args.base_size, crop_size=args.crop_size, multi_grid=args.multi_grid, multi_dilation=args.multi_dilation)

        #####################################################################
        self.logger.info(model)
        # optimizer using different LR
        params_list = [{'params': model.pretrained.parameters(), 'lr': 1 * args.lr},]
        if hasattr(model, 'head'):
            params_list.append({'params': model.head.parameters(), 'lr': 1 * args.lr*10})
        if hasattr(model, 'auxlayer'):
            params_list.append({'params': model.auxlayer.parameters(), 'lr': 1 * args.lr*10})
        params_list.append({'params': model.weather_classifier.parameters(), 'lr': 0 * args.lr*10})
        params_list.append({'params': model.time_classifier.parameters(), 'lr': 0 * args.lr*10})
        optimizer = torch.optim.SGD(params_list, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

        # self.criterion = SegmentationMultiLosses(nclass=self.nclass)
        self.criterion = SegmentationLosses(se_loss=args.se_loss, aux=args.aux, nclass=self.nclass)
        # self.criterion = torch.nn.CrossEntropyLoss()
        #####################################################################

        self.model, self.optimizer = model, optimizer
        # using cuda
        if args.cuda:
            self.model = DataParallelModel(self.model).cuda()
            self.criterion = DataParallelCriterion(self.criterion).cuda()
        # finetune from a trained model
        if args.ft:
            args.start_epoch = 0
            checkpoint = torch.load(args.ft_resume)
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'], strict=False)
            else:
                self.model.load_state_dict(checkpoint['state_dict'], strict=False)
            self.logger.info("=> loaded checkpoint '{}' (epoch {})".format(args.ft_resume, checkpoint['epoch']))
        # resuming checkpoint
        if args.resume:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            self.logger.info("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
        # lr scheduler
        self.scheduler = utils.LR_Scheduler(args.lr_scheduler, args.lr, args.epochs, len(self.trainloader), logger=self.logger, lr_step=args.lr_step)
        self.best_pred = 0.0

        self.logger.info(self.args)

    def training(self, epoch):
        train_loss = 0.0

        ################################################
        self.model.train()
        self.model.module.weather_classifier.eval()
        self.model.module.time_classifier.eval()
        # self.model.eval()
        # self.model.module.weather_classifier.train()
        # self.model.module.time_classifier.train()
        ################################################

        tbar = tqdm(self.trainloader)

        for i, (image, target, weather, timeofday, scene) in enumerate(tbar):
            weather = weather.cuda(); timeofday = timeofday.cuda()
            ################################################
            # self.scheduler(self.optimizer, i, epoch, self.best_pred)
            ################################################
            self.optimizer.zero_grad()
            if torch_ver == "0.3":
                image = Variable(image)
                target = Variable(target)
            outputs, weather_o, timeofday_o = self.model(image)

            # create weather / timeofday target mask #######################
            b, _, h, w = weather_o.size()
            weather_t = torch.ones((b, h, w)).long().cuda()
            for bi in range(b): weather_t[bi] *= weather[bi]
            timeofday_t = torch.ones((b, h, w)).long().cuda()
            for bi in range(b): timeofday_t[bi] *= timeofday[bi]
            ################################################################

            loss = self.criterion(outputs, target)
            # loss = self.criterion(weather_o, weather_t) + self.criterion(timeofday_o, timeofday_t)

            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
        self.logger.info('Train loss: %.3f' % (train_loss / (i + 1)))

        # save checkpoint every 5 epochs
        is_best = False
        if epoch % 5 == 0:
            # filename = "checkpoint_%s.pth.tar"%(epoch+1)
            filename = "checkpoint_%s.%s.%s.%s.pth.tar"%(self.args.log_root, self.args.checkname, self.args.model, epoch+1)
            utils.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.module.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
                }, self.args, is_best, filename)


    def validation(self, epoch=None):
        # Fast test during the training
        def eval_batch(model, image, target, weather, timeofday, scene):
            outputs, weather_o, timeofday_o = model(image)
            # Gathers tensors from different GPUs on a specified device
            # outputs = gather(outputs, 0, dim=0)
            pred = outputs[0]

            b, _, h, w = weather_o.size()
            weather_t = torch.ones((b, h, w)).long()
            for bi in range(b): weather_t[bi] *= weather[bi]
            timeofday_t = torch.ones((b, h, w)).long()
            for bi in range(b): timeofday_t[bi] *= timeofday[bi]

            self.confusion_matrix_weather.update([ m.astype(np.int64) for m in weather_t.numpy() ], weather_o.cpu().numpy().argmax(1))
            self.confusion_matrix_timeofday.update([ m.astype(np.int64) for m in timeofday_t.numpy() ], timeofday_o.cpu().numpy().argmax(1))

            correct, labeled = utils.batch_pix_accuracy(pred.data, target)
            inter, union = utils.batch_intersection_union(pred.data, target, self.nclass)

            correct_weather, labeled_weather = utils.batch_pix_accuracy(weather_o.data, weather_t)
            correct_timeofday, labeled_timeofday = utils.batch_pix_accuracy(timeofday_o.data, timeofday_t)
            return correct, labeled, inter, union, correct_weather, labeled_weather, correct_timeofday, labeled_timeofday

        is_best = False
        self.model.eval()
        total_inter, total_union, total_correct, total_label = 0, 0, 0, 0
        total_correct_weather = 0; total_label_weather = 0; total_correct_timeofday = 0; total_label_timeofday = 0
        name2inter = {}; name2union = {}
        tbar = tqdm(self.valloader, desc='\r')

        for i, (image, target, weather, timeofday, scene, name) in enumerate(tbar):
            if torch_ver == "0.3":
                image = Variable(image, volatile=True)
                correct, labeled, inter, union, correct_weather, labeled_weather, correct_timeofday, labeled_timeofday = eval_batch(self.model, image, target, weather, timeofday, scene)
            else:
                with torch.no_grad():
                    correct, labeled, inter, union, correct_weather, labeled_weather, correct_timeofday, labeled_timeofday = eval_batch(self.model, image, target, weather, timeofday, scene)

            total_correct += correct
            total_label += labeled
            total_inter += inter
            total_union += union
            pixAcc = 1.0 * total_correct / (np.spacing(1) + total_label)
            IoU = 1.0 * total_inter / (np.spacing(1) + total_union)
            mIoU = IoU.mean()
            name2inter[name[0]] = inter.tolist()
            name2union[name[0]] = union.tolist()

            total_correct_weather += correct_weather
            total_label_weather += labeled_weather
            pixAcc_weather = 1.0 * total_correct_weather / (np.spacing(1) + total_label_weather)
            total_correct_timeofday += correct_timeofday
            total_label_timeofday += labeled_timeofday
            pixAcc_timeofday = 1.0 * total_correct_timeofday / (np.spacing(1) + total_label_timeofday)

            tbar.set_description('pixAcc: %.2f, mIoU: %.2f, weather: %.2f, timeofday: %.2f' % (pixAcc, mIoU, pixAcc_weather, pixAcc_timeofday))
        self.logger.info('pixAcc: %.3f, mIoU: %.3f, pixAcc_weather: %.3f, pixAcc_timeofday: %.3f' % (pixAcc, mIoU, pixAcc_weather, pixAcc_timeofday))
        with open("name2inter", 'w') as fp:
            json.dump(name2inter, fp)
        with open("name2union", 'w') as fp:
            json.dump(name2union, fp)

        cm = self.confusion_matrix_weather.get_scores()['cm']
        self.logger.info(str(cm))
        self.confusion_matrix_weather.reset()
        cm = self.confusion_matrix_timeofday.get_scores()['cm']
        self.logger.info(str(cm))
        self.confusion_matrix_timeofday.reset()

        if epoch is not None:
            new_pred = (pixAcc + mIoU) / 2
            if new_pred > self.best_pred:
                is_best = True
                self.best_pred = new_pred
                utils.save_checkpoint({
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, self.args, is_best)
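The training loop in this example expands each image-level weather / timeofday label into a full (B, H, W) target map with a Python loop over the batch. The same expansion can be written with broadcasting; a small sketch with illustrative tensor names:

import torch

b, h, w = 4, 16, 16
weather = torch.randint(0, 7, (b,))             # one class id per image

# loop version, as in the snippet above
weather_t = torch.ones((b, h, w)).long()
for bi in range(b):
    weather_t[bi] *= weather[bi]

# broadcast version: expand the (B,) labels to (B, H, W) in one step
weather_t_fast = weather.view(b, 1, 1).expand(b, h, w)
assert torch.equal(weather_t, weather_t_fast)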
Example n. 17
from osgeo import gdal
import threading
gdal.AllRegister()
# Get the model
checkpoint = torch.load(
    r'E:\Project\PyTorch-Encoding\runs\arcs\deeplab\resnest269\model_best.pth.tar\model_best.pth.tar'
)
model = get_segmentation_model("deeplab",
                               dataset="arcs",
                               backbone="resnest269",
                               aux=True,
                               se_loss=False,
                               norm_layer=SyncBatchNorm,
                               base_size=128,
                               crop_size=128)
model = DataParallelModel(model).cuda()
model.module.load_state_dict(checkpoint['state_dict'])
model.eval()


def processData(tmpName):
    oriTileDir = "F:\\色林错\\dataSet\\" + str(tmpName) + r"\OriginTileData"
    # maskTileDir = "F:\\色林错\\dataSet\\" + str(tmpName) + r"\MaskTileData"
    tmpDir = "F:\\色林错\\dataSet\\" + str(tmpName) + r"\tmpTrainTest"
    if not os.path.exists(tmpDir):
        os.makedirs(tmpDir)
    length = dataSet[tmpName]["length"]
    for i in range(length):

        filename = oriTileDir + "\\" + str(i) + ".tif"
        img = encoding.utils.load_image(filename)
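Example n. 17 wraps the network in DataParallelModel and therefore loads the checkpoint into model.module; checkpoints saved from a wrapped model carry a 'module.' prefix on every key (which is also why an earlier example strips k[7:]). A hedged, stand-alone sketch of stripping that prefix so the same file also loads into an unwrapped model (toy nn.Linear used for illustration):

import torch
import torch.nn as nn

def strip_module_prefix(state_dict):
    # remove a leading 'module.' left by DataParallel, if present
    return {(k[7:] if k.startswith('module.') else k): v
            for k, v in state_dict.items()}

net = nn.Linear(4, 2)
wrapped = nn.DataParallel(net)
ckpt = {'state_dict': wrapped.state_dict()}      # keys look like 'module.weight'
plain = nn.Linear(4, 2)
plain.load_state_dict(strip_module_prefix(ckpt['state_dict']))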
Example n. 18
    def __init__(self, args):
        self.args = args
        args.log_name = str(args.checkname)
        self.logger = utils.create_logger(args.log_root, args.log_name)
        # data transforms
        input_transform = transform.Compose([
            transform.ToTensor(),
            transform.Normalize([.485, .456, .406], [.229, .224, .225])
        ])
        # dataset
        data_kwargs = {
            'transform': input_transform,
            'base_size': args.base_size,
            'crop_size': args.crop_size,
            'logger': self.logger,
            'scale': args.scale
        }
        trainset = get_segmentation_dataset(args.dataset,
                                            split='train',
                                            mode='train',
                                            **data_kwargs)
        testset = get_segmentation_dataset(args.dataset,
                                           split='val',
                                           mode='val',
                                           **data_kwargs)
        # dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True} \
            if args.cuda else {}
        self.trainloader = data.DataLoader(trainset,
                                           batch_size=args.batch_size,
                                           drop_last=True,
                                           shuffle=True,
                                           **kwargs)
        self.valloader = data.DataLoader(testset,
                                         batch_size=args.batch_size,
                                         drop_last=False,
                                         shuffle=False,
                                         **kwargs)
        self.nclass = trainset.num_class

        # model
        model = get_segmentation_model(args.model,
                                       dataset=args.dataset,
                                       backbone=args.backbone,
                                       aux=args.aux,
                                       se_loss=args.se_loss,
                                       norm_layer=BatchNorm2d,
                                       base_size=args.base_size,
                                       crop_size=args.crop_size,
                                       multi_grid=args.multi_grid,
                                       multi_dilation=args.multi_dilation)
        #print(model)
        self.logger.info(model)
        # optimizer using different LR
        params_list = [
            {
                'params': model.pretrained.parameters(),
                'lr': args.lr
            },
        ]
        if hasattr(model, 'head'):
            params_list.append({
                'params': model.head.parameters(),
                'lr': args.lr * 10
            })
        if hasattr(model, 'auxlayer'):
            params_list.append({
                'params': model.auxlayer.parameters(),
                'lr': args.lr * 10
            })

        cityscape_weight = torch.FloatTensor([
            0.8373, 0.918, 0.866, 1.0345, 1.0166, 0.9969, 0.9754, 1.0489,
            0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037, 1.0865, 1.0955,
            1.0865, 1.1529, 1.0507
        ])

        optimizer = torch.optim.SGD(params_list,
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
        #weight for class imbalance
        # self.criterion = SegmentationMultiLosses(nclass=self.nclass, weight=cityscape_weight)
        self.criterion = SegmentationMultiLosses(nclass=self.nclass)
        #self.criterion = SegmentationLosses(se_loss=args.se_loss, aux=args.aux,nclass=self.nclass)

        self.model, self.optimizer = model, optimizer
        # using cuda
        if args.cuda:
            self.model = DataParallelModel(self.model).cuda()
            self.criterion = DataParallelCriterion(self.criterion).cuda()
        # finetune from a trained model
        if args.ft:
            args.start_epoch = 0
            checkpoint = torch.load(args.ft_resume)
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'],
                                                  strict=False)
            else:
                self.model.load_state_dict(checkpoint['state_dict'],
                                           strict=False)
            self.logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                args.ft_resume, checkpoint['epoch']))
        # resuming checkpoint
        if args.resume:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            self.logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # lr scheduler
        self.scheduler = utils.LR_Scheduler(args.lr_scheduler,
                                            args.lr,
                                            args.epochs,
                                            len(self.trainloader),
                                            logger=self.logger,
                                            lr_step=args.lr_step)
        self.best_pred = 0.0
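Example n. 18 prepares a per-class cityscape_weight vector for the 19 Cityscapes classes but leaves the weighted criterion commented out. SegmentationMultiLosses is project code; as a hedged illustration of the same idea, the standard nn.CrossEntropyLoss accepts such a weight tensor directly (the values below are illustrative):

import torch
import torch.nn as nn

nclass = 19
class_weight = torch.linspace(0.8, 1.2, nclass)   # higher weight = rarer class counts more

criterion = nn.CrossEntropyLoss(weight=class_weight, ignore_index=-1)
logits = torch.randn(2, nclass, 32, 32)           # (B, C, H, W) predictions
target = torch.randint(0, nclass, (2, 32, 32))    # (B, H, W) class ids
print(criterion(logits, target).item())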
Example n. 19
    def __init__(self, args):
        if args.se_loss:
            args.checkname = args.checkname + "_se"

        self.args = args
        # data transforms
        input_transform = transform.Compose([
            transform.ToTensor(),
            transform.Normalize([.485, .456, .406], [.229, .224, .225])])
        # dataset
        data_kwargs = {'transform': input_transform, 'base_size': args.base_size,
                       'crop_size': args.crop_size}
        trainset = get_segmentation_dataset(args.dataset, split='train', mode='train',
                                           **data_kwargs)
        testset = get_segmentation_dataset(args.dataset, split='val', mode ='val',
                                           **data_kwargs)
        # dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': False} \
            if args.cuda else {}
        self.trainloader = data.DataLoader(trainset, batch_size=args.batch_size,
                                           drop_last=True, shuffle=True, **kwargs)
        self.valloader = data.DataLoader(testset, batch_size=args.batch_size,
                                         drop_last=False, shuffle=False, **kwargs)
        self.nclass = trainset.num_class
        # model
        model = get_segmentation_model(args.model, dataset=args.dataset,
                                       backbone = args.backbone, aux = args.aux,
                                       se_loss = args.se_loss, norm_layer = BatchNorm2d,
                                       base_size=args.base_size, crop_size=args.crop_size)
        print(model)

        # count parameter number
        pytorch_total_params = sum(p.numel() for p in model.parameters())
        print("Total number of parameters: %d"%pytorch_total_params)

        # optimizer using different LR
        params_list = [{'params': model.pretrained.parameters(), 'lr': args.lr},]
        if hasattr(model, 'head'):
            if args.diflr:
                params_list.append({'params': model.head.parameters(), 'lr': args.lr*10})
            else:
                params_list.append({'params': model.head.parameters(), 'lr': args.lr})
        if hasattr(model, 'auxlayer'):
            if args.diflr:
                params_list.append({'params': model.auxlayer.parameters(), 'lr': args.lr*10})
            else:
                params_list.append({'params': model.auxlayer.parameters(), 'lr': args.lr})

        optimizer = torch.optim.SGD(params_list,
                    lr=args.lr,
                    momentum=args.momentum,
                    weight_decay=args.weight_decay)

        #optimizer = torch.optim.ASGD(params_list,
        #                            lr=args.lr,
        #                            weight_decay=args.weight_decay)

        # criterions
        self.criterion = SegmentationLosses(se_loss=args.se_loss, aux=args.aux,
                                            nclass=self.nclass)
        self.model, self.optimizer = model, optimizer
        # using cuda
        if args.cuda:
            self.model = DataParallelModel(self.model).cuda()
            self.criterion = DataParallelCriterion(self.criterion).cuda()
        # resuming checkpoint
        if args.resume is not None and len(args.resume)>0:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                # load weights for the same model
                # self.model.module.load_state_dict(checkpoint['state_dict'])

                # model and checkpoint have different structures
                pretrained_dict = checkpoint['state_dict']
                model_dict = self.model.module.state_dict()

                for name, param in pretrained_dict.items():
                    if name not in model_dict:
                        continue
                    if isinstance(param, Parameter):
                        # backwards compatibility for serialized parameters
                        param = param.data
                    model_dict[name].copy_(param)

            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))

        # clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

        # lr scheduler
        self.scheduler = utils.LR_Scheduler(args.lr_scheduler, args.lr,
                                            args.epochs, len(self.trainloader),lr_step=args.lr_step)
        self.best_pred = 0.0
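Example n. 19 resumes from a checkpoint whose structure differs from the current model by copying only the parameters whose names match. A compact sketch of the same filtering with load_state_dict(strict=False), shown on a toy pair of models (assumption: parameters worth reusing share both name and shape):

import torch.nn as nn

old = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))   # "checkpoint" model
new = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 3))   # current model, different last layer

pretrained_dict = old.state_dict()
model_dict = new.state_dict()
# keep only entries present in the new model with an identical shape
filtered = {k: v for k, v in pretrained_dict.items()
            if k in model_dict and v.shape == model_dict[k].shape}
new.load_state_dict(filtered, strict=False)              # '1.weight'/'1.bias' stay untouched
print(sorted(filtered))                                  # ['0.bias', '0.weight']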
Example n. 20
 def __init__(self, args):
     self.args = args
     args.log_name = str(args.checkname)
     args.log_root = os.path.join(args.dataset, args.log_root) # dataset/log/
     self.logger = utils.create_logger(args.log_root, args.log_name)
     # data transforms
     input_transform = transform.Compose([
         transform.ToTensor(),
         transform.Normalize([.485, .456, .406], [.229, .224, .225])])
     # dataset
     data_kwargs = {'transform': input_transform, 'base_size': args.base_size,
                    'crop_size': args.crop_size, 'logger': self.logger,
                    'scale': args.scale}
     trainset = get_segmentation_dataset(args.dataset, split='trainval', mode='trainval',
                                         **data_kwargs)
     testset = get_segmentation_dataset(args.dataset, split='val', mode='val',  # crop fixed size as model input
                                        **data_kwargs)
     # dataloader
     kwargs = {'num_workers': args.workers, 'pin_memory': True} \
         if args.cuda else {}
     self.trainloader = data.DataLoader(trainset, batch_size=args.batch_size,
                                        drop_last=True, shuffle=True, **kwargs)
     self.valloader = data.DataLoader(testset, batch_size=args.batch_size,
                                      drop_last=False, shuffle=False, **kwargs)
     self.nclass = trainset.num_class
     # model
     model = get_segmentation_model(args.model, dataset=args.dataset,
                                    backbone=args.backbone,
                                    norm_layer=BatchNorm2d,
                                    base_size=args.base_size, crop_size=args.crop_size,
                                    )
     #print(model)
     self.logger.info(model)
     # optimizer using different LR
     params_list = [{'params': model.pretrained.parameters(), 'lr': args.lr},]
     if hasattr(model, 'head'):
         print("this model has object, head")
         params_list.append({'params': model.head.parameters(), 'lr': args.lr*10})
     optimizer = torch.optim.SGD(params_list,
                 lr=args.lr,
                 momentum=args.momentum,
                 weight_decay=args.weight_decay)
     self.criterion = SegmentationLosses(nclass=self.nclass)
     
     self.model, self.optimizer = model, optimizer
     # using cuda
     if args.cuda:
         self.model = DataParallelModel(self.model).cuda()
         self.criterion = DataParallelCriterion(self.criterion).cuda()
     
     # resuming checkpoint
     if args.resume:
         if not os.path.isfile(args.resume):
             raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume))
         checkpoint = torch.load(args.resume)
         args.start_epoch = checkpoint['epoch']
         if args.cuda:
             self.model.module.load_state_dict(checkpoint['state_dict'])
         else:
             self.model.load_state_dict(checkpoint['state_dict'])
         if not args.ft:
             self.optimizer.load_state_dict(checkpoint['optimizer'])
         self.best_pred = checkpoint['best_pred']
         self.logger.info("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
     # lr scheduler
     self.scheduler = utils.LR_Scheduler(args.lr_scheduler, args.lr,
                                         args.epochs, len(self.trainloader), logger=self.logger,
                                         lr_step=args.lr_step)
     self.best_pred = 0.0
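utils.LR_Scheduler is part of the project; trainers like the one above typically use a per-iteration 'poly' decay. A hedged sketch of that schedule (the 0.9 power is a common convention, assumed here rather than taken from the snippet):

def poly_lr(base_lr, cur_iter, total_iters, power=0.9):
    # polynomial decay from base_lr down to ~0 over total_iters
    return base_lr * (1 - cur_iter / total_iters) ** power

base_lr, epochs, iters_per_epoch = 0.01, 50, 300
total_iters = epochs * iters_per_epoch
for cur_iter in (0, total_iters // 2, total_iters - 1):
    print(cur_iter, round(poly_lr(base_lr, cur_iter, total_iters), 6))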
Example n. 21
class Trainer():
    def __init__(self, args):
        self.args = args
        # data transforms
        input_transform = transform.Compose([
            transform.ToTensor(),
            transform.Normalize([.485, .456, .406], [.229, .224, .225])
        ])
        # dataset
        data_kwargs = {
            'transform': input_transform,
            'base_size': args.base_size,
            'crop_size': args.crop_size
        }

        trainset_1 = get_dataset(
            'pascal_voc',
            root=os.path.expanduser('/fast/users/a1675776/data/encoding/data'),
            split='train',
            mode='train',
            **data_kwargs)

        trainset_2 = get_dataset(
            'pascal_aug',
            root=os.path.expanduser('/fast/users/a1675776/data/encoding/data'),
            split='train',
            mode='train',
            **data_kwargs)
        testset = get_dataset(
            'pascal_voc',
            root=os.path.expanduser('/fast/users/a1675776/data/encoding/data'),
            split='val',
            mode='val',
            **data_kwargs)

        concatenate_trainset = torch.utils.data.ConcatDataset(
            [trainset_1, trainset_2])
        # dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True} \
            if args.cuda else {}
        self.trainloader = data.DataLoader(concatenate_trainset,
                                           batch_size=args.batch_size,
                                           drop_last=True,
                                           shuffle=True,
                                           **kwargs)
        self.valloader = data.DataLoader(testset,
                                         batch_size=args.batch_size,
                                         drop_last=False,
                                         shuffle=False,
                                         **kwargs)

        self.nclass = trainset_1.num_class
        # model
        model = get_segmentation_model(args.model,
                                       dataset=args.dataset,
                                       backbone=args.backbone,
                                       aux=args.aux,
                                       se_loss=args.se_loss,
                                       norm_layer=SyncBatchNorm,
                                       base_size=args.base_size,
                                       crop_size=args.crop_size)
        #       print(model)
        # optimizer using different LR
        params_list = [
            {
                'params': model.pretrained.parameters(),
                'lr': args.lr
            },
        ]
        if hasattr(model, 'head'):
            params_list.append({
                'params': model.head.parameters(),
                'lr': args.lr * 10
            })
        if hasattr(model, 'auxlayer'):
            params_list.append({
                'params': model.auxlayer.parameters(),
                'lr': args.lr * 10
            })
        optimizer = torch.optim.Adam(params_list,
                                     lr=args.lr,
                                     weight_decay=args.weight_decay)
        # criterions
        self.criterion = SegmentationLosses(se_loss=args.se_loss,
                                            aux=args.aux,
                                            nclass=self.nclass,
                                            se_weight=args.se_weight,
                                            aux_weight=args.aux_weight)
        self.model, self.optimizer = model, optimizer
        # using cuda
        if args.cuda:
            self.model = DataParallelModel(self.model).cuda()
            self.criterion = DataParallelCriterion(self.criterion).cuda()
        # resuming checkpoint
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        # clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0
        # lr scheduler
        self.scheduler = utils.LR_Scheduler(args.lr_scheduler, args.lr,
                                            args.epochs, len(self.trainloader))
        self.best_pred = 0.0

    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.trainloader)
        for i, (image, target) in enumerate(tbar):
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            if torch_ver == "0.3":
                image = Variable(image)
                target = Variable(target)
            outputs = self.model(image)
            loss = self.criterion(outputs, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            utils.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, self.args, is_best)

    def validation(self, epoch):
        # Fast test during the training
        def eval_batch(model, image, target):
            outputs = model(image)
            outputs = gather(outputs, 0, dim=0)
            pred = outputs[0]
            target = target.cuda()
            correct, labeled = utils.batch_pix_accuracy(pred.data, target)
            inter, union = utils.batch_intersection_union(
                pred.data, target, self.nclass)
            return correct, labeled, inter, union

        is_best = False
        self.model.eval()
        total_inter, total_union, total_correct, total_label = 0, 0, 0, 0
        tbar = tqdm(self.valloader, desc='\r')
        for i, (image, target) in enumerate(tbar):
            if torch_ver == "0.3":
                image = Variable(image, volatile=True)
                correct, labeled, inter, union = eval_batch(
                    self.model, image, target)
            else:
                with torch.no_grad():
                    correct, labeled, inter, union = eval_batch(
                        self.model, image, target)

            total_correct += correct
            total_label += labeled
            total_inter += inter
            total_union += union
            pixAcc = 1.0 * total_correct / (np.spacing(1) + total_label)
            IoU = 1.0 * total_inter / (np.spacing(1) + total_union)
            mIoU = IoU.mean()
            tbar.set_description('pixAcc: %.3f, mIoU: %.3f' % (pixAcc, mIoU))

        test_record.append(mIoU)
        np.save('test_record.npy', test_record)

        new_pred = mIoU / 2  # note: only mIoU is used here (cf. (pixAcc + mIoU) / 2 in other examples)
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
        utils.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': self.model.module.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, self.args, is_best)
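The validation loop above accumulates per-batch intersection and union counts and derives mIoU from the running totals. utils.batch_intersection_union is project code; a small self-contained sketch of the same accumulation on random integer label maps:

import numpy as np

def batch_inter_union(pred, target, nclass):
    # per-class intersection and union for integer label maps
    inter = np.zeros(nclass, dtype=np.int64)
    union = np.zeros(nclass, dtype=np.int64)
    for c in range(nclass):
        p, t = pred == c, target == c
        inter[c] = np.logical_and(p, t).sum()
        union[c] = np.logical_or(p, t).sum()
    return inter, union

nclass = 3
total_inter = np.zeros(nclass, dtype=np.int64)
total_union = np.zeros(nclass, dtype=np.int64)
rng = np.random.default_rng(0)
for _ in range(4):                                   # pretend: four validation batches
    pred = rng.integers(0, nclass, size=(2, 8, 8))
    target = rng.integers(0, nclass, size=(2, 8, 8))
    inter, union = batch_inter_union(pred, target, nclass)
    total_inter += inter
    total_union += union
IoU = total_inter / (np.spacing(1) + total_union)
print('mIoU: %.3f' % IoU.mean())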
Example n. 22
class Trainer():
    def __init__(self, args):
        self.args = args
        # data transforms
        input_transform = transform.Compose([
            transform.ToTensor(),
            transform.Normalize([.485, .456, .406], [.229, .224, .225])
        ])
        # dataset
        data_kwargs = {
            'transform': input_transform,
            'base_size': args.base_size,
            'crop_size': args.crop_size
        }
        trainset = get_dataset(args.dataset,
                               split=args.train_split,
                               mode='train',
                               **data_kwargs)
        valset = get_dataset(
            args.dataset,
            split='val',
            mode='ms_val' if args.multi_scale_eval else 'fast_val',
            **data_kwargs)
        # dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.trainloader = data.DataLoader(trainset,
                                           batch_size=args.batch_size,
                                           drop_last=True,
                                           shuffle=True,
                                           **kwargs)
        if self.args.multi_scale_eval:
            kwargs['collate_fn'] = test_batchify_fn
        self.valloader = data.DataLoader(valset,
                                         batch_size=args.test_batch_size,
                                         drop_last=False,
                                         shuffle=False,
                                         **kwargs)
        self.nclass = trainset.num_class
        # model
        if args.norm_layer == 'bn':
            norm_layer = BatchNorm2d
        elif args.norm_layer == 'sync_bn':
            assert args.multi_gpu, "SyncBatchNorm can only be used when multi GPUs are available!"
            norm_layer = SyncBatchNorm
        else:
            raise ValueError('Invalid norm_layer {}'.format(args.norm_layer))
        model = get_segmentation_model(
            args.model,
            dataset=args.dataset,
            backbone=args.backbone,
            aux=args.aux,
            se_loss=args.se_loss,
            norm_layer=norm_layer,
            base_size=args.base_size,
            crop_size=args.crop_size,
            multi_grid=True,
            multi_dilation=[2, 4, 8],
            only_pam=True,
        )
        print(model)
        # optimizer using different LR
        params_list = [
            {
                'params': model.pretrained.parameters(),
                'lr': args.lr
            },
        ]
        if hasattr(model, 'head'):
            params_list.append({
                'params': model.head.parameters(),
                'lr': args.lr
            })
        if hasattr(model, 'auxlayer'):
            params_list.append({
                'params': model.auxlayer.parameters(),
                'lr': args.lr
            })
        optimizer = torch.optim.SGD(params_list,
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
        # criterions
        self.criterion = SegmentationMultiLosses()
        self.model, self.optimizer = model, optimizer
        # using cuda
        if args.multi_gpu:
            self.model = DataParallelModel(self.model).cuda()
            self.criterion = DataParallelCriterion(self.criterion).cuda()
        else:
            self.model = self.model.cuda()
            self.criterion = self.criterion.cuda()
        self.single_device_model = self.model.module if self.args.multi_gpu else self.model
        # resuming checkpoint
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            self.single_device_model.load_state_dict(checkpoint['state_dict'])
            if not args.ft and not (args.only_val or args.only_vis
                                    or args.only_infer):
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {}), best_pred {}".format(
                args.resume, checkpoint['epoch'], checkpoint['best_pred']))
        # clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0
        # lr scheduler
        self.lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
            optimizer, 0.6)
        self.best_pred = 0.0

    def save_ckpt(self, epoch, score):
        is_best = False
        if score >= self.best_pred:
            is_best = True
            self.best_pred = score
        utils.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': self.single_device_model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, self.args, is_best)

    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        self.lr_scheduler.step()
        tbar = tqdm(self.trainloader, miniters=20)
        for i, (image, target) in enumerate(tbar):

            if not self.args.multi_gpu:
                image = image.cuda()
                target = target.cuda()
            self.optimizer.zero_grad()
            if torch_ver == "0.3":
                image = Variable(image)
                target = Variable(target)
            outputs = self.model(image)
            if self.args.multi_gpu:
                loss = self.criterion(outputs, target)
            else:
                loss = self.criterion(*(outputs + (target, )))
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            ep_log = 'ep {}'.format(epoch + 1)
            lr_log = 'lr ' + '{:.6f}'.format(
                self.optimizer.param_groups[0]['lr']).rstrip('0')
            loss_log = 'loss {:.3f}'.format(train_loss / (i + 1))
            tbar.set_description(', '.join([ep_log, lr_log, loss_log]))

    def validation(self, epoch):
        def _get_pred(batch_im):
            with torch.no_grad():
                # metric.update also accepts list, so no need to gather results from multi gpus
                if self.args.multi_scale_eval:
                    assert len(batch_im) <= torch.cuda.device_count(
                    ), "Multi-scale testing only allows batch size <= number of GPUs"
                    scattered_pred = self.ms_evaluator.parallel_forward(
                        batch_im)
                else:
                    outputs = self.model(batch_im)
                    scattered_pred = [
                        out[0] for out in outputs
                    ] if self.args.multi_gpu else [outputs[0]]
            return scattered_pred

        # Lazy creation
        if not hasattr(self, 'ms_evaluator'):
            self.ms_evaluator = MultiEvalModule(self.single_device_model,
                                                self.nclass,
                                                scales=self.args.eval_scales,
                                                crop=self.args.crop_eval)
            self.metric = utils.SegmentationMetric(self.nclass)
        self.model.eval()
        tbar = tqdm(self.valloader, desc='\r')
        for i, (batch_im, target) in enumerate(tbar):
            # No need to put target to GPU, since the metrics are calculated by numpy.
            # And no need to put data to GPU manually if we use data parallel.
            if not self.args.multi_gpu and not isinstance(
                    batch_im, (list, tuple)):
                batch_im = batch_im.cuda()
            scattered_pred = _get_pred(batch_im)
            scattered_target = []
            ind = 0
            for p in scattered_pred:
                target_tmp = target[ind:ind + len(p)]
                # Multi-scale testing. In fact, len(target_tmp) == 1
                if isinstance(target_tmp, (list, tuple)):
                    assert len(target_tmp) == 1
                    target_tmp = torch.stack(target_tmp)
                scattered_target.append(target_tmp)
                ind += len(p)
            self.metric.update(scattered_target, scattered_pred)
            pixAcc, mIoU = self.metric.get()
            tbar.set_description('ep {}, pixAcc: {:.4f}, mIoU: {:.4f}'.format(
                epoch + 1, pixAcc, mIoU))
        return self.metric.get()

    def visualize(self, epoch):
        if (self.args.dir_of_im_to_vis
                == 'None') and (self.args.im_list_file_to_vis == 'None'):
            return
        if not hasattr(self, 'vis_im_paths'):
            if self.args.dir_of_im_to_vis != 'None':
                print('=> Visualize Dir {}'.format(self.args.dir_of_im_to_vis))
                im_paths = list(
                    walkdir(self.args.dir_of_im_to_vis, exts=['.jpg', '.png']))
            else:
                print('=> Visualize Image List {}'.format(
                    self.args.im_list_file_to_vis))
                im_paths = read_lines(self.args.im_list_file_to_vis)
            print('=> Save Dir {}'.format(self.args.vis_save_dir))
            im_paths = sorted(im_paths)
            # np.random.RandomState(seed=1).shuffle(im_paths)
            self.vis_im_paths = im_paths[:self.args.max_num_vis]
        cfg = {
            'save_path': os.path.join(self.args.vis_save_dir,
                                      'vis_epoch{}.png'.format(epoch)),
            'multi_scale': self.args.multi_scale_eval,
            'crop': self.args.crop_eval,
            'num_class': self.nclass,
            'scales': self.args.eval_scales,
            'base_size': self.args.base_size,
        }
        vis_im_list(self.single_device_model, self.vis_im_paths, cfg)

    def infer_and_save(self, infer_dir, infer_save_dir):
        print('=> Infer Dir {}'.format(infer_dir))
        print('=> Save Dir {}'.format(infer_save_dir))
        sub_im_paths = list(
            walkdir(infer_dir, exts=['.jpg', '.png'], sub_path=True))
        im_paths = [os.path.join(infer_dir, p) for p in sub_im_paths]
        # NOTE: Don't save result as JPEG, since it causes aliasing.
        save_paths = [
            os.path.join(infer_save_dir, p.replace('.jpg', '.png'))
            for p in sub_im_paths
        ]
        cfg = {
            'multi_scale': self.args.multi_scale_eval,
            'crop': self.args.crop_eval,
            'num_class': self.nclass,
            'scales': self.args.eval_scales,
            'base_size': self.args.base_size,
        }
        infer_and_save_im_list(self.single_device_model, im_paths, save_paths,
                               cfg)
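Example n. 22 replaces the project scheduler with torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.6), stepped once at the start of each training epoch. A minimal sketch of that decay on a toy optimizer:

import torch

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.SGD(params, lr=0.003)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.6)

for epoch in range(3):
    # ... one epoch of forward/backward/optimizer.step() calls would run here ...
    scheduler.step()                                 # lr *= 0.6 once per epoch
    print(epoch, optimizer.param_groups[0]['lr'])    # roughly 0.0018, 0.00108, 0.000648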
Example n. 23
 def __init__(self, args):
     self.args = args
     # data transforms
     input_transform = transform.Compose([
         transform.ToTensor(),
         transform.Normalize([.485, .456, .406], [.229, .224, .225])
     ])
     # dataset
     data_kwargs = {
         'transform': input_transform,
         'base_size': args.base_size,
         'crop_size': args.crop_size
     }
     trainset = get_segmentation_dataset(args.dataset,
                                         split=args.train_split,
                                         mode='train',
                                         **data_kwargs)
     testset = get_segmentation_dataset(args.dataset,
                                        split='val',
                                        mode='val',
                                        **data_kwargs)
     # dataloader
     kwargs = {'num_workers': args.workers, 'pin_memory': True} \
         if args.cuda else {}
     self.trainloader = data.DataLoader(trainset,
                                        batch_size=args.batch_size,
                                        drop_last=True,
                                        shuffle=True,
                                        **kwargs)
     self.valloader = data.DataLoader(testset,
                                      batch_size=args.batch_size,
                                      drop_last=False,
                                      shuffle=False,
                                      **kwargs)
     self.nclass = trainset.num_class
     # model
     model = get_segmentation_model(
         args.model,
         dataset=args.dataset,
         backbone=args.backbone,
         dilated=args.dilated,
         lateral=args.lateral,
         jpu=args.jpu,
         aux=args.aux,
         se_loss=args.se_loss,
         norm_layer=torch.nn.BatchNorm2d,  ## BatchNorm2d
         base_size=args.base_size,
         crop_size=args.crop_size)
     print(model)
     # optimizer using different LR
     params_list = [
         {
             'params': model.pretrained.parameters(),
             'lr': args.lr
         },
     ]
     if hasattr(model, 'jpu'):
         params_list.append({
             'params': model.jpu.parameters(),
             'lr': args.lr * 10
         })
     if hasattr(model, 'head'):
         params_list.append({
             'params': model.head.parameters(),
             'lr': args.lr * 10
         })
     if hasattr(model, 'auxlayer'):
         params_list.append({
             'params': model.auxlayer.parameters(),
             'lr': args.lr * 10
         })
     optimizer = torch.optim.SGD(params_list,
                                 lr=args.lr,
                                 momentum=args.momentum,
                                 weight_decay=args.weight_decay)
     # criterions
     self.criterion = SegmentationLosses(se_loss=args.se_loss,
                                         aux=args.aux,
                                         nclass=self.nclass,
                                         se_weight=args.se_weight,
                                         aux_weight=args.aux_weight)
     self.model, self.optimizer = model, optimizer
     # using cuda
     if args.cuda:
         self.model = DataParallelModel(self.model).cuda()
         self.criterion = DataParallelCriterion(self.criterion).cuda()
     # resuming checkpoint
     if args.resume is not None:
         if not os.path.isfile(args.resume):
             raise RuntimeError("=> no checkpoint found at '{}'".format(
                 args.resume))
         checkpoint = torch.load(args.resume)
         args.start_epoch = checkpoint['epoch']
         if args.cuda:
             self.model.module.load_state_dict(checkpoint['state_dict'])
         else:
             self.model.load_state_dict(checkpoint['state_dict'])
         if not args.ft:
             self.optimizer.load_state_dict(checkpoint['optimizer'])
         self.best_pred = checkpoint['best_pred']
         print("=> loaded checkpoint '{}' (epoch {})".format(
             args.resume, checkpoint['epoch']))
     # clear start epoch if fine-tuning
     if args.ft:
         args.start_epoch = 0
     # lr scheduler
     self.scheduler = utils.LR_Scheduler(args.lr_scheduler, args.lr,
                                         args.epochs, len(self.trainloader))
     self.best_pred = 0.0
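SegmentationLosses in the example above combines the main prediction loss with an auxiliary-head loss scaled by aux_weight (plus an optional SE term). The project class itself is not shown here; a hedged sketch of the aux-weighted combination using plain cross-entropy:

import torch
import torch.nn.functional as F

def seg_loss_with_aux(main_logits, aux_logits, target, aux_weight=0.2):
    # cross-entropy on the main output plus a down-weighted auxiliary term
    return F.cross_entropy(main_logits, target) + aux_weight * F.cross_entropy(aux_logits, target)

nclass = 5
main_logits = torch.randn(2, nclass, 16, 16)
aux_logits = torch.randn(2, nclass, 16, 16)
target = torch.randint(0, nclass, (2, 16, 16))
print(seg_loss_with_aux(main_logits, aux_logits, target).item())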
Example n. 24
class Trainer():
    def __init__(self, args):
        self.args = args
        self.args.start_epoch = 0
        self.args.cuda = True
        # data transforms
        input_transform = transform.Compose([
            transform.ToTensor(),
            transform.Normalize([.490, .490, .490], [.247, .247, .247])
        ])  # TODO: change mean and std

        # dataset
        train_chain = Compose([
            HorizontalFlip(p=0.5),
            OneOf([
                ElasticTransform(
                    alpha=300, sigma=300 * 0.05, alpha_affine=300 * 0.03),
                GridDistortion(),
                OpticalDistortion(distort_limit=2, shift_limit=0.5),
            ],
                  p=0.3),
            RandomSizedCrop(
                min_max_height=(900, 1024), height=1024, width=1024, p=0.5),
            ShiftScaleRotate(rotate_limit=20, p=0.5),
            Resize(self.args.size, self.args.size)
        ],
                              p=1)

        val_chain = Compose([Resize(self.args.size, self.args.size)], p=1)
        num_fold = self.args.num_fold
        df_train = pd.read_csv(os.path.join(args.imagelist_path, 'train.csv'))
        df_val = pd.read_csv(os.path.join(args.imagelist_path, 'val.csv'))
        df_full = pd.concat((df_train, df_val), ignore_index=True, axis=0)
        df_full['lbl'] = (df_full['mask_name'].astype(str) == '-1').astype(int)
        skf = StratifiedKFold(8, shuffle=True, random_state=777)
        train_ids, val_ids = list(
            skf.split(df_full['mask_name'], df_full['lbl']))[num_fold]

        df_test = pd.read_csv(
            os.path.join(args.imagelist_path, 'test_true.csv'))

        df_new_train = pd.concat((df_full.iloc[train_ids], df_test),
                                 ignore_index=True,
                                 axis=0,
                                 sort=False)
        df_new_val = df_full.iloc[val_ids]

        df_new_train.to_csv(f'/tmp/train_new_pneumo_{num_fold}.csv')
        df_new_val.to_csv(f'/tmp/val_new_pneumo_{num_fold}.csv')

        trainset = SegmentationDataset(f'/tmp/train_new_pneumo_{num_fold}.csv',
                                       args.image_path,
                                       args.masks_path,
                                       input_transform=input_transform,
                                       transform_chain=train_chain,
                                       base_size=1024)
        testset = SegmentationDataset(f'/tmp/val_new_pneumo_{num_fold}.csv',
                                      args.image_path,
                                      args.masks_path,
                                      input_transform=input_transform,
                                      transform_chain=val_chain,
                                      base_size=1024)

        imgs = trainset.mask_img_map[:, [0, 3]]
        weights = make_weights_for_balanced_classes(imgs, 2)
        weights = torch.DoubleTensor(weights)
        train_sampler = (torch.utils.data.sampler.WeightedRandomSampler(
            weights, len(weights)))

        # dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.trainloader = data.DataLoader(
            trainset,
            batch_size=args.batch_size,
            drop_last=True,
            sampler=train_sampler,  #shuffle=True, 
            **kwargs)
        self.valloader = data.DataLoader(testset,
                                         batch_size=args.batch_size,
                                         drop_last=False,
                                         shuffle=False,
                                         **kwargs)

        self.nclass = 1
        if self.args.model == 'unet':
            model = UNet(n_classes=self.nclass, norm_layer=SyncBatchNorm)
            params_list = [
                {
                    'params': model.parameters(),
                    'lr': args.lr
                },
            ]
        elif self.args.model == 'encnet':
            model = EncNet(
                nclass=self.nclass,
                backbone=args.backbone,
                aux=args.aux,
                se_loss=args.se_loss,
                norm_layer=SyncBatchNorm  #nn.BatchNorm2d
            )

            # optimizer using different LR
            params_list = [
                {
                    'params': model.pretrained.parameters(),
                    'lr': args.lr
                },
            ]
            if hasattr(model, 'head'):
                params_list.append({
                    'params': model.head.parameters(),
                    'lr': args.lr * 10
                })
            if hasattr(model, 'auxlayer'):
                params_list.append({
                    'params': model.auxlayer.parameters(),
                    'lr': args.lr * 10
                })
        else:
            raise ValueError(f'Unknown model: {self.args.model}')

        print(model)
        optimizer = torch.optim.SGD(params_list,
                                    lr=args.lr,
                                    momentum=0.9,
                                    weight_decay=args.wd)

        # criterions
        if self.nclass == 1:
            self.criterion = SegmentationLossesBCE(se_loss=args.se_loss,
                                                   aux=args.aux,
                                                   nclass=self.nclass,
                                                   se_weight=args.se_weight,
                                                   aux_weight=args.aux_weight,
                                                   use_dice=args.use_dice)
        else:
            self.criterion = SegmentationLosses(
                se_loss=args.se_loss,
                aux=args.aux,
                nclass=self.nclass,
                se_weight=args.se_weight,
                aux_weight=args.aux_weight,
            )
        self.model, self.optimizer = model, optimizer

        self.best_pred = 0.0
        self.model = DataParallelModel(self.model).cuda()
        self.criterion = DataParallelCriterion(self.criterion).cuda()

        # resuming checkpoint
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)  #, map_location='cpu')
            self.args.start_epoch = checkpoint['epoch']
            # The model is already wrapped in DataParallelModel, so the
            # checkpoint keys (with their 'module.' prefix) load as-is.
            state_dict = checkpoint['state_dict']
            self.model.load_state_dict(state_dict)
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            for g in self.optimizer.param_groups:
                g['lr'] = args.lr
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            print(f'Best dice: {checkpoint["best_pred"]}')
            print(f'LR: {get_lr(self.optimizer):.5f}')

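        # Cut the learning rate by 20% (factor 0.8) whenever the monitored
        # validation loss has not improved by at least 0.001 for 4 consecutive
        # epochs, down to a floor of 1e-5.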
        self.scheduler = ReduceLROnPlateau(self.optimizer,
                                           mode='min',
                                           factor=0.8,
                                           patience=4,
                                           threshold=0.001,
                                           threshold_mode='abs',
                                           min_lr=0.00001)
        self.logger = Logger(args.logger_dir)
        self.step_train = 0
        self.best_loss = 20  # any validation loss below this initial value counts as an improvement
        self.step_val = 0

    def logging(self,
                loss,
                running_acc,
                total,
                is_train,
                step,
                is_per_epoch,
                inputs=None,
                pred_masks=None,
                true_masks=None):
        #============ TensorBoard logging ============#
        # Log the scalar values
        accuracy = 100.0 * running_acc / total
        loss_str = 'Loss per epoch' if is_per_epoch else 'Loss per step'
        accuracy_str = 'Accuracy per epoch' if is_per_epoch else 'Accuracy per step'
        if is_per_epoch:
            loss = loss / len(self.trainloader) if is_train else loss / len(
                self.valloader)
        info = {loss_str: loss, accuracy_str: accuracy}

        for tag, value in info.items():
            self.logger.scalar_summary(tag, value, step, is_train)

        # Log values and gradients of the parameters (histogram)
        for tag, value in filter(lambda p: p[1].requires_grad,
                                 self.model.named_parameters()):
            tag = tag.replace('.', '/')
            self.logger.histo_summary(tag, to_np(value), step, 1000, is_train)
            if value.grad is not None:
                self.logger.histo_summary(tag + '/grad', to_np(value.grad),
                                          step, 1000, is_train)

        if inputs is not None:
            # Log the images
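            # Undo the normalisation applied by input_transform, then blend the
            # predicted masks into channel 0 and the ground-truth masks into
            # channel 1 of the logged images.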
            inputs = to_np(inputs)[:10].transpose(0, 2, 3, 1)
            for i in range(inputs.shape[0]):
                inputs[i] *= np.array([.247, .247, .247])
                inputs[i] += np.array([.490, .490, .490])
            inputs = (255 * inputs).astype(np.uint8)
            pred_masks = (255 * pred_masks)[:10].astype(np.uint8)
            true_masks = (255 * true_masks)[:10].astype(np.uint8)
            inputs[..., 0] = (0.5 * inputs[..., 0] + 0.5 * pred_masks).astype(np.uint8)
            inputs[..., 1] = (0.5 * inputs[..., 1] + 0.5 * true_masks).astype(np.uint8)
            info = {
                'images': inputs,
            }

            for tag, inputs in info.items():
                self.logger.image_summary(tag, inputs, step, is_train)

    def training(self, epoch):
        self.model.train()
        # tbar = tqdm(self.trainloader)
        total_score = 0
        total_score_simple = 0
        total_count = 0
        total_loss = 0
        for i, (_, image, target) in enumerate(self.trainloader):
            # if i >= 1000:
            #     break
            start_t = dtm.now()
            torch.cuda.empty_cache()
            if torch_ver == "0.3":
                image = Variable(image)
                target = Variable(target)

            outputs = self.model(image)
            loss = self.criterion(outputs, target)
            # loss = loss / 4
            loss.backward()
            # if (i+1) % 4 == 0:
            self.optimizer.step()
            # self.model.zero_grad()
            self.optimizer.zero_grad()

            total_loss += loss.item()

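            # Gather the per-GPU output lists: v[0] is the segmentation output;
            # for EncNet, v[1] is the image-level classification output that is
            # thresholded into cls_mask for the DICE computation.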
            preds_ten = [v[0].data.cpu() for v in outputs]
            preds_ten = torch.cat(preds_ten)
            if self.args.model == 'encnet':
                cls_preds_ten = [v[1].data.cpu() for v in outputs]
                cls_preds_ten = torch.cat(cls_preds_ten)
                cls_mask = torch.sigmoid(cls_preds_ten).numpy().reshape(
                    -1) < 0.5
                preds = torch.sigmoid(preds_ten).numpy()[:, 0, :, :]
            elif self.args.model == 'unet':
                cls_mask = np.zeros(preds_ten.size(0))
                preds = preds_ten.numpy()[:, 0, :, :]
            trues = target.numpy()

            local_score = dice_loss(trues, preds, cls_mask=cls_mask)
            local_score_simple = dice_loss(trues, preds)
            batch_size = preds.shape[0]

            total_score += local_score
            total_score_simple += local_score_simple
            total_count += batch_size

            print((
                f'Epoch: {epoch}, Batch: {i + 1} / {len(self.trainloader)}, '
                f'loss: {total_loss / (i + 1):.3f}, batch loss: {loss.item():.3f}'
                f', batch simple DICE: {local_score_simple / batch_size:.3f}'
                f', total simple DICE: {total_score_simple / total_count:.3f}'
                f', batch DICE: {local_score / batch_size:.3f}'
                f', total DICE: {total_score / total_count:.5f}'
                f', lr: {get_lr(self.optimizer):.5f}'
                f', time: {(dtm.now() - start_t).total_seconds():.2f}'))

            # if i > 5:
            #     break
            self.step_train += 1
            if i % 10 == 0:
                sys.stdout.flush()
                pred_masks = np.array(
                    [preds[j] * cls_mask[j] for j in range(len(cls_mask))])
                pred_masks = (pred_masks > 0.5).astype(int)
                self.logging(loss.item(),
                             total_score,
                             total_count,
                             is_train=True,
                             step=self.step_train,
                             is_per_epoch=False,
                             inputs=image,
                             pred_masks=pred_masks,
                             true_masks=trues)

    def validation(self, epoch):
        # Fast test during the training
        def eval_batch(model, image, target):
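            # Forward one batch and return everything needed for the DICE/loss
            # bookkeeping: predictions, targets, classifier mask, batch DICE,
            # batch size, and the loss.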
            outputs = model(image)
            loss = self.criterion(outputs, target)

            preds_ten = [v[0].data.cpu() for v in outputs]
            preds_ten = torch.cat(preds_ten)
            if self.args.model == 'encnet':
                cls_preds_ten = [v[1].data.cpu() for v in outputs]
                cls_preds_ten = torch.cat(cls_preds_ten)
                cls_mask = torch.sigmoid(cls_preds_ten).numpy().reshape(
                    -1) < 0.5
                preds = torch.sigmoid(preds_ten).numpy()[:, 0, :, :]
            elif self.args.model == 'unet':
                cls_mask = np.zeros(preds_ten.size(0))
                preds = preds_ten.numpy()[:, 0, :, :]

            trues = target.numpy()

            local_score = dice_loss(trues, preds, cls_mask=cls_mask)
            batch_size = preds.shape[0]

            return preds, trues, cls_mask, local_score, batch_size, loss

        is_best = False
        self.model.eval()
        total_inter, total_union, total_correct, total_label = 0, 0, 0, 0
        total_score = 0
        total_loss = 0
        total_count = 0

        for i, (_, image, target) in enumerate(self.valloader):
            if torch_ver == "0.3":
                image = Variable(image, volatile=True)
                # eval_batch returns six values; unpack them the same way as
                # the torch >= 0.4 branch below.
                preds, trues, cls_mask, local_score, batch_size, loss = (
                    eval_batch(self.model, image, target))
            else:
                with torch.no_grad():
                    preds, trues, cls_mask, local_score, batch_size, loss = (
                        eval_batch(self.model, image, target))

            total_score += local_score
            total_loss += loss.item()
            total_count += batch_size

            dice = total_score / total_count
            print(
                f'val epoch: {epoch}, batch: {i + 1} / {len(self.valloader)}, DICE: {dice:.5f}'
            )
            self.step_val += 1
            # if i > 15:
            #     break
            if i % 10 == 0:
                sys.stdout.flush()
                pred_masks = np.array(
                    [preds[j] * cls_mask[j] for j in range(len(cls_mask))])
                pred_masks = (pred_masks > 0.5).astype(int)
                self.logging(loss.item(),
                             total_score,
                             total_count,
                             is_train=False,
                             step=self.step_val,
                             is_per_epoch=False,
                             inputs=image,
                             pred_masks=pred_masks,
                             true_masks=trues)

        new_pred = dice
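        # Three snapshots are kept: best DICE (*_best.pth), lowest validation
        # loss (*_best_loss.pth), and the most recent epoch (*_last.pth).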
        if new_pred > self.best_pred:
            self.best_pred = new_pred
            torch.save(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                    'best_loss': self.best_loss
                }, self.args.ckpt_name[:-4] + '_best.pth')

        new_loss = total_loss / total_count
        if new_loss < self.best_loss:
            self.best_loss = new_loss
            torch.save(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.model.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                    'best_loss': self.best_loss
                }, self.args.ckpt_name[:-4] + '_best_loss.pth')

        print(f'Validation DICE: {dice:.5f}, loss: {new_loss:.5f}')
        print(
            f'Validation best DICE: {self.best_pred:.5f}, best loss: {self.best_loss:.5f}'
        )
        torch.save(
            {
                'epoch': epoch + 1,
                'state_dict': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': dice,
            }, self.args.ckpt_name[:-4] + '_last.pth')
        return new_loss

    def __del__(self):
        del self.model
        gc.collect()
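
# --- Hypothetical usage sketch (not part of the original listing) ---
# A minimal driver loop, assuming the class above is named Trainer, that the
# args namespace has an `epochs` field, and that a `parse_args()` helper
# exists (all three are assumptions, not taken from the source):
#
# if __name__ == '__main__':
#     args = parse_args()
#     trainer = Trainer(args)
#     for epoch in range(args.start_epoch, args.epochs):
#         trainer.training(epoch)
#         val_loss = trainer.validation(epoch)
#         # ReduceLROnPlateau monitors the value passed to step(); here that is
#         # the validation loss returned by validation().
#         trainer.scheduler.step(val_loss)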