Example 1
def train(name, df, VAL_FOLD=0, resume=False):
    dt_string = datetime.now().strftime("%d|%m_%H|%M|%S")
    print("Starting -->", dt_string)

    os.makedirs(OUTPUT_DIR, exist_ok=True)
    os.makedirs('checkpoint', exist_ok=True)
    run = f"{name}_[{dt_string}]"

    wandb.init(project="imanip", config=config_defaults, name=run)
    config = wandb.config

    # model = SRM_Classifer(num_classes=1, encoder_checkpoint='weights/pretrain_[31|03_12|16|32].h5')
    model = SMP_SRM_UPP(classifier_only=True)

    # for name_, param in model.named_parameters():
    #     if 'classifier' in name_:
    #         continue
    #     else:
    #         param.requires_grad = False

    print("Parameters : ",
          sum(p.numel() for p in model.parameters() if p.requires_grad))

    wandb.save('segmentation/smp_srm.py')
    wandb.save('dataset.py')

    train_imgaug, train_geo_aug = get_train_transforms()
    transforms_normalize = get_transforms_normalize()

    #region ########################-- CREATE DATASET and DATALOADER --########################
    train_dataset = DATASET(dataframe=df,
                            mode="train",
                            val_fold=VAL_FOLD,
                            test_fold=TEST_FOLD,
                            transforms_normalize=transforms_normalize,
                            imgaug_augment=train_imgaug,
                            geo_augment=train_geo_aug)
    train_loader = DataLoader(train_dataset,
                              batch_size=config.train_batch_size,
                              shuffle=True,
                              num_workers=4,
                              pin_memory=True,
                              drop_last=False)

    valid_dataset = DATASET(
        dataframe=df,
        mode="val",
        val_fold=VAL_FOLD,
        test_fold=TEST_FOLD,
        transforms_normalize=transforms_normalize,
    )
    valid_loader = DataLoader(valid_dataset,
                              batch_size=config.valid_batch_size,
                              shuffle=True,
                              num_workers=4,
                              pin_memory=True,
                              drop_last=False)

    test_dataset = DATASET(
        dataframe=df,
        mode="test",
        val_fold=VAL_FOLD,
        test_fold=TEST_FOLD,
        transforms_normalize=transforms_normalize,
    )
    test_loader = DataLoader(test_dataset,
                             batch_size=config.valid_batch_size,
                             shuffle=True,
                             num_workers=4,
                             pin_memory=True,
                             drop_last=False)
    #endregion ######################################################################################

    optimizer = get_optimizer(model, config.optimizer, config.learning_rate,
                              config.weight_decay)
    # after_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    #     optimizer,
    #     patience=config.schedule_patience,
    #     mode="min",
    #     factor=config.schedule_factor,
    # )
    after_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=35, T_mult=2)
    scheduler = GradualWarmupScheduler(optimizer=optimizer,
                                       multiplier=1,
                                       total_epoch=config.warmup + 1,
                                       after_scheduler=after_scheduler)

    # this zero gradient update is needed to avoid a warning message, issue #8.
    # optimizer.zero_grad()
    # optimizer.step()

    criterion = nn.BCEWithLogitsLoss()
    es = EarlyStopping(patience=200, mode="min")

    model = nn.DataParallel(model).to(device)

    # wandb.watch(model, log_freq=50, log='all')

    start_epoch = 0
    if resume:
        checkpoint = torch.load(
            'checkpoint/(using pretrain)COMBO_ALL_FULL_[09|04_12|46|35].pt')
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        print("-----------> Resuming <------------")

    for epoch in range(start_epoch, config.epochs):
        print(f"Epoch = {epoch}/{config.epochs-1}")
        print("------------------")

        train_metrics = train_epoch(model, train_loader, optimizer, scheduler,
                                    criterion, epoch)
        valid_metrics = valid_epoch(model, valid_loader, criterion, epoch)

        # A plain step is enough with the cosine after_scheduler; the validation
        # loss would only be needed if after_scheduler were ReduceLROnPlateau.
        scheduler.step()

        print(
            f"TRAIN_ACC = {train_metrics['train_acc_05']}, TRAIN_LOSS = {train_metrics['train_loss']}"
        )
        print(
            f"VALID_ACC = {valid_metrics['valid_acc_05']}, VALID_LOSS = {valid_metrics['valid_loss']}"
        )
        print("Optimizer LR", optimizer.param_groups[0]['lr'])
        print("Scheduler LR", scheduler.get_lr()[0])
        wandb.log({
            'optim_lr': optimizer.param_groups[0]['lr'],
            'schedule_lr': scheduler.get_lr()[0]
        })

        es(
            valid_metrics["valid_loss"],
            model,
            model_path=os.path.join(OUTPUT_DIR, f"{run}.h5"),
        )
        if es.early_stop:
            print("Early stopping")
            break

        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
        }
        torch.save(checkpoint, os.path.join('checkpoint', f"{run}.pt"))

    if os.path.exists(os.path.join(OUTPUT_DIR, f"{run}.h5")):
        print(
            model.load_state_dict(
                torch.load(os.path.join(OUTPUT_DIR, f"{run}.h5"))))
        print("LOADED FOR TEST")

    test_metrics = test(model, test_loader, criterion)
    wandb.save(os.path.join(OUTPUT_DIR, f"{run}.h5"))

    return test_metrics
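
Example 1 wraps a cosine schedule in GradualWarmupScheduler (from the pytorch-gradual-warmup-lr package). If that extra dependency is unwanted, a similar warmup-then-cosine schedule can be built from PyTorch's built-in schedulers alone. A minimal sketch, assuming torch >= 1.10; the model, optimizer and epoch counts below are placeholders, not values from Example 1:

from torch import nn, optim

# Placeholder model and optimizer; hyperparameters are illustrative only.
model = nn.Linear(10, 1)
optimizer = optim.AdamW(model.parameters(), lr=1e-3)

warmup_epochs, total_epochs = 5, 100

# Linear ramp from 10% of the base LR to 100%, then cosine annealing.
warmup = optim.lr_scheduler.LinearLR(optimizer, start_factor=0.1,
                                     total_iters=warmup_epochs)
cosine = optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                              T_max=total_epochs - warmup_epochs)
scheduler = optim.lr_scheduler.SequentialLR(optimizer,
                                            schedulers=[warmup, cosine],
                                            milestones=[warmup_epochs])

for epoch in range(total_epochs):
    # ... run one training epoch here ...
    optimizer.step()   # stand-in for the real optimization steps
    scheduler.step()   # one scheduler step per epoch, after the optimizer

Unlike GradualWarmupScheduler with multiplier=1, LinearLR ramps up from a fraction of the base LR rather than from zero, so the exact warmup shape differs slightly.
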
Example 2
def train(args, train_dataset, model):
    tb_writer = SummaryWriter(args.tb_writer_dir)
    result_writer = ResultWriter(args.eval_results_dir)

    if args.weighted_sampling == 1:
        # The three pitch types are unevenly distributed, so sample each type at an equal rate.
        # In the end this did not improve results, so weighted sampling was not used.
        ball_type, counts = np.unique(train_dataset.pitch, return_counts=True)
        count_dict = dict(zip(ball_type, counts))
        weights = [1.0 / count_dict[p] for p in train_dataset.pitch]
        sampler = WeightedRandomSampler(weights,
                                        len(train_dataset),
                                        replacement=True)
        logger.info("Do Weighted Sampling")
    else:
        sampler = RandomSampler(train_dataset)

    train_dataloader = DataLoader(train_dataset,
                                  batch_size=args.train_batch_size,
                                  sampler=sampler)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // len(train_dataloader) + 1
    else:
        t_total = len(train_dataloader) * args.num_train_epochs

    args.warmup_step = int(args.warmup_percent * t_total)

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = [
        "bias",
        "layernorm.weight",
    ]

    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0,
        },
    ]

    optimizer = optim.Adam(optimizer_grouped_parameters,
                           lr=args.learning_rate,
                           eps=args.adam_epsilon)
    if args.warmup_step != 0:
        scheduler_cosine = CosineAnnealingLR(optimizer, t_total)
        scheduler = GradualWarmupScheduler(optimizer,
                                           1,
                                           args.warmup_step,
                                           after_scheduler=scheduler_cosine)
    else:
        scheduler = CosineAnnealingLR(optimizer, t_total)

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    loss_fct = torch.nn.NLLLoss()

    # Train!
    logger.info("***** Running Baseball Transformer *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Warmup Steps = %d", args.warmup_step)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.train_batch_size)
    logger.info("  Total train batch size = %d", args.train_batch_size)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    best_step = 0
    steps_trained_in_current_epoch = 0
    tr_loss, logging_loss, logging_val_loss = 0.0, 0.0, 0.0

    best_pitch_micro_f1, best_pitch_macro_f1 = 0, 0
    best_loss = 1e10

    model.zero_grad()
    train_iterator = trange(
        epochs_trained,
        int(args.num_train_epochs),
        desc="Epoch",
    )
    set_seed(args)  # Added here for reproducibility
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration")
        for step, batch in enumerate(epoch_iterator):

            (
                pitcher,
                batter,
                state,
                pitch,
                label,
                pitch_memory,
                label_memory,
                memory_mask,
            ) = list(map(lambda x: x.to(args.device), batch))
            model.train()
            pitching_score, memories = model(
                pitcher,
                batter,
                state,
                pitch_memory,
                label_memory,
                memory_mask,
            )

            pitching_score = pitching_score.log_softmax(dim=-1)
            loss = loss_fct(pitching_score, pitch)

            if args.n_gpu > 1:
                loss = loss.mean()

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()

            if args.fp16:
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                               args.max_grad_norm)
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_grad_norm)
            optimizer.step()
            scheduler.step()
            model.zero_grad()
            global_step += 1

            if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                if args.evaluate_during_training:
                    results, f1_results, f1_log, cm = evaluate(
                        args, args.eval_data_file, model)
                    output_eval_file = os.path.join(args.output_dir,
                                                    "eval_results.txt")
                    print_result(output_eval_file, results, f1_log, cm)

                    for key, value in results.items():
                        tb_writer.add_scalar("eval_{}".format(key), value,
                                             global_step)
                    logging_val_loss = results["loss"]

                tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
                tb_writer.add_scalar("loss", (tr_loss - logging_loss) /
                                     args.logging_steps, global_step)
                logging_loss = tr_loss

                # Select the best model by macro-F1 rather than loss (there is a trade-off).
                # if best_loss > results["loss"]:
                # NOTE: assumes args.evaluate_during_training is set, so `results` is defined here.
                if best_pitch_macro_f1 < results["pitch_macro_f1"]:
                    best_pitch_micro_f1 = results["pitch_micro_f1"]
                    best_pitch_macro_f1 = results["pitch_macro_f1"]
                    best_loss = results["loss"]
                    results["best_step"] = best_step = global_step

                    output_dir = os.path.join(args.output_dir, "best_model/")
                    os.makedirs(output_dir, exist_ok=True)
                    torch.save(model.state_dict(),
                               os.path.join(output_dir, "pytorch_model.bin"))
                    torch.save(args,
                               os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving best model to %s", output_dir)

                    result_path = os.path.join(output_dir, "best_results.txt")
                    print_result(result_path,
                                 results,
                                 f1_log,
                                 cm,
                                 off_logger=True)

                    results.update(dict(f1_results))
                    result_writer.update(args, **results)

                logger.info("  best pitch micro f1 : %s", best_pitch_micro_f1)
                logger.info("  best pitch macro f1 : %s", best_pitch_macro_f1)
                logger.info("  best loss : %s", best_loss)
                logger.info("  best step : %s", best_step)

            if args.save_steps > 0 and global_step % args.save_steps == 0:
                checkpoint_prefix = "checkpoint"
                # Save model checkpoint
                output_dir = os.path.join(
                    args.output_dir, "{}-{}".format(checkpoint_prefix,
                                                    global_step))
                os.makedirs(output_dir, exist_ok=True)
                torch.save(model.state_dict(),
                           os.path.join(output_dir, "pytorch_model.bin"))
                torch.save(args, os.path.join(output_dir, "training_args.bin"))
                logger.info("Saving model checkpoint to %s", output_dir)

                rotate_checkpoints(args, checkpoint_prefix)

                torch.save(optimizer.state_dict(),
                           os.path.join(output_dir, "optimizer.pt"))
                torch.save(scheduler.state_dict(),
                           os.path.join(output_dir, "scheduler.pt"))
                logger.info("Saving optimizer and scheduler states to %s",
                            output_dir)

    tb_writer.close()

    return global_step, tr_loss / global_step
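
Example 2 balances the unevenly distributed pitch types by weighting each sample with the reciprocal of its class count and sampling with replacement. A self-contained sketch of that weighting scheme on toy data (the labels and sizes here are made up):

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

# Toy, imbalanced labels: class 0 dominates.
labels = np.array([0] * 80 + [1] * 15 + [2] * 5)
dataset = TensorDataset(torch.randn(len(labels), 4), torch.as_tensor(labels))

# Per-sample weight = 1 / frequency of its class, as in Example 2.
classes, counts = np.unique(labels, return_counts=True)
count_dict = dict(zip(classes, counts))
weights = [1.0 / count_dict[l] for l in labels]

sampler = WeightedRandomSampler(weights, num_samples=len(dataset),
                                replacement=True)
loader = DataLoader(dataset, batch_size=20, sampler=sampler)

# Batches drawn this way are roughly class-balanced in expectation.
_, batch_labels = next(iter(loader))
print(batch_labels.bincount())
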
Example 3
if warmup:
    warmup_epochs = 3
    scheduler_cosine = optim.lr_scheduler.CosineAnnealingLR(optimizer, opt.OPTIM.NUM_EPOCHS-warmup_epochs, eta_min=1e-6)
    scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_epochs, after_scheduler=scheduler_cosine)
    scheduler.step()

######### Resume ###########
if opt.TRAINING.RESUME:
    path_chk_rest = utils.get_last_path(model_dir, '_latest.pth')
    utils.load_checkpoint(model_restoration, path_chk_rest)
    start_epoch = utils.load_start_epoch(path_chk_rest) + 1
    utils.load_optim(optimizer, path_chk_rest)

    for i in range(1, start_epoch):
        scheduler.step()
    new_lr = scheduler.get_lr()[0]
    print('------------------------------------------------------------------------------')
    print("==> Resuming Training with learning rate:", new_lr)
    print('------------------------------------------------------------------------------')

if len(device_ids) > 1:
    model_restoration = nn.DataParallel(model_restoration, device_ids=device_ids)

######### Loss ###########
criterion = CharbonnierLoss().cuda()

######### DataLoaders ###########
img_options_train = {'patch_size':opt.TRAINING.TRAIN_PS}

train_dataset = get_training_data(train_dir, img_options_train)
train_loader = DataLoader(dataset=train_dataset, batch_size=opt.OPTIM.BATCH_SIZE, shuffle=True, num_workers=16, drop_last=False)
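
Example 3 resumes training by replaying scheduler.step() once per already-completed epoch instead of restoring the scheduler's state dict. A small sketch comparing the two approaches for a plain CosineAnnealingLR (the model, learning rate and epoch count are placeholders):

from torch import nn, optim

def fresh_optimizer_and_scheduler():
    model = nn.Linear(4, 1)
    opt = optim.SGD(model.parameters(), lr=0.1)
    sched = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=50, eta_min=1e-6)
    return opt, sched

start_epoch = 7  # pretend 7 epochs were completed before the interruption

# Option A (as in Example 3): fast-forward a freshly created scheduler.
opt_a, sched_a = fresh_optimizer_and_scheduler()
for _ in range(start_epoch):
    sched_a.step()

# Option B: save the scheduler state with the checkpoint and restore it.
opt_b, sched_b = fresh_optimizer_and_scheduler()
sched_b.load_state_dict(sched_a.state_dict())  # stands in for torch.load(...)

print(sched_a.get_last_lr(), sched_b.get_last_lr())  # identical learning rates

Option B also works for stateful wrappers such as the warmup scheduler, which is why Example 1 stores scheduler.state_dict() in its checkpoints.
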
Example 4
scheduler = GradualWarmupScheduler(optimizer,
                                   multiplier=1,
                                   total_epoch=warmup_epochs,
                                   after_scheduler=scheduler_cosine)
scheduler.step()

######### Resume ###########
if opt.TRAINING.RESUME:
    path_chk_rest = utils.get_last_path(model_dir, '_latest.pth')
    utils.load_checkpoint(model_restoration, path_chk_rest)
    start_epoch = utils.load_start_epoch(path_chk_rest) + 1
    utils.load_optim(optimizer, path_chk_rest)

    for i in range(1, start_epoch):
        scheduler.step()
    new_lr = scheduler.get_lr()[0]
    print(
        '------------------------------------------------------------------------------'
    )
    print("==> Resuming Training with learning rate:", new_lr)
    print(
        '------------------------------------------------------------------------------'
    )

if len(device_ids) > 1:
    model_restoration = nn.DataParallel(model_restoration,
                                        device_ids=device_ids)

######### Loss ###########
criterion_char = losses.CharbonnierLoss()
criterion_edge = losses.EdgeLoss()
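
Examples 3 and 4 optimize a Charbonnier loss (CharbonnierLoss / losses.CharbonnierLoss), a smooth L1-style penalty commonly defined as the mean of sqrt((x - y)^2 + eps^2). A minimal sketch of that common formulation; the epsilon value is an assumption, not taken from these projects:

import torch
import torch.nn as nn

class CharbonnierLoss(nn.Module):
    """Charbonnier (smooth L1-like) loss: mean of sqrt(diff^2 + eps^2)."""

    def __init__(self, eps: float = 1e-3):
        super().__init__()
        self.eps = eps

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        diff = pred - target
        return torch.mean(torch.sqrt(diff * diff + self.eps ** 2))

# Usage on dummy restored / ground-truth image batches.
criterion = CharbonnierLoss()
print(criterion(torch.rand(2, 3, 64, 64), torch.rand(2, 3, 64, 64)))
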
Example 5
                        }, os.path.join(model_dir, "model_best.pth"))

                print(
                    "[Ep %d it %d\t PSNR SIDD: %.4f\t] ----  [best_Ep_SIDD %d best_it_SIDD %d Best_PSNR_SIDD %.4f] "
                    %
                    (epoch, i, psnr_val_rgb, best_epoch, best_iter, best_psnr))

            model_restoration.train()

    scheduler.step()

    print("------------------------------------------------------------------")
    print("Epoch: {}\tTime: {:.4f}\tLoss: {:.4f}\tLearningRate {:.6f}".format(
        epoch,
        time.time() - epoch_start_time, epoch_loss,
        scheduler.get_lr()[0]))
    print("------------------------------------------------------------------")

    torch.save(
        {
            'epoch': epoch,
            'state_dict': model_restoration.state_dict(),
            'optimizer': optimizer.state_dict()
        }, os.path.join(model_dir, "model_latest.pth"))

    torch.save(
        {
            'epoch': epoch,
            'state_dict': model_restoration.state_dict(),
            'optimizer': optimizer.state_dict()
        }, os.path.join(model_dir, f"model_epoch_{epoch}.pth"))
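
Example 5 writes a model_latest.pth plus a per-epoch checkpoint holding the epoch index, model weights and optimizer state. A small sketch of the matching save-and-resume round trip, assuming the same dictionary keys ('epoch', 'state_dict', 'optimizer'); the model, directory and learning rate are placeholders:

import os
import torch
from torch import nn, optim

model_restoration = nn.Conv2d(3, 3, 3, padding=1)  # placeholder model
optimizer = optim.Adam(model_restoration.parameters(), lr=2e-4)
model_dir = "./checkpoints"
os.makedirs(model_dir, exist_ok=True)

# Save in the same format as Example 5.
torch.save({
    'epoch': 12,
    'state_dict': model_restoration.state_dict(),
    'optimizer': optimizer.state_dict(),
}, os.path.join(model_dir, "model_latest.pth"))

# Resume: reload weights and optimizer state, then continue from the next epoch.
checkpoint = torch.load(os.path.join(model_dir, "model_latest.pth"))
model_restoration.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch'] + 1
print("Resuming from epoch", start_epoch)
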
Example 6
def train(args, train_dataset, model):
    tb_writer = SummaryWriter()
    result_writer = ResultWriter(args.eval_results_dir)

    # Configure the sampling frequency.
    if args.sample_criteria is None:
        # Random sampling, following the natural ratio of positive/negative situations & pitch types.
        sampler = RandomSampler(train_dataset)
    else:
        # Equalize sampling frequency across pitch types (7 classes).
        if args.sample_criteria == "pitcher":
            counts = train_dataset.pitch_counts
            logger.info("  Counts of each ball type : %s", counts)
            pitch_contiguous = [
                i for p in train_dataset.origin_pitch for i, j in enumerate(p)
                if j == 1
            ]
            weights = [
                0 if p == 5 or p == 6 else 1.0 / counts[p]
                for p in pitch_contiguous
            ]
            sampler = WeightedRandomSampler(weights,
                                            len(train_dataset),
                                            replacement=True)
        # Equalize sampling frequency across positive/negative labels (2 classes).
        elif args.sample_criteria == "batter":
            counts = train_dataset.label_counts
            logger.info("  Counts of each label type : %s", counts)
            weights = [1.0 / counts[l] for l in train_dataset.label]
            sampler = WeightedRandomSampler(weights,
                                            len(train_dataset),
                                            replacement=True)
        # Equalize sampling frequency across pitch type & positive/negative label pairs (14 classes).
        elif args.sample_criteria == "both":
            counts = train_dataset.pitch_and_label_count
            logger.info("  Counts of each both type : %s", counts)
            pitch_contiguous = [
                i for p in train_dataset.origin_pitch for i, j in enumerate(p)
                if j == 1
            ]
            weights = [
                0 if p == 5 or p == 6 else 1.0 / counts[(p, l)]
                for p, l in zip(pitch_contiguous, train_dataset.label)
            ]
            sampler = WeightedRandomSampler(weights,
                                            len(train_dataset),
                                            replacement=True)
        else:
            sampler = RandomSampler(train_dataset)

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=args.train_batch_size,
        sampler=sampler,
    )
    t_total = len(train_dataloader) * args.num_train_epochs
    args.warmup_step = int(args.warmup_percent * t_total)

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = [
        "bias",
        "layernorm.weight",
    ]  # LayerNorm.weight -> layernorm.weight (model_parameter name)
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0,
        },
    ]

    optimizer = optim.Adam(optimizer_grouped_parameters,
                           lr=args.learning_rate,
                           eps=args.adam_epsilon)
    if args.warmup_step != 0:
        scheduler_cosine = CosineAnnealingLR(optimizer, t_total)
        scheduler = GradualWarmupScheduler(optimizer,
                                           1,
                                           args.warmup_step,
                                           after_scheduler=scheduler_cosine)
        # scheduler_plateau = ReduceLROnPlateau(optimizer, "min")
        # scheduler = GradualWarmupScheduler(
        #     optimizer, 1, args.warmup_step, after_scheduler=scheduler_plateau
        # )
    else:
        scheduler = CosineAnnealingLR(optimizer, t_total)
        # scheduler = ReduceLROnPlateau(optimizer, "min")

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)
    m = torch.nn.Sigmoid()
    loss_fct = torch.nn.BCELoss()

    # Train!
    logger.info("***** Running Baseball Transformer *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Warmup Steps = %d", args.warmup_step)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.train_batch_size)
    logger.info("  Total train batch size = %d", args.train_batch_size)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    tr_loss, logging_loss, logging_val_loss = 0.0, 0.0, 0.0

    best_pitch_micro_f1, best_pitch_macro_f1 = 0, 0
    best_loss = 1e10

    model.zero_grad()
    train_iterator = trange(
        epochs_trained,
        int(args.num_train_epochs),
        desc="Epoch",
    )
    set_seed(args)  # Added here for reproducibility
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration")
        for step, batch in enumerate(epoch_iterator):

            (
                pitcher_discrete,
                pitcher_continuous,
                batter_discrete,
                batter_continuous,
                state_discrete,
                state_continuous,
                pitch,
                hit,
                label,
                masked_pitch,
                origin_pitch,
            ) = list(map(lambda x: x.to(args.device), batch))
            model.train()

            # sentiment input
            pitching_score = model(
                pitcher_discrete,
                pitcher_continuous,
                batter_discrete,
                batter_continuous,
                state_discrete,
                state_continuous,
                label,
                args.concat if args.concat else 0,
            )

            pitching_score = pitching_score.contiguous()
            pitch = pitch.contiguous()
            # with sigmoid(m)
            loss = loss_fct(m(pitching_score), pitch)

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()

            if args.fp16:
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                               args.max_grad_norm)
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_grad_norm)
            optimizer.step()
            scheduler.step()
            model.zero_grad()
            global_step += 1

            if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                # Log metrics
                if args.evaluate_during_training:
                    results, f1_results, f1_log, cm_pos, cm_neg = evaluate(
                        args, args.eval_data_file, model)
                    output_eval_file = os.path.join(args.output_dir,
                                                    "eval_results_pos.txt")
                    print_result(output_eval_file, results, f1_log, cm_pos)

                    for key, value in results.items():
                        tb_writer.add_scalar("eval_{}".format(key), value,
                                             global_step)
                    logging_val_loss = results["loss"]
                    # scheduler.step(logging_val_loss)

                tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
                tb_writer.add_scalar("loss", (tr_loss - logging_loss) /
                                     args.logging_steps, global_step)
                logging_loss = tr_loss
                # NOTE: assumes args.evaluate_during_training is set, so `results` is defined here.
                if best_loss > results["loss"]:
                    best_pitch_micro_f1 = results["pitch_micro_f1"]
                    best_pitch_macro_f1 = results["pitch_macro_f1"]
                    best_loss = results["loss"]

                    output_dir = os.path.join(args.output_dir, "best_model/")
                    os.makedirs(output_dir, exist_ok=True)
                    torch.save(model.state_dict(),
                               os.path.join(output_dir, "pytorch_model.bin"))
                    torch.save(args,
                               os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving best model to %s", output_dir)

                    result_path = os.path.join(output_dir, "best_results.txt")
                    print_result(result_path,
                                 results,
                                 f1_log,
                                 cm_pos,
                                 off_logger=True)

                    results.update(dict(f1_results))
                    result_writer.update(args, **results)

                logger.info("  best pitch micro f1 : %s", best_pitch_micro_f1)
                logger.info("  best pitch macro f1 : %s", best_pitch_macro_f1)
                logger.info("  best loss : %s", best_loss)

            if args.save_steps > 0 and global_step % args.save_steps == 0:
                checkpoint_prefix = "checkpoint"
                # Save model checkpoint
                output_dir = os.path.join(
                    args.output_dir, "{}-{}".format(checkpoint_prefix,
                                                    global_step))
                os.makedirs(output_dir, exist_ok=True)
                torch.save(model.state_dict(),
                           os.path.join(output_dir, "pytorch_model.bin"))
                torch.save(args, os.path.join(output_dir, "training_args.bin"))
                logger.info("Saving model checkpoint to %s", output_dir)

                rotate_checkpoints(args, checkpoint_prefix)

                torch.save(optimizer.state_dict(),
                           os.path.join(output_dir, "optimizer.pt"))
                torch.save(scheduler.state_dict(),
                           os.path.join(output_dir, "scheduler.pt"))
                logger.info("Saving optimizer and scheduler states to %s",
                            output_dir)

    tb_writer.close()

    return global_step, tr_loss / global_step
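
Examples 2 and 6 exclude biases and layer-norm weights from weight decay by splitting model.named_parameters() into two optimizer groups. A self-contained sketch of that grouping on a toy model; the attribute names mirror the no_decay list above, and the decay value is illustrative:

from torch import nn, optim

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 32)
        self.layernorm = nn.LayerNorm(32)   # parameter name becomes "layernorm.weight"
        self.head = nn.Linear(32, 7)

    def forward(self, x):
        return self.head(self.layernorm(self.proj(x)))

model = TinyModel()
no_decay = ["bias", "layernorm.weight"]

decay_params, no_decay_params = [], []
for name, param in model.named_parameters():
    if any(nd in name for nd in no_decay):
        no_decay_params.append(param)
    else:
        decay_params.append(param)

optimizer = optim.Adam(
    [
        {"params": decay_params, "weight_decay": 0.01},
        {"params": no_decay_params, "weight_decay": 0.0},
    ],
    lr=1e-3,
)

print(len(decay_params), "parameters with decay,",
      len(no_decay_params), "without")

This mirrors the common Transformer practice of not applying weight decay to biases and normalization parameters.
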