Example #1
def _save(model_prior: Prior,
          ckpt_loc: str,
          optim: AdaBound):
    """
    Save checkpoint

    Args:
        model_prior (Prior): The prior network
        ckpt_loc (str): Checkpoint location
        optim (AdaBound): The optimizer
    """
    torch.save(model_prior.state_dict(),
               os.path.join(ckpt_loc, 'mdl.ckpt'))
    torch.save(optim.state_dict(),
               os.path.join(ckpt_loc, 'optimizer.ckpt'))
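Example #1 only shows the save side. A minimal sketch of the matching restore step, assuming the same checkpoint layout (the _load name and the map_location choice are illustrative, not part of the original):

import os
import torch

def _load(model_prior, ckpt_loc: str, optim) -> None:
    """Hypothetical counterpart to _save: restore model and optimizer state."""
    # map_location='cpu' keeps the load device-agnostic; move the model afterwards
    model_prior.load_state_dict(
        torch.load(os.path.join(ckpt_loc, 'mdl.ckpt'), map_location='cpu'))
    optim.load_state_dict(
        torch.load(os.path.join(ckpt_loc, 'optimizer.ckpt'), map_location='cpu'))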
Example #2
class TrainNetwork(object):
    """The main train network"""

    def __init__(self, args):
        super(TrainNetwork, self).__init__()
        self.args = args
        self.dur_time = 0
        self.logger = self._init_log()

        if not torch.cuda.is_available():
            self.logger.info('no gpu device available')
            sys.exit(1)

        self._init_hyperparam()
        self._init_random_and_device()
        self._init_model()

    def _init_hyperparam(self):
        if 'cifar100' == self.args.train_dataset:
            # cifar10:  6000 images per class, 10 classes, 50000 training images and 10000 test images
            # cifar100: 600 images per class, 100 classes, 500 training images and 100 testing images per class
            self.args.num_classes = 100
            self.args.layers = 20
            self.args.data = '/train_tiny_data/train_data/cifar100'
        elif 'imagenet' == self.args.train_dataset:
            self.args.data = '/train_data/imagenet'
            self.args.num_classes = 1000
            self.args.weight_decay = 3e-5
            self.args.report_freq = 100
            self.args.init_channels = 50
            self.args.drop_path_prob = 0
        elif 'tiny-imagenet' == self.args.train_dataset:
            self.args.data = '/train_tiny_data/train_data/tiny-imagenet'
            self.args.num_classes = 200
        elif 'food101' == self.args.train_dataset:
            self.args.data = '/train_tiny_data/train_data/food-101'
            self.args.num_classes = 101
            self.args.init_channels = 48

    def _init_log(self):
        self.args.save = os.path.join(
            '../logs/eval', self.args.arch, self.args.train_dataset,
            'eval-{}-{}'.format(self.args.save, time.strftime('%Y%m%d-%H%M')))
        dutils.create_exp_dir(self.args.save, scripts_to_save=None)

        log_format = '%(asctime)s %(message)s'
        logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                            format=log_format, datefmt='%m/%d %I:%M:%S %p')
        fh = logging.FileHandler(os.path.join(self.args.save, 'log.txt'))
        fh.setFormatter(logging.Formatter(log_format))
        logger = logging.getLogger('Architecture Training')
        logger.addHandler(fh)
        return logger

    def _init_random_and_device(self):
        # Set random seed and cuda device
        np.random.seed(self.args.seed)
        cudnn.benchmark = True
        torch.manual_seed(self.args.seed)
        cudnn.enabled = True
        torch.cuda.manual_seed(self.args.seed)
        max_free_gpu_id, gpus_info = dutils.get_gpus_memory_info()
        self.device_id = max_free_gpu_id
        self.gpus_info = gpus_info
        self.device = torch.device('cuda:{}'.format(0 if self.args.multi_gpus else self.device_id))

    def _init_model(self):

        self.train_queue, self.valid_queue = self._load_dataset_queue()

        def _init_scheduler():
            if 'cifar' in self.args.train_dataset:
                scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, float(self.args.epochs))
            else:
                scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, self.args.decay_period,
                                                            gamma=self.args.gamma)
            return scheduler

        genotype = eval('geno_types.%s' % self.args.arch)
        # Note: the original conditional here evaluated to 0 in both branches
        reduce_level = 0
        model = EvalNetwork(self.args.init_channels, self.args.num_classes, 0,
                            self.args.layers, self.args.auxiliary, genotype, reduce_level)

        # Try to move the model to multiple GPUs
        if torch.cuda.device_count() > 1 and self.args.multi_gpus:
            self.logger.info('use: %d gpus', torch.cuda.device_count())
            model = nn.DataParallel(model)
        else:
            self.logger.info('gpu device = %d' % self.device_id)
            torch.cuda.set_device(self.device_id)
        self.model = model.to(self.device)

        self.logger.info('param size = %fM', dutils.calc_parameters_count(model))

        criterion = nn.CrossEntropyLoss()
        if self.args.num_classes >= 50:
            criterion = CrossEntropyLabelSmooth(self.args.num_classes, self.args.label_smooth)
        self.criterion = criterion.to(self.device)

        if self.args.opt == 'adam':
            # Note: the 'adam' option actually instantiates Adamax
            self.optimizer = torch.optim.Adamax(
                model.parameters(),
                self.args.learning_rate,
                weight_decay=self.args.weight_decay
            )
        elif self.args.opt == 'adabound':
            self.optimizer = AdaBound(
                model.parameters(),
                self.args.learning_rate,
                weight_decay=self.args.weight_decay
            )
        else:
            self.optimizer = torch.optim.SGD(
                model.parameters(),
                self.args.learning_rate,
                momentum=self.args.momentum,
                weight_decay=self.args.weight_decay
            )

        self.best_acc_top1 = 0
        # optionally resume from a checkpoint
        if self.args.resume:
            if os.path.isfile(self.args.resume):
                print("=> loading checkpoint {}".format(self.args.resume))
                checkpoint = torch.load(self.args.resume)
                self.dur_time = checkpoint['dur_time']
                self.args.start_epoch = checkpoint['epoch']
                self.best_acc_top1 = checkpoint['best_acc_top1']
                self.args.drop_path_prob = checkpoint['drop_path_prob']
                self.model.load_state_dict(checkpoint['state_dict'])
                self.optimizer.load_state_dict(checkpoint['optimizer'])
                print("=> loaded checkpoint '{}' (epoch {})".format(self.args.resume, checkpoint['epoch']))
            else:
                print("=> no checkpoint found at '{}'".format(self.args.resume))

        self.scheduler = _init_scheduler()
        # reload the scheduler if possible
        if self.args.resume and os.path.isfile(self.args.resume):
            checkpoint = torch.load(self.args.resume)
            self.scheduler.load_state_dict(checkpoint['scheduler'])

    def _load_dataset_queue(self):
        if 'cifar' in self.args.train_dataset:
            train_transform, valid_transform = dutils.data_transforms_cifar(self.args)
            if 'cifar10' == self.args.train_dataset:
                train_data = dset.CIFAR10(root=self.args.data, train=True, download=True, transform=train_transform)
                valid_data = dset.CIFAR10(root=self.args.data, train=False, download=True, transform=valid_transform)
            else:
                train_data = dset.CIFAR100(root=self.args.data, train=True, download=True, transform=train_transform)
                valid_data = dset.CIFAR100(root=self.args.data, train=False, download=True, transform=valid_transform)

            train_queue = torch.utils.data.DataLoader(
                train_data, batch_size=self.args.batch_size, shuffle=True, pin_memory=True, num_workers=4
            )
            valid_queue = torch.utils.data.DataLoader(
                valid_data, batch_size=self.args.batch_size, shuffle=True, pin_memory=True, num_workers=4
            )
        elif 'tiny-imagenet' == self.args.train_dataset:
            train_transform, valid_transform = dutils.data_transforms_tiny_imagenet()
            train_data = dartsdset.TinyImageNet200(self.args.data, train=True, download=True, transform=train_transform)
            valid_data = dartsdset.TinyImageNet200(self.args.data, train=False, download=True, transform=valid_transform)
            train_queue = torch.utils.data.DataLoader(
                train_data, batch_size=self.args.batch_size, shuffle=True, pin_memory=True, num_workers=4
            )
            valid_queue = torch.utils.data.DataLoader(
                valid_data, batch_size=self.args.batch_size, shuffle=True, pin_memory=True, num_workers=4
            )
        elif 'imagenet' == self.args.train_dataset:
            traindir = os.path.join(self.args.data, 'train')
            validdir = os.path.join(self.args.data, 'val')
            train_transform, valid_transform = dutils.data_transforms_imagenet()
            train_data = dset.ImageFolder(traindir, train_transform)
            valid_data = dset.ImageFolder(validdir, valid_transform)

            train_queue = torch.utils.data.DataLoader(
                train_data, batch_size=self.args.batch_size, shuffle=True, pin_memory=True, num_workers=4)

            valid_queue = torch.utils.data.DataLoader(
                valid_data, batch_size=self.args.batch_size, shuffle=False, pin_memory=True, num_workers=4)
        elif 'food101' == self.args.train_dataset:
            traindir = os.path.join(self.args.data, 'train')
            validdir = os.path.join(self.args.data, 'val')
            train_transform, valid_transform = dutils.data_transforms_food101()
            train_data = dset.ImageFolder(traindir, train_transform)
            valid_data = dset.ImageFolder(validdir, valid_transform)

            train_queue = torch.utils.data.DataLoader(
                train_data, batch_size=self.args.batch_size, shuffle=True, pin_memory=True, num_workers=4)

            valid_queue = torch.utils.data.DataLoader(
                valid_data, batch_size=self.args.batch_size, shuffle=False, pin_memory=True, num_workers=4)

        return train_queue, valid_queue

    def run(self):
        self.logger.info('args = %s', self.args)
        run_start = time.time()
        for epoch in range(self.args.start_epoch, self.args.epochs):
            # Note: since PyTorch 1.1, scheduler.step() is meant to be called after
            # optimizer.step(); the original (legacy) call order is kept here.
            self.scheduler.step()
            self.logger.info('epoch %d / %d  lr %e', epoch, self.args.epochs, self.scheduler.get_lr()[0])

            if self.args.no_dropout:
                self.model._drop_path_prob = 0
            else:
                self.model._drop_path_prob = self.args.drop_path_prob * epoch / self.args.epochs
                self.logger.info('drop_path_prob %e', self.model._drop_path_prob)

            train_acc, train_obj = self.train()
            self.logger.info('train loss %e, train acc %f', train_obj, train_acc)

            valid_acc_top1, valid_acc_top5, valid_obj = self.infer()
            self.logger.info('valid loss %e, top1 valid acc %f top5 valid acc %f',
                        valid_obj, valid_acc_top1, valid_acc_top5)
            self.logger.info('best valid acc %f', self.best_acc_top1)

            is_best = False
            if valid_acc_top1 > self.best_acc_top1:
                self.best_acc_top1 = valid_acc_top1
                is_best = True

            dutils.save_checkpoint({
                'epoch': epoch+1,
                'dur_time': self.dur_time + time.time() - run_start,
                'state_dict': self.model.state_dict(),
                'drop_path_prob': self.args.drop_path_prob,
                'best_acc_top1': self.best_acc_top1,
                'optimizer': self.optimizer.state_dict(),
                'scheduler': self.scheduler.state_dict()
            }, is_best, self.args.save)
        self.logger.info('train epochs %d, best_acc_top1 %f, dur_time %s',
                         self.args.epochs, self.best_acc_top1, dutils.calc_time(self.dur_time + time.time() - run_start))

    def train(self):
        objs = dutils.AverageMeter()
        top1 = dutils.AverageMeter()
        top5 = dutils.AverageMeter()

        self.model.train()

        for step, (input, target) in enumerate(self.train_queue):

            input = input.cuda(self.device, non_blocking=True)
            target = target.cuda(self.device, non_blocking=True)

            self.optimizer.zero_grad()
            logits, logits_aux = self.model(input)
            loss = self.criterion(logits, target)
            if self.args.auxiliary:
                loss_aux = self.criterion(logits_aux, target)
                loss += self.args.auxiliary_weight*loss_aux
            loss.backward()
            nn.utils.clip_grad_norm_(self.model.parameters(), self.args.grad_clip)
            self.optimizer.step()

            prec1, prec5 = dutils.accuracy(logits, target, topk=(1,5))
            n = input.size(0)
            objs.update(loss.item(), n)
            top1.update(prec1.item(), n)
            top5.update(prec5.item(), n)

            if step % self.args.report_freq == 0:
                self.logger.info('train %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)

        return top1.avg, objs.avg

    def infer(self):
        objs = dutils.AverageMeter()
        top1 = dutils.AverageMeter()
        top5 = dutils.AverageMeter()
        self.model.eval()
        with torch.no_grad():
            for step, (input, target) in enumerate(self.valid_queue):
                input = input.cuda(self.device, non_blocking=True)
                target = target.cuda(self.device, non_blocking=True)

                logits, _ = self.model(input)
                loss = self.criterion(logits, target)

                prec1, prec5 = dutils.accuracy(logits, target, topk=(1,5))
                n = input.size(0)
                objs.update(loss.item(), n)
                top1.update(prec1.item(), n)
                top5.update(prec5.item(), n)

                if step % self.args.report_freq == 0:
                    self.logger.info('valid %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)
            return top1.avg, top5.avg, objs.avg
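Example #2 constructs AdaBound with only a learning rate and weight decay, leaving the bound-related arguments at their defaults. A minimal self-contained sketch of the full constructor exposed by the adabound package (the model and the hyperparameter values here are illustrative, not taken from the example):

import torch
from adabound import AdaBound

model = torch.nn.Linear(10, 2)  # stand-in model for illustration
# final_lr is the SGD-like step size AdaBound converges to;
# gamma controls how quickly the dynamic bounds tighten around it.
optimizer = AdaBound(model.parameters(),
                     lr=1e-3,           # initial (Adam-like) step size
                     final_lr=0.1,      # asymptotic (SGD-like) step size
                     gamma=1e-3,        # convergence speed of the bounds
                     weight_decay=3e-5,
                     amsbound=False)    # True selects the AMSBound variant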
Example #3
def main():
    args = parse_args()
    update_config(cfg_hrnet, args)

    # create checkpoint dir
    if not isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    # create model
    #print('networks.'+ cfg_hrnet.MODEL.NAME+'.get_pose_net')
    model = eval('models.' + cfg_hrnet.MODEL.NAME + '.get_pose_net')(
        cfg_hrnet, is_train=True)
    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

    # show net
    args.channels = 3
    args.height = cfg.data_shape[0]
    args.width = cfg.data_shape[1]
    #net_vision(model, args)

    # define loss function (criterion) and optimizer
    criterion = torch.nn.MSELoss(reduction='mean').cuda()

    #torch.optim.Adam
    optimizer = AdaBound(model.parameters(),
                         lr=cfg.lr,
                         weight_decay=cfg.weight_decay)

    if args.resume:
        if isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            pretrained_dict = checkpoint['state_dict']
            model.load_state_dict(pretrained_dict)
            args.start_epoch = checkpoint['epoch']
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            logger = Logger(join(args.checkpoint, 'log.txt'), resume=True)
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    else:
        logger = Logger(join(args.checkpoint, 'log.txt'))
        logger.set_names(['Epoch', 'LR', 'Train Loss'])

    cudnn.benchmark = True
    torch.backends.cudnn.enabled = True
    # approximate model size in MB, assuming 4 bytes (float32) per parameter
    print('    Total params: %.2fMB' %
          (sum(p.numel() for p in model.parameters()) / (1024 * 1024) * 4))

    train_loader = torch.utils.data.DataLoader(
        #MscocoMulti(cfg),
        KPloader(cfg),
        batch_size=cfg.batch_size * len(args.gpus))
    #, shuffle=True,
    #num_workers=args.workers, pin_memory=True)

    #for i, (img, targets, valid) in enumerate(train_loader):
    #    print(i, img, targets, valid)

    for epoch in range(args.start_epoch, args.epochs):
        lr = adjust_learning_rate(optimizer, epoch, cfg.lr_dec_epoch,
                                  cfg.lr_gamma)
        print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr))

        # train for one epoch
        train_loss = train(train_loader, model, criterion, optimizer)
        print('train_loss: ', train_loss)

        # append logger file
        logger.append([epoch + 1, lr, train_loss])

        save_model(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            },
            checkpoint=args.checkpoint)

    logger.close()
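adjust_learning_rate is project code that Example #3 does not show. A sketch of a typical step-decay implementation matching the call site, assuming lr_dec_epoch is a list of decay epochs and lr_gamma the multiplicative factor; the real helper presumably reads the base rate from cfg.lr, which is passed here as a parameter to keep the sketch self-contained:

def adjust_learning_rate(optimizer, epoch, lr_dec_epoch, lr_gamma, base_lr=0.001):
    """Step decay (sketch): multiply base_lr by lr_gamma once per decay epoch passed."""
    lr = base_lr * (lr_gamma ** sum(1 for e in lr_dec_epoch if epoch >= e))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr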
Example #4
def train_model(cfg: DictConfig) -> None:
    output_dir = Path.cwd()
    logging.basicConfig(format='%(asctime)s\t%(levelname)s\t%(message)s',
                        datefmt='%Y/%m/%d %H:%M:%S',
                        filename=str(output_dir / 'log.txt'),
                        level=logging.DEBUG)
    # Prevent Hydra from also echoing the log output to the console
    logger = logging.getLogger()
    assert isinstance(logger.handlers[0], logging.StreamHandler)
    logger.handlers[0].setLevel(logging.CRITICAL)

    if cfg.gpu >= 0:
        device = torch.device(f"cuda:{cfg.gpu}")
        # noinspection PyUnresolvedReferences
        torch.backends.cudnn.benchmark = True
    else:
        device = torch.device("cpu")
    model = load_model(model_name=cfg.model_name)
    model.to(device)
    if cfg.swa.enable:
        swa_model = AveragedModel(model=model, device=device)
    else:
        swa_model = None

    # optimizer = optim.SGD(
    #     model.parameters(), lr=cfg.optimizer.lr,
    #     momentum=cfg.optimizer.momentum,
    #     weight_decay=cfg.optimizer.weight_decay,
    #     nesterov=cfg.optimizer.nesterov
    # )
    optimizer = AdaBound(model.parameters(),
                         lr=cfg.optimizer.lr,
                         final_lr=cfg.optimizer.final_lr,
                         weight_decay=cfg.optimizer.weight_decay,
                         amsbound=False)
    scaler = torch.cuda.amp.GradScaler(enabled=cfg.use_amp)
    if cfg.scheduler.enable:
        scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer=optimizer,
            T_0=1,
            T_mult=1,
            eta_min=cfg.scheduler.eta_min)
        # scheduler = optim.lr_scheduler.CyclicLR(
        #     optimizer, base_lr=cfg.scheduler.base_lr,
        #     max_lr=cfg.scheduler.max_lr,
        #     step_size_up=cfg.scheduler.step_size,
        #     mode=cfg.scheduler.mode
        # )
    else:
        scheduler = None
    if cfg.input_dir is not None:
        input_dir = Path(cfg.input_dir)
        model_path = input_dir / 'model.pt'
        print('load model from {}'.format(model_path))
        model.load_state_dict(torch.load(model_path))

        state_path = input_dir / 'state.pt'
        print('load optimizer state from {}'.format(state_path))
        checkpoint = torch.load(state_path, map_location=device)
        epoch = checkpoint['epoch']
        t = checkpoint['t']
        optimizer.load_state_dict(checkpoint['optimizer'])
        if cfg.swa.enable and 'swa_model' in checkpoint:
            swa_model.load_state_dict(checkpoint['swa_model'])
        if cfg.scheduler.enable and 'scheduler' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler'])
        if cfg.use_amp and 'scaler' in checkpoint:
            scaler.load_state_dict(checkpoint['scaler'])
    else:
        epoch = 0
        t = 0

    # Hydra changes the current working directory, so resolve the data paths to absolute ones
    if isinstance(cfg.train_data, str):
        train_path_list = (hydra.utils.to_absolute_path(cfg.train_data), )
    else:
        train_path_list = [
            hydra.utils.to_absolute_path(path) for path in cfg.train_data
        ]
    logging.info('train data path: {}'.format(train_path_list))

    train_data = load_train_data(path_list=train_path_list)
    train_dataset = train_data
    train_data = train_dataset[0]
    test_data = load_test_data(
        path=hydra.utils.to_absolute_path(cfg.test_data))

    logging.info('train position num = {}'.format(len(train_data)))
    logging.info('test position num = {}'.format(len(test_data)))

    train_loader = DataLoader(train_data,
                              device=device,
                              batch_size=cfg.batch_size,
                              shuffle=True)
    validation_loader = DataLoader(test_data[:cfg.test_batch_size * 10],
                                   device=device,
                                   batch_size=cfg.test_batch_size)
    test_loader = DataLoader(test_data,
                             device=device,
                             batch_size=cfg.test_batch_size)

    train_writer = SummaryWriter(log_dir=str(output_dir / 'train'))
    test_writer = SummaryWriter(log_dir=str(output_dir / 'test'))

    train_metrics = Metrics()
    eval_interval = cfg.eval_interval
    total_epoch = cfg.epoch + epoch
    for e in range(cfg.epoch):
        train_metrics_epoch = Metrics()

        model.train()
        desc = 'train [{:03d}/{:03d}]'.format(epoch + 1, total_epoch)
        train_size = len(train_loader) * 4
        for x1, x2, t1, t2, z, value, mask in tqdm(train_loader, desc=desc):
            with torch.cuda.amp.autocast(enabled=cfg.use_amp):
                model.zero_grad()

                metric_value = compute_metric(model=model,
                                              x1=x1,
                                              x2=x2,
                                              t1=t1,
                                              t2=t2,
                                              z=z,
                                              value=value,
                                              mask=mask,
                                              val_lambda=cfg.val_lambda,
                                              beta=cfg.beta)

            scaler.scale(metric_value.loss).backward()
            if cfg.clip_grad_max_norm:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               cfg.clip_grad_max_norm)
            scaler.step(optimizer)
            scaler.update()
            if cfg.swa.enable and t % cfg.swa.freq == 0:
                swa_model.update_parameters(model=model)

            t += 1
            if cfg.scheduler.enable:
                scheduler.step(t / train_size)

            train_metrics.update(metric_value=metric_value)
            train_metrics_epoch.update(metric_value=metric_value)

            # print train loss
            if t % eval_interval == 0:
                model.eval()

                validation_metrics = Metrics()
                with torch.no_grad():
                    # noinspection PyAssignmentToLoopOrWithParameter
                    for x1, x2, t1, t2, z, value, mask in validation_loader:
                        m = compute_metric(model=model,
                                           x1=x1,
                                           x2=x2,
                                           t1=t1,
                                           t2=t2,
                                           z=z,
                                           value=value,
                                           mask=mask,
                                           val_lambda=cfg.val_lambda)
                        validation_metrics.update(metric_value=m)

                last_lr = (scheduler.get_last_lr()[-1]
                           if cfg.scheduler.enable else cfg.optimizer.lr)
                logging.info(
                    'epoch = {}, iteration = {}, lr = {}, {}, {}'.format(
                        epoch + 1, t, last_lr,
                        make_metric_log('train', train_metrics),
                        make_metric_log('validation', validation_metrics)))
                write_summary(writer=train_writer,
                              metrics=train_metrics,
                              t=t,
                              prefix='iteration')
                write_summary(writer=test_writer,
                              metrics=validation_metrics,
                              t=t,
                              prefix='iteration')
                train_metrics = Metrics()

                train_writer.add_scalar('learning_rate',
                                        last_lr,
                                        global_step=t)

                model.train()
            elif t % cfg.train_log_interval == 0:
                last_lr = (scheduler.get_last_lr()[-1]
                           if cfg.scheduler.enable else cfg.optimizer.lr)
                logging.info('epoch = {}, iteration = {}, lr = {}, {}'.format(
                    epoch + 1, t, last_lr,
                    make_metric_log('train', train_metrics)))
                write_summary(writer=train_writer,
                              metrics=train_metrics,
                              t=t,
                              prefix='iteration')
                train_metrics = Metrics()

                train_writer.add_scalar('learning_rate',
                                        last_lr,
                                        global_step=t)

        if cfg.swa.enable:
            with torch.cuda.amp.autocast(enabled=cfg.use_amp):
                desc = 'update BN [{:03d}/{:03d}]'.format(
                    epoch + 1, total_epoch)
                np.random.shuffle(train_data)
                # Computing the BN moments needs a fair amount of data;
                # using all of it was more accurate than cutting it to 1/16.
                # Limit the volume to what can be processed in about 10 minutes.
                # The DataLoader may not handle non-contiguous memory correctly.
                train_data = np.ascontiguousarray(train_data[::4])
                torch.optim.swa_utils.update_bn(loader=tqdm(
                    hcpe_loader(data=train_data,
                                device=device,
                                batch_size=cfg.batch_size),
                    desc=desc,
                    total=len(train_data) // cfg.batch_size),
                                                model=swa_model)

        # print train loss for each epoch
        test_metrics = Metrics()

        if cfg.swa.enable:
            test_model = swa_model
        else:
            test_model = model
        test_model.eval()
        with torch.no_grad():
            desc = 'test [{:03d}/{:03d}]'.format(epoch + 1, total_epoch)
            for x1, x2, t1, t2, z, value, mask in tqdm(test_loader, desc=desc):
                metric_value = compute_metric(model=test_model,
                                              x1=x1,
                                              x2=x2,
                                              t1=t1,
                                              t2=t2,
                                              z=z,
                                              value=value,
                                              mask=mask,
                                              val_lambda=cfg.val_lambda)

                test_metrics.update(metric_value=metric_value)

        logging.info('epoch = {}, iteration = {}, {}, {}'.format(
            epoch + 1, t, make_metric_log('train', train_metrics_epoch),
            make_metric_log('test', test_metrics)))
        write_summary(writer=train_writer,
                      metrics=train_metrics_epoch,
                      t=epoch + 1,
                      prefix='epoch')
        write_summary(writer=test_writer,
                      metrics=test_metrics,
                      t=epoch + 1,
                      prefix='epoch')

        epoch += 1

        if e != cfg.epoch - 1:
            # Swap in the next chunk of training data
            train_data = train_dataset[e + 1]
            train_loader.data = train_data

    train_writer.close()
    test_writer.close()

    print('save the model')
    torch.save(model.state_dict(), output_dir / 'model.pt')

    print('save the optimizer')
    state = {'epoch': epoch, 't': t, 'optimizer': optimizer.state_dict()}
    if cfg.scheduler.enable:
        state['scheduler'] = scheduler.state_dict()
    if cfg.swa.enable:
        state['swa_model'] = swa_model.state_dict()
    if cfg.use_amp:
        state['scaler'] = scaler.state_dict()
    torch.save(state, output_dir / 'state.pt')
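The SWA handling in Example #4 is interleaved with Hydra and AMP plumbing. Stripped to its essentials, the torch.optim.swa_utils pattern it relies on looks like the following sketch (toy model and random data, not the example's loaders; the update frequency of 2 stands in for cfg.swa.freq):

import torch
from torch.optim.swa_utils import AveragedModel, update_bn

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.BatchNorm1d(8))
swa_model = AveragedModel(model)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loader = [(torch.randn(4, 8), torch.randn(4, 8)) for _ in range(10)]

for t, (x, y) in enumerate(loader):
    loss = torch.nn.functional.mse_loss(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if t % 2 == 0:  # every cfg.swa.freq steps in the example
        swa_model.update_parameters(model)

# recompute BatchNorm statistics for the averaged weights before evaluation
update_bn(loader, swa_model)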
Example #5
class SRSolver(BaseSolver):
    def __init__(self, opt):
        super(SRSolver, self).__init__(opt)
        self.train_opt = opt['solver']
        self.LR = self.Tensor()
        self.HR = self.Tensor()
        self.SR = None

        self.records = {'train_loss': [],
                        'val_loss': [],
                        'psnr': [],
                        'ssim': [],
                        'lr': []}

        self.model = create_model(opt)
        self.print_network()

        if self.is_train:
            self.model.train()

            # set cl_loss
            if self.use_cl:
                self.cl_weights = self.opt['solver']['cl_weights']
                assert self.cl_weights, "[Error] 'cl_weights' must be declared when 'use_cl' is true"

            # set loss
            loss_type = self.train_opt['loss_type']
            if loss_type == 'l1':
                self.criterion_pix = nn.L1Loss()
            elif loss_type == 'l2':
                self.criterion_pix = nn.MSELoss()
            else:
                raise NotImplementedError('Loss type [%s] is not implemented!'%loss_type)

            if self.use_gpu:
                self.criterion_pix = self.criterion_pix.cuda()

            # set optimizer
            weight_decay = self.train_opt['weight_decay'] if self.train_opt['weight_decay'] else 0
            optim_type = self.train_opt['type'].upper()
            if optim_type == "ADAM":
                self.optimizer = optim.Adam(self.model.parameters(),
                                            lr=self.train_opt['learning_rate'], weight_decay=weight_decay)
            elif optim_type == 'ADABOUND':
                self.optimizer = AdaBound(self.model.parameters(),
                                          lr=self.train_opt['learning_rate'], weight_decay=weight_decay)
            elif optim_type == 'SGD':
                self.optimizer = optim.SGD(self.model.parameters(),
                                           lr=self.train_opt['learning_rate'], momentum=0.90, weight_decay=weight_decay)
            else:
                raise NotImplementedError('Optimizer type [%s] is not implemented!' % optim_type)

            # set lr_scheduler
            if self.train_opt['lr_scheme'].lower() == 'multisteplr':
                self.scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer,
                                                                self.train_opt['lr_steps'],
                                                                self.train_opt['lr_gamma'])
            elif self.train_opt['lr_scheme'].lower() == 'cos':
                self.scheduler = optim.lr_scheduler.CosineAnnealingLR(self.optimizer,
                                                                      T_max=self.opt['solver']['num_epochs'],
                                                                      eta_min=self.train_opt['lr_min'])
            else:
                raise NotImplementedError('Only MultiStepLR and CosineAnnealingLR schemes are supported!')

        self.load()

        print('===> Solver Initialized : [%s] || Use CL : [%s] || Use GPU : [%s]'%(self.__class__.__name__,
                                                                                       self.use_cl, self.use_gpu))
        if self.is_train:
            print("optimizer: ", self.optimizer)
            if self.train_opt['lr_scheme'].lower() == 'multisteplr':
                print("lr_scheduler milestones: %s   gamma: %f"%(self.scheduler.milestones, self.scheduler.gamma))

    def _net_init(self, init_type='kaiming'):
        print('==> Initializing the network using [%s]'%init_type)
        init_weights(self.model, init_type)


    def feed_data(self, batch, need_HR=True):
        input = batch['LR']
        self.LR.resize_(input.size()).copy_(input)

        if need_HR:
            target = batch['HR']
            self.HR.resize_(target.size()).copy_(target)


    def train_step(self):
        self.model.train()
        self.optimizer.zero_grad()

        loss_batch = 0.0
        sub_batch_size = int(self.LR.size(0) / self.split_batch)
        for i in range(self.split_batch):
            loss_sbatch = 0.0
            split_LR = self.LR.narrow(0, i*sub_batch_size, sub_batch_size)
            split_HR = self.HR.narrow(0, i*sub_batch_size, sub_batch_size)
            if self.use_cl:
                outputs = self.model(split_LR)
                loss_steps = [self.criterion_pix(sr, split_HR) for sr in outputs]
                for step in range(len(loss_steps)):
                    loss_sbatch += self.cl_weights[step] * loss_steps[step]
            else:
                output = self.model(split_LR)
                loss_sbatch = self.criterion_pix(output, split_HR)

            loss_sbatch /= self.split_batch
            loss_sbatch.backward()

            loss_batch += (loss_sbatch.item())

        # for stable training
        if loss_batch < self.skip_threshold * self.last_epoch_loss:
            self.optimizer.step()
            self.last_epoch_loss = loss_batch
        else:
            print('[Warning] Skip this batch! (Loss: {})'.format(loss_batch))

        self.model.eval()
        return loss_batch

    def test(self):
        self.model.eval()
        with torch.no_grad():  # run the forward pass without tracking gradients
            forward_func = self._overlap_crop_forward if self.use_chop else self.model.forward
            if self.self_ensemble and not self.is_train:
                SR = self._forward_x8(self.LR, forward_func)
            else:
                SR = forward_func(self.LR)

            if isinstance(SR, list):
                self.SR = SR[-1]
            else:
                self.SR = SR

        self.model.train()
        if self.is_train:
            loss_pix = self.criterion_pix(self.SR, self.HR)
            return loss_pix.item()


    def _forward_x8(self, x, forward_function):
        """
        self ensemble
        """
        def _transform(v, op):
            v = v.float()

            v2np = v.data.cpu().numpy()
            if op == 'v':
                tfnp = v2np[:, :, :, ::-1].copy()
            elif op == 'h':
                tfnp = v2np[:, :, ::-1, :].copy()
            elif op == 't':
                tfnp = v2np.transpose((0, 1, 3, 2)).copy()

            ret = self.Tensor(tfnp)

            return ret

        lr_list = [x]
        for tf in 'v', 'h', 't':
            lr_list.extend([_transform(t, tf) for t in lr_list])

        sr_list = []
        for aug in lr_list:
            sr = forward_function(aug)
            if isinstance(sr, list):
                sr_list.append(sr[-1])
            else:
                sr_list.append(sr)

        for i in range(len(sr_list)):
            if i > 3:
                sr_list[i] = _transform(sr_list[i], 't')
            if i % 4 > 1:
                sr_list[i] = _transform(sr_list[i], 'h')
            if (i % 4) % 2 == 1:
                sr_list[i] = _transform(sr_list[i], 'v')

        output_cat = torch.cat(sr_list, dim=0)
        output = output_cat.mean(dim=0, keepdim=True)

        return output


    def _overlap_crop_forward(self, x, shave=10, min_size=100000, bic=None):
        """
        chop for less memory consumption during test
        """
        n_GPUs = 2
        scale = self.scale
        b, c, h, w = x.size()
        h_half, w_half = h // 2, w // 2
        h_size, w_size = h_half + shave, w_half + shave
        lr_list = [
            x[:, :, 0:h_size, 0:w_size],
            x[:, :, 0:h_size, (w - w_size):w],
            x[:, :, (h - h_size):h, 0:w_size],
            x[:, :, (h - h_size):h, (w - w_size):w]]

        if bic is not None:
            bic_h_size = h_size*scale
            bic_w_size = w_size*scale
            bic_h = h*scale
            bic_w = w*scale
            
            bic_list = [
                bic[:, :, 0:bic_h_size, 0:bic_w_size],
                bic[:, :, 0:bic_h_size, (bic_w - bic_w_size):bic_w],
                bic[:, :, (bic_h - bic_h_size):bic_h, 0:bic_w_size],
                bic[:, :, (bic_h - bic_h_size):bic_h, (bic_w - bic_w_size):bic_w]]

        if w_size * h_size < min_size:
            sr_list = []
            for i in range(0, 4, n_GPUs):
                lr_batch = torch.cat(lr_list[i:(i + n_GPUs)], dim=0)
                if bic is not None:
                    bic_batch = torch.cat(bic_list[i:(i + n_GPUs)], dim=0)

                sr_batch_temp = self.model(lr_batch)

                if isinstance(sr_batch_temp, list):
                    sr_batch = sr_batch_temp[-1]
                else:
                    sr_batch = sr_batch_temp

                sr_list.extend(sr_batch.chunk(n_GPUs, dim=0))
        else:
            sr_list = [
                self._overlap_crop_forward(patch, shave=shave, min_size=min_size) \
                for patch in lr_list
                ]

        h, w = scale * h, scale * w
        h_half, w_half = scale * h_half, scale * w_half
        h_size, w_size = scale * h_size, scale * w_size
        shave *= scale

        output = x.new(b, c, h, w)
        output[:, :, 0:h_half, 0:w_half] \
            = sr_list[0][:, :, 0:h_half, 0:w_half]
        output[:, :, 0:h_half, w_half:w] \
            = sr_list[1][:, :, 0:h_half, (w_size - w + w_half):w_size]
        output[:, :, h_half:h, 0:w_half] \
            = sr_list[2][:, :, (h_size - h + h_half):h_size, 0:w_half]
        output[:, :, h_half:h, w_half:w] \
            = sr_list[3][:, :, (h_size - h + h_half):h_size, (w_size - w + w_half):w_size]

        return output


    def save_checkpoint(self, epoch, is_best):
        """
        save checkpoint to experimental dir
        """
        filename = os.path.join(self.checkpoint_dir, 'last_ckp.pth')
        print('===> Saving last checkpoint to [%s] ...'%filename)
        ckp = {
            'epoch': epoch,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'best_pred': self.best_pred,
            'best_epoch': self.best_epoch,
            'records': self.records
        }
        torch.save(ckp, filename)
        if is_best:
            print('===> Saving best checkpoint to [%s] ...' % filename.replace('last_ckp','best_ckp'))
            torch.save(ckp, filename.replace('last_ckp','best_ckp'))

        if epoch % self.train_opt['save_ckp_step'] == 0:
            print('===> Saving checkpoint [%d] to [%s] ...' % (epoch,
                                                               filename.replace('last_ckp','epoch_%d_ckp.pth'%epoch)))

            torch.save(ckp, filename.replace('last_ckp','epoch_%d_ckp.pth'%epoch))


    def load(self):
        """
        load or initialize network
        """
        if (self.is_train and self.opt['solver']['pretrain']) or not self.is_train:
            model_path = self.opt['solver']['pretrained_path']
            if model_path is None: raise ValueError("[Error] 'pretrained_path' is not declared in *.json")

            print('===> Loading model from [%s]...' % model_path)
            if self.is_train:
                checkpoint = torch.load(model_path)
                self.model.load_state_dict(checkpoint['state_dict'])

                # if self.opt['solver']['pretrain'] == 'resume':
                #     self.cur_epoch = checkpoint['epoch'] + 1
                #     self.optimizer.load_state_dict(checkpoint['optimizer'])
                #     self.best_pred = checkpoint['best_pred']
                #     self.best_epoch = checkpoint['best_epoch']
                #     self.records = checkpoint['records']

            else:
                checkpoint = torch.load(model_path)
                if 'state_dict' in checkpoint.keys(): checkpoint = checkpoint['state_dict']
                # unwrap DataParallel so the checkpoint keys match in both cases
                load_func = self.model.module.load_state_dict if isinstance(self.model, nn.DataParallel) \
                    else self.model.load_state_dict
                load_func(checkpoint)
        else:
            print('===> Initialize model')
            self._net_init()


    def get_current_visual(self, need_np=True, need_HR=True):
        """
        return LR SR (HR) images
        """
        out_dict = OrderedDict()
        out_dict['LR'] = self.LR.data[0].float().cpu()
        out_dict['SR'] = self.SR.data[0].float().cpu()
        if need_np:  out_dict['LR'], out_dict['SR'] = util.Tensor2np([out_dict['LR'], out_dict['SR']],
                                                                        self.opt['rgb_range'])
        if need_HR:
            out_dict['HR'] = self.HR.data[0].float().cpu()
            if need_np: out_dict['HR'] = util.Tensor2np([out_dict['HR']],
                                                           self.opt['rgb_range'])[0]
        return out_dict


    def save_current_visual(self, epoch, iter):
        """
        save visual results for comparison
        """
        if epoch % self.save_vis_step == 0:
            visuals_list = []
            visuals = self.get_current_visual(need_np=False)
            visuals_list.extend([util.quantize(visuals['HR'].squeeze(0), self.opt['rgb_range']),
                                 util.quantize(visuals['SR'].squeeze(0), self.opt['rgb_range'])])
            visual_images = torch.stack(visuals_list)
            visual_images = thutil.make_grid(visual_images, nrow=2, padding=5)
            visual_images = visual_images.byte().permute(1, 2, 0).numpy()
            misc.imsave(os.path.join(self.visual_dir, 'epoch_%d_img_%d.png' % (epoch, iter + 1)),
                        visual_images)


    def get_current_learning_rate(self):
        # return self.scheduler.get_lr()[-1]
        return self.optimizer.param_groups[0]['lr']


    def update_learning_rate(self, epoch):
        self.scheduler.step(epoch)


    def get_current_log(self):
        log = OrderedDict()
        log['epoch'] = self.cur_epoch
        log['best_pred'] = self.best_pred
        log['best_epoch'] = self.best_epoch
        log['records'] = self.records
        return log


    def set_current_log(self, log):
        self.cur_epoch = log['epoch']
        self.best_pred = log['best_pred']
        self.best_epoch = log['best_epoch']
        self.records = log['records']


    def save_current_log(self):
        data_frame = pd.DataFrame(
            data={'train_loss': self.records['train_loss']
                , 'val_loss': self.records['val_loss']
                , 'psnr': self.records['psnr']
                , 'ssim': self.records['ssim']
                , 'lr': self.records['lr']
                  },
            index=range(1, self.cur_epoch + 1)
        )
        data_frame.to_csv(os.path.join(self.records_dir, 'train_records.csv'),
                          index_label='epoch')


    def print_network(self):
        """
        print network summary including module and number of parameters
        """
        s, n = self.get_network_description(self.model)
        if isinstance(self.model, nn.DataParallel):
            net_struc_str = '{} - {}'.format(self.model.__class__.__name__,
                                                 self.model.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.model.__class__.__name__)

        print("==================================================")
        print("===> Network Summary\n")
        net_lines = []
        line = s + '\n'
        print(line)
        net_lines.append(line)
        line = 'Network structure: [{}], with parameters: [{:,d}]'.format(net_struc_str, n)
        print(line)
        net_lines.append(line)

        if self.is_train:
            with open(os.path.join(self.exp_root, 'network_summary.txt'), 'w') as f:
                f.writelines(net_lines)

        print("==================================================")
Example #6
def train(model, test_loader, lang, args, pairs, extra_loader):
	start = time.time()
	
	if args.optimizer == "adam":
		optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)
	elif args.optimizer == "adabound":
		optimizer = AdaBound(model.parameters(), lr=0.0001, final_lr=0.1)
	else:
		print("unknow optimizer.")
		exit(0)

	print_model(args)

	save_name = get_name(args)
	if not os.path.exists(args.save_path):
		os.mkdir(args.save_path)
	print("the model is saved in: "+save_name)

	n_epochs = args.n_epochs
	step = 0.0
	begin_epoch = 0
	best_val_bleu = 0
	if not args.from_scratch:
		if os.path.exists(save_name):
			checkpoint = torch.load(save_name)
			model.load_state_dict(checkpoint['model_state_dict'])
			lr = checkpoint['lr']
			step = checkpoint['step']
			begin_epoch = checkpoint['epoch'] + 1
			best_val_bleu = checkpoint['bleu']
			optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
			for param_group in optimizer.param_groups:
				param_group['lr'] = lr
			print("load successful!")
			checkpoint = []
		else:
			print("load unsuccessful!")

	if args.use_dataset_B:
		extra_iter = iter(extra_loader)
		num_iter = 1
		if args.ratio >= 2:
			num_iter = int(args.ratio)
	else:
		extra_iter = None
		num_iter = 0

	for epoch in range(begin_epoch, n_epochs):

		model.train()

		train_loader = Data.DataLoader(pairs, batch_size=args.batch_size, shuffle=True)

		print_loss_total, step, extra_iter, lr = train_epoch(model, lang, args, train_loader, 
															extra_loader, extra_iter, num_iter, optimizer, step)
			
		print('total loss: %f'%(print_loss_total))
		model.eval()
		curr_bleu = evaluate(model, test_loader, lang, args.max_length)
		print('%s (epoch: %d %d%%)' % (timeSince(start, (epoch+1-begin_epoch)/(n_epochs-begin_epoch)), epoch, (epoch+1-begin_epoch)/(n_epochs-begin_epoch)*100))

		if curr_bleu > best_val_bleu:
			best_val_bleu = curr_bleu
			torch.save({
				'model_state_dict': model.state_dict(),
				'optimizer_state_dict': optimizer.state_dict(),
				'lr': lr,
				'step': step,
				'epoch': epoch,
				'bleu': curr_bleu,
			}, save_name)
			print("checkpoint saved!")
		print()
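Across these examples the optimizer is chosen with an if/elif chain. A small factory dict keyed by name is an equally workable pattern; a sketch using only optimizers that appear above (the names and default hyperparameters are illustrative):

import torch
from functools import partial
from adabound import AdaBound

OPTIMIZERS = {
    'adam': partial(torch.optim.Adam, lr=1e-4, betas=(0.9, 0.98), eps=1e-9),
    'adabound': partial(AdaBound, lr=1e-4, final_lr=0.1),
}

def build_optimizer(name, params):
    """Look up an optimizer constructor by name; fail loudly on unknown names."""
    try:
        return OPTIMIZERS[name](params)
    except KeyError:
        raise ValueError('unknown optimizer: %s' % name)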