Example #1
def set_model(args, cfg, checkpoint):
    # model
    if checkpoint:
        model = Classifier(pretrained=False)
        model.load_state_dict(checkpoint['model'])
    else:
        model = Classifier(pretrained=True)
    if args.data_parallel:
        model = DataParallel(model)
    model = model.to(device=args.device)

    # optimizer
    if cfg['optimizer'] == 'sgd':
        # note: despite the 'sgd' name, this branch constructs torch.optim.ASGD (averaged SGD)
        optimizer = optim.ASGD(model.parameters(),
                               lr=cfg['learning_rate'],
                               weight_decay=cfg['weight_decay'])
    elif cfg['optimizer'] == 'adam':
        optimizer = optim.Adam(model.parameters(),
                               lr=cfg['learning_rate'],
                               weight_decay=cfg['weight_decay'])
    elif cfg['optimizer'] == 'adabound':
        optimizer = AdaBound(model.parameters(),
                             lr=cfg['learning_rate'],
                             final_lr=0.1,
                             weight_decay=cfg['weight_decay'])
    elif cfg['optimizer'] == 'amsbound':
        optimizer = AdaBound(model.parameters(),
                             lr=cfg['learning_rate'],
                             final_lr=0.1,
                             weight_decay=cfg['weight_decay'],
                             amsbound=True)
    else:
        raise ValueError('unknown optimizer: {}'.format(cfg['optimizer']))

    # checkpoint
    if checkpoint and args.load_optimizer:
        optimizer.load_state_dict(checkpoint['optimizer'])

    return model, optimizer
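
Every example on this page constructs AdaBound the same way. A minimal, self-contained sketch of that pattern, with a toy model and hyperparameter values chosen purely for illustration (they are not taken from the example above):

import torch
import torch.nn as nn
from adabound import AdaBound

# toy model and batch, for illustration only
model = nn.Linear(10, 2)
optimizer = AdaBound(model.parameters(), lr=1e-3, final_lr=0.1, weight_decay=5e-4)

x, y = torch.randn(8, 10), torch.randint(0, 2, (8,))
loss = nn.CrossEntropyLoss()(model(x), y)

optimizer.zero_grad()
loss.backward()
optimizer.step()

# the optimizer state can be checkpointed and restored, as set_model() does above
state = optimizer.state_dict()
optimizer.load_state_dict(state)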
Example #2
        assert opt.model == checkpoint['model']

        if opt.pretrained:
            model_dict = model.state_dict()
            passed_dict = ['conv9.weight','conv10.weight','conv11.weight']
            new_state_dict = OrderedDict(
                (k, v) for k, v in checkpoint['state_dict'].items() if k not in passed_dict)
            model_dict.update(new_state_dict)
            model.load_state_dict(model_dict)
        else:
            model.load_state_dict(checkpoint['state_dict'])

        opt.begin_epoch = checkpoint['epoch']
        model = model.to(opt.device)
        if not opt.no_train and not opt.pretrained:
            optimizer.load_state_dict(checkpoint['optimizer'])
        best_mAP = checkpoint["best_mAP"]


    ########################################
    #           Train, Val, Test           #
    ########################################
    if opt.test:
        test(model,test_dataloader,opt.begin_epoch,opt)
    else:
        for epoch in range(opt.begin_epoch, opt.num_epochs + 1):
            if not opt.no_train:
                print("\n---- Training Model ----")
                train(model,optimizer,train_dataloader,epoch,opt,train_logger, best_mAP=best_mAP)

            if not opt.no_val and (epoch+1) % opt.val_interval == 0:
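
The pretrained branch above loads a checkpoint only partially, dropping a few layer weights so they keep their fresh initialization. A standalone sketch of that partial-loading pattern; the module and key names here are hypothetical, not the layers of the original model:

import torch.nn as nn
from collections import OrderedDict

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))   # stand-in model
checkpoint_state = {k: v.clone() for k, v in model.state_dict().items()}
passed_keys = {'1.weight', '1.bias'}                       # layers to keep freshly initialized

# keep only the wanted checkpoint entries, merge them into the current state, then load
filtered = OrderedDict((k, v) for k, v in checkpoint_state.items() if k not in passed_keys)
model_state = model.state_dict()
model_state.update(filtered)
model.load_state_dict(model_state)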
Example #3
class TrainNetwork(object):
    """The main train network"""

    def __init__(self, args):
        super(TrainNetwork, self).__init__()
        self.args = args
        self.dur_time = 0
        self.logger = self._init_log()

        if not torch.cuda.is_available():
            self.logger.info('no gpu device available')
            sys.exit(1)

        self._init_hyperparam()
        self._init_random_and_device()
        self._init_model()

    def _init_hyperparam(self):
        if 'cifar100' == self.args.train_dataset:
            # cifar10:  6000 images per class, 10 classes, 50000 training images and 10000 test images
            # cifar100: 600 images per class, 100 classes, 500 training images and 100 testing images per class
            self.args.num_classes = 100
            self.args.layers = 20
            self.args.data = '/train_tiny_data/train_data/cifar100'
        elif 'imagenet' == self.args.train_dataset:
            self.args.data = '/train_data/imagenet'
            self.args.num_classes = 1000
            self.args.weight_decay = 3e-5
            self.args.report_freq = 100
            self.args.init_channels = 50
            self.args.drop_path_prob = 0
        elif 'tiny-imagenet' == self.args.train_dataset:
            self.args.data = '/train_tiny_data/train_data/tiny-imagenet'
            self.args.num_classes = 200
        elif 'food101' == self.args.train_dataset:
            self.args.data = '/train_tiny_data/train_data/food-101'
            self.args.num_classes = 101
            self.args.init_channels = 48

    def _init_log(self):
        self.args.save = '../logs/eval/' + self.args.arch + '/' + self.args.train_dataset + '/eval-{}-{}'.format(self.args.save, time.strftime('%Y%m%d-%H%M'))
        dutils.create_exp_dir(self.args.save, scripts_to_save=None)

        log_format = '%(asctime)s %(message)s'
        logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                            format=log_format, datefmt='%m/%d %I:%M:%S %p')
        fh = logging.FileHandler(os.path.join(self.args.save, 'log.txt'))
        fh.setFormatter(logging.Formatter(log_format))
        logger = logging.getLogger('Architecture Training')
        logger.addHandler(fh)
        return logger

    def _init_random_and_device(self):
        # Set random seed and cuda device
        np.random.seed(self.args.seed)
        cudnn.benchmark = True
        torch.manual_seed(self.args.seed)
        cudnn.enabled = True
        torch.cuda.manual_seed(self.args.seed)
        max_free_gpu_id, gpus_info = dutils.get_gpus_memory_info()
        self.device_id = max_free_gpu_id
        self.gpus_info = gpus_info
        self.device = torch.device('cuda:{}'.format(0 if self.args.multi_gpus else self.device_id))

    def _init_model(self):

        self.train_queue, self.valid_queue = self._load_dataset_queue()

        def _init_scheduler():
            if 'cifar' in self.args.train_dataset:
                scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, float(self.args.epochs))
            else:
                scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, self.args.decay_period,
                                                            gamma=self.args.gamma)
            return scheduler

        genotype = eval('geno_types.%s' % self.args.arch)
        reduce_level = (0 if 'cifar10' in self.args.train_dataset else 0)
        model = EvalNetwork(self.args.init_channels, self.args.num_classes, 0,
                            self.args.layers, self.args.auxiliary, genotype, reduce_level)

        # Try move model to multi gpus
        if torch.cuda.device_count() > 1 and self.args.multi_gpus:
            self.logger.info('use: %d gpus', torch.cuda.device_count())
            model = nn.DataParallel(model)
        else:
            self.logger.info('gpu device = %d' % self.device_id)
            torch.cuda.set_device(self.device_id)
        self.model = model.to(self.device)

        self.logger.info('param size = %fM', dutils.calc_parameters_count(model))

        criterion = nn.CrossEntropyLoss()
        if self.args.num_classes >= 50:
            criterion = CrossEntropyLabelSmooth(self.args.num_classes, self.args.label_smooth)
        self.criterion = criterion.to(self.device)

        if self.args.opt == 'adam':
            # note: the 'adam' option constructs torch.optim.Adamax
            self.optimizer = torch.optim.Adamax(
                model.parameters(),
                self.args.learning_rate,
                weight_decay=self.args.weight_decay
            )
        elif self.args.opt == 'adabound':
            self.optimizer = AdaBound(model.parameters(),
                                      self.args.learning_rate,
                                      weight_decay=self.args.weight_decay)
        else:
            self.optimizer = torch.optim.SGD(
                model.parameters(),
                self.args.learning_rate,
                momentum=self.args.momentum,
                weight_decay=self.args.weight_decay
            )

        self.best_acc_top1 = 0
        # optionally resume from a checkpoint
        if self.args.resume:
            if os.path.isfile(self.args.resume):
                print("=> loading checkpoint {}".format(self.args.resume))
                checkpoint = torch.load(self.args.resume)
                self.dur_time = checkpoint['dur_time']
                self.args.start_epoch = checkpoint['epoch']
                self.best_acc_top1 = checkpoint['best_acc_top1']
                self.args.drop_path_prob = checkpoint['drop_path_prob']
                self.model.load_state_dict(checkpoint['state_dict'])
                self.optimizer.load_state_dict(checkpoint['optimizer'])
                print("=> loaded checkpoint '{}' (epoch {})".format(self.args.resume, checkpoint['epoch']))
            else:
                print("=> no checkpoint found at '{}'".format(self.args.resume))

        self.scheduler = _init_scheduler()
        # reload the scheduler if possible
        if self.args.resume and os.path.isfile(self.args.resume):
            checkpoint = torch.load(self.args.resume)
            self.scheduler.load_state_dict(checkpoint['scheduler'])

    def _load_dataset_queue(self):
        if 'cifar' in self.args.train_dataset:
            train_transform, valid_transform = dutils.data_transforms_cifar(self.args)
            if 'cifar10' == self.args.train_dataset:
                train_data = dset.CIFAR10(root=self.args.data, train=True, download=True, transform=train_transform)
                valid_data = dset.CIFAR10(root=self.args.data, train=False, download=True, transform=valid_transform)
            else:
                train_data = dset.CIFAR100(root=self.args.data, train=True, download=True, transform=train_transform)
                valid_data = dset.CIFAR100(root=self.args.data, train=False, download=True, transform=valid_transform)

            train_queue = torch.utils.data.DataLoader(
                train_data, batch_size = self.args.batch_size, shuffle=True, pin_memory=True, num_workers=4
            )
            valid_queue = torch.utils.data.DataLoader(
                valid_data, batch_size = self.args.batch_size, shuffle=True, pin_memory=True, num_workers=4
            )
        elif 'tiny-imagenet' == self.args.train_dataset:
            train_transform, valid_transform = dutils.data_transforms_tiny_imagenet()
            train_data = dartsdset.TinyImageNet200(self.args.data, train=True, download=True, transform=train_transform)
            valid_data = dartsdset.TinyImageNet200(self.args.data, train=False, download=True, transform=valid_transform)
            train_queue = torch.utils.data.DataLoader(
                train_data, batch_size=self.args.batch_size, shuffle=True, pin_memory=True, num_workers=4
            )
            valid_queue = torch.utils.data.DataLoader(
                valid_data, batch_size=self.args.batch_size, shuffle=True, pin_memory=True, num_workers=4
            )
        elif 'imagenet' == self.args.train_dataset:
            traindir = os.path.join(self.args.data, 'train')
            validdir = os.path.join(self.args.data, 'val')
            train_transform, valid_transform = dutils.data_transforms_imagenet()
            train_data = dset.ImageFolder(
                traindir,train_transform)
            valid_data = dset.ImageFolder(
                validdir,valid_transform)

            train_queue = torch.utils.data.DataLoader(
                train_data, batch_size=self.args.batch_size, shuffle=True, pin_memory=True, num_workers=4)

            valid_queue = torch.utils.data.DataLoader(
                valid_data, batch_size=self.args.batch_size, shuffle=False, pin_memory=True, num_workers=4)
        elif 'food101' == self.args.train_dataset:
            traindir = os.path.join(self.args.data, 'train')
            validdir = os.path.join(self.args.data, 'val')
            train_transform, valid_transform = dutils.data_transforms_food101()
            train_data = dset.ImageFolder(
                traindir,train_transform)
            valid_data = dset.ImageFolder(
                validdir,valid_transform)

            train_queue = torch.utils.data.DataLoader(
                train_data, batch_size=self.args.batch_size, shuffle=True, pin_memory=True, num_workers=4)

            valid_queue = torch.utils.data.DataLoader(
                valid_data, batch_size=self.args.batch_size, shuffle=False, pin_memory=True, num_workers=4)

        return train_queue, valid_queue

    def run(self):
        self.logger.info('args = %s', self.args)
        run_start = time.time()
        for epoch in range(self.args.start_epoch, self.args.epochs):
            self.scheduler.step()
            self.logger.info('epoch %d / %d lr %e', epoch, self.args.epochs, self.scheduler.get_lr()[0])

            if self.args.no_dropout:
                self.model._drop_path_prob = 0
            else:
                self.model._drop_path_prob = self.args.drop_path_prob * epoch / self.args.epochs
                self.logger.info('drop_path_prob %e', self.model._drop_path_prob)

            train_acc, train_obj = self.train()
            self.logger.info('train loss %e, train acc %f', train_obj, train_acc)

            valid_acc_top1, valid_acc_top5, valid_obj = self.infer()
            self.logger.info('valid loss %e, top1 valid acc %f top5 valid acc %f',
                        valid_obj, valid_acc_top1, valid_acc_top5)
            self.logger.info('best valid acc %f', self.best_acc_top1)

            is_best = False
            if valid_acc_top1 > self.best_acc_top1:
                self.best_acc_top1 = valid_acc_top1
                is_best = True

            dutils.save_checkpoint({
                'epoch': epoch+1,
                'dur_time': self.dur_time + time.time() - run_start,
                'state_dict': self.model.state_dict(),
                'drop_path_prob': self.args.drop_path_prob,
                'best_acc_top1': self.best_acc_top1,
                'optimizer': self.optimizer.state_dict(),
                'scheduler': self.scheduler.state_dict()
            }, is_best, self.args.save)
        self.logger.info('train epoches %d, best_acc_top1 %f, dur_time %s',
                         self.args.epochs, self.best_acc_top1, dutils.calc_time(self.dur_time + time.time() - run_start))

    def train(self):
        objs = dutils.AverageMeter()
        top1 = dutils.AverageMeter()
        top5 = dutils.AverageMeter()

        self.model.train()

        for step, (input, target) in enumerate(self.train_queue):

            input = input.cuda(self.device, non_blocking=True)
            target = target.cuda(self.device, non_blocking=True)

            self.optimizer.zero_grad()
            logits, logits_aux = self.model(input)
            loss = self.criterion(logits, target)
            if self.args.auxiliary:
                loss_aux = self.criterion(logits_aux, target)
                loss += self.args.auxiliary_weight*loss_aux
            loss.backward()
            nn.utils.clip_grad_norm_(self.model.parameters(), self.args.grad_clip)
            self.optimizer.step()

            prec1, prec5 = dutils.accuracy(logits, target, topk=(1,5))
            n = input.size(0)
            objs.update(loss.item(), n)
            top1.update(prec1.item(), n)
            top5.update(prec5.item(), n)

            if step % self.args.report_freq == 0:
                self.logger.info('train %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)

        return top1.avg, objs.avg

    def infer(self):
        objs = dutils.AverageMeter()
        top1 = dutils.AverageMeter()
        top5 = dutils.AverageMeter()
        self.model.eval()
        with torch.no_grad():
            for step, (input, target) in enumerate(self.valid_queue):
                input = input.cuda(self.device, non_blocking=True)
                target = target.cuda(self.device, non_blocking=True)

                logits, _ = self.model(input)
                loss = self.criterion(logits, target)

                prec1, prec5 = dutils.accuracy(logits, target, topk=(1,5))
                n = input.size(0)
                objs.update(loss.item(), n)
                top1.update(prec1.item(), n)
                top5.update(prec5.item(), n)

                if step % self.args.report_freq == 0:
                    self.logger.info('valid %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)
            return top1.avg, top5.avg, objs.avg
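
dutils.accuracy and dutils.AverageMeter come from the repository's utility module and are not shown here. The calls above (topk=(1, 5), then .item() on each result) suggest the usual top-k accuracy helper; a common implementation looks roughly like this, offered as an assumption rather than the repo's actual code:

import torch

def accuracy(logits, target, topk=(1,)):
    """Return a list with the top-k accuracy (in percent) for every k in topk."""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = logits.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    results = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        results.append(correct_k.mul_(100.0 / batch_size))
    return results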
Example #4
def main():
    args = parse_args()
    update_config(cfg_hrnet, args)

    # create checkpoint dir
    if not isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    # create model
    #print('networks.'+ cfg_hrnet.MODEL.NAME+'.get_pose_net')
    model = eval('models.' + cfg_hrnet.MODEL.NAME + '.get_pose_net')(
        cfg_hrnet, is_train=True)
    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

    # show net
    args.channels = 3
    args.height = cfg.data_shape[0]
    args.width = cfg.data_shape[1]
    #net_vision(model, args)

    # define loss function (criterion) and optimizer
    criterion = torch.nn.MSELoss(reduction='mean').cuda()

    #torch.optim.Adam
    optimizer = AdaBound(model.parameters(),
                         lr=cfg.lr,
                         weight_decay=cfg.weight_decay)

    if args.resume:
        if isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            pretrained_dict = checkpoint['state_dict']
            model.load_state_dict(pretrained_dict)
            args.start_epoch = checkpoint['epoch']
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            logger = Logger(join(args.checkpoint, 'log.txt'), resume=True)
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
            # fall back to a fresh logger so the training loop below still has one to append to
            logger = Logger(join(args.checkpoint, 'log.txt'))
            logger.set_names(['Epoch', 'LR', 'Train Loss'])
    else:
        logger = Logger(join(args.checkpoint, 'log.txt'))
        logger.set_names(['Epoch', 'LR', 'Train Loss'])

    cudnn.benchmark = True
    torch.backends.cudnn.enabled = True
    print('    Total params: %.2fMB' %
          (sum(p.numel() for p in model.parameters()) / (1024 * 1024) * 4))

    train_loader = torch.utils.data.DataLoader(
        #MscocoMulti(cfg),
        KPloader(cfg),
        batch_size=cfg.batch_size * len(args.gpus))
    #, shuffle=True,
    #num_workers=args.workers, pin_memory=True)

    #for i, (img, targets, valid) in enumerate(train_loader):
    #    print(i, img, targets, valid)

    for epoch in range(args.start_epoch, args.epochs):
        lr = adjust_learning_rate(optimizer, epoch, cfg.lr_dec_epoch,
                                  cfg.lr_gamma)
        print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr))

        # train for one epoch
        train_loss = train(train_loader, model, criterion, optimizer)
        print('train_loss: ', train_loss)

        # append logger file
        logger.append([epoch + 1, lr, train_loss])

        save_model(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            },
            checkpoint=args.checkpoint)

    logger.close()
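
adjust_learning_rate is defined elsewhere in that repository; from the call site above (epoch, cfg.lr_dec_epoch, cfg.lr_gamma) it appears to be a step-decay helper. A plausible minimal version, labelled as an assumption rather than the project's actual implementation:

def adjust_learning_rate(optimizer, epoch, dec_epochs, gamma):
    """Multiply the current learning rate by gamma whenever epoch hits a decay milestone."""
    for param_group in optimizer.param_groups:
        if epoch in dec_epochs:
            param_group['lr'] *= gamma
        lr = param_group['lr']
    return lr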
Example #5
def train_model_v2_1(net,
                     trainloader,
                     validloader,
                     epochs,
                     lr,
                     grad_accum_steps=1,
                     warmup_epoch=1,
                     patience=5,
                     factor=0.5,
                     opt='AdaBound',
                     weight_decay=0.0,
                     loss_w=[0.5, 0.25, 0.25],
                     reference_labels=None,
                     cb_beta=0.99,
                     start_epoch=0,
                     opt_state_dict=None):
    """
    mixup, ReduceLROnPlateau, class balance
    """
    net = net.cuda()

    # loss
    loss_w = loss_w if loss_w is not None else [0.5, 0.25, 0.25]
    if reference_labels is None:
        if len(loss_w) == 3:
            criterion = multiloss_wrapper_v1_mixup(loss_funcs=[
                mixup.CrossEntropyLossForMixup(num_class=168),
                mixup.CrossEntropyLossForMixup(num_class=11),
                mixup.CrossEntropyLossForMixup(num_class=7)
            ],
                                                   weights=loss_w)
        elif len(loss_w) == 4:
            criterion = multiloss_wrapper_v1_mixup(loss_funcs=[
                mixup.CrossEntropyLossForMixup(num_class=168),
                mixup.CrossEntropyLossForMixup(num_class=11),
                mixup.CrossEntropyLossForMixup(num_class=7),
                mixup.CrossEntropyLossForMixup(num_class=1292)
            ],
                                                   weights=loss_w)

    else:
        if len(loss_w) == 3:
            criterion = multiloss_wrapper_v1_mixup(loss_funcs=[
                cbl.CB_CrossEntropyLoss(reference_labels[:, 0],
                                        num_class=168,
                                        beta=cb_beta,
                                        label_smooth=0.0),
                cbl.CB_CrossEntropyLoss(reference_labels[:, 1],
                                        num_class=11,
                                        beta=cb_beta,
                                        label_smooth=0.0),
                cbl.CB_CrossEntropyLoss(reference_labels[:, 2],
                                        num_class=7,
                                        beta=cb_beta,
                                        label_smooth=0.0)
            ],
                                                   weights=loss_w)
        elif len(loss_w) == 4:
            criterion = multiloss_wrapper_v1_mixup(loss_funcs=[
                cbl.CB_CrossEntropyLoss(reference_labels[:, 0],
                                        num_class=168,
                                        beta=cb_beta,
                                        label_smooth=0.0),
                cbl.CB_CrossEntropyLoss(reference_labels[:, 1],
                                        num_class=11,
                                        beta=cb_beta,
                                        label_smooth=0.0),
                cbl.CB_CrossEntropyLoss(reference_labels[:, 2],
                                        num_class=7,
                                        beta=cb_beta,
                                        label_smooth=0.0),
                cbl.CB_CrossEntropyLoss(reference_labels[:, 3],
                                        num_class=1292,
                                        beta=cb_beta,
                                        label_smooth=0.0)
            ],
                                                   weights=loss_w)

    test_criterion = multiloss_wrapper_v1(loss_funcs=[
        nn.CrossEntropyLoss(),
        nn.CrossEntropyLoss(),
        nn.CrossEntropyLoss()
    ],
                                          weights=loss_w)

    # opt
    if opt == 'SGD':
        optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9)
    elif opt == 'AdaBound':
        optimizer = AdaBound(net.parameters(),
                             lr=lr,
                             final_lr=0.1,
                             weight_decay=weight_decay)
    else:
        raise ValueError('unknown optimizer: {}'.format(opt))

    # scheduler
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     mode="min",
                                                     patience=patience,
                                                     factor=factor,
                                                     verbose=True)
    warmup_scheduler = WarmUpLR(optimizer, len(trainloader) * warmup_epoch)

    if opt_state_dict is not None:
        optimizer.load_state_dict(opt_state_dict)

    # train
    loglist = []
    val_loss = 100
    for epoch in range(start_epoch, epochs):
        if epoch > warmup_epoch - 1:
            scheduler.step(val_loss)

        print('epoch ', epoch)
        tr_log = _trainer_v1(net,
                             trainloader,
                             criterion,
                             optimizer,
                             epoch,
                             grad_accum_steps,
                             warmup_epoch,
                             warmup_scheduler,
                             use_mixup=True)
        vl_log = _tester_v1(net, validloader, test_criterion)
        loglist.append(list(tr_log) + list(vl_log))

        val_loss = vl_log[0]

        save_checkpoint(epoch, net, optimizer, 'checkpoint')
        save_log(loglist, 'training_log.csv')

    return net
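
WarmUpLR is imported from outside this snippet. Schedulers of this name typically ramp the learning rate linearly from near zero to its base value over the warmup iterations; a minimal sketch assuming that behavior (not necessarily the exact class used here):

from torch.optim.lr_scheduler import _LRScheduler

class WarmUpLR(_LRScheduler):
    """Scale the learning rate linearly from ~0 to its base value over total_iters steps."""

    def __init__(self, optimizer, total_iters, last_epoch=-1):
        self.total_iters = total_iters
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        return [base_lr * self.last_epoch / (self.total_iters + 1e-8)
                for base_lr in self.base_lrs]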
Example #6
def train_model(cfg: DictConfig) -> None:
    output_dir = Path.cwd()
    logging.basicConfig(format='%(asctime)s\t%(levelname)s\t%(message)s',
                        datefmt='%Y/%m/%d %H:%M:%S',
                        filename=str(output_dir / 'log.txt'),
                        level=logging.DEBUG)
    # suppress hydra's default behavior of also emitting the log to the console
    logger = logging.getLogger()
    assert isinstance(logger.handlers[0], logging.StreamHandler)
    logger.handlers[0].setLevel(logging.CRITICAL)

    if cfg.gpu >= 0:
        device = torch.device(f"cuda:{cfg.gpu}")
        # noinspection PyUnresolvedReferences
        torch.backends.cudnn.benchmark = True
    else:
        device = torch.device("cpu")
    model = load_model(model_name=cfg.model_name)
    model.to(device)
    if cfg.swa.enable:
        swa_model = AveragedModel(model=model, device=device)
    else:
        swa_model = None

    # optimizer = optim.SGD(
    #     model.parameters(), lr=cfg.optimizer.lr,
    #     momentum=cfg.optimizer.momentum,
    #     weight_decay=cfg.optimizer.weight_decay,
    #     nesterov=cfg.optimizer.nesterov
    # )
    optimizer = AdaBound(model.parameters(),
                         lr=cfg.optimizer.lr,
                         final_lr=cfg.optimizer.final_lr,
                         weight_decay=cfg.optimizer.weight_decay,
                         amsbound=False)
    scaler = torch.cuda.amp.GradScaler(enabled=cfg.use_amp)
    if cfg.scheduler.enable:
        scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer=optimizer,
            T_0=1,
            T_mult=1,
            eta_min=cfg.scheduler.eta_min)
        # scheduler = optim.lr_scheduler.CyclicLR(
        #     optimizer, base_lr=cfg.scheduler.base_lr,
        #     max_lr=cfg.scheduler.max_lr,
        #     step_size_up=cfg.scheduler.step_size,
        #     mode=cfg.scheduler.mode
        # )
    else:
        scheduler = None
    if cfg.input_dir is not None:
        input_dir = Path(cfg.input_dir)
        model_path = input_dir / 'model.pt'
        print('load model from {}'.format(model_path))
        model.load_state_dict(torch.load(model_path))

        state_path = input_dir / 'state.pt'
        print('load optimizer state from {}'.format(state_path))
        checkpoint = torch.load(state_path, map_location=device)
        epoch = checkpoint['epoch']
        t = checkpoint['t']
        optimizer.load_state_dict(checkpoint['optimizer'])
        if cfg.swa.enable and 'swa_model' in checkpoint:
            swa_model.load_state_dict(checkpoint['swa_model'])
        if cfg.scheduler.enable and 'scheduler' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler'])
        if cfg.use_amp and 'scaler' in checkpoint:
            scaler.load_state_dict(checkpoint['scaler'])
    else:
        epoch = 0
        t = 0

    # hydra changes the current working directory, so convert the data paths to absolute paths
    if isinstance(cfg.train_data, str):
        train_path_list = (hydra.utils.to_absolute_path(cfg.train_data), )
    else:
        train_path_list = [
            hydra.utils.to_absolute_path(path) for path in cfg.train_data
        ]
    logging.info('train data path: {}'.format(train_path_list))

    train_data = load_train_data(path_list=train_path_list)
    # the loaded data is indexed per epoch (see the swap at the end of each epoch); start with the first chunk
    train_dataset = train_data
    train_data = train_dataset[0]
    test_data = load_test_data(
        path=hydra.utils.to_absolute_path(cfg.test_data))

    logging.info('train position num = {}'.format(len(train_data)))
    logging.info('test position num = {}'.format(len(test_data)))

    train_loader = DataLoader(train_data,
                              device=device,
                              batch_size=cfg.batch_size,
                              shuffle=True)
    validation_loader = DataLoader(test_data[:cfg.test_batch_size * 10],
                                   device=device,
                                   batch_size=cfg.test_batch_size)
    test_loader = DataLoader(test_data,
                             device=device,
                             batch_size=cfg.test_batch_size)

    train_writer = SummaryWriter(log_dir=str(output_dir / 'train'))
    test_writer = SummaryWriter(log_dir=str(output_dir / 'test'))

    train_metrics = Metrics()
    eval_interval = cfg.eval_interval
    total_epoch = cfg.epoch + epoch
    for e in range(cfg.epoch):
        train_metrics_epoch = Metrics()

        model.train()
        desc = 'train [{:03d}/{:03d}]'.format(epoch + 1, total_epoch)
        train_size = len(train_loader) * 4
        for x1, x2, t1, t2, z, value, mask in tqdm(train_loader, desc=desc):
            with torch.cuda.amp.autocast(enabled=cfg.use_amp):
                model.zero_grad()

                metric_value = compute_metric(model=model,
                                              x1=x1,
                                              x2=x2,
                                              t1=t1,
                                              t2=t2,
                                              z=z,
                                              value=value,
                                              mask=mask,
                                              val_lambda=cfg.val_lambda,
                                              beta=cfg.beta)

            scaler.scale(metric_value.loss).backward()
            if cfg.clip_grad_max_norm:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               cfg.clip_grad_max_norm)
            scaler.step(optimizer)
            scaler.update()
            if cfg.swa.enable and t % cfg.swa.freq == 0:
                swa_model.update_parameters(model=model)

            t += 1
            if cfg.scheduler.enable:
                scheduler.step(t / train_size)

            train_metrics.update(metric_value=metric_value)
            train_metrics_epoch.update(metric_value=metric_value)

            # print train loss
            if t % eval_interval == 0:
                model.eval()

                validation_metrics = Metrics()
                with torch.no_grad():
                    # noinspection PyAssignmentToLoopOrWithParameter
                    for x1, x2, t1, t2, z, value, mask in validation_loader:
                        m = compute_metric(model=model,
                                           x1=x1,
                                           x2=x2,
                                           t1=t1,
                                           t2=t2,
                                           z=z,
                                           value=value,
                                           mask=mask,
                                           val_lambda=cfg.val_lambda)
                        validation_metrics.update(metric_value=m)

                last_lr = (scheduler.get_last_lr()[-1]
                           if cfg.scheduler.enable else cfg.optimizer.lr)
                logging.info(
                    'epoch = {}, iteration = {}, lr = {}, {}, {}'.format(
                        epoch + 1, t, last_lr,
                        make_metric_log('train', train_metrics),
                        make_metric_log('validation', validation_metrics)))
                write_summary(writer=train_writer,
                              metrics=train_metrics,
                              t=t,
                              prefix='iteration')
                write_summary(writer=test_writer,
                              metrics=validation_metrics,
                              t=t,
                              prefix='iteration')
                train_metrics = Metrics()

                train_writer.add_scalar('learning_rate',
                                        last_lr,
                                        global_step=t)

                model.train()
            elif t % cfg.train_log_interval == 0:
                last_lr = (scheduler.get_last_lr()[-1]
                           if cfg.scheduler.enable else cfg.optimizer.lr)
                logging.info('epoch = {}, iteration = {}, lr = {}, {}'.format(
                    epoch + 1, t, last_lr,
                    make_metric_log('train', train_metrics)))
                write_summary(writer=train_writer,
                              metrics=train_metrics,
                              t=t,
                              prefix='iteration')
                train_metrics = Metrics()

                train_writer.add_scalar('learning_rate',
                                        last_lr,
                                        global_step=t)

        if cfg.swa.enable:
            with torch.cuda.amp.autocast(enabled=cfg.use_amp):
                desc = 'update BN [{:03d}/{:03d}]'.format(
                    epoch + 1, total_epoch)
                np.random.shuffle(train_data)
                # computing the BN statistics needs a reasonable amount of data
                # using all of it gave better accuracy than cutting it down to 1/16
                # limit the volume to an amount that can be processed in roughly 10 minutes
                # the DataLoader may not handle the data correctly unless it is contiguous in memory
                train_data = np.ascontiguousarray(train_data[::4])
                torch.optim.swa_utils.update_bn(loader=tqdm(
                    hcpe_loader(data=train_data,
                                device=device,
                                batch_size=cfg.batch_size),
                    desc=desc,
                    total=len(train_data) // cfg.batch_size),
                                                model=swa_model)

        # print train loss for each epoch
        test_metrics = Metrics()

        if cfg.swa.enable:
            test_model = swa_model
        else:
            test_model = model
        test_model.eval()
        with torch.no_grad():
            desc = 'test [{:03d}/{:03d}]'.format(epoch + 1, total_epoch)
            for x1, x2, t1, t2, z, value, mask in tqdm(test_loader, desc=desc):
                metric_value = compute_metric(model=test_model,
                                              x1=x1,
                                              x2=x2,
                                              t1=t1,
                                              t2=t2,
                                              z=z,
                                              value=value,
                                              mask=mask,
                                              val_lambda=cfg.val_lambda)

                test_metrics.update(metric_value=metric_value)

        logging.info('epoch = {}, iteration = {}, {}, {}'.format(
            epoch + 1, t, make_metric_log('train', train_metrics_epoch),
            make_metric_log('test', test_metrics)))
        write_summary(writer=train_writer,
                      metrics=train_metrics_epoch,
                      t=epoch + 1,
                      prefix='epoch')
        write_summary(writer=test_writer,
                      metrics=test_metrics,
                      t=epoch + 1,
                      prefix='epoch')

        epoch += 1

        if e != cfg.epoch - 1:
            # swap in the next chunk of training data
            train_data = train_dataset[e + 1]
            train_loader.data = train_data

    train_writer.close()
    test_writer.close()

    print('save the model')
    torch.save(model.state_dict(), output_dir / 'model.pt')

    print('save the optimizer')
    state = {'epoch': epoch, 't': t, 'optimizer': optimizer.state_dict()}
    if cfg.scheduler.enable:
        state['scheduler'] = scheduler.state_dict()
    if cfg.swa.enable:
        state['swa_model'] = swa_model.state_dict()
    if cfg.use_amp:
        state['scaler'] = scaler.state_dict()
    torch.save(state, output_dir / 'state.pt')
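
The inner loop above combines torch.cuda.amp with gradient clipping; the detail worth noting is that scaler.unscale_(optimizer) must run before clip_grad_norm_ so the threshold applies to unscaled gradients. A stripped-down, self-contained version of one such step (toy model, data, and hyperparameters are assumptions; requires a CUDA device):

import torch
import torch.nn as nn

model = nn.Linear(10, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=True)

x = torch.randn(8, 10, device='cuda')
y = torch.randint(0, 2, (8,), device='cuda')

with torch.cuda.amp.autocast(enabled=True):
    loss = nn.CrossEntropyLoss()(model(x), y)

optimizer.zero_grad()
scaler.scale(loss).backward()
scaler.unscale_(optimizer)                                   # unscale before clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
scaler.step(optimizer)
scaler.update()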
Example #7
def train(model, test_loader, lang, args, pairs, extra_loader):
	start = time.time()
	
	if args.optimizer == "adam":
		optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)
	elif args.optimizer == "adabound":
		optimizer = AdaBound(model.parameters(), lr=0.0001, final_lr=0.1)
	else:
		print("unknow optimizer.")
		exit(0)

	print_model(args)

	save_name = get_name(args)
	if not os.path.exists(args.save_path):
		os.mkdir(args.save_path)
	print("the model is saved in: "+save_name)

	n_epochs = args.n_epochs
	step = 0.0
	begin_epoch = 0
	best_val_bleu = 0
	if not args.from_scratch:
		if os.path.exists(save_name):
			checkpoint = torch.load(save_name)
			model.load_state_dict(checkpoint['model_state_dict'])
			lr = checkpoint['lr']
			step = checkpoint['step']
			begin_epoch = checkpoint['epoch'] + 1
			best_val_bleu = checkpoint['bleu']
			optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
			for param_group in optimizer.param_groups:
				param_group['lr'] = lr
			print("load successful!")
			checkpoint = []
		else:
			print("load unsuccessful!")

	if args.use_dataset_B:
		extra_iter = iter(extra_loader)
		num_iter = 1
		if args.ratio >= 2:
			num_iter = int(args.ratio)
	else:
		extra_iter = None
		num_iter = 0

	for epoch in range(begin_epoch, n_epochs):

		model.train()

		train_loader = Data.DataLoader(pairs, batch_size=args.batch_size, shuffle=True)

		print_loss_total, step, extra_iter, lr = train_epoch(model, lang, args, train_loader, 
															extra_loader, extra_iter, num_iter, optimizer, step)
			
		print('total loss: %f'%(print_loss_total))
		model.eval()
		curr_bleu = evaluate(model, test_loader, lang, args.max_length)
		print('%s (epoch: %d %d%%)' % (timeSince(start, (epoch+1-begin_epoch)/(n_epochs-begin_epoch)), epoch, (epoch+1-begin_epoch)/(n_epochs-begin_epoch)*100))

		if curr_bleu > best_val_bleu:
			best_val_bleu = curr_bleu
			torch.save({
									'model_state_dict': model.state_dict(),
									'optimizer_state_dict': optimizer.state_dict(),
									'lr': lr,
									'step': step,
									'epoch': epoch,
									'bleu': curr_bleu,
									}, save_name)
			print("checkpoint saved!")
		print()
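
When this function resumes, it loads the optimizer state and then overwrites each param_group's lr with the checkpointed value. A small sketch of that restore-then-override pattern; the optimizer and values are illustrative only:

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# pretend this came from torch.load(save_name)
checkpoint = {'optimizer_state_dict': optimizer.state_dict(), 'lr': 5e-5}

optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
for param_group in optimizer.param_groups:   # force the learning rate saved with the checkpoint
    param_group['lr'] = checkpoint['lr']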