コード例 #1
0
ファイル: train.py プロジェクト: zhuzhu18/pytorch_cifar10
    def fit(self, train_loader, test_loader, start_epoch=0, max_epochs=200,
            lr_scheduler=None):
        """Alternate training and evaluation, saving a checkpoint every epoch.

        Each epoch calls ``self.train`` and then ``self.evaluate``. The value
        returned by ``self.evaluate`` is compared with ``>`` against
        ``self.best_prec1`` to decide whether the checkpoint is the best so far.

        Args:
            train_loader: data passed to ``self.train``.
            test_loader: data passed to ``self.evaluate``.
            start_epoch: first epoch index to run (inclusive).
            max_epochs: epoch index to stop at (exclusive).
            lr_scheduler: optional scheduler; when given, ``step(epoch)`` is
                called after each evaluation.
        """
        run_args = self.args
        for epoch in range(start_epoch, max_epochs):
            self.train(train_loader, self.model, self.criterion,
                       self.optimizer, epoch)

            # NOTE(review): the original comment called this value "top1 avg
            # error", yet it is maximised below; confirm which it really is.
            score = self.evaluate(test_loader, self.model, self.criterion,
                                  epoch)

            if lr_scheduler is not None:
                lr_scheduler.step(epoch)
            if not epoch % 5:
                # Regenerate the figure every fifth epoch.
                self.figure.generate()

            is_best = score > self.best_prec1
            self.best_prec1 = max(score, self.best_prec1)
            state = {
                'epoch': epoch + 1,
                'arch': run_args.arch,
                'state_dict': self.model.state_dict(),
                'best_prec1': self.best_prec1,
                'optimizer': self.optimizer.state_dict(),
            }
            save_checkpoint(state, is_best, folder=run_args.folder)
コード例 #2
0
 def _save_checkpoint(self, epoch, rank1, save_dir, is_best=False):
     """Write a resumable checkpoint for ``epoch`` into ``save_dir``.

     The stored ``epoch`` is ``epoch + 1``. Model, optimizer and scheduler
     state dicts are stored alongside the ``rank1`` score, and ``is_best`` is
     forwarded unchanged to ``save_checkpoint``.
     """
     next_epoch = epoch + 1
     checkpoint = dict(
         state_dict=self.model.state_dict(),
         epoch=next_epoch,
         rank1=rank1,
         optimizer=self.optimizer.state_dict(),
         scheduler=self.scheduler.state_dict(),
     )
     save_checkpoint(checkpoint, save_dir, is_best=is_best)
コード例 #3
0
                                       is_validate=True,
                                       offset=offset)
            offset += 1

            is_best = False
            if validation_loss < best_err:
                best_err = validation_loss
                is_best = True

            checkpoint_progress = tqdm(ncols=100,
                                       desc='Saving Checkpoint',
                                       position=offset)
            tools.save_checkpoint(
                {
                    'arch': args.model,
                    'epoch': epoch,
                    'state_dict': model.module.state_dict(),
                    'best_EPE': best_err
                }, is_best, args.save, args.model)
            checkpoint_progress.update(1)
            checkpoint_progress.close()
            offset += 1

        if not args.skip_training:
            train_loss, iterations = train(args=args,
                                           epoch=epoch,
                                           start_iteration=global_iteration,
                                           data_loader=train_loader,
                                           model=model,
                                           optimizer=optimizer,
                                           loss=loss,
コード例 #4
0
            if not args.skip_validation and ((epoch - 1) % args.validation_frequency) == 0:
                validation_loss, _ = train(args=args, epoch=epoch - 1, start_iteration=global_iteration,
                                           data_loader=validation_loader, model=model_and_loss, optimizer=optimizer,
                                           logger=validation_logger, is_validate=True, offset=offset)
                offset += 1

                is_best = False
                if validation_loss < best_err:
                    best_err = validation_loss
                    is_best = True

                checkpoint_progress = tqdm(
                    ncols=100, desc='Saving Checkpoint', position=offset)
                tools.save_checkpoint({'arch': args.model,
                                       'epoch': epoch,
                                       'state_dict': model_and_loss.module.model.state_dict(),
                                       'best_EPE': best_err},
                                      is_best, args.save, args.model)
                checkpoint_progress.update(1)
                checkpoint_progress.close()
                offset += 1

            if not args.skip_training:
                train_loss, iterations = train(args=args, epoch=epoch, start_iteration=global_iteration,
                                               data_loader=train_loader, model=model_and_loss, optimizer=optimizer,
                                               logger=train_logger, offset=offset)
                global_iteration += iterations
                offset += 1

                # save checkpoint after every validation_frequency number of epochs
                if ((epoch - 1) % args.validation_frequency) == 0:
コード例 #5
0
                offset=offset,
            )
            offset += 1

            is_best = False
            if validation_loss < best_err:
                best_err = validation_loss
                is_best = True

            checkpoint_progress = tqdm(ncols=100, desc="Saving Checkpoint", position=offset)
            tools.save_checkpoint(
                {
                    "arch": args.model,
                    "epoch": epoch,
                    "state_dict": model_and_loss.module.model.state_dict(),
                    "best_EPE": best_err,
                },
                is_best,
                args.save,
                args.model,
            )
            checkpoint_progress.update(1)
            checkpoint_progress.close()
            offset += 1

        if not args.skip_training:
            train_loss, iterations = train(
                args=args,
                epoch=epoch,
                start_iteration=global_iteration,
                data_loader=train_loader,
コード例 #6
0
def train(model, device, train_dataloader, test_dataloader, multi_gpu, args):
    """Train the model with gradient accumulation and validate after each epoch.

    Resumes from ``load_checkpoint()`` and optimises with AdamW plus a linear
    warmup schedule. Per-epoch metrics are written with ``save_epoch_csv`` and
    a checkpoint is saved with ``save_checkpoint`` after every epoch.

    Args:
        model: model called as ``model.forward(input_ids=...)``.
        device: device each batch is moved to.
        train_dataloader: iterable of token-id batches.
        test_dataloader: passed to ``evaluate`` after every epoch.
        multi_gpu: when true, per-batch loss/accuracy are reduced with
            ``.mean()``.
        args: namespace read for ``epochs``, ``lr``, ``warmup_steps``,
            ``gradient_accumulation``, ``max_grad_norm``, ``log_step`` and
            ``writer_dir``.
    """
    # Resume state: first epoch to run, epochs without improvement, best
    # validation loss so far.
    start_epoch, epochs_since_improvement, best_loss = load_checkpoint()
    model.train()
    # Total optimizer updates over the run (one per
    # ``args.gradient_accumulation`` batches); drives the LR schedule.
    total_steps = int(len(train_dataloader) * args.epochs / args.gradient_accumulation)
    print('total training steps = {}'.format(total_steps))
    # AdamW with a linear warmup followed by linear decay.
    optimizer = AdamW(model.parameters(), lr=args.lr, correct_bias=True)
    scheduler = get_linear_schedule_with_warmup(optimizer,
                                                num_warmup_steps=args.warmup_steps,
                                                num_training_steps=total_steps)
    print('开始训练')
    # Number of optimizer updates so far (TensorBoard x-axis).
    overall_step = 0
    tb_writer = SummaryWriter(log_dir=args.writer_dir)
    try:
        for epoch in range(start_epoch, args.epochs):
            # evaluate() may leave the model in eval mode; make sure every
            # epoch trains in train mode.
            model.train()
            epoch_start_time = datetime.now()
            running_train_loss = 0
            running_train_correct = 0  # correctly predicted targets this epoch
            running_train_num = 0  # number of targets this epoch
            data_iter = tqdm(enumerate(train_dataloader), total=len(train_dataloader))
            for batch_idx, input_ids in data_iter:
                # GPT-2's forward pass returns one hidden state per input token;
                # hidden state n is used to predict token n+1.
                input_ids = input_ids.to(device)
                outputs = model.forward(input_ids=input_ids)
                # NOTE(review): pad_id is a module-level name, not a parameter.
                loss, correct, num_targets = calculate_loss_and_accuracy(outputs, input_ids, device, pad_id)
                running_train_loss += loss.item()
                running_train_correct += correct
                running_train_num += num_targets
                # Per-target loss/accuracy for this batch.
                loss = loss / num_targets
                accuracy = correct / num_targets
                if multi_gpu:
                    loss = loss.mean()
                    accuracy = accuracy.mean()
                if args.gradient_accumulation > 1:
                    loss = loss / args.gradient_accumulation
                    accuracy = accuracy / args.gradient_accumulation
                loss.backward()
                # Update parameters once args.gradient_accumulation batches
                # have accumulated their gradients.
                if (batch_idx + 1) % args.gradient_accumulation == 0:
                    # Clip the fully accumulated gradient once, right before the
                    # update, rather than clipping partial sums every batch.
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                    optimizer.step()
                    optimizer.zero_grad()
                    # Advance the warmup/decay schedule by one update.
                    scheduler.step()
                    overall_step += 1
                    if (overall_step + 1) % args.log_step == 0:
                        tb_writer.add_scalar('loss', loss.item(), overall_step)
                # Show the latest batch loss and accuracy.
                data_iter.set_description("epoch:{}/{}, train_loss:{:.4f}, train_acc:{:.2f}%"
                                          .format(epoch + 1, args.epochs, loss, accuracy * 100))
            epoch_train_loss = running_train_loss / running_train_num
            epoch_train_acc = running_train_correct / running_train_num
            epoch_valid_loss, epoch_valid_acc = evaluate(epoch, model, device, test_dataloader, multi_gpu, args)
            print(('epoch: {} / {} train_loss:{:.4f}, train_acc:{:.2f}% '
                   '                 valid_loss:{:.4f}, valid_acc:{:.2f}%')
                  .format(epoch + 1, args.epochs, epoch_train_loss, epoch_train_acc * 100,
                          epoch_valid_loss, epoch_valid_acc * 100))
            # Append this epoch's metrics to the CSV log.
            save_epoch_csv(epoch, epoch_train_loss, epoch_train_acc, epoch_valid_loss, epoch_valid_acc)
            # Track the best validation loss and the early-stopping counter.
            is_best = epoch_valid_loss < best_loss
            if is_best:
                best_loss = epoch_valid_loss
                epochs_since_improvement = 0
            else:
                epochs_since_improvement += 1
                print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,))
            save_checkpoint(epoch, epochs_since_improvement, model, best_loss, is_best)
            epoch_finish_time = datetime.now()
            print('time for one epoch: {}'.format(epoch_finish_time - epoch_start_time))
    finally:
        # Flush pending TensorBoard events even if training is interrupted.
        tb_writer.close()
    print('training finished')
コード例 #7
0
# State carried across epochs by the loop below.
best_err = 1000
# NOTE(review): never set to True below, so no checkpoint is ever flagged "best".
is_best = False
global_iteration = 0

for epoch in range(config['total_epochs']):
    print('Epoch {}/{}'.format(epoch, config['total_epochs'] - 1))
    print('-' * 10)

    if not config['skip_training']:
        global_iteration = train_convnet(g_model, epoch, g_criterion,
                                         g_optimizer, train_loader,
                                         global_iteration)

    # Step the LR scheduler after this epoch's optimizer updates (PyTorch >= 1.1
    # ordering). Stepping before training skips the first scheduled LR value.
    g_scheduler.step()

    if not config['skip_validate'] and ((epoch + 1) %
                                        config['validate_frequency']) == 0:
        # NOTE(review): validation()'s return value is discarded and
        # 'best_err' is saved as a 0.0 placeholder. Wire the real error into
        # best_err / is_best once validation() is confirmed to return it.
        validation(g_model, epoch, val_loss, val_loader)
        _error = 0.0
        tools.save_checkpoint(
            {
                'arch': config['model'],
                'epoch': epoch,
                'state_dict': g_model.state_dict(),
                'best_err': _error
            }, is_best, config['save'], config['model'])
コード例 #8
0
                                       is_validate=True,
                                       offset=offset)
            offset += 1

            is_best = False
            if validation_loss < best_err:
                best_err = validation_loss
                is_best = True

            checkpoint_progress = tqdm(ncols=100,
                                       desc='Saving Checkpoint',
                                       position=offset)
            tools.save_checkpoint(
                {
                    'arch': args.model,
                    'epoch': epoch,
                    'state_dict': SRmodel.model.state_dict(),
                    'best_EPE': best_err,
                    'optimizer': optimizer
                }, is_best, args.save, args.model)
            checkpoint_progress.update(1)
            checkpoint_progress.close()
            offset += 1

        if not args.skip_training:
            train_loss, iterations = train(args=args,
                                           epoch=epoch,
                                           data_loader=train_loader,
                                           model=SRmodel,
                                           optimizer=optimizer,
                                           offset=offset)
            global_iteration += iterations
コード例 #9
0
    def train(self):
        """Periodically evaluate recall, save checkpoints and log to wandb.

        Every ``self.valFrequency`` epochs the model is tested on the Tokyo
        24/7 split (when ``config['train']['val_dataset'] == 'tokyo247'``) or
        on the validation split. When ``recalls[2][1]`` beats the best value so
        far, a checkpoint is written to ``statedict_root['best']``. A rolling
        checkpoint is written to ``statedict_root['checkpoint']`` on every
        evaluation. When ``self.wandb`` is truthy, recall curves are logged.

        Note: the per-epoch training/scheduler calls are disabled, so this
        loop currently only evaluates.
        """
        best_recall = 0
        not_improved = 0

        if self.config['train']['wandb']:
            wandb.init(project='netvlad-tokyotm', config=self.config)
            wandb.watch(self.model)

        # One entry per tracked curve: (history rows, index into recalls, label).
        tracked = (
            ([], 0, "Recall @1"),
            ([], 4, "Recall @5"),
            ([], 5, "Recall @10"),
        )
        for e in range(self.epochs):
            # Training step disabled here: self.train_epoch(e) / self.scheduler.step(e)

            if e % self.valFrequency:
                continue

            val_name = self.config['train']['val_dataset']
            if val_name == 'tokyo247':
                query_set, db_set = self.testQ_dataset, self.testDb_dataset
            else:
                query_set, db_set = self.valQ_dataset, self.valDb_dataset
            recalls = self.test(e, query_set, db_set)

            if recalls[2][1] > best_recall:
                not_improved = 0
                best_recall = recalls[2][1]
                save_checkpoint(e=e,
                                model=self.model,
                                recalls=recalls,
                                filepath=self.config['statedict_root']['best'])
                if self.wandb:
                    best_table = wandb.Table(data=recalls,
                                             columns=["@N", "Recall"])
                    best_title = "Best Model's Recall @N" + " (" + val_name + ")"
                    wandb.log({
                        'Best Recall Curve':
                        wandb.plot.line(best_table, "@N", "Recall",
                                        title=best_title)
                    })
                    for _, position, label in tracked:
                        wandb.log({label: recalls[position][1]})
            else:
                not_improved += 1

            save_checkpoint(e=e,
                            model=self.model,
                            recalls=recalls,
                            filepath=self.config['statedict_root']['checkpoint'])

            if self.wandb:
                # Append this epoch to each curve and re-plot it.
                for history, position, label in tracked:
                    history.append([e, recalls[position][1]])
                    curve_name = label + " Changes over Training"
                    curve_table = wandb.Table(data=history,
                                              columns=["Epoch", label])
                    wandb.log({
                        curve_name:
                        wandb.plot.line(curve_table, "Epoch", label,
                                        title=curve_name + " (" + val_name + ")")
                    })
コード例 #10
0
                                           optimizer=optimizer,
                                           scheduler=scheduler,
                                           logger=train_logger,
                                           offset=offset)
            global_iteration += iterations
            offset += 1

            # save checkpoint after every validation_frequency number of epochs
            if ((epoch - 1) % args.validation_frequency) == 0:
                checkpoint_progress = tqdm(ncols=100,
                                           desc='Saving Checkpoint',
                                           position=offset)
                tools.save_checkpoint(generate_checkpoint_state(
                    model_and_loss.module.model, optimizer, scheduler,
                    args.model, epoch + 1, global_iteration, train_loss, True),
                                      False,
                                      args.save,
                                      args.model,
                                      filename='train-checkpoint.pth.tar')
                checkpoint_progress.update(1)
                checkpoint_progress.close()

        if not args.skip_validation and (epoch %
                                         args.validation_frequency) == 0:
            validation_loss, _ = train(args=args,
                                       epoch=epoch,
                                       start_iteration=global_iteration,
                                       data_loader=validation_loader,
                                       model=model_and_loss,
                                       optimizer=optimizer,
                                       scheduler=scheduler,
コード例 #11
0
def train_valid(model, criterion, optimizer):
    """
    Run the complete training and validation loop.

    Resumes from ``config.checkpoint_file`` when it exists, restoring epoch,
    best score, early-stopping counter, model and optimizer state. Each epoch
    trains and validates, then writes metrics under ``config.out_data_dir``:
    averaged metrics to ``out_average.csv`` and per-file metrics to
    ``out_<file_name>.csv``. It then saves a checkpoint. The best model is
    chosen by mean validation F1. Training stops early once
    ``config.break_epoch`` consecutive epochs bring no improvement.

    Returns:
        The ``(model, optimizer)`` pair after training.
    """
    best_score = 0  # best mean validation F1-score so far
    epochs_since_improvement = 0  # consecutive epochs without improvement (early stopping)
    start_epoch = 1  # first epoch to run
    since = time.time()
    # Resume from a previous run's checkpoint, if any.
    if os.path.exists(config.checkpoint_file):
        print('发现模型', config.checkpoint_file)
        checkpoint = torch.load(config.checkpoint_file)
        start_epoch = checkpoint['epoch'] + 1
        best_score = checkpoint['best_score']
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        model.load_state_dict(checkpoint['model_state'])
        optimizer.load_state_dict(checkpoint['optimizer_state'])
    # Per-epoch loss weights.
    epoch_loss_weights = EpochLossWeight()
    # Start each run with an empty block output directory.
    block_dir = os.path.join(config.out_data_dir, 'block_dir')
    if not os.path.exists(block_dir):
        os.mkdir(block_dir)
    else:
        for entry in os.listdir(block_dir):
            os.remove(os.path.join(block_dir, entry))
    # --- main training loop
    for epoch in range(start_epoch, config.epochs + 1):
        epoch_loss_weights_dict = epoch_loss_weights.run()
        train_loss_list, train_acc_list, train_f1_score_list = train_epoch(
            model, criterion, optimizer, epoch, epoch_loss_weights_dict)
        valid_loss_list, valid_acc_list, valid_f1_score_list = valid_epoch(
            model, criterion, epoch)
        # Average the per-file metrics.
        train_loss_average = np.average(train_loss_list)
        train_acc_average = np.average(train_acc_list)
        train_f1_score_average = np.average(train_f1_score_list)
        valid_loss_average = np.average(valid_loss_list)
        valid_acc_average = np.average(valid_acc_list)
        valid_f1_score_average = np.average(valid_f1_score_list)
        # Log averaged metrics.
        out_path = os.path.join(config.out_data_dir, 'out_average.csv')
        save_epoch_csv(epoch, train_loss_average, train_acc_average,
                       train_f1_score_average, valid_loss_average,
                       valid_acc_average, valid_f1_score_average, out_path)
        # Log each file's metrics to its own CSV.
        for file_name, train_loss, train_acc, train_f1, valid_loss, valid_acc, valid_f1 \
                in zip(config.file_name_list, train_loss_list, train_acc_list, train_f1_score_list,
                       valid_loss_list, valid_acc_list, valid_f1_score_list):
            out_path = os.path.join(config.out_data_dir,
                                    f'out_{file_name}.csv')
            save_epoch_csv(epoch, train_loss, train_acc, train_f1, valid_loss,
                           valid_acc, valid_f1, out_path)
        print('epoch: {} / {} train_loss:{:.4f}, train_acc:{:.2f}%  train_f1_score:{:.4f}\n'
              '        valid_loss:{:.4f}, valid_acc:{:.2f}% valid_f1_score:{:.4f}'.format(
                  epoch, config.epochs, train_loss_average, train_acc_average * 100,
                  train_f1_score_average, valid_loss_average, valid_acc_average * 100,
                  valid_f1_score_average))
        # -- Track the best model by mean validation F1 -- #
        is_best = valid_f1_score_average > best_score
        if is_best:
            best_score = valid_f1_score_average
            epochs_since_improvement = 0
        else:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" %
                  (epochs_since_improvement, ))
            if epochs_since_improvement >= config.break_epoch:
                break
        # Split elapsed seconds into hours / minutes-within-hour / seconds.
        # The old code printed total minutes, e.g. "1时75分" after 1h15m.
        time_elapsed = time.time() - since
        hours, remainder = divmod(time_elapsed, 3600)
        minutes, seconds = divmod(remainder, 60)
        print('当前训练共用时{:.0f}时{:.0f}分{:.0f}秒'.format(hours, minutes, seconds))
        # NOTE(review): run() is called a second time this epoch just for
        # printing; confirm it has no state side effects.
        loss_weights_dict = epoch_loss_weights.run()
        print('当前loss权重为:', loss_weights_dict)
        print('目前最高的F1-score为:{:.6f}'.format(best_score))
        save_checkpoint(epoch, epochs_since_improvement, model, optimizer,
                        best_score, is_best)
    return model, optimizer
コード例 #12
0
def main():
    """Train and evaluate a classifier on the Planet dataset.

    Settings come from the module-level ``args``; the module-level
    ``use_cuda``, ``gpu_ids`` and ``writer`` are also used. The function
    builds the loaders, model, loss, optimizer and LR scheduler, and
    optionally resumes from ``args.resume``. It then trains for epochs
    ``start_epoch .. args.epochs - 1``. After every epoch it logs metrics to
    ``writer`` and saves a checkpoint, flagged as best according to
    ``args.score_metric``. Ctrl-C stops training early; the best loss / F2
    summary is still printed.
    """
    # --------------------------------config-------------------------------
    global use_cuda
    global gpu_ids
    threshold = args.threshold
    # (value, epoch) of the best eval loss / F2 so far; None until the first
    # epoch has been evaluated.
    best_loss = None
    best_f2 = None

    start_epoch = args.start_epoch  # start from epoch 0 or last checkpoint epoch
    # ------------------------------ load dataset---------------------------
    print('==> Loader dataset {}'.format(args.train_data))

    train_transform = get_transform(size=args.image_size, mode='train')
    train_dataset = PlanetDataset(image_root=args.train_data,
                                  target_path=args.labels,
                                  phase='train',
                                  fold=args.fold,
                                  img_type=args.image_type,
                                  img_size=args.image_size,
                                  transform=train_transform)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.num_works)

    eval_transform = get_transform(size=args.image_size, mode='eval')
    eval_dataset = PlanetDataset(image_root=args.train_data,
                                 target_path=args.labels,
                                 phase='eval',
                                 fold=args.fold,
                                 img_type=args.image_type,
                                 img_size=args.image_size,
                                 transform=eval_transform)
    eval_loader = torch.utils.data.DataLoader(eval_dataset,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=args.num_works)

    # ---------------------------------model---------------------------------
    model = build_model(model_name=args.model_name,
                        num_classes=args.num_classes,
                        pretrained=args.pretrained,
                        global_pool=args.global_pool)
    if use_cuda:
        if len(gpu_ids) > 1:
            model = torch.nn.DataParallel(
                model, device_ids=gpu_ids).cuda()  # load model to cuda
        else:
            model.cuda()
    # Checkpoints hold the unwrapped weights, so save from / load into the
    # inner module when the model is wrapped in DataParallel.
    base_model = model.module if isinstance(model, torch.nn.DataParallel) else model
    # show model size
    print('\t Total params volumes: {:.2f} M'.format(
        sum(param.numel() for param in model.parameters()) / 1000000.0))

    # --------------------------------criterion-----------------------
    criterion = None
    if args.reweight:
        class_weights = torch.from_numpy(
            train_dataset.get_class_weights()).float()
        class_weights_norm = class_weights / class_weights.sum()
        if use_cuda:
            class_weights = class_weights.cuda()
            class_weights_norm = class_weights_norm.cuda()
    else:
        class_weights = None
        class_weights_norm = None

    loss_name = args.loss.lower()
    if loss_name == 'nll':
        criterion = torch.nn.CrossEntropyLoss(weight=class_weights)
    elif loss_name == 'mlsm':
        assert args.multi_label, "'mlsm' loss requires multi-label targets"
        criterion = torch.nn.MultiLabelSoftMarginLoss(weight=class_weights)
    else:
        # `assert False and "..."` dropped its message and disappears under
        # `python -O`; fail explicitly instead.
        raise ValueError('Invalid loss function: {}'.format(args.loss))

    # ---------------------------------optimizer----------------------------
    optimizer = get_optimizer(model, args)

    # lr scheduler: ReduceLROnPlateau on eval loss unless a decay epoch is
    # configured. The same flag decides how the scheduler is stepped below.
    use_plateau = not args.decay_epoch
    if use_plateau:
        lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.1, patience=8)
    else:
        lr_scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=args.decay_epoch, gamma=0.1)

    # Resume model
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']

            base_model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            start_epoch = checkpoint['epoch']
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
            raise SystemExit(-1)

    since = time.time()
    try:
        for epoch in range(start_epoch, args.epochs):
            train_metrics = train(loader=train_loader,
                                  model=model,
                                  epoch=epoch,
                                  criterion=criterion,
                                  optimizer=optimizer,
                                  threshold=threshold,
                                  class_weights=class_weights_norm,
                                  use_cuda=use_cuda)
            eval_metrics, latest_threshold = eval(loader=eval_loader,
                                                  model=model,
                                                  epoch=epoch,
                                                  criterion=criterion,
                                                  threshold=threshold,
                                                  use_cuda=use_cuda)

            if use_plateau:
                lr_scheduler.step(eval_metrics['loss'])
            else:
                lr_scheduler.step()

            # save train and eval metric
            writer.add_scalars(main_tag='epoch/loss',
                               tag_scalar_dict={
                                   'train': train_metrics['loss'],
                                   'val': eval_metrics['loss']
                               },
                               global_step=epoch)

            if args.multi_label:
                writer.add_scalars(main_tag='epoch/acc',
                                   tag_scalar_dict={
                                       'train': train_metrics['acc'],
                                       'val': eval_metrics['acc']
                                   },
                                   global_step=epoch)
            else:
                writer.add_scalars(main_tag='epoch/acc_top1',
                                   tag_scalar_dict={
                                       'train': train_metrics['acc_top1'],
                                       'val': eval_metrics['acc_top1']
                                   },
                                   global_step=epoch)
                writer.add_scalars(main_tag='epoch/acc_top5',
                                   tag_scalar_dict={
                                       'train': train_metrics['acc_top5'],
                                       'val': eval_metrics['acc_top5']
                                   },
                                   global_step=epoch)

            writer.add_scalar(tag='epoch/f2_score',
                              scalar_value=eval_metrics['f2'],
                              global_step=epoch)

            # add learning_rate to logs
            writer.add_scalar(tag='lr',
                              scalar_value=optimizer.param_groups[0]['lr'],
                              global_step=epoch)

            # ----------------------------- save model every epoch -----------------------------
            model_weights = base_model.state_dict()

            # Update best loss and best F2 independently; either may set
            # is_best, depending on args.score_metric.
            is_best = False
            if best_loss is None or best_f2 is None:
                best_loss = (eval_metrics['loss'], epoch)
                best_f2 = (eval_metrics['f2'], epoch)
                is_best = True
            else:
                if eval_metrics['loss'] < best_loss[0]:
                    best_loss = (eval_metrics['loss'], epoch)
                    if args.score_metric == 'loss':
                        is_best = True
                if eval_metrics['f2'] > best_f2[0]:
                    best_f2 = (eval_metrics['f2'], epoch)
                    if args.score_metric == 'f2':
                        is_best = True

            state = {
                'epoch': epoch + 1,
                'arch': args.model_name,
                'state_dict': model_weights,
                'optimizer': optimizer.state_dict(),
                'threshold': latest_threshold,
                'loss': eval_metrics['loss'],
                'f2': eval_metrics['f2'],
                'fold': args.fold,
                'num_gpu': len(gpu_ids)
            }
            save_checkpoint(state,
                            os.path.join(
                                args.checkpoint,
                                'ckpt-{}-f{}-{:.6f}.pth.tar'.format(
                                    epoch, args.fold, eval_metrics['f2'])),
                            is_best=is_best)

    except KeyboardInterrupt:
        # Ctrl-C ends training early; fall through to the summary below.
        pass
    finally:
        writer.close()

    time_elapsed = time.time() - since
    print('*** Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    # best_* are (value, epoch); still None if interrupted before epoch 1.
    if best_loss is not None:
        print('*** Eval best loss: {0} (epoch {1})'.format(best_loss[0],
                                                           best_loss[1]))
    if best_f2 is not None:
        print('*** Eval best f2_score: {0} (epoch {1})'.format(
            best_f2[0], best_f2[1]))