예제 #1
0
def modelDeploy(args, model, optimizer, scheduler, logger):
    """Prepare a model for training: GPU placement, cuDNN flags, and
    optional checkpoint resume / finetuning / BN freezing.

    Args:
        args: namespace with num_gpus, resume, finetune and freeze_bn.
        model: network to deploy (wrapped in DataParallel when num_gpus >= 1).
        optimizer: optimizer whose state may be restored from a checkpoint.
        scheduler: LR scheduler, stepped forward to the resumed epoch.
        logger: logger used for progress / error messages.

    Returns:
        (model, trainData): the (possibly wrapped) model and the training
        bookkeeping dict {'epoch', 'loss', 'miou', 'val', 'bestMiou'}.

    Raises:
        FileNotFoundError: if args.resume is set but is not a file.
    """
    if args.num_gpus >= 1:
        from torch.nn.parallel import DataParallel
        model = DataParallel(model)
        model = model.cuda()

    if torch.backends.cudnn.is_available():
        import torch.backends.cudnn as cudnn
        # NOTE(review): benchmark=True enables non-deterministic kernel
        # autotuning, which contradicts deterministic=True; kept as-is to
        # preserve existing behavior -- confirm which one is intended.
        cudnn.benchmark = True
        cudnn.deterministic = True

    trainData = {'epoch': 0, 'loss': [], 'miou': [], 'val': [], 'bestMiou': 0}

    if args.resume:
        if os.path.isfile(args.resume):
            logger.info("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume,
                                    map_location=torch.device('cpu'))

            # model & optimizer state
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])

            # Restore the stop point and fast-forward the scheduler to it.
            trainData = checkpoint['trainData']
            for _ in range(trainData['epoch']):
                scheduler.step()

            logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, trainData['epoch']))

        else:
            logger.error("=> no checkpoint found at '{}'".format(args.resume))
            # Raise instead of `assert False`: asserts vanish under `python -O`.
            raise FileNotFoundError(
                "=> no checkpoint found at '{}'".format(args.resume))

    if args.finetune:
        if os.path.isfile(args.finetune):
            logger.info("=> finetuning checkpoint '{}'".format(args.finetune))
            state_all = torch.load(args.finetune, map_location='cpu')['model']
            # NOTE(review): despite the original "only use backbone
            # parameters" comment, every key was copied unchanged -- the
            # equivalent plain copy is kept; add filtering if intended.
            state_clip = dict(state_all)
            model.load_state_dict(state_clip, strict=False)
        else:
            logger.warning("finetune is not a file.")

    if args.freeze_bn:
        logger.warning('Freezing batch normalization layers')
        for m in model.modules():
            if isinstance(m, nn.BatchNorm2d):
                # eval() stops running-stat updates; grads disabled on affine params
                m.eval()
                m.weight.requires_grad = False
                m.bias.requires_grad = False

    return model, trainData
예제 #2
0
def load_reid_model():
    """Build the ReID network, restore its pretrained weights from the
    hard-coded checkpoint, move it to GPU and switch it to eval mode."""
    ckpt = '/home/honglongcai/Github/PretrainedModel/model_410.pt'
    model = DataParallel(Model())
    state = torch.load(ckpt, map_location='cuda')
    model.load_state_dict(state)
    logger.info('Load ReID model from {}'.format(ckpt))

    model = model.cuda()
    model.eval()
    return model
예제 #3
0
class TestProcess:
    """Patch-wise inference over the test split of a segmentation dataset.

    Loads an ET-Net with weights from ARGS['weight'] and writes one
    prediction image per test sample into ARGS['prediction_save_folder'].
    """

    def __init__(self):
        self.net = ET_Net()

        if (ARGS['gpu']):
            self.net = DataParallel(module=self.net.cuda())

        self.net.load_state_dict(torch.load(ARGS['weight']))

        self.test_dataset = get_dataset(dataset_name=ARGS['dataset'], part='test')

    def predict(self):
        """Predict every test image patch-by-patch and save results to disk."""
        t_start = time.time()
        self.net.eval()
        # only support batch size = 1
        loader = DataLoader(self.test_dataset, batch_size=1)
        os.makedirs(ARGS['prediction_save_folder'], exist_ok=True)
        for sample in loader:
            img = sample['image'].float()
            msk = sample['mask'].long()
            fname = sample['filename']
            print('image shape:', img.size())

            # Split the (possibly large) image into overlapping crops.
            patches, big_h, big_w = get_test_patches(img, ARGS['crop_size'], ARGS['stride_size'])
            patch_loader = DataLoader(patches, batch_size=ARGS['batch_size'], shuffle=False, drop_last=False)
            outputs = []
            print('Number of batches for testing:', len(patch_loader))

            for batch in patch_loader:

                if ARGS['gpu']:
                    batch = batch.cuda()

                with torch.no_grad():
                    # the network returns (edge map, segmentation map);
                    # only the segmentation output is kept
                    edge_out, seg_out = self.net(batch)

                outputs.append(seg_out.cpu())

            merged = torch.cat(outputs, dim=0)
            # Stitch overlapping patch predictions back into one image,
            # keep the foreground channel cropped to the input size, and
            # zero everything outside the mask.
            merged = recompone_overlap(merged, ARGS['crop_size'], ARGS['stride_size'], big_h, big_w)
            merged = merged[:, 1, :img.size(2), :img.size(3)] * msk
            out_img = Image.fromarray(merged[0].numpy())
            out_img.save(os.path.join(ARGS['prediction_save_folder'], fname[0]))
            print(f'Finish prediction for {fname[0]}')

        t_end = time.time()

        print('Predicting time consumed: {:.2f}s'.format(t_end - t_start))
예제 #4
0
class BaseInferencer(object):
    """Base class for per-patient inference.

    Handles GPU placement / DataParallel wrapping and checkpoint restore;
    subclasses implement __inference__ for a single patient.
    """

    def __init__(self,
                 model,
                 images_path,
                 labels_path,
                 patient_ids,
                 sample_shape,
                 checkpoint_restore,
                 inference_dir,
                 use_gpu=False,
                 gpu_ids=None):
        # model settings
        self.model = model

        # data settings
        assert len(images_path) == len(labels_path)
        self.images_path = images_path
        self.labels_path = labels_path
        self.patient_ids = patient_ids
        self.length = len(images_path)
        self.sample_shape = sample_shape

        self.inference_dir = inference_dir

        # gpu settings: wrap in DataParallel when more than one GPU is used
        self.use_gpu = use_gpu
        if use_gpu and torch.cuda.device_count() > 0:
            self.model.cuda()
            if gpu_ids is not None:
                if len(gpu_ids) > 1:
                    self.multi_gpu = True
                    self.model = DataParallel(model, gpu_ids)
                else:
                    self.multi_gpu = False
            else:
                if torch.cuda.device_count() > 1:
                    self.multi_gpu = True
                    self.model = DataParallel(model)
                else:
                    self.multi_gpu = False
        else:
            self.multi_gpu = False
            self.model = self.model.cpu()

        # BUGFIX: load the checkpoint onto CPU first so GPU-saved
        # checkpoints can be restored on CPU-only hosts; load_state_dict
        # then copies the tensors to the model's current device.
        state = torch.load(checkpoint_restore,
                           map_location=torch.device('cpu'))
        # BUGFIX: when wrapped in DataParallel, load into .module so the
        # unprefixed state dicts (as saved by BaseTrainer) match.
        target = self.model.module if self.multi_gpu else self.model
        target.load_state_dict(state)

    def inference(self):
        """Run __inference__ for every patient and log progress."""
        self.model.eval()

        logging.info('*' * 80)
        logging.info('start inference loop')
        logging.info('%d patients need to be inference ' % self.length)

        for index in range(self.length):
            logging.info('start inference %d-th patient' % (index + 1))
            self.__inference__(index)

        logging.info('*' * 80)
        logging.info('inference patient: %d' % self.length)
        logging.info('inference result saved in: %s' % self.inference_dir)

    def __inference__(self, index):
        """Infer one patient (by index); implemented by subclasses.

        :rtype: object
        """
        pass
예제 #5
0
class BaseEvaluation(object):
    """Base class for per-patient evaluation with a dict of metrics.

    Handles GPU placement / DataParallel wrapping and checkpoint restore;
    subclasses implement load_data and eval_one_patient.
    """

    def __init__(self,
                 model,
                 metrics,
                 images_path,
                 labels_path,
                 sample_shape,
                 checkpoint_restore,
                 use_gpu=False,
                 gpu_ids=None):
        # model settings
        self.model = model

        # metrics settings
        assert type(metrics) == dict
        self.metrics = metrics

        # data settings
        assert len(images_path) == len(labels_path)
        self.images_path = images_path
        self.labels_path = labels_path
        self.length = len(images_path)
        self.sample_shape = sample_shape

        # gpu settings: wrap in DataParallel when more than one GPU is used
        self.use_gpu = use_gpu
        if use_gpu and torch.cuda.device_count() > 0:
            self.model.cuda()
            if gpu_ids is not None:
                if len(gpu_ids) > 1:
                    self.multi_gpu = True
                    self.model = DataParallel(model, gpu_ids)
                else:
                    self.multi_gpu = False
            else:
                if torch.cuda.device_count() > 1:
                    self.multi_gpu = True
                    self.model = DataParallel(model)
                else:
                    self.multi_gpu = False
        else:
            self.multi_gpu = False
            self.model = self.model.cpu()

        # BUGFIX: load the checkpoint onto CPU first so GPU-saved
        # checkpoints can be restored on CPU-only hosts; load_state_dict
        # then copies the tensors to the model's current device.
        state = torch.load(checkpoint_restore,
                           map_location=torch.device('cpu'))
        # BUGFIX: when wrapped in DataParallel, load into .module so the
        # unprefixed state dicts (as saved by BaseTrainer) match.
        target = self.model.module if self.multi_gpu else self.model
        target.load_state_dict(state)

    def load_data(self, index):
        """Load one patient's data (by index); implemented by subclasses.

        :rtype: image -> nd-array, label -> nd-array
        """
        pass

    def eval_one_patient(self, image, label):
        """Evaluate one patient; implemented by subclasses.

        :rtype: metrics -> dict
        """
        pass

    def eval(self):
        """Evaluate all patients and log per-patient and mean metrics."""
        self.model.eval()

        logging.info('*' * 80)
        logging.info('start evaluation loop')
        logging.info('%d patients need to be evaluated ' % self.length)

        result = dict()
        for index in range(self.length):

            logging.info('start evaluation %d-th patient' % (index + 1))
            image, label = self.load_data(index)
            with torch.no_grad():
                metrics = self.eval_one_patient(image, label)

            # accumulate each metric across patients
            for key in metrics.keys():
                if key not in result.keys():
                    result[key] = list()
                result[key].append(metrics[key])

            logging.info('evaluation metrics result: %s' % str(metrics))

        # mean of every metric over all patients
        mean_result = dict()
        for key in result.keys():
            mean_result[key] = np.mean(result[key])

        logging.info('*' * 80)
        logging.info('evaluation report: ')
        logging.info('evaluation patient: %d' % self.length)
        logging.info('evaluation metrics %s' % str(mean_result))
예제 #6
0
class BaseTrainer(object):
    """Skeleton training loop: epochs, optional validation, LR scheduling,
    multi-GPU wrapping and periodic checkpointing.

    Subclasses implement train_epochs / val_epochs.
    """

    def __init__(self,
                 epochs,
                 model,
                 train_dataloader,
                 train_loss_func,
                 train_metrics_func,
                 optimizer,
                 log_dir,
                 checkpoint_dir,
                 checkpoint_frequency,
                 checkpoint_restore=None,
                 val_dataloader=None,
                 val_metrics_func=None,
                 lr_scheduler=None,
                 lr_reduce_metric=None,
                 use_gpu=False,
                 gpu_ids=None):
        # train settings
        self.epochs = epochs
        self.model = model
        self.train_dataloader = train_dataloader
        self.train_loss_func = train_loss_func
        self.train_metrics_func = train_metrics_func
        self.optimizer = optimizer
        self.checkpoint_dir = checkpoint_dir
        self.checkpoint_frequency = checkpoint_frequency
        self.writer = SummaryWriter(logdir=log_dir)

        # validation settings
        if val_dataloader is not None:
            self.validation = True
            self.val_dataloader = val_dataloader
            self.val_metrics_func = val_metrics_func
        else:
            self.validation = False

        # lr scheduler settings: ReduceLROnPlateau needs a metric to watch
        if lr_scheduler is not None:
            self.lr_schedule = True
            self.lr_scheduler = lr_scheduler
            if isinstance(lr_scheduler,
                          torch.optim.lr_scheduler.ReduceLROnPlateau):
                self.lr_reduce_metric = lr_reduce_metric
        else:
            self.lr_schedule = False

        # multi-gpu settings
        self.use_gpu = use_gpu
        # BUGFIX: the CUDA_VISIBLE_DEVICES list was built unconditionally,
        # raising TypeError when gpu_ids is None (None is explicitly
        # supported below).
        if gpu_ids is not None:
            os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(
                str(gpu_id) for gpu_id in gpu_ids)

        if use_gpu and torch.cuda.device_count() > 0:
            self.model.cuda()
            if gpu_ids is not None:
                if len(gpu_ids) > 1:
                    self.multi_gpu = True
                    self.model = DataParallel(model, gpu_ids)
                else:
                    self.multi_gpu = False
            else:
                if torch.cuda.device_count() > 1:
                    self.multi_gpu = True
                    self.model = DataParallel(model)
                else:
                    self.multi_gpu = False
        else:
            self.multi_gpu = False
            self.device = torch.device('cpu')
            self.model = self.model.cpu()

        # checkpoint settings
        if checkpoint_restore is not None:
            # BUGFIX: load onto CPU so GPU-saved checkpoints restore on
            # CPU-only hosts, and load into .module when wrapped so the
            # unprefixed state dicts written by train() below round-trip.
            state = torch.load(checkpoint_restore,
                               map_location=torch.device('cpu'))
            target = self.model.module if self.multi_gpu else self.model
            target.load_state_dict(state)

    def train(self):
        """Run the full training loop: train, validate, schedule, checkpoint."""

        for epoch in range(1, self.epochs + 1):
            logging.info('*' * 80)
            logging.info('start epoch %d training loop' % epoch)
            # train
            self.model.train()
            loss, metrics = self.train_epochs(epoch)

            self.writer.add_scalar('train_loss', loss, epoch)
            for key in metrics.keys():
                self.writer.add_scalar(key, metrics[key], epoch)
            if self.lr_schedule:
                if isinstance(self.lr_scheduler,
                              torch.optim.lr_scheduler.ReduceLROnPlateau):
                    # plateau scheduler steps on the watched loss entry
                    self.lr_scheduler.step(loss[self.lr_reduce_metric])
                else:
                    self.lr_scheduler.step()
            logging.info('train loss result: %s' % str(loss))
            logging.info('train metrics result: %s' % str(metrics))

            # validation
            if self.validation:
                logging.info('validation start ... ')
                self.model.eval()
                loss, metrics = self.val_epochs(epoch)
                self.writer.add_scalar('val_loss', loss, epoch)
                for key in metrics.keys():
                    self.writer.add_scalar(key, metrics[key], epoch)
                logging.info('validation loss result: %s' % str(loss))
                logging.info('validation metrics result: %s' % str(metrics))

            # model checkpoint: save the bare module so the state dict has
            # no 'module.' prefix regardless of DataParallel wrapping
            if epoch % self.checkpoint_frequency == 0:
                logging.info('saving model...')
                checkpoint_name = 'checkpoint_%d.pth' % epoch
                if self.multi_gpu:
                    torch.save(
                        self.model.module.state_dict(),
                        os.path.join(self.checkpoint_dir, checkpoint_name))
                else:
                    torch.save(
                        self.model.state_dict(),
                        os.path.join(self.checkpoint_dir, checkpoint_name))
                logging.info('model have saved for epoch_%d ' % epoch)
            else:
                logging.info('saving model skipped.')

    def train_epochs(self, epoch) -> (dict, dict):
        """Train one epoch; implemented by subclasses.

        :rtype: loss -> dict , metrics -> dict
        """
        pass

    def val_epochs(self, epoch) -> (dict, dict):
        """Validate one epoch; implemented by subclasses.

        :rtype: loss -> dict , metrics -> dict
        """
        pass
예제 #7
0
    'params': IDE.classifier.parameters(),
    'lr': 0.01
}],
                                momentum=0.9,
                                weight_decay=5e-4,
                                nesterov=True)
# Decay LR by a factor of 0.1 every 20 epochs (20 epochs for market and 30 epochs for duke)
# NOTE(review): the comment above says 20/30 epochs but step_size is 10 -- confirm.
scheduler_IDE = lr_scheduler.StepLR(IDE_optimizer, step_size=10, gamma=0.1)

## load checkpoint: resume every network / optimizer state if a checkpoint
## exists, otherwise start training from scratch.
ckpt_dir = './checkpoints/espgan_m2d_lam5/'
utils.mkdir(ckpt_dir)
try:
    ckpt = utils.load_checkpoint(ckpt_dir, map_location=torch.device('cpu'))
    start_epoch = ckpt['epoch']
    Da.load_state_dict(ckpt['Da'])
    Db.load_state_dict(ckpt['Db'])
    Ga.load_state_dict(ckpt['Ga'])
    Gb.load_state_dict(ckpt['Gb'])
    IDE.load_state_dict(ckpt['IDE'])

    da_optimizer.load_state_dict(ckpt['da_optimizer'])
    db_optimizer.load_state_dict(ckpt['db_optimizer'])
    ga_optimizer.load_state_dict(ckpt['ga_optimizer'])
    gb_optimizer.load_state_dict(ckpt['gb_optimizer'])
    IDE_optimizer.load_state_dict(ckpt['IDE_optimizer'])
except Exception:
    # BUGFIX: the bare `except:` also swallowed KeyboardInterrupt/SystemExit;
    # the best-effort resume behavior is kept for ordinary exceptions only.
    start_epoch = 0
    print('Training from zero')  # fixed message typo ("form" -> "from")

## run
예제 #8
0
class SSLOnlineEvaluator(Callback):  # pragma: no cover
    """Attaches a MLP for fine-tuning using the standard self-supervised protocol.

    Example::

        # your datamodule must have 2 attributes
        dm = DataModule()
        dm.num_classes = ... # the num of classes in the datamodule
        dm.name = ... # name of the datamodule (e.g. ImageNet, STL10, CIFAR10)

        # your model must have 1 attribute
        model = Model()
        model.z_dim = ... # the representation dim

        online_eval = SSLOnlineEvaluator(
            z_dim=model.z_dim
        )
    """
    def __init__(
        self,
        z_dim: int,
        drop_p: float = 0.2,
        hidden_dim: Optional[int] = None,
        num_classes: Optional[int] = None,
        dataset: Optional[str] = None,
    ):
        """
        Args:
            z_dim: Representation dimension
            drop_p: Dropout probability
            hidden_dim: Hidden dimension for the fine-tune MLP
            num_classes: Number of classes; read from the datamodule if None
            dataset: Dataset name; read from the datamodule if None
        """
        super().__init__()

        self.z_dim = z_dim
        self.hidden_dim = hidden_dim
        self.drop_p = drop_p

        self.optimizer: Optional[Optimizer] = None
        self.online_evaluator: Optional[SSLEvaluator] = None
        # BUGFIX: num_classes and dataset were first set to None and then
        # immediately reassigned from the constructor arguments; the dead
        # duplicate assignments are removed.
        self.num_classes: Optional[int] = num_classes
        self.dataset: Optional[str] = dataset

        self._recovered_callback_state: Optional[Dict[str, Any]] = None

    def setup(self,
              trainer: Trainer,
              pl_module: LightningModule,
              stage: Optional[str] = None) -> None:
        # fall back to the datamodule for anything not given at construction
        if self.num_classes is None:
            self.num_classes = trainer.datamodule.num_classes
        if self.dataset is None:
            self.dataset = trainer.datamodule.name

    def on_pretrain_routine_start(self, trainer: Trainer,
                                  pl_module: LightningModule) -> None:
        # must move to device after setup, as during setup, pl_module is still on cpu
        self.online_evaluator = SSLEvaluator(
            n_input=self.z_dim,
            n_classes=self.num_classes,
            p=self.drop_p,
            n_hidden=self.hidden_dim,
        ).to(pl_module.device)

        # switch for PL compatibility reasons
        accel = (trainer.accelerator_connector if hasattr(
            trainer, "accelerator_connector") else
                 trainer._accelerator_connector)
        if accel.is_distributed:
            if accel.use_ddp:
                from torch.nn.parallel import DistributedDataParallel as DDP

                self.online_evaluator = DDP(self.online_evaluator,
                                            device_ids=[pl_module.device])
            elif accel.use_dp:
                from torch.nn.parallel import DataParallel as DP

                self.online_evaluator = DP(self.online_evaluator,
                                           device_ids=[pl_module.device])
            else:
                rank_zero_warn(
                    "Does not support this type of distributed accelerator. The online evaluator will not sync."
                )

        self.optimizer = torch.optim.Adam(self.online_evaluator.parameters(),
                                          lr=1e-4)

        # restore evaluator / optimizer state captured by on_load_checkpoint
        if self._recovered_callback_state is not None:
            self.online_evaluator.load_state_dict(
                self._recovered_callback_state["state_dict"])
            self.optimizer.load_state_dict(
                self._recovered_callback_state["optimizer_state"])

    def to_device(self, batch: Sequence,
                  device: Union[str, torch.device]) -> Tuple[Tensor, Tensor]:
        # get the labeled batch
        if self.dataset == "stl10":
            labeled_batch = batch[1]
            batch = labeled_batch

        inputs, y = batch

        # last input is for online eval
        x = inputs[-1]
        x = x.to(device)
        y = y.to(device)

        return x, y

    def shared_step(
        self,
        pl_module: LightningModule,
        batch: Sequence,
    ):
        """Compute (accuracy, loss) of the online MLP on one batch;
        representations are taken from the frozen pl_module."""
        with torch.no_grad():
            with set_training(pl_module, False):
                x, y = self.to_device(batch, pl_module.device)
                representations = pl_module(x).flatten(start_dim=1)

        # forward pass
        mlp_logits = self.online_evaluator(
            representations)  # type: ignore[operator]
        mlp_loss = F.cross_entropy(mlp_logits, y)

        acc = accuracy(mlp_logits.softmax(-1), y)

        return acc, mlp_loss

    def on_train_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: Sequence,
        batch: Sequence,
        batch_idx: int,
        dataloader_idx: int,
    ) -> None:
        train_acc, mlp_loss = self.shared_step(pl_module, batch)

        # update finetune weights
        mlp_loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()

        pl_module.log("online_train_acc",
                      train_acc,
                      on_step=True,
                      on_epoch=False)
        pl_module.log("online_train_loss",
                      mlp_loss,
                      on_step=True,
                      on_epoch=False)

    def on_validation_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: Sequence,
        batch: Sequence,
        batch_idx: int,
        dataloader_idx: int,
    ) -> None:
        val_acc, mlp_loss = self.shared_step(pl_module, batch)
        pl_module.log("online_val_acc",
                      val_acc,
                      on_step=False,
                      on_epoch=True,
                      sync_dist=True)
        pl_module.log("online_val_loss",
                      mlp_loss,
                      on_step=False,
                      on_epoch=True,
                      sync_dist=True)

    def on_save_checkpoint(self, trainer: Trainer, pl_module: LightningModule,
                           checkpoint: Dict[str, Any]) -> dict:
        return {
            "state_dict": self.online_evaluator.state_dict(),
            "optimizer_state": self.optimizer.state_dict()
        }

    def on_load_checkpoint(self, trainer: Trainer, pl_module: LightningModule,
                           callback_state: Dict[str, Any]) -> None:
        self._recovered_callback_state = callback_state
def main():
    """Build the ModelNet40 datasets and the multi-view CNN, then train or
    (with --evaluate) only test.

    Ported from Python 2 (print statements / xrange) to Python 3, matching
    the rest of the file; the fragile `epoch % 5 is 0` identity check on an
    int was replaced with `== 0`, and the two identical save branches were
    merged.
    """
    global args, best_prec1
    args = parser.parse_args()

    # Read list of training and validation data
    listfiles_train, labels_train = read_lists(TRAIN_OUT)
    listfiles_val, labels_val = read_lists(VAL_OUT)
    listfiles_test, labels_test = read_lists(TEST_OUT)
    dataset_train = Dataset(listfiles_train,
                            labels_train,
                            subtract_mean=False,
                            V=12)
    dataset_val = Dataset(listfiles_val, labels_val, subtract_mean=False, V=12)
    dataset_test = Dataset(listfiles_test,
                           labels_test,
                           subtract_mean=False,
                           V=12)

    # shuffle data
    dataset_train.shuffle()
    dataset_val.shuffle()
    dataset_test.shuffle()
    tra_data_size, val_data_size, test_data_size = dataset_train.size(
    ), dataset_val.size(), dataset_test.size()
    print('training size:', tra_data_size)
    print('validation size:', val_data_size)
    print('testing size:', test_data_size)

    batch_size = args.b
    print("batch_size is :" + str(batch_size))
    learning_rate = args.lr
    print("learning_rate is :" + str(learning_rate))
    num_cuda = cuda.device_count()
    print("number of GPUs have been detected:" + str(num_cuda))

    # create model; DataParallel spreads each batch over the visible GPUs
    print("model building...")
    mvcnn = DataParallel(modelnet40_Alex(num_cuda, batch_size))
    mvcnn.cuda()

    # Optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint'{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            mvcnn.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.Adadelta(mvcnn.parameters(), weight_decay=1e-4)

    # evaluate performance only
    if args.evaluate:
        print('testing mode ------------------')
        validate(dataset_test, mvcnn, criterion, optimizer, batch_size)
        return

    print('training mode ------------------')
    for epoch in range(args.start_epoch, args.epochs):
        print('epoch:', epoch)

        # train for one epoch
        train(dataset_train, mvcnn, criterion, optimizer, epoch, batch_size)

        # evaluate on validation set
        prec1 = validate(dataset_val, mvcnn, criterion, optimizer, batch_size)

        # remember best prec@1; checkpoint on a new best and as a periodic
        # backup every 5 epochs (previously two identical branches)
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        if is_best or epoch % 5 == 0:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': mvcnn.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best, epoch)
예제 #10
0
# Train or Test 


if not args.demo:

    avg_tool = CumulativeAverager()

    # Defaults used when no checkpoint is restored.
    # BUGFIX: `start` was previously assigned only in the "no --load_from"
    # branch, so a --load_from path that is not a file left `start`
    # undefined and crashed with NameError at the training loop below.
    start = 0
    vloss, is_best = torch.tensor(float(np.inf)), None
    if args.load_from is not None:
        if os.path.isfile(args.load_from):
            log_str = add_to_log("=> loading checkpoint '{}'".format(args.load_from))
            checkpoint = torch.load(args.load_from)
            start = checkpoint['epoch']
            vloss = checkpoint['best_val_loss']
            net.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            args.lr = checkpoint['learning_rate']
            log_str = add_to_log("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.load_from, checkpoint['epoch']))
        else:
            log_str = add_to_log("=> no checkpoint found at '{}'".format(args.load_from))

    for epoch in range(start, args.epochs):

        train(epoch, losstype=losstype)
        val_loss = validate(losstype=losstype).cpu()
        # ReduceLROnPlateau-style scheduler steps on the validation loss
        scheduler.step(val_loss)
		
예제 #11
0
def main(args):
    """Two-stage CamVid semantic-segmentation training.

    Stage 1 (skipped when ``args.finetune`` is set): train ``args.model``
    on the 13-class CamVid labels.  Stage 2: rebuild the model for the
    label-converted dataset (fewer classes, ``ignore_idx`` forced to 4),
    warm-start it from the stage-1 weights, and fine-tune for 50 epochs
    with a 100x smaller learning rate.  Checkpoints, stats and TensorBoard
    logs go to ``args.savedir``.

    Args:
        args: parsed command-line namespace.  Mutated in place
            (``classes``, ``use_depth``, ``ignore_idx``, ``lr``,
            ``epochs``).
    """
    crop_size = args.crop_size
    assert isinstance(crop_size, tuple)
    print_info_message(
        'Running Model at image resolution {}x{} with batch size {}'.format(
            crop_size[0], crop_size[1], args.batch_size))
    if not os.path.isdir(args.savedir):
        os.makedirs(args.savedir)

    num_gpus = torch.cuda.device_count()
    device = 'cuda' if num_gpus > 0 else 'cpu'
    print('device : ' + device)

    # Get a summary writer for tensorboard
    writer = SummaryWriter(log_dir=args.savedir,
                           comment='Training and Validation logs')

    #
    # Stage 1: training the model with 13 classes of the CamVid dataset
    # TODO: This process should be done only if specified
    #
    if not args.finetune:
        train_dataset, val_dataset, class_wts, seg_classes, color_encoding = import_dataset(
            label_conversion=False)  # 13 classes
        args.use_depth = False  # 'use_depth' is always false for camvid

        print_info_message('Training samples: {}'.format(len(train_dataset)))
        print_info_message('Validation samples: {}'.format(len(val_dataset)))

        # Import the requested architecture (imports are deferred so only
        # the selected model's module is loaded).
        if args.model == 'espnetv2':
            from model.segmentation.espnetv2 import espnetv2_seg
            args.classes = seg_classes
            model = espnetv2_seg(args)
        elif args.model == 'espdnet':
            from model.segmentation.espdnet import espdnet_seg
            args.classes = seg_classes
            print("Trainable fusion : {}".format(args.trainable_fusion))
            print("Segmentation classes : {}".format(seg_classes))
            model = espdnet_seg(args)
        elif args.model == 'espdnetue':
            from model.segmentation.espdnet_ue import espdnetue_seg2
            args.classes = seg_classes
            print("Trainable fusion : {}".format(args.trainable_fusion))
            # BUG FIX: this line was a bare expression (missing print), so
            # the message was built and silently discarded.
            print("Segmentation classes : {}".format(seg_classes))
            print(args.weights)
            model = espdnetue_seg2(args, False, fix_pyr_plane_proj=True)
        else:
            print_error_message('Arch: {} not yet supported'.format(
                args.model))
            exit(-1)

        # Freeze batch normalization layers?
        if args.freeze_bn:
            freeze_bn_layer(model)

        # Set learning rates: the backbone trains at the base LR, the
        # segmentation head at lr * lr_mult.
        train_params = [{
            'params': model.get_basenet_params(),
            'lr': args.lr
        }, {
            'params': model.get_segment_params(),
            'lr': args.lr * args.lr_mult
        }]

        # Define an optimizer
        optimizer = optim.SGD(train_params,
                              lr=args.lr * args.lr_mult,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)

        # Compute the FLOPs and the number of parameters, and display it
        num_params, flops = show_network_stats(model, crop_size)

        try:
            writer.add_graph(model,
                             input_to_model=torch.Tensor(
                                 1, 3, crop_size[0], crop_size[1]))
        # Was a bare `except:` — that also swallows SystemExit and
        # KeyboardInterrupt; narrow to Exception.
        except Exception:
            print_log_message(
                "Not able to generate the graph. Likely because your model is not supported by ONNX"
            )

        #criterion = nn.CrossEntropyLoss(weight=class_wts, reduction='none', ignore_index=args.ignore_idx)
        criterion = SegmentationLoss(n_classes=seg_classes,
                                     loss_type=args.loss_type,
                                     device=device,
                                     ignore_idx=args.ignore_idx,
                                     class_wts=class_wts.to(device))
        nid_loss = NIDLoss(image_bin=32,
                           label_bin=seg_classes) if args.use_nid else None

        if num_gpus >= 1:
            if num_gpus == 1:
                # for a single GPU, we do not need DataParallel wrapper for Criteria.
                # So, falling back to its internal wrapper
                from torch.nn.parallel import DataParallel
                model = DataParallel(model)
                model = model.cuda()
                criterion = criterion.cuda()
                if args.use_nid:
                    nid_loss.cuda()
            else:
                from utilities.parallel_wrapper import DataParallelModel, DataParallelCriteria
                model = DataParallelModel(model)
                model = model.cuda()
                criterion = DataParallelCriteria(criterion)
                criterion = criterion.cuda()
                if args.use_nid:
                    nid_loss = DataParallelCriteria(nid_loss)
                    nid_loss = nid_loss.cuda()

            if torch.backends.cudnn.is_available():
                import torch.backends.cudnn as cudnn
                cudnn.benchmark = True
                cudnn.deterministic = True

        # Get data loaders for training and validation data
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   pin_memory=True,
                                                   num_workers=args.workers)
        val_loader = torch.utils.data.DataLoader(val_dataset,
                                                 batch_size=20,
                                                 shuffle=False,
                                                 pin_memory=True,
                                                 num_workers=args.workers)

        # Get a learning rate scheduler
        lr_scheduler = get_lr_scheduler(args.scheduler)

        write_stats_to_json(num_params, flops)

        extra_info_ckpt = '{}_{}_{}'.format(args.model, args.s, crop_size[0])
        #
        # Main training loop of 13 classes
        #
        start_epoch = 0
        best_miou = 0.0
        for epoch in range(start_epoch, args.epochs):
            lr_base = lr_scheduler.step(epoch)
            # set the optimizer with the learning rate
            # This can be done inside the MyLRScheduler
            lr_seg = lr_base * args.lr_mult
            optimizer.param_groups[0]['lr'] = lr_base
            optimizer.param_groups[1]['lr'] = lr_seg

            print_info_message(
                'Running epoch {} with learning rates: base_net {:.6f}, segment_net {:.6f}'
                .format(epoch, lr_base, lr_seg))

            # Use different training functions for espdnetue
            if args.model == 'espdnetue':
                from utilities.train_eval_seg import train_seg_ue as train
                from utilities.train_eval_seg import val_seg_ue as val
            else:
                from utilities.train_eval_seg import train_seg as train
                from utilities.train_eval_seg import val_seg as val

            miou_train, train_loss = train(model,
                                           train_loader,
                                           optimizer,
                                           criterion,
                                           seg_classes,
                                           epoch,
                                           device=device,
                                           use_depth=args.use_depth,
                                           add_criterion=nid_loss)
            miou_val, val_loss = val(model,
                                     val_loader,
                                     criterion,
                                     seg_classes,
                                     device=device,
                                     use_depth=args.use_depth,
                                     add_criterion=nid_loss)

            # Grab one batch from each loader for TensorBoard visualization.
            # BUG FIX: `iter(loader).next()` is the Python-2 iterator
            # protocol and raises AttributeError on Python 3.
            batch_train = next(iter(train_loader))
            batch = next(iter(val_loader))
            in_training_visualization_img(
                model,
                images=batch_train[0].to(device=device),
                labels=batch_train[1].to(device=device),
                class_encoding=color_encoding,
                writer=writer,
                epoch=epoch,
                data='Segmentation/train',
                device=device)
            in_training_visualization_img(model,
                                          images=batch[0].to(device=device),
                                          labels=batch[1].to(device=device),
                                          class_encoding=color_encoding,
                                          writer=writer,
                                          epoch=epoch,
                                          data='Segmentation/val',
                                          device=device)

            # remember best miou and save checkpoint
            is_best = miou_val > best_miou
            best_miou = max(miou_val, best_miou)

            # Unwrap the DataParallel module so the saved keys carry no
            # 'module.' prefix.
            weights_dict = model.module.state_dict(
            ) if device == 'cuda' else model.state_dict()
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.model,
                    'state_dict': weights_dict,
                    'best_miou': best_miou,
                    'optimizer': optimizer.state_dict(),
                }, is_best, args.savedir, extra_info_ckpt)

            writer.add_scalar('Segmentation/LR/base', round(lr_base, 6), epoch)
            writer.add_scalar('Segmentation/LR/seg', round(lr_seg, 6), epoch)
            writer.add_scalar('Segmentation/Loss/train', train_loss, epoch)
            writer.add_scalar('Segmentation/Loss/val', val_loss, epoch)
            writer.add_scalar('Segmentation/mIOU/train', miou_train, epoch)
            writer.add_scalar('Segmentation/mIOU/val', miou_val, epoch)
            # NOTE(review): logging best_miou with flops/num_params as the
            # global step looks inverted, but is kept as-is to preserve the
            # existing TensorBoard layout.
            writer.add_scalar('Segmentation/Complexity/Flops', best_miou,
                              math.ceil(flops))
            writer.add_scalar('Segmentation/Complexity/Params', best_miou,
                              math.ceil(num_params))

        # Save the pretrained weights and free GPU memory before stage 2.
        model_dict = copy.deepcopy(model.state_dict())
        del model
        torch.cuda.empty_cache()

    #
    # Stage 2: finetuning with the label-converted (reduced) class set
    #
    args.ignore_idx = 4
    train_dataset, val_dataset, class_wts, seg_classes, color_encoding = import_dataset(
        label_conversion=True)  # 5 classes

    print_info_message('Training samples: {}'.format(len(train_dataset)))
    print_info_message('Validation samples: {}'.format(len(val_dataset)))

    #set_parameters_for_finetuning()

    # Import model
    if args.model == 'espnetv2':
        from model.segmentation.espnetv2 import espnetv2_seg
        args.classes = seg_classes
        model = espnetv2_seg(args)
    elif args.model == 'espdnet':
        from model.segmentation.espdnet import espdnet_seg
        args.classes = seg_classes
        print("Trainable fusion : {}".format(args.trainable_fusion))
        print("Segmentation classes : {}".format(seg_classes))
        model = espdnet_seg(args)
    elif args.model == 'espdnetue':
        from model.segmentation.espdnet_ue import espdnetue_seg2
        args.classes = seg_classes
        print("Trainable fusion : {}".format(args.trainable_fusion))
        print("Segmentation classes : {}".format(seg_classes))
        print(args.weights)
        model = espdnetue_seg2(args, args.finetune, fix_pyr_plane_proj=True)
    else:
        print_error_message('Arch: {} not yet supported'.format(args.model))
        exit(-1)

    if not args.finetune:
        # Warm-start from the stage-1 weights: copy every parameter whose
        # (module.-stripped) name and shape match the new model.
        new_model_dict = model.state_dict()
        overlap_dict = {
            k.replace('module.', ''): v
            for k, v in model_dict.items()
            if k.replace('module.', '') in new_model_dict
            and new_model_dict[k.replace('module.', '')].size() == v.size()
        }
        # NOTE(review): this comprehension iterates new_model_dict and tests
        # membership against new_model_dict itself, so it is almost always
        # empty; it probably should iterate model_dict.  Kept as-is because
        # it is only printed for inspection.
        no_overlap_dict = {
            k.replace('module.', ''): v
            for k, v in new_model_dict.items()
            if k.replace('module.', '') not in new_model_dict
            or new_model_dict[k.replace('module.', '')].size() != v.size()
        }
        print(no_overlap_dict.keys())

        new_model_dict.update(overlap_dict)
        model.load_state_dict(new_model_dict)

    # Sanity forward pass on the CPU with a hard-coded CamVid-sized input.
    output = model(torch.ones(1, 3, 288, 480))
    print(output[0].size())

    print(seg_classes)
    print(class_wts.size())
    criterion = SegmentationLoss(n_classes=seg_classes,
                                 loss_type=args.loss_type,
                                 device=device,
                                 ignore_idx=args.ignore_idx,
                                 class_wts=class_wts.to(device))
    nid_loss = NIDLoss(image_bin=32,
                       label_bin=seg_classes) if args.use_nid else None

    # Set learning rates: finetuning uses a 100x smaller base LR.
    args.lr /= 100
    train_params = [{
        'params': model.get_basenet_params(),
        'lr': args.lr
    }, {
        'params': model.get_segment_params(),
        'lr': args.lr * args.lr_mult
    }]
    # Define an optimizer
    optimizer = optim.SGD(train_params,
                          lr=args.lr * args.lr_mult,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    if num_gpus >= 1:
        if num_gpus == 1:
            # for a single GPU, we do not need DataParallel wrapper for Criteria.
            # So, falling back to its internal wrapper
            from torch.nn.parallel import DataParallel
            model = DataParallel(model)
            model = model.cuda()
            criterion = criterion.cuda()
            if args.use_nid:
                nid_loss.cuda()
        else:
            from utilities.parallel_wrapper import DataParallelModel, DataParallelCriteria
            model = DataParallelModel(model)
            model = model.cuda()
            criterion = DataParallelCriteria(criterion)
            criterion = criterion.cuda()
            if args.use_nid:
                nid_loss = DataParallelCriteria(nid_loss)
                nid_loss = nid_loss.cuda()

        if torch.backends.cudnn.is_available():
            import torch.backends.cudnn as cudnn
            cudnn.benchmark = True
            cudnn.deterministic = True

    # Get data loaders for training and validation data
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=args.workers)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=20,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=args.workers)

    # Get a learning rate scheduler (finetuning is fixed to 50 epochs)
    args.epochs = 50
    lr_scheduler = get_lr_scheduler(args.scheduler)

    # Compute the FLOPs and the number of parameters, and display it
    num_params, flops = show_network_stats(model, crop_size)
    write_stats_to_json(num_params, flops)

    extra_info_ckpt = '{}_{}_{}_{}'.format(args.model, seg_classes, args.s,
                                           crop_size[0])
    #
    # Main finetuning loop
    #
    start_epoch = 0
    best_miou = 0.0
    for epoch in range(start_epoch, args.epochs):
        lr_base = lr_scheduler.step(epoch)
        # set the optimizer with the learning rate
        # This can be done inside the MyLRScheduler
        lr_seg = lr_base * args.lr_mult
        optimizer.param_groups[0]['lr'] = lr_base
        optimizer.param_groups[1]['lr'] = lr_seg

        print_info_message(
            'Running epoch {} with learning rates: base_net {:.6f}, segment_net {:.6f}'
            .format(epoch, lr_base, lr_seg))

        # Use different training functions for espdnetue
        if args.model == 'espdnetue':
            from utilities.train_eval_seg import train_seg_ue as train
            from utilities.train_eval_seg import val_seg_ue as val
        else:
            from utilities.train_eval_seg import train_seg as train
            from utilities.train_eval_seg import val_seg as val

        miou_train, train_loss = train(model,
                                       train_loader,
                                       optimizer,
                                       criterion,
                                       seg_classes,
                                       epoch,
                                       device=device,
                                       use_depth=args.use_depth,
                                       add_criterion=nid_loss)
        miou_val, val_loss = val(model,
                                 val_loader,
                                 criterion,
                                 seg_classes,
                                 device=device,
                                 use_depth=args.use_depth,
                                 add_criterion=nid_loss)

        # BUG FIX: Python-3 iterator protocol (see stage-1 loop above).
        batch_train = next(iter(train_loader))
        batch = next(iter(val_loader))
        in_training_visualization_img(model,
                                      images=batch_train[0].to(device=device),
                                      labels=batch_train[1].to(device=device),
                                      class_encoding=color_encoding,
                                      writer=writer,
                                      epoch=epoch,
                                      data='SegmentationConv/train',
                                      device=device)
        in_training_visualization_img(model,
                                      images=batch[0].to(device=device),
                                      labels=batch[1].to(device=device),
                                      class_encoding=color_encoding,
                                      writer=writer,
                                      epoch=epoch,
                                      data='SegmentationConv/val',
                                      device=device)

        # remember best miou and save checkpoint
        is_best = miou_val > best_miou
        best_miou = max(miou_val, best_miou)

        weights_dict = model.module.state_dict(
        ) if device == 'cuda' else model.state_dict()
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.model,
                'state_dict': weights_dict,
                'best_miou': best_miou,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.savedir, extra_info_ckpt)

        writer.add_scalar('SegmentationConv/LR/base', round(lr_base, 6), epoch)
        writer.add_scalar('SegmentationConv/LR/seg', round(lr_seg, 6), epoch)
        writer.add_scalar('SegmentationConv/Loss/train', train_loss, epoch)
        writer.add_scalar('SegmentationConv/Loss/val', val_loss, epoch)
        writer.add_scalar('SegmentationConv/mIOU/train', miou_train, epoch)
        writer.add_scalar('SegmentationConv/mIOU/val', miou_val, epoch)
        writer.add_scalar('SegmentationConv/Complexity/Flops', best_miou,
                          math.ceil(flops))
        writer.add_scalar('SegmentationConv/Complexity/Params', best_miou,
                          math.ceil(num_params))

    writer.close()
예제 #12
0
class GetFeature(object):
    """Extract L2-normalized features with a pre-trained model.

    Arguments
        model_weight_file: path of the pre-trained model weights
        sys_device_ids: CUDA device id string (e.g. '0'); '' runs on CPU
    """
    def __init__(self, model_weight_file, sys_device_ids=''):
        if len(sys_device_ids) > 0:
            os.environ['CUDA_VISIBLE_DEVICES'] = sys_device_ids
        self.sys_device_ids = sys_device_ids
        self.model = DataParallel(Model())
        # Use the GPU only when CUDA is available AND a device id was given.
        # The device is kept on the instance so __call__ can move inputs to
        # the same place the model lives.
        if torch.cuda.is_available() and self.sys_device_ids != '':
            self.device = torch.device('cuda')
        else:
            self.device = torch.device('cpu')
        self.model.load_state_dict(torch.load(model_weight_file,
                                              map_location=self.device))
        self.model.to(self.device)
        self.model.eval()

    def __call__(self, photo_path=None, batch_size=1):
        """
        Get features for a directory of photos or a single image.
        :param photo_path : either photo directory or a single image
        :param batch_size : useful only when photo_path is a directory
        :return: feature: numpy array, dim = num_images * 2048,
                 photo_name: a list, len = num_images
        """
        # Directory input: batch the images through a DataLoader.
        if os.path.isdir(photo_path):
            dataset = Data(photo_path, self._img_process)
            data_loader = DataLoader(dataset, batch_size=batch_size,
                                     num_workers=8)
            features = torch.FloatTensor()
            photos = []
            with torch.no_grad():
                for batch, (images, names) in enumerate(data_loader):
                    images = images.float().to(self.device)
                    feature = self.model(images).data.cpu()
                    features = torch.cat((features, feature), 0)
                    photos = photos + list(names)
                    if batch % 10 == 0:
                        print('processing batch: {}'.format(batch))
            features = features.numpy()
            # Row-wise L2 normalization so cosine similarity == dot product.
            features = features/np.linalg.norm(features, axis=1,
                                               keepdims=True)
            return features, photos
        # Single-image input.
        else:
            photo_name = photo_path.split('/')[-1]
            img = Image.open(photo_path)
            image = self._img_process(img)
            image = np.expand_dims(image, axis=0)
            image = torch.from_numpy(image).float().to(self.device)
            # BUG FIX: move the input to the model's device and bring the
            # result back to the CPU before converting to numpy; the old
            # code crashed here whenever the model lived on the GPU.
            with torch.no_grad():
                feature = self.model(image).data.cpu().numpy()
            feature = feature/np.linalg.norm(feature, axis=1,
                                             keepdims=True)
            return feature, [photo_name]

    def _img_process(self, img):
        """Resize a PIL image to 128x384, normalize with ImageNet statistics
        and return a float CHW numpy array."""
        img = img.resize((128, 384), resample=3)
        img = np.asarray(img)
        img = img[:, :, :3]  # drop the alpha channel if present
        img = img.astype(float)
        img = img / 255
        im_mean = np.array([0.485, 0.456, 0.406])
        im_std = np.array([0.229, 0.224, 0.225])
        img = img - im_mean
        img = img / im_std
        img = np.transpose(img, (2, 0, 1))  # HWC -> CHW
        return img
예제 #13
0
def main(args):
    """Run EDVR super-resolution inference over a low-resolution test set.

    Builds the EDVR network, optionally restores weights from
    ``args.resume`` (a .pth file, or a directory from which the latest
    .pth is picked), then writes every super-resolved frame under
    ``args.outputs_dir``, mirroring the source sub-directory layout.
    """
    if not os.path.exists(args.outputs_dir):
        os.makedirs(args.outputs_dir)

    print("===> Loading datasets")
    data_set = EvalDataset(
        args.test_lr,
        n_frames=args.n_frames,
        interval_list=args.interval_list,
    )
    eval_loader = DataLoader(data_set,
                             batch_size=args.batch_size,
                             num_workers=args.workers)

    # Fix all random seeds for reproducibility.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    cudnn.benchmark = True
    #cudnn.deterministic = True

    print("===> Building model")
    model = EDVR_arch.EDVR(nf=args.nf,
                           nframes=args.n_frames,
                           groups=args.groups,
                           front_RBs=args.front_RBs,
                           back_RBs=args.back_RBs,
                           center=args.center,
                           predeblur=args.predeblur,
                           HR_in=args.HR_in,
                           w_TSA=args.w_TSA)
    print("===> Setting GPU")
    # args.gpus == 0 means "use every visible GPU".
    num_gpus = args.gpus if args.gpus != 0 else torch.cuda.device_count()
    device_ids = list(range(num_gpus))
    model = DataParallel(model, device_ids=device_ids)
    model = model.cuda()

    # Optionally resume from a checkpoint.
    if args.resume:
        if os.path.isdir(args.resume):
            # Pick the newest .pth file in the directory (sorted by name).
            pth_list = sorted(glob(os.path.join(args.resume, '*.pth')))
            if len(pth_list) > 0:
                args.resume = pth_list[-1]
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            state_dict = checkpoint['state_dict']

            # Prepend 'module.' so the checkpoint keys match the
            # DataParallel-wrapped model.  (The original comment claimed
            # the prefix was being removed, which was wrong.)
            new_state_dict = OrderedDict()
            for k, v in state_dict.items():
                new_state_dict['module.' + k] = v
            model.load_state_dict(new_state_dict)

    print("===> Eval")
    model.eval()
    with tqdm(total=(len(data_set) - len(data_set) % args.batch_size)) as t:
        for data in eval_loader:
            data_x = data['LRs'].cuda()
            names = data['files']

            with torch.no_grad():
                outputs = model(data_x).data.float().cpu()
            # Scale [0, 1] predictions to 8-bit pixel values.
            outputs = outputs * 255.
            outputs = outputs.clamp_(0, 255).numpy()
            for img, fname in zip(outputs, names):
                # CHW RGB -> HWC BGR, as cv2.imwrite expects.
                img = np.transpose(img[[2, 1, 0], :, :], (1, 2, 0))
                img = img.round()

                # Mirror the "<sequence>/<frame>" layout under outputs_dir.
                parts = fname.split('/')
                dst_dir = os.path.join(args.outputs_dir, parts[-2])
                if not os.path.exists(dst_dir):
                    os.makedirs(dst_dir)
                dst_name = os.path.join(dst_dir, parts[-1])

                cv2.imwrite(dst_name, img)
            t.update(len(names))
예제 #14
0
class Im2latex(BaseAgent):
    """Agent that trains and evaluates an image-to-LaTeX sequence model.

    Builds the datasets/loaders, initializes the model weights, and drives
    the train / validate / predict loops, logging to TensorBoard and saving
    checkpoints under ``./experiments/<exp_name>/checkpoints``.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.device = get_device()
        cfg.device = self.device
        self.cfg = cfg

        # The training split defines the vocabulary shared by every split.
        train_dataset = Im2LatexDataset(cfg, mode="train")
        self.id2token = train_dataset.id2token
        self.token2id = train_dataset.token2id

        collate = custom_collate(self.token2id, cfg.max_len)

        self.train_loader = DataLoader(train_dataset,
                                       batch_size=cfg.bs,
                                       shuffle=cfg.data_shuffle,
                                       num_workers=cfg.num_w,
                                       collate_fn=collate,
                                       drop_last=True)
        if cfg.valid_img_path != "":
            # Validation reuses the training vocabulary.
            valid_dataset = Im2LatexDataset(cfg,
                                            mode="valid",
                                            vocab={
                                                'id2token': self.id2token,
                                                'token2id': self.token2id
                                            })
            # Batch size is divided by the beam width so beam search fits
            # in the same memory budget as training.
            self.valid_loader = DataLoader(valid_dataset,
                                           batch_size=cfg.bs //
                                           cfg.beam_search_k,
                                           shuffle=cfg.data_shuffle,
                                           num_workers=cfg.num_w,
                                           collate_fn=collate,
                                           drop_last=True)

        # Model and weight initialization: zero biases, Kaiming weights.
        self.model = Im2LatexModel(cfg)
        for name, param in self.model.named_parameters():
            if 'localization_fc2' in name:
                print(f'Skip {name} as it is already initialized')
                continue
            try:
                if 'bias' in name:
                    torch.nn.init.constant_(param, 0.0)
                elif 'weight' in name:
                    torch.nn.init.kaiming_normal_(param)
            except Exception:
                # Kaiming init fails on 1-D tensors (e.g. batchnorm weight);
                # fall back to filling those with ones.
                if 'weight' in name:
                    param.data.fill_(1)
                continue

        self.model = DataParallel(self.model)
        # Criterion is a plain function (cross-entropy-style sequence loss).
        self.criterion = cal_loss

        self.optimizer = torch.optim.Adam(params=self.model.parameters(),
                                          lr=cfg.lr,
                                          betas=(cfg.adam_beta_1,
                                                 cfg.adam_beta_2))

        self.scheduler = torch.optim.lr_scheduler.MultiStepLR(
            self.optimizer,
            cfg.milestones,
            gamma=cfg.gamma,
            verbose=True)

        # Counters are 1-based; best_metric tracks the lowest perplexity seen.
        self.current_epoch = 1
        self.current_iteration = 1
        self.best_metric = 100
        self.best_info = ''

        # Seed CUDA RNGs for reproducibility (no-op when CUDA is absent).
        torch.cuda.manual_seed_all(self.cfg.seed)
        if self.cfg.cuda:
            self.model = self.model.to(self.device)
            self.logger.info("Program will run on *****GPU-CUDA***** ")
        else:
            self.logger.info("Program will run on *****CPU*****\n")

        # Resume from a checkpoint if one exists; otherwise start fresh.
        self.exp_dir = os.path.join('./experiments', cfg.exp_name)
        self.load_checkpoint(cfg.checkpoint_filename)
        self.summary_writer = SummaryWriter(
            log_dir=os.path.join(self.exp_dir, 'summaries'))

    def load_checkpoint(self, file_name):
        """
        Latest checkpoint loader.

        :param file_name: name of the checkpoint file
        :return: None; restores epoch/iteration counters, model and optimizer
                 state in place, or logs and continues when no file is found.
        """
        try:
            self.logger.info("Loading checkpoint '{}'".format(file_name))
            checkpoint = torch.load(file_name, map_location=self.device)

            self.current_epoch = checkpoint['epoch']
            self.current_iteration = checkpoint['iteration']
            # strict=False tolerates architecture changes between runs.
            self.model.load_state_dict(checkpoint['model'], strict=False)
            self.optimizer.load_state_dict(checkpoint['optimizer'])

            info = "Checkpoint loaded successfully from "
            self.logger.info(
                info + "'{}' at (epoch {}) at (iteration {})\n".format(
                    file_name, checkpoint['epoch'], checkpoint['iteration']))

        except OSError:
            self.logger.info("Checkpoint not found in '{}'.".format(file_name))
            self.logger.info("**First time to train**")

    def save_checkpoint(self, file_name="checkpoint.pth", is_best=False):
        """
        Checkpoint saver.

        :param file_name: name of the checkpoint file (ignored when is_best)
        :param is_best: boolean flag to indicate whether current
                        checkpoint's accuracy is the best so far
        :return: None
        """
        state = {
            'epoch': self.current_epoch,
            'iteration': self.current_iteration,
            'model': self.model.state_dict(),
            'vocab': self.id2token,
            'optimizer': self.optimizer.state_dict()
        }

        # Make sure the target directory exists before torch.save, which
        # would otherwise fail on the first checkpoint of a fresh run.
        checkpoint_dir = os.path.join(self.exp_dir, 'checkpoints')
        os.makedirs(checkpoint_dir, exist_ok=True)
        if is_best:
            torch.save(state, os.path.join(checkpoint_dir, 'best.pt'))
            self.best_info = 'best: e{}_i{}'.format(self.current_epoch,
                                                    self.current_iteration)
        else:
            file_name = "e{}-i{}.pt".format(self.current_epoch,
                                            self.current_iteration)
            torch.save(state, os.path.join(checkpoint_dir, file_name))

    def run(self):
        """
        The main operator: dispatches to train or predict based on cfg.mode.

        :return: None
        """
        try:
            if self.cfg.mode == 'train':
                self.train()
            elif self.cfg.mode == 'predict':
                self.predict()

        except KeyboardInterrupt:
            self.logger.info("You have entered CTRL+C.. Wait to finalize")

    def train(self):
        """
        Main training loop; validates once after the last epoch.

        :return: None
        """
        for _ in range(self.current_epoch, self.cfg.epochs + 1):
            self.train_one_epoch()
            self.save_checkpoint()
            self.scheduler.step()
            self.current_epoch += 1

        if self.cfg.valid_img_path:
            self.validate()

    def train_one_epoch(self):
        """
        One epoch of training.

        :return: the last windowed average perplexity logged this epoch
                 (0 if fewer than cfg.log_freq iterations ran).
        """
        tqdm_bar = tqdm(enumerate(self.train_loader, 1),
                        total=len(self.train_loader))

        self.model.train()
        last_avg_perplexity, avg_perplexity = 0, 0
        for i, (imgs, tgt) in tqdm_bar:
            imgs = imgs.float().to(self.device)
            tgt = tgt.long().to(self.device)

            # logits: [B, MAXLEN, VOCABSIZE]
            logits = self.model(imgs, tgt, is_train=True)

            loss = self.criterion(logits, tgt)
            avg_perplexity += loss.item()

            # BUGFIX: gradients must be cleared every step; without this
            # they accumulate across iterations and corrupt the update.
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                           self.cfg.grad_clip)
            self.optimizer.step()
            self.current_iteration += 1

            # Windowed logging every log_freq iterations.
            if i % self.cfg.log_freq == 0:
                avg_perplexity = avg_perplexity / self.cfg.log_freq
                self.summary_writer.add_scalar(
                    'perplexity/train',
                    avg_perplexity,
                    global_step=self.current_iteration)
                # get_last_lr() returns one lr per param group (a list);
                # add_scalar requires a scalar, so take the first group.
                self.summary_writer.add_scalar(
                    'lr',
                    self.scheduler.get_last_lr()[0],
                    global_step=self.current_iteration)
                tqdm_bar.set_description("e{} | avg_perplexity: {:.3f}".format(
                    self.current_epoch, avg_perplexity))

                # Save whenever the windowed perplexity beats the best so far.
                if avg_perplexity < self.best_metric:
                    self.save_checkpoint(is_best=True)

                    self.best_metric = avg_perplexity
                last_avg_perplexity = avg_perplexity
                avg_perplexity = 0

        # Log one decoded example from the final batch. Token id 2 is
        # presumably the padding id -- confirm against the vocabulary.
        mask = (tgt[0] != 2)
        pred = str(logits[0].argmax(1)[mask].cpu().detach().tolist())
        gt = str(tgt[0][mask].cpu().tolist())
        self.summary_writer.add_text('example/train',
                                     pred + '  \n' + gt,
                                     global_step=self.current_iteration)

        return last_avg_perplexity

    def validate(self):
        """
        One cycle of model validation: exact-sequence-match accuracy.

        :return: None
        """
        tqdm_bar = tqdm(enumerate(self.valid_loader, 1),
                        total=len(self.valid_loader))
        self.model.eval()
        acc = 0
        with torch.no_grad():
            for i, (imgs, tgt) in tqdm_bar:
                imgs = imgs.to(self.device).float()
                tgt = tgt.to(self.device).long()

                # [B, MAXLEN] decoded token ids (inference mode).
                logits = self.model(imgs, is_train=False).long()
                # A sample counts as correct only when every token matches.
                acc += torch.all(tgt == logits, dim=1).sum() / imgs.size(0)
                tqdm_bar.set_description('acc {:.4f}'.format(acc / i))
                if i % self.cfg.log_freq == 0:
                    self.summary_writer.add_scalar(
                        'accuracy/valid',
                        acc.item() / i,
                        global_step=self.current_iteration)

    def predict(self):
        """
        Decode every .jpg/.png under cfg.test_img_path and print the
        recognized token sequences (token id 1 is skipped as padding/EOS --
        confirm against the vocabulary).

        :return: None
        """
        from torchvision import transforms
        from pathlib import Path
        from PIL import Image
        from time import time

        self.model.eval()
        transform = transforms.ToTensor()
        image_path = Path(self.cfg.test_img_path)

        t = time()
        with torch.no_grad():
            images = []
            imgPath = list(image_path.glob('*.jpg')) + list(
                image_path.glob('*.png'))
            for i, img in enumerate(imgPath):
                print(i, ':', img)
                img = Image.open(img)
                img = transform(img)
                images.append(img)
            images = torch.stack(images, dim=0)
            out = self.model(images)  # [B, max_len, vocab_size]

        for i, output in enumerate(out):
            print(
                i, ' '.join([
                    self.id2token[out.item()] for out in output
                    if out.item() != 1
                ]))

        print(time() - t)

    def finalize(self):
        """
        Finalizes all the operations of the 2 Main classes of the process,
        the operator and the data loader.

        :return: None
        """
        print(self.best_info)
예제 #15
0
class SRGANModel(BaseModel):
    """SRGAN training wrapper: generator G, discriminator D, optional VGG
    feature network F and optional text-recognition network T.

    Losses combined for G: pixel (L1/L2), perceptual feature, text
    recognition (when a dataset with a recognizer vocabulary is supplied),
    and GAN/RaGAN adversarial loss. The ordering of forward/backward calls
    and the requires_grad toggling in optimize_parameters are deliberate;
    do not reorder.
    """

    def __init__(self, opt, dataset=None):
        """Build networks, losses, optimizers and schedulers from ``opt``.

        :param opt: nested option dict ('train', 'dist', 'path', ...).
        :param dataset: optional recognition dataset; enables the text loss.
        """
        super(SRGANModel, self).__init__(opt)

        # NOTE(review): self.cri_text is only assigned when ``dataset`` is
        # truthy; the later ``if self.cri_text:`` checks presumably rely on
        # BaseModel providing a falsy default -- confirm.
        if dataset:
            self.cri_text = True

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        if self.is_train:
            self.netD = networks.define_D(opt).to(self.device)
            if opt['dist']:
                self.netD = DistributedDataParallel(
                    self.netD, device_ids=[torch.cuda.current_device()])
            else:
                self.netD = DataParallel(self.netD)

            self.netG.train()
            self.netD.train()

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss (zero weight disables it entirely)
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            # G feature (perceptual) loss; computed on netF activations
            if train_opt['feature_weight'] > 0:
                l_fea_type = train_opt['feature_criterion']
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_fea_type))
                self.l_fea_w = train_opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = networks.define_F(opt,
                                              use_bn=False).to(self.device)
                if opt['dist']:
                    pass  # do not need to use DistributedDataParallel for netF
                else:
                    self.netF = DataParallel(self.netF)
            if self.cri_text:
                # Frozen text recognizer used as an auxiliary loss on G's
                # output; weights come from train_opt['text_model'].
                from lib.models.model_builder import ModelBuilder
                self.netT = ModelBuilder(
                    arch="ResNet_ASTER",
                    rec_num_classes=dataset.rec_num_classes,
                    sDim=512,
                    attDim=512,
                    max_len_labels=100,
                    eos=dataset.char2id[dataset.EOS],
                    STN_ON=True).to(self.device)

                self.netT = DataParallel(self.netT)
                self.netT.eval()
                from lib.util.serialization import load_checkpoint
                checkpoint = load_checkpoint(train_opt['text_model'])
                self.netT.load_state_dict(checkpoint['state_dict'])

            # GD gan loss (real label 1.0, fake label 0.0)
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0,
                                   0.0).to(self.device)
            self.l_gan_w = train_opt['gan_weight']
            # D_update_ratio and D_init_iters control how often G updates
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[
                'D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt[
                'D_init_iters'] else 0

            # optimizers
            # G
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters(
            ):  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    if self.rank <= 0:
                        logger.warning(
                            'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params,
                                                lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1_G'],
                                                       train_opt['beta2_G']))
            self.optimizers.append(self.optimizer_G)
            # D
            wd_D = train_opt['weight_decay_D'] if train_opt[
                'weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=train_opt['lr_D'],
                                                weight_decay=wd_D,
                                                betas=(train_opt['beta1_D'],
                                                       train_opt['beta2_D']))
            self.optimizers.append(self.optimizer_D)

            # schedulers (one per optimizer, shared settings)
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()

        self.print_network()  # print network
        self.load()  # load G and D if needed

    def feed_data(self, data, need_GT=True):
        """Move one batch onto the device; 'ref' falls back to 'GT'."""
        self.var_L = data['LQ'].to(self.device)  # LQ
        if need_GT:
            self.var_H = data['GT'].to(self.device)  # GT
            input_ref = data['ref'] if 'ref' in data else data['GT']
            self.var_ref = input_ref.to(self.device)

    def optimize_parameters(self, step, text_input=None):
        """One optimization step: update G (subject to D_update_ratio /
        D_init_iters gating), then update D.

        :param step: global iteration counter used for the G-update gate.
        :param text_input: (ignored, label, length) tuple for the text loss.
        """
        # G: freeze D so its gradients are not computed during G's backward
        for p in self.netD.parameters():
            p.requires_grad = False

        self.optimizer_G.zero_grad()
        self.fake_H = self.netG(self.var_L)

        l_g_total = 0
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            if self.cri_pix:  # pixel loss
                l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)
                l_g_total += l_g_pix
            if self.cri_fea:  # feature loss (real features are detached)
                real_fea = self.netF(self.var_H).detach()
                fake_fea = self.netF(self.fake_H)
                l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
                l_g_total += l_g_fea

            if self.cri_text:
                # Recognition loss of the frozen text model on G's output.
                _, label, length = text_input
                input_dict = {}
                input_dict['images'] = self.fake_H
                input_dict['rec_target'] = label
                input_dict['rec_length'] = length
                output_dict = self.netT(input_dict)
                l_g_total += output_dict['losses']['loss_rec'].mean(dim=0)

            if self.opt['train']['gan_type'] == 'gan':
                pred_g_fake = self.netD(self.fake_H)
                l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
            elif self.opt['train']['gan_type'] == 'ragan':
                # Relativistic average GAN loss for G.
                pred_d_real = self.netD(self.var_ref).detach()
                pred_g_fake = self.netD(self.fake_H)
                l_g_gan = self.l_gan_w * (
                    self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False)
                    + self.cri_gan(pred_g_fake - torch.mean(pred_d_real),
                                   True)) / 2
            l_g_total += l_g_gan

            l_g_total.backward()
            self.optimizer_G.step()

        # D: re-enable gradients disabled above
        for p in self.netD.parameters():
            p.requires_grad = True

        self.optimizer_D.zero_grad()
        if self.opt['train']['gan_type'] == 'gan':
            # need to forward and backward separately, since batch norm statistics differ
            # real
            pred_d_real = self.netD(self.var_ref)
            l_d_real = self.cri_gan(pred_d_real, True)
            l_d_real.backward()
            # fake
            pred_d_fake = self.netD(
                self.fake_H.detach())  # detach to avoid BP to G
            l_d_fake = self.cri_gan(pred_d_fake, False)
            l_d_fake.backward()
        elif self.opt['train']['gan_type'] == 'ragan':
            # Relativistic average GAN loss for D, split into two backward
            # passes (0.5 weight each) to keep memory bounded.
            pred_d_fake = self.netD(self.fake_H.detach()).detach()
            pred_d_real = self.netD(self.var_ref)
            l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake),
                                    True) * 0.5
            l_d_real.backward()
            pred_d_fake = self.netD(self.fake_H.detach())
            l_d_fake = self.cri_gan(
                pred_d_fake - torch.mean(pred_d_real.detach()), False) * 0.5
            l_d_fake.backward()
        self.optimizer_D.step()

        # set log (G terms only on iterations where G was updated)
        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
            if self.cri_pix:
                self.log_dict['l_g_pix'] = l_g_pix.item()
            if self.cri_fea:
                self.log_dict['l_g_fea'] = l_g_fea.item()
            self.log_dict['l_g_gan'] = l_g_gan.item()

        self.log_dict['l_d_real'] = l_d_real.item()
        self.log_dict['l_d_fake'] = l_d_fake.item()
        self.log_dict['D_real'] = torch.mean(pred_d_real.detach())
        self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach())

    def test(self):
        """Run G on the current LQ input without gradients; result in
        self.fake_H. Restores train mode afterwards."""
        self.netG.eval()
        with torch.no_grad():
            self.fake_H = self.netG(self.var_L)
        self.netG.train()

    def get_current_log(self):
        """Return the OrderedDict of the latest logged loss values."""
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        """Return first-sample LQ / result / (optionally) GT tensors on CPU."""
        out_dict = OrderedDict()
        out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
        out_dict['rlt'] = self.fake_H.detach()[0].float().cpu()
        if need_GT:
            out_dict['GT'] = self.var_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        """Log structure and parameter counts of G (and D/F when training)."""
        # Generator
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, nn.DataParallel) or isinstance(
                self.netG, DistributedDataParallel):
            net_struc_str = '{} - {}'.format(
                self.netG.__class__.__name__,
                self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info(
                'Network G structure: {}, with parameters: {:,d}'.format(
                    net_struc_str, n))
            logger.info(s)
        if self.is_train:
            # Discriminator
            s, n = self.get_network_description(self.netD)
            if isinstance(self.netD, nn.DataParallel) or isinstance(
                    self.netD, DistributedDataParallel):
                net_struc_str = '{} - {}'.format(
                    self.netD.__class__.__name__,
                    self.netD.module.__class__.__name__)
            else:
                net_struc_str = '{}'.format(self.netD.__class__.__name__)
            if self.rank <= 0:
                logger.info(
                    'Network D structure: {}, with parameters: {:,d}'.format(
                        net_struc_str, n))
                logger.info(s)

            if self.cri_fea:  # F, Perceptual Network
                s, n = self.get_network_description(self.netF)
                if isinstance(self.netF, nn.DataParallel) or isinstance(
                        self.netF, DistributedDataParallel):
                    net_struc_str = '{} - {}'.format(
                        self.netF.__class__.__name__,
                        self.netF.module.__class__.__name__)
                else:
                    net_struc_str = '{}'.format(self.netF.__class__.__name__)
                if self.rank <= 0:
                    logger.info(
                        'Network F structure: {}, with parameters: {:,d}'.
                        format(net_struc_str, n))
                    logger.info(s)

    def load(self):
        """Load pretrained G (always when a path is set) and D (train only)."""
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG,
                              self.opt['path']['strict_load'])
        load_path_D = self.opt['path']['pretrain_model_D']
        if self.opt['is_train'] and load_path_D is not None:
            logger.info('Loading model for D [{:s}] ...'.format(load_path_D))
            self.load_network(load_path_D, self.netD,
                              self.opt['path']['strict_load'])

    def save(self, iter_step):
        """Save G and D checkpoints tagged with the iteration number."""
        self.save_network(self.netG, 'G', iter_step)
        self.save_network(self.netD, 'D', iter_step)
예제 #16
0
class VisualizeProcess:
    """Loads a trained ET-Net and visualizes its predictions on the
    validation split: shows input image, label, segmentation output and
    edge output side by side, and prints per-batch IoU diagnostics."""

    def __init__(self):
        self.net = ET_Net()

        if (ARGS['gpu']):
            self.net = DataParallel(module=self.net.cuda())

        # NOTE(review): the checkpoint's key layout must match the
        # (possibly DataParallel-wrapped) module -- confirm when toggling
        # the 'gpu' flag between save and load.
        self.net.load_state_dict(torch.load(ARGS['weight']))

        self.train_dataset = get_dataset(dataset_name=ARGS['dataset'],
                                         part='train')
        self.val_dataset = get_dataset(dataset_name=ARGS['dataset'],
                                       part='val')

    def visualize(self):
        """Run inference over the validation set and plot each batch's
        first sample; blocks on plt.show() per batch."""
        start = time.time()
        self.net.eval()
        val_batch_size = min(ARGS['batch_size'], len(self.val_dataset))
        val_dataloader = DataLoader(self.val_dataset,
                                    batch_size=val_batch_size)
        for batch_index, items in enumerate(val_dataloader):
            images, labels, edges = items['image'], items['label'], items[
                'edge']
            images = images.float()
            labels = labels.long()
            edges = edges.long()

            if ARGS['gpu']:
                labels = labels.cuda()
                images = images.cuda()
                edges = edges.cuda()

            print('image shape:', images.size())

            with torch.no_grad():
                outputs_edge, outputs = self.net(images)

            # IoU of the first sample (epsilon avoids division by zero).
            pred = torch.max(outputs, dim=1)[1]
            iou = torch.sum(pred[0] & labels[0]) / (
                torch.sum(pred[0] | labels[0]) + 1e-6)

            # Undo the per-channel mean subtraction for display.
            # Mean values are presumably ImageNet BGR/RGB means -- confirm
            # against the dataset's preprocessing.
            mean = torch.FloatTensor([123.68, 116.779, 103.939]).reshape(
                (3, 1, 1)) / 255.
            # BUGFIX: move the mean to the images' device instead of
            # unconditionally calling .cuda(), which fails (device mismatch
            # or missing CUDA) when ARGS['gpu'] is false.
            images = images + mean.to(images.device)

            print('pred min: ', pred[0].min(), ' max: ', pred[0].max())
            print('label min:', labels[0].min(), ' max: ', labels[0].max())
            print('edge min:', edges[0].min(), ' max: ', edges[0].max())
            print('output edge min:', outputs_edge[0].min(), ' max: ',
                  outputs_edge[0].max())
            print('IoU:', iou)
            print('Intersect num:', torch.sum(pred[0] & labels[0]))
            print('Union num:', torch.sum(pred[0] | labels[0]))

            # 2x2 grid: image / label / class-1 probability / edge output.
            plt.subplot(221)
            plt.imshow(images[0].cpu().numpy().transpose(
                (1, 2, 0))), plt.axis('off')
            plt.subplot(222)
            plt.imshow(labels[0].cpu().numpy(), cmap='gray'), plt.axis('off')
            plt.subplot(223)
            plt.imshow(outputs[0, 1].cpu().numpy(),
                       cmap='gray'), plt.axis('off')
            plt.subplot(224)
            plt.imshow(outputs_edge[0, 1].cpu().numpy(),
                       cmap='gray'), plt.axis('off')
            plt.show()

        finish = time.time()

        print('validating time consumed: {:.2f}s'.format(finish - start))
예제 #17
0
class BaseEngine(object):
    """Skeleton training engine.

    Subclasses supply dataset/model/optimizer/loss/metric via the
    ``_make_*`` hooks; this base handles seeding, device placement, and
    checkpoint dump/load keyed by ``tag``.
    """

    def __init__(self, args):
        self._make_dataset(args)
        self._make_model(args)
        # NOTE(review): seeding happens after model construction, so weight
        # initialization is not covered by the seed -- confirm intended.
        tc.manual_seed(args.seed)
        if args.cuda and tc.cuda.is_available():
            tc.cuda.manual_seed_all(args.seed)
            if tc.cuda.device_count() > 1:
                # Scale the effective batch size with the number of GPUs.
                self.batch_size = args.batch_size * tc.cuda.device_count()
                # BUGFIX: DataParallel requires the wrapped module's
                # parameters to already be on the primary CUDA device;
                # the original never moved the model off the CPU here.
                self.model = DataParallel(self.model.cuda())
            else:
                self.batch_size = args.batch_size
                self.model = self.model.cuda()
        else:
            self.batch_size = args.batch_size
        self._make_optimizer(args)
        self._make_loss(args)
        self._make_metric(args)
        self.num_training_samples = args.num_training_samples
        self.tag = args.tag or 'default'
        self.dump_dir = get_dir(args.dump_dir)
        self.train_logger = get_logger('train.{}.{}'.format(
            self.__class__.__name__, self.tag))

    def _make_dataset(self, args):
        """Hook: build datasets/loaders. Must be overridden."""
        raise NotImplementedError

    def _make_model(self, args):
        """Hook: build self.model. Must be overridden."""
        raise NotImplementedError

    def _make_optimizer(self, args):
        """Hook: build self.optimizer (and optionally self.decayer)."""
        raise NotImplementedError

    def _make_loss(self, args):
        """Hook: build the loss function. Must be overridden."""
        raise NotImplementedError

    def _make_metric(self, args):
        """Hook: build the evaluation metric. Must be overridden."""
        raise NotImplementedError

    def dump(self, epoch, model=True, optimizer=True, decayer=True):
        """Save a checkpoint ``state_<tag>.pkl`` into dump_dir.

        :param epoch: epoch number stored alongside the state dicts.
        :param model/optimizer/decayer: flags selecting what to persist
            (the decayer is skipped when the attribute does not exist).
        """
        state = {'epoch': epoch}
        if model:
            state['model'] = self.model.state_dict()
        if optimizer:
            state['optimizer'] = self.optimizer.state_dict()
        if decayer and (getattr(self, 'decayer', None) is not None):
            state['decayer'] = self.decayer.state_dict()
        tc.save(state,
                os.path.join(self.dump_dir, 'state_{}.pkl'.format(self.tag)))
        self.train_logger.info('Checkpoint {} dumped'.format(self.tag))

    def load(self, model=True, optimizer=True, decayer=True):
        """Restore the checkpoint for this tag if it exists.

        :return: the stored epoch number, or 0 when no checkpoint is found.
        """
        try:
            state = tc.load(
                os.path.join(self.dump_dir, 'state_{}.pkl'.format(self.tag)))
        except FileNotFoundError:
            return 0
        if model and (state.get('model') is not None):
            self.model.load_state_dict(state['model'])
        if optimizer and (state.get('optimizer') is not None):
            self.optimizer.load_state_dict(state['optimizer'])
        if decayer and (state.get('decayer') is not None) and (getattr(
                self, 'decayer', None) is not None):
            self.decayer.load_state_dict(state['decayer'])
        return state['epoch']

    def eval(self):
        """Hook: evaluate on the validation set. Must be overridden."""
        raise NotImplementedError

    def test(self):
        """Hook: evaluate on the test set. Must be overridden."""
        raise NotImplementedError

    def train(self, num_epochs, resume=False):
        """Hook: run the training loop. Must be overridden."""
        raise NotImplementedError
0
파일: trainer.py 프로젝트: kyuyeonpooh/STTN
class Trainer():
    """Training harness for an STTN-style video-inpainting GAN.

    Builds the AVE dataset/loader, the inpainting generator + discriminator
    (loaded dynamically from ``model.<name>``), their Adam optimizers and the
    adversarial/L1 losses, resumes from the most recent checkpoint in
    ``save_dir``, and drives the alternating D/G training loop with
    TensorBoard summaries.
    """

    def __init__(self, config, debug=False):
        self.config = config
        self.epoch = 0
        self.iteration = 0
        # Debug runs shrink the save/valid/iteration counts so a full
        # train-save cycle completes almost immediately.
        if debug:
            self.config['trainer']['save_freq'] = 5
            self.config['trainer']['valid_freq'] = 5
            self.config['trainer']['iterations'] = 5

        # setup data set and data loader
        self.train_dataset = AVEDataset(config['data_loader'], split='train')
        self.train_sampler = None
        self.train_args = config['trainer']
        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.train_args['batch_size'],
            shuffle=True,
            num_workers=self.train_args['num_workers'],
            pin_memory=True)

        # set loss functions
        self.adversarial_loss = AdversarialLoss(
            type=self.config['losses']['GAN_LOSS'])
        self.adversarial_loss = self.adversarial_loss.to(self.config['device'])
        self.l1_loss = nn.L1Loss()

        # setup models including generator and discriminator
        net = importlib.import_module('model.' + config['model'])
        self.netG = net.InpaintGenerator()
        self.netG = self.netG.to(self.config['device'])
        # hinge loss works on raw scores, so sigmoid is disabled for it
        self.netD = net.Discriminator(
            in_channels=3, use_sigmoid=config['losses']['GAN_LOSS'] != 'hinge')
        self.netD = self.netD.to(self.config['device'])
        self.optimG = torch.optim.Adam(self.netG.parameters(),
                                       lr=config['trainer']['lr'],
                                       betas=(self.config['trainer']['beta1'],
                                              self.config['trainer']['beta2']))
        self.optimD = torch.optim.Adam(self.netD.parameters(),
                                       lr=config['trainer']['lr'],
                                       betas=(self.config['trainer']['beta1'],
                                              self.config['trainer']['beta2']))
        # resume before the DataParallel wrap so checkpoint keys match
        self.load()

        if config['distributed']:
            self.netG = DataParallel(self.netG)
            self.netD = DataParallel(self.netD)

        # set summary writer (rank 0 only, or always when not distributed)
        self.dis_writer = None
        self.gen_writer = None
        self.summary = {}
        if self.config['global_rank'] == 0 or (not config['distributed']):
            self.dis_writer = SummaryWriter(
                os.path.join(config['save_dir'], 'dis'))
            self.gen_writer = SummaryWriter(
                os.path.join(config['save_dir'], 'gen'))

    # get current learning rate (G and D are kept at the same lr)
    def get_lr(self):
        return self.optimG.param_groups[0]['lr']

    # learning rate scheduler, step: staircase decay by 0.1 every `niter`
    # iterations, frozen once `niter_steady` iterations are reached
    def adjust_learning_rate(self):
        decay = 0.1**(
            min(self.iteration, self.config['trainer']['niter_steady']) //
            self.config['trainer']['niter'])
        new_lr = self.config['trainer']['lr'] * decay
        if new_lr != self.get_lr():
            for param_group in self.optimG.param_groups:
                param_group['lr'] = new_lr
            for param_group in self.optimD.param_groups:
                param_group['lr'] = new_lr

    # add summary: accumulate scalar values and flush the mean of the last
    # 100 iterations to TensorBoard every 100th iteration
    def add_summary(self, writer, name, val):
        if name not in self.summary:
            self.summary[name] = 0
        self.summary[name] += val
        if writer is not None and self.iteration % 100 == 0:
            writer.add_scalar(name, self.summary[name] / 100, self.iteration)
            self.summary[name] = 0

    # add image grids every 100 iterations; frames are in [-1, 1] and are
    # rescaled to [0, 1] for display
    def add_images(self, writer, input_image, output_image, gt_image):
        if writer is not None and self.iteration % 100 == 0:
            b, t, c, h, w = input_image.size()
            input_image = input_image.view(b * t, c, h, w)
            output_image = output_image.view(b * t, c, h, w)
            gt_image = gt_image.view(b * t, c, h, w)
            writer.add_image("input/input_image",
                             make_grid((input_image + 1) / 2, t),
                             self.iteration)
            writer.add_image("output/output_image",
                             make_grid((output_image + 1) / 2, t),
                             self.iteration)
            writer.add_image("output/gt_image", make_grid((gt_image + 1) / 2,
                                                          t), self.iteration)

    # load netG and netD from the most recent checkpoint, if any
    def load(self):
        model_path = self.config['save_dir']
        if os.path.isfile(os.path.join(model_path, 'latest.ckpt')):
            # Use a context manager so the handle is closed promptly
            # (the original leaked the file object returned by open()).
            with open(os.path.join(model_path, 'latest.ckpt'), 'r') as f:
                latest_epoch = f.read().splitlines()[-1]
        else:
            # Fall back to the lexically-last *.pth file in save_dir.
            ckpts = [
                os.path.basename(i).split('.pth')[0]
                for i in glob.glob(os.path.join(model_path, '*.pth'))
            ]
            ckpts.sort()
            latest_epoch = ckpts[-1] if len(ckpts) > 0 else None
        if latest_epoch is not None:
            gen_path = os.path.join(
                model_path, 'gen_{}.pth'.format(str(latest_epoch).zfill(5)))
            dis_path = os.path.join(
                model_path, 'dis_{}.pth'.format(str(latest_epoch).zfill(5)))
            opt_path = os.path.join(
                model_path, 'opt_{}.pth'.format(str(latest_epoch).zfill(5)))
            if self.config['global_rank'] == 0:
                print('Loading model from {}...'.format(gen_path))
            data = torch.load(gen_path, map_location=self.config['device'])
            self.netG.load_state_dict(data['netG'])
            data = torch.load(dis_path, map_location=self.config['device'])
            self.netD.load_state_dict(data['netD'])
            data = torch.load(opt_path, map_location=self.config['device'])
            self.optimG.load_state_dict(data['optimG'])
            self.optimD.load_state_dict(data['optimD'])
            self.epoch = data['epoch']
            self.iteration = data['iteration']
        else:
            if self.config['global_rank'] == 0:
                print(
                    'Warning: There is no trained model found. An initialized model will be used.'
                )

    # save parameters every eval_epoch (rank 0 only)
    def save(self, it):
        if self.config['global_rank'] == 0:
            gen_path = os.path.join(self.config['save_dir'],
                                    'gen_{}.pth'.format(str(it).zfill(5)))
            dis_path = os.path.join(self.config['save_dir'],
                                    'dis_{}.pth'.format(str(it).zfill(5)))
            opt_path = os.path.join(self.config['save_dir'],
                                    'opt_{}.pth'.format(str(it).zfill(5)))
            print('\nsaving model to {} ...'.format(gen_path))
            # Unwrap (D)DP so the saved state dicts carry no 'module.' prefix.
            if isinstance(self.netG, torch.nn.DataParallel) or isinstance(
                    self.netG, DDP):
                netG = self.netG.module
                netD = self.netD.module
            else:
                netG = self.netG
                netD = self.netD
            torch.save({'netG': netG.state_dict()}, gen_path)
            torch.save({'netD': netD.state_dict()}, dis_path)
            torch.save(
                {
                    'epoch': self.epoch,
                    'iteration': self.iteration,
                    'optimG': self.optimG.state_dict(),
                    'optimD': self.optimD.state_dict()
                }, opt_path)
            # Record the latest checkpoint id with a plain file write instead
            # of shelling out via os.system('echo ...') — portable and safe.
            with open(os.path.join(self.config['save_dir'], 'latest.ckpt'),
                      'w') as f:
                f.write(str(it).zfill(5) + '\n')

    # train entry
    def train(self):
        pbar = range(int(self.train_args['iterations']))
        if self.config['global_rank'] == 0:
            pbar = tqdm(pbar,
                        initial=self.iteration,
                        dynamic_ncols=True,
                        smoothing=0.01)

        # Loop over epochs until the global iteration budget is exhausted.
        while True:
            self.epoch += 1
            # if self.config['distributed']:
            #     self.train_sampler.set_epoch(self.epoch)

            self._train_epoch(pbar)
            if self.iteration > self.train_args['iterations']:
                break
        print('\nEnd training....')

    # process input and calculate loss every training epoch
    def _train_epoch(self, pbar):
        device = self.config['device']

        for frames, masks in self.train_loader:
            self.adjust_learning_rate()
            self.iteration += 1

            frames, masks = frames.to(device), masks.to(device)
            b, t, c, h, w = frames.size()
            # zero out the masked region before feeding the generator
            masked_frame = (frames * (1 - masks).float())
            pred_img = self.netG(masked_frame, masks)
            # flatten the temporal dim so losses operate on (b*t) images
            frames = frames.view(b * t, c, h, w)
            masks = masks.view(b * t, 1, h, w)
            # composite: ground truth outside the mask, prediction inside
            comp_img = frames * (1. - masks) + masks * pred_img

            gen_loss = 0
            dis_loss = 0

            # discriminator adversarial loss (fake branch is detached so the
            # D step does not backprop into the generator)
            real_vid_feat = self.netD(frames)
            fake_vid_feat = self.netD(comp_img.detach())
            dis_real_loss = self.adversarial_loss(real_vid_feat, True, True)
            dis_fake_loss = self.adversarial_loss(fake_vid_feat, False, True)
            dis_loss += (dis_real_loss + dis_fake_loss) / 2
            self.add_summary(self.dis_writer, 'loss/dis_vid_fake',
                             dis_fake_loss.item())
            self.add_summary(self.dis_writer, 'loss/dis_vid_real',
                             dis_real_loss.item())
            self.optimD.zero_grad()
            dis_loss.backward()
            self.optimD.step()

            # generator adversarial loss
            gen_vid_feat = self.netD(comp_img)
            gan_loss = self.adversarial_loss(gen_vid_feat, True, False)
            gan_loss = gan_loss * self.config['losses']['adversarial_weight']
            gen_loss += gan_loss
            self.add_summary(self.gen_writer, 'loss/gan_loss', gan_loss.item())

            # generator l1 loss inside the hole, normalized by hole area
            hole_loss = self.l1_loss(pred_img * masks, frames * masks)
            hole_loss = hole_loss / torch.mean(
                masks) * self.config['losses']['hole_weight']
            gen_loss += hole_loss
            self.add_summary(self.gen_writer, 'loss/hole_loss',
                             hole_loss.item())

            # l1 loss outside the hole, normalized by valid area
            valid_loss = self.l1_loss(pred_img * (1 - masks),
                                      frames * (1 - masks))
            valid_loss = valid_loss / torch.mean(
                1 - masks) * self.config['losses']['valid_weight']
            gen_loss += valid_loss
            self.add_summary(self.gen_writer, 'loss/valid_loss',
                             valid_loss.item())

            self.optimG.zero_grad()
            gen_loss.backward()
            self.optimG.step()

            self.add_images(self.gen_writer,
                            masked_frame.cpu().detach(),
                            comp_img.cpu().detach(),
                            frames.cpu().detach())

            # console logs
            if self.config['global_rank'] == 0:
                pbar.update(1)
                pbar.set_description((
                    f"d: {dis_loss.item():.3f}; g: {gan_loss.item():+.3f}; "
                    f"hole: {hole_loss.item():.3f}; valid: {valid_loss.item():.3f}"
                ))

            # saving models
            if self.iteration % self.train_args['save_freq'] == 0:
                self.save(int(self.iteration // self.train_args['save_freq']))
            if self.iteration > self.train_args['iterations']:
                break
예제 #19
0
def main():
    """Evaluate saved triplet-network checkpoints on the detection dataloaders.

    Loads the config file given via ``--config-file``, builds a DataParallel
    model with a frozen backbone, then for each checkpoint in
    ``inference_list`` computes overall and per-triplet-type accuracy counts
    (a zero per-triplet loss counts as correctly ranked) and prints one
    summary line per checkpoint.
    """
    parser = argparse.ArgumentParser(description="PyTorch Object Detection Training")
    parser.add_argument(
        "--config-file",
        default="",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    args = parser.parse_args()
    cfg.merge_from_file(args.config_file)
    cfg.freeze()

    viewer = Visualizer(cfg.OUTPUT_DIR)  # kept for the commented-out loss plotting below
    # Model
    model = build_model(cfg)
    model = DataParallel(model).cuda()
    if cfg.MODEL.WEIGHT != "":
        model.module.backbone.load_state_dict(torch.load(cfg.MODEL.WEIGHT))
        # freeze backbone
        for key, val in model.module.backbone.named_parameters():
            val.requires_grad = False

    batch_time = AverageMeter()
    data_time = AverageMeter()

    # optimizer — only needed because lr_sche.step() is still called below
    optimizer = getattr(torch.optim, cfg.SOLVER.OPTIM)(
        model.parameters(),
        lr=cfg.SOLVER.BASE_LR,
        weight_decay=cfg.SOLVER.WEIGHT_DECAY)
    lr_sche = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, cfg.SOLVER.STEPS, gamma=cfg.SOLVER.GAMMA)

    # dataset
    datasets = make_dataset(cfg, is_train=False)
    dataloaders = make_dataloaders(cfg, datasets, False)
    if not os.path.exists(cfg.OUTPUT_DIR):
        os.mkdir(cfg.OUTPUT_DIR)

    # start time
    start = time.time()
    inference_list = ['resnet18_14.pth', 'resnet18_13.pth', 'resnet18_12.pth',
                      'resnet18_11.pth', 'resnet18_10.pth']
    for inference_weight in inference_list:
        # NOTE(review): `resume_dir` is not defined in this function; it must
        # be a module-level name — confirm it points at the checkpoint dir.
        model.load_state_dict(torch.load(os.path.join(resume_dir, inference_weight)))
        model.eval()

        total_count = 0
        one_count = 0
        two_count = 0
        three_count = 0
        one_number = 0
        two_number = 0
        three_number = 0
        for dataloader in dataloaders:
            for imgs, labels, types in tqdm.tqdm(dataloader, desc="dataloader:"):
                types = np.asarray(types)
                # leftover from the training script; stepping the LR
                # scheduler during inference has no effect on the outputs
                lr_sche.step()
                data_time.update(time.time() - start)

                # the three crops of each triplet are stacked along the batch dim
                inputs = torch.cat([imgs[0].cuda(), imgs[1].cuda(), imgs[2].cuda()], dim=0)
                with torch.no_grad():
                    features = model(inputs)
                acc, batch_loss = loss_opts.batch_triple_loss_acc(
                    features, labels, types, size_average=True)
                # (removed a stray debug `print(batch_loss)` and a bare `xxx`
                # name that raised NameError on the first batch)
                total_count += batch_loss.shape[0] - acc

                # split per-triplet losses by triplet type; a zero loss means
                # the triplet was ranked correctly
                ONE_CLASS = (batch_loss[np.nonzero(types == 'ONE_CLASS_TRIPLET')[0]])
                TWO_CLASS = (batch_loss[np.nonzero(types == 'TWO_CLASS_TRIPLET')[0]])
                THREE_CLASS = (batch_loss[np.nonzero(types == 'THREE_CLASS_TRIPLET')[0]])
                one_count += ONE_CLASS.shape[0] - torch.nonzero(ONE_CLASS).shape[0]
                two_count += TWO_CLASS.shape[0] - torch.nonzero(TWO_CLASS).shape[0]
                three_count += THREE_CLASS.shape[0] - torch.nonzero(THREE_CLASS).shape[0]
                one_number += ONE_CLASS.shape[0]
                two_number += TWO_CLASS.shape[0]
                three_number += THREE_CLASS.shape[0]
                # viewer.line("train/loss",loss.item()*100,ite)
        # NOTE(review): divides by per-type counts; this raises
        # ZeroDivisionError if a triplet type never occurs — confirm the
        # dataset always contains all three types.
        print(inference_weight,
              total_count / (one_number + two_number + three_number),
              one_count / one_number,
              two_count / two_number,
              three_count / three_number)
예제 #20
0
class CalculateMetricProcess:
    """Runs patch-wise inference with a trained ET_Net over the 'metric'
    dataset split and computes evaluation metrics on the masked region."""

    def __init__(self):
        # Build the network and optionally wrap it for multi-GPU inference.
        self.net = ET_Net()

        if (ARGS['gpu']):
            self.net = DataParallel(module=self.net.cuda())

        # NOTE(review): the checkpoint is loaded AFTER the DataParallel wrap,
        # so its keys must carry the 'module.' prefix — confirm the weights
        # were saved from a DataParallel-wrapped model.
        self.net.load_state_dict(torch.load(ARGS['weight']))

        self.metric_dataset = get_dataset(dataset_name=ARGS['dataset'],
                                          part='metric')

    def predict(self):
        """Predict over the metric split image-by-image: split each image
        into overlapping patches, run them through the net in batches,
        recompose the overlapping predictions, then compute metrics over
        the pixels where mask == 1."""

        start = time.time()
        self.net.eval()
        metric_dataloader = DataLoader(
            self.metric_dataset, batch_size=1)  # only support batch size = 1
        os.makedirs(ARGS['prediction_save_folder'], exist_ok=True)
        y_true = []
        y_pred = []
        for items in metric_dataloader:
            images, labels, mask = items['image'], items['label'], items[
                'mask']
            images = images.float()
            print('image shape:', images.size())

            # Tile the image into crop_size patches with stride_size overlap;
            # big_h/big_w are the padded dimensions used for recomposition.
            image_patches, big_h, big_w = get_test_patches(
                images, ARGS['crop_size'], ARGS['stride_size'])
            test_patch_dataloader = DataLoader(image_patches,
                                               batch_size=ARGS['batch_size'],
                                               shuffle=False,
                                               drop_last=False)
            test_results = []
            print('Number of batches for testing:', len(test_patch_dataloader))

            for patches in test_patch_dataloader:

                if ARGS['gpu']:
                    patches = patches.cuda()

                # Only the main (non-edge) output is kept for metrics.
                with torch.no_grad():
                    result_patches_edge, result_patches = self.net(patches)

                test_results.append(result_patches.cpu())

            test_results = torch.cat(test_results, dim=0)
            # merge overlapping patch predictions back into one image
            test_results = recompone_overlap(test_results, ARGS['crop_size'],
                                             ARGS['stride_size'], big_h, big_w)
            # channel 1 (foreground score), cropped back to the input size
            test_results = test_results[:, 1, :images.size(2), :images.size(3)]
            # only pixels inside the evaluation mask contribute to metrics
            y_pred.append(test_results[mask == 1].reshape(-1))
            y_true.append(labels[mask == 1].reshape(-1))

        y_pred = torch.cat(y_pred).numpy()
        y_true = torch.cat(y_true).numpy()
        calc_metrics(y_pred, y_true)
        finish = time.time()

        print('Calculating metric time consumed: {:.2f}s'.format(finish -
                                                                 start))
예제 #21
0
파일: trainEDVR.py 프로젝트: xuexiy1ge/AI4K
def main(args):
    """Train an EDVR video super-resolution model.

    Builds the multi-frame dataset/loader, constructs EDVR with a Charbonnier
    loss on all available GPUs, optionally resumes from the newest checkpoint
    in ``args.resume``, then runs the epoch loop and saves a checkpoint
    (state dict, epoch, lr) after every epoch.
    """
    print("===> Loading datasets")
    data_set = DatasetLoader(args.data_lr,
                             args.data_hr,
                             size_w=args.size_w,
                             size_h=args.size_h,
                             scale=args.scale,
                             n_frames=args.n_frames,
                             interval_list=args.interval_list,
                             border_mode=args.border_mode,
                             random_reverse=args.random_reverse)
    train_loader = DataLoader(data_set,
                              batch_size=args.batch_size,
                              num_workers=args.workers,
                              shuffle=True,
                              pin_memory=False,
                              drop_last=True)

    #### random seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    cudnn.benchmark = True
    #cudnn.deterministic = True

    print("===> Building model")
    #### create model
    model = EDVR_arch.EDVR(nf=args.nf,
                           nframes=args.n_frames,
                           groups=args.groups,
                           front_RBs=args.front_RBs,
                           back_RBs=args.back_RBs,
                           center=args.center,
                           predeblur=args.predeblur,
                           HR_in=args.HR_in,
                           w_TSA=args.w_TSA)
    criterion = CharbonnierLoss()
    print("===> Setting GPU")
    # args.gpus == 0 means "use every visible GPU"
    gups = args.gpus if args.gpus != 0 else torch.cuda.device_count()
    device_ids = list(range(gups))
    model = DataParallel(model, device_ids=device_ids)
    model = model.cuda()
    criterion = criterion.cuda()

    # print(model)

    start_epoch = args.start_epoch
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isdir(args.resume):
            # pick the most recent (lexically last) checkpoint in the dir
            pth_list = sorted(glob(os.path.join(args.resume, '*.pth')))
            if len(pth_list) > 0:
                args.resume = pth_list[-1]
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)

            start_epoch = checkpoint['epoch'] + 1
            state_dict = checkpoint['state_dict']

            # prepend 'module.' so the keys match the DataParallel wrapper
            # (the checkpoint was saved from the unwrapped model.module)
            new_state_dict = OrderedDict()
            for k, v in state_dict.items():
                namekey = 'module.' + k
                new_state_dict[namekey] = v
            model.load_state_dict(new_state_dict)

            # prefer the lr stored in the checkpoint over the CLI argument
            args.lr = checkpoint.get('lr', args.lr)

        # an explicit --start_epoch overrides the epoch from the checkpoint
        start_epoch = args.start_epoch if args.start_epoch != 0 else start_epoch

    # if use_current_lr > 0 it overrides the lr determined above
    args.lr = args.use_current_lr if args.use_current_lr > 0 else args.lr
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                        model.parameters()),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay,
                                 betas=(args.beta1, args.beta2),
                                 eps=1e-8)

    #### training
    print("===> Training")
    for epoch in range(start_epoch, args.epochs):
        adjust_lr(optimizer, epoch)
        if args.use_tqdm == 1:
            losses, psnrs = one_epoch_train_tqdm(
                model, optimizer, criterion, len(data_set), train_loader,
                epoch, args.epochs, args.batch_size,
                optimizer.param_groups[0]["lr"])
        else:
            losses, psnrs = one_epoch_train_logger(
                model, optimizer, criterion, len(data_set), train_loader,
                epoch, args.epochs, args.batch_size,
                optimizer.param_groups[0]["lr"])

        # save model
        # if epoch %9 != 0:
        #     continue

        model_out_path = os.path.join(
            args.checkpoint, "model_epoch_%04d_edvr_loss_%.3f_psnr_%.3f.pth" %
            (epoch, losses.avg, psnrs.avg))
        if not os.path.exists(args.checkpoint):
            os.makedirs(args.checkpoint)
        # save the unwrapped module so keys carry no 'module.' prefix
        torch.save(
            {
                'state_dict': model.module.state_dict(),
                "epoch": epoch,
                'lr': optimizer.param_groups[0]["lr"]
            }, model_out_path)
예제 #22
0
def create_and_test_triplet_network(batch_triplet_indices_loader,
                                    experiment_name,
                                    path_to_emb_net,
                                    unseen_triplets,
                                    dataset_name,
                                    model_name,
                                    logger,
                                    test_n,
                                    n,
                                    dim,
                                    layers,
                                    learning_rate=5e-2,
                                    epochs=20,
                                    hl_size=100):
    """
    Construct a frozen, pretrained OENN "train" network and a fresh "test"
    network, train the test network w.r.t. triplet loss (anchor from the test
    net, positive/negative from the frozen train net), checkpoint after every
    epoch, and return the unseen-triplet error.

    :param model_name: architecture identifier passed to define_model
    :param dataset_name: dataset identifier (logged only)
    :param test_n: # test points
    :param path_to_emb_net: checkpoint path for the pretrained embedding net
    :param n: # points
    :param dim: # features/ dimensions
    :param layers: # layers
    :param learning_rate: learning rate of optimizer.
    :param epochs: # epochs
    :param hl_size: # width of the hidden layer
    :param unseen_triplets: indices of held-out triplets used for evaluation
    :param logger: # for logging
    :param batch_triplet_indices_loader: data loader giving triplet indices
        in batches
    :return: unseen triplet error (float)
    """

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)
    # binary-code input width: each point index is encoded in `digits` bits
    digits = int(math.ceil(math.log2(n)))

    #  Define train model (pretrained, frozen — acts as the reference embedding)
    emb_net_train = define_model(model_name=model_name,
                                 digits=digits,
                                 hl_size=hl_size,
                                 dim=dim,
                                 layers=layers)
    emb_net_train = emb_net_train.to(device)

    for param in emb_net_train.parameters():
        param.requires_grad = False

    if torch.cuda.device_count() > 1:
        emb_net_train = DataParallel(emb_net_train)
        print('multi-gpu')

    # Strip the 'module.' prefix when the checkpoint was saved from a
    # DataParallel-wrapped model.
    # NOTE(review): if this process is itself multi-GPU, emb_net_train is
    # wrapped in DataParallel above and loading stripped keys would fail —
    # confirm the multi-GPU + 'module.'-checkpoint combination is exercised.
    checkpoint = torch.load(path_to_emb_net)['model_state_dict']
    key_word = list(checkpoint.keys())[0].split('.')[0]
    if key_word == 'module':
        from collections import OrderedDict
        new_state_dict = OrderedDict()
        for k, v in checkpoint.items():
            name = k[7:]  # remove `module.`
            new_state_dict[name] = v
        emb_net_train.load_state_dict(new_state_dict)
    else:
        emb_net_train.load_state_dict(checkpoint)

    emb_net_train.eval()

    #  Define test model (this is the one being trained)
    emb_net_test = define_model(model_name=model_name,
                                digits=digits,
                                hl_size=hl_size,
                                dim=dim,
                                layers=layers)
    emb_net_test = emb_net_test.to(device)

    if torch.cuda.device_count() > 1:
        emb_net_test = DataParallel(emb_net_test)
        print('multi-gpu')

    # Optimizer
    optimizer = torch.optim.Adam(emb_net_test.parameters(), lr=learning_rate)
    criterion = nn.TripletMarginLoss(margin=1, p=2)
    criterion = criterion.to(device)

    logger.info('#### Dataset Selection #### \n')
    # Fixed: the original called logger.info('dataset:', dataset_name), which
    # passes a %-format argument with no placeholder and logs a formatting
    # error instead of the dataset name.
    logger.info('dataset: %s', dataset_name)
    logger.info('#### Network and learning parameters #### \n')
    logger.info('------------------------------------------ \n')
    logger.info('Model Name: ' + model_name + '\n')
    logger.info('Number of hidden layers: ' + str(layers) + '\n')
    logger.info('Hidden layer width: ' + str(hl_size) + '\n')
    logger.info('Embedding dimension: ' + str(dim) + '\n')
    logger.info('Learning rate: ' + str(learning_rate) + '\n')
    logger.info('Number of epochs: ' + str(epochs) + '\n')

    logger.info(' #### Training begins #### \n')
    logger.info('---------------------------\n')

    # (removed a duplicate recomputation of `digits` and a dead pre-loop
    # `trip` assignment that was overwritten on the first batch)
    bin_array = data_utils.get_binary_array(n, digits)

    # Training begins
    train_time = 0
    for ep in range(epochs):
        # Epoch is one pass over the dataset
        epoch_loss = 0

        for batch_ind, trips in enumerate(batch_triplet_indices_loader):
            sys.stdout.flush()
            trip = trips.squeeze().to(device).float()

            # Training time
            begin_train_time = time.time()
            # Forward pass: anchor through the trainable net, positive and
            # negative through the frozen pretrained net.
            embedded_a = emb_net_test(trip[:, :digits])
            embedded_p = emb_net_train(trip[:, digits:2 * digits])
            embedded_n = emb_net_train(trip[:, 2 * digits:])
            # Compute loss
            loss = criterion(embedded_a, embedded_p, embedded_n).to(device)
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # End of training
            end_train_time = time.time()
            if batch_ind % 50 == 0:
                logger.info('Epoch: ' + str(ep) + ' Mini batch: ' +
                            str(batch_ind) + '/' +
                            str(len(batch_triplet_indices_loader)) +
                            ' Loss: ' + str(loss.item()))
                sys.stdout.flush()  # Prints faster to the out file
            epoch_loss += loss.item()
            train_time = train_time + end_train_time - begin_train_time

        # Log
        logger.info('Epoch ' + str(ep) + ' - Average Epoch Loss:  ' +
                    str(epoch_loss / len(batch_triplet_indices_loader)) +
                    ' Training time ' + str(train_time))
        sys.stdout.flush()  # Prints faster to the out file

        # Saving the results
        logger.info('Saving the models and the results')
        sys.stdout.flush()  # Prints faster to the out file

        os.makedirs('test_checkpoints', mode=0o777, exist_ok=True)
        model_path = 'test_checkpoints/' + \
                     experiment_name + \
                     '.pt'
        torch.save(
            {
                'epochs': ep,
                'model_state_dict': emb_net_test.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss:': epoch_loss,
            }, model_path)

    # Compute the embedding of the data points.
    bin_array_test = data_utils.get_binary_array(test_n, digits)
    test_embeddings = emb_net_test(
        torch.Tensor(bin_array_test).cuda().float()).cpu().detach().numpy()
    train_embeddings = emb_net_train(
        torch.Tensor(bin_array).cuda().float()).cpu().detach().numpy()
    unseen_triplet_error, _ = data_utils.triplet_error_unseen(
        test_embeddings, train_embeddings, unseen_triplets)

    logger.info('Unseen triplet error is ' + str(unseen_triplet_error))
    return unseen_triplet_error
예제 #23
0
class TrainValProcess():
    """Drive training and validation of an ET_Net segmentation model.

    All configuration is read from the module-level ``ARGS`` dict
    (weight path, gpu flag, dataset name, lr, batch size, num_epochs,
    scheduler_power, combine_alpha, epoch_save, weight_save_folder).
    Metrics are logged to TensorBoard via ``SummaryWriter``.
    """

    def __init__(self):
        # Build the network; restore full weights if provided, otherwise
        # only load the pretrained encoder.
        self.net = ET_Net()
        if ARGS['weight']:
            self.net.load_state_dict(torch.load(ARGS['weight']))
        else:
            self.net.load_encoder_weight()
        if ARGS['gpu']:
            self.net = DataParallel(module=self.net.cuda())

        self.train_dataset = get_dataset(dataset_name=ARGS['dataset'],
                                         part='train')
        self.val_dataset = get_dataset(dataset_name=ARGS['dataset'],
                                       part='val')

        self.optimizer = Adam(self.net.parameters(), lr=ARGS['lr'])
        # Polynomial ("poly") LR decay over the total number of optimizer
        # steps: lr(i) = lr0 * (1 - i/total_iters) ** scheduler_power.
        # Use / to get an approximate result, // to get an accurate result
        total_iters = len(
            self.train_dataset) // ARGS['batch_size'] * ARGS['num_epochs']
        self.lr_scheduler = LambdaLR(
            self.optimizer,
            lr_lambda=lambda iter:
            (1 - iter / total_iters)**ARGS['scheduler_power'])
        self.writer = SummaryWriter()

    def train(self, epoch):
        """Run one training epoch.

        Args:
            epoch: 1-based epoch index (used for logging only).
        """
        start = time.time()
        self.net.train()
        train_dataloader = DataLoader(self.train_dataset,
                                      batch_size=ARGS['batch_size'],
                                      shuffle=False)
        epoch_loss = 0.
        for batch_index, items in enumerate(train_dataloader):
            images, labels, edges = items['image'], items['label'], items[
                'edge']
            images = images.float()
            labels = labels.long()
            edges = edges.long()

            if ARGS['gpu']:
                labels = labels.cuda()
                images = images.cuda()
                edges = edges.cuda()

            self.optimizer.zero_grad()
            # The network returns an auxiliary edge map plus the main
            # segmentation output; both are trained with Lovasz-Softmax.
            outputs_edge, outputs = self.net(images)
            loss_edge = lovasz_softmax(outputs_edge,
                                       edges)  # Lovasz-Softmax loss
            loss_seg = lovasz_softmax(outputs, labels)
            # Convex combination of segmentation and edge losses.
            loss = ARGS['combine_alpha'] * loss_seg + (
                1 - ARGS['combine_alpha']) * loss_edge
            loss.backward()
            self.optimizer.step()
            self.lr_scheduler.step()  # per-iteration poly decay

            # Crude IoU on the argmax prediction; &/| act elementwise on
            # the integer label tensors, epsilon guards empty unions.
            pred = torch.max(outputs, dim=1)[1]
            iou = torch.sum(pred & labels) / (torch.sum(pred | labels) + 1e-6)

            print(
                'Training Epoch: {epoch} [{trained_samples}/{total_samples}]\tL_edge: {:0.4f}\tL_seg: {:0.4f}\tL_all: {:0.4f}\tIoU: {:0.4f}\tLR: {:0.4f}'
                .format(loss_edge.item(),
                        loss_seg.item(),
                        loss.item(),
                        iou.item(),
                        self.optimizer.param_groups[0]['lr'],
                        epoch=epoch,
                        trained_samples=batch_index * ARGS['batch_size'],
                        total_samples=len(train_dataloader.dataset)))

            epoch_loss += loss.item()

        # Histogram every parameter, grouped as <layer>/<attr> where attr
        # is the text after the last dot (splitext splits on the last '.').
        for name, param in self.net.named_parameters():
            layer, attr = os.path.splitext(name)
            attr = attr[1:]
            self.writer.add_histogram("{}/{}".format(layer, attr), param,
                                      epoch)

        epoch_loss /= len(train_dataloader)
        self.writer.add_scalar('Train/loss', epoch_loss, epoch)
        finish = time.time()

        print('epoch {} training time consumed: {:.2f}s'.format(
            epoch, finish - start))

    def validate(self, epoch):
        """Run one validation pass (no gradients, net in eval mode).

        Args:
            epoch: 1-based epoch index (used for logging only).
        """
        start = time.time()
        self.net.eval()
        # Guard against a validation set smaller than the batch size.
        val_batch_size = min(ARGS['batch_size'], len(self.val_dataset))
        val_dataloader = DataLoader(self.val_dataset,
                                    batch_size=val_batch_size)
        epoch_loss = 0.
        for batch_index, items in enumerate(val_dataloader):
            images, labels, edges = items['image'], items['label'], items[
                'edge']

            if ARGS['gpu']:
                labels = labels.cuda()
                images = images.cuda()
                edges = edges.cuda()

            print('image shape:', images.size())

            with torch.no_grad():
                outputs_edge, outputs = self.net(images)
                loss_edge = lovasz_softmax(outputs_edge,
                                           edges)  # Lovasz-Softmax loss
                loss_seg = lovasz_softmax(outputs, labels)
                loss = ARGS['combine_alpha'] * loss_seg + (
                    1 - ARGS['combine_alpha']) * loss_edge

            pred = torch.max(outputs, dim=1)[1]
            iou = torch.sum(pred & labels) / (torch.sum(pred | labels) + 1e-6)

            print(
                'Validating Epoch: {epoch} [{val_samples}/{total_samples}]\tLoss: {:0.4f}\tIoU: {:0.4f}'
                .format(loss.item(),
                        iou.item(),
                        epoch=epoch,
                        val_samples=batch_index * val_batch_size,
                        total_samples=len(val_dataloader.dataset)))

            # BUGFIX: accumulate the scalar, not the tensor, so epoch_loss
            # stays a plain float for logging.
            epoch_loss += loss.item()

        epoch_loss /= len(val_dataloader)
        self.writer.add_scalar('Val/loss', epoch_loss, epoch)

        finish = time.time()

        # BUGFIX: this is validation time, not training time.
        print('epoch {} validation time consumed: {:.2f}s'.format(
            epoch, finish - start))

    def train_val(self):
        """Alternate training and validation for ARGS['num_epochs'] epochs,
        checkpointing every ARGS['epoch_save'] epochs."""
        print('Begin training and validating:')
        # BUGFIX: iterate 1-based so train()'s per-iteration bookkeeping
        # ((epoch - 1) * len(loader)) starts at 0 instead of going negative
        # on the first epoch. Save cadence and filenames are unchanged.
        for epoch in range(1, ARGS['num_epochs'] + 1):
            self.train(epoch)
            self.validate(epoch)
            print(f'Finish training and validating epoch #{epoch}')
            if epoch % ARGS['epoch_save'] == 0:
                os.makedirs(ARGS['weight_save_folder'], exist_ok=True)
                torch.save(
                    self.net.state_dict(),
                    os.path.join(ARGS['weight_save_folder'],
                                 f'epoch_{epoch}.pth'))
                print(f'Model saved for epoch #{epoch}.')
        print('Finish training and validating.')