# 示例#1 (Example #1)
# 0
def train(num_epoch=5, save_path="LeNet.pth", log_interval=2000):
    """Train ``Net`` on the training data loader and save its weights.

    Uses ``cuda:0`` when CUDA is available, otherwise the CPU. The model and
    every mini-batch are moved to that device.

    Args:
        num_epoch: number of passes over the training loader (default 5).
        save_path: file the final ``state_dict`` is written to.
        log_interval: print the mean loss every this many mini-batches.
    """
    # Prefer the GPU when one is available.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("device: ", device)
    model = Net()  # build the model
    # Move the model before constructing the optimizer, as torch.optim
    # recommends; previously this was commented out, so `device` was unused.
    model.to(device)
    model.train()
    train_data_loader = get_train_data_loader()
    criterion = get_criterion()
    optimizer = get_optimizer(model)
    for epoch in range(num_epoch):
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(train_data_loader):
            # Batches must live on the same device as the model.
            inputs, labels = inputs.to(device), labels.to(device)
            # Reset gradients accumulated by the previous step.
            optimizer.zero_grad()
            # Forward + backward + parameter update.
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # Report the mean loss over the last `log_interval` mini-batches.
            running_loss += loss.item()
            if i % log_interval == log_interval - 1:
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, i + 1, running_loss / log_interval))
                running_loss = 0.0
    print('Finished Training')
    # Save CPU tensors so the checkpoint loads on machines without CUDA,
    # exactly as the original CPU-only training produced.
    model.cpu()
    torch.save(model.state_dict(), save_path)
def recover_pack():
    """Assemble the dotdict "pack" of training components.

    Entries come from ``get_loader()`` (train/test loaders), ``get_model()``,
    ``get_trainer()`` and ``get_criterion()``; ``optimizer`` and
    ``lr_scheduler`` start as ``None``. ``adjust_learning_rate`` is then
    called with ``cfg.base.epoch`` and the pack before it is returned.
    """
    loaders = get_loader()
    train_loader, test_loader = loaders

    contents = dict(
        net=get_model(),
        train_loader=train_loader,
        test_loader=test_loader,
        trainer=get_trainer(),
        criterion=get_criterion(),
        optimizer=None,
        lr_scheduler=None,
    )
    pack = dotdict(contents)

    adjust_learning_rate(cfg.base.epoch, pack)
    return pack
# 示例#3 (Example #3)
# 0
    logging.info("Number of GPUs: {}, using DaraParallel.".format(args.n_gpu))
# Distributed training (should be after apex fp16 initialization)
if args.local_rank != -1 and cfg.distributed:
    # Convert BatchNorm layers to SyncBatchNorm over a process group made of
    # ranks 0 .. args.num_gpus - 1, then wrap the net in DDP on this rank's GPU.
    process_group = torch.distributed.new_group(list(range(args.num_gpus)))
    net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net, process_group)

    net = torch.nn.parallel.DistributedDataParallel(
        net,
        device_ids=[args.local_rank],
        output_device=args.local_rank,
    )
    logging.info("Number of GPUs: {}, using DistributedDataParallel.".format(
        args.num_gpus))

##################### Loss function and optimizer ############################
# The evaluation criterion is always built. The training criterion, optimizer
# and LR scheduler are only built when not in evaluate-only mode; otherwise
# optimizer/scheduler stay None.
criterion_eval = get_criterion(cfg, train=False)
criterion_eval.cuda()
optimizer = None
scheduler = None
if not cfg.EVALUATE:
    criterion = get_criterion(cfg)
    criterion.cuda()
    # The resume flag is set when starting from a non-zero iteration.
    optimizer = get_opt(cfg, net, resume=iteration > 0)
    scheduler = get_lr_scheduler(cfg, optimizer, last_iter=iteration)

##################### make a checkpoint ############################
best_acc = 0.0
checkpointer = Checkpointer(net,
                            cfg.MODEL.ARCH,
                            best_acc=best_acc,
                            optimizer=optimizer,
def train_cv(
    task: str = Task.AgeC,  # task to run (classification: main task, mask state, age group, gender; regression: age)
    model_type: str = Config.VanillaEfficientNet,  # name of the model to load
    load_state_dict: str = None,  # path of saved parameters when resuming training
    train_root: str = Config.TrainS,  # training data path
    valid_root: str = Config.ValidS,
    transform_type: str = Aug.BaseTransform,  # transform to apply
    age_filter: int = 58,
    epochs: int = Config.Epochs,
    cv: int = 5,
    batch_size: int = Config.Batch32,
    optim_type: str = Config.Adam,
    loss_type: str = Loss.CE,
    lr: float = Config.LRBase,
    lr_scheduler: str = Config.CosineScheduler,
    save_path: str = Config.ModelPath,
    seed: int = Config.Seed,
    skip_folds: tuple = (0, 1, 2, 3),
):
    """Train one model per K-fold split of the training set (for ensembling).

    Each fold trains on its KFold training indices. The fold's held-out
    indices are not used; every fold is validated on the separate
    ``valid_root`` loader instead. Metrics are printed and logged to wandb.
    When ``save_path`` is set, checkpoints go to a timestamped
    ``kfold_<model_type>_...`` sub-directory:
      * classification tasks save only when valid F1 does not decrease;
      * the age-regression task saves after every epoch.

    Args:
        skip_folds: fold indices to skip. The default reproduces the
            previously hard-coded behavior (skip folds 0-3, i.e. only the last
            fold trains when ``cv=5``); pass ``()`` to train every fold.
    """
    skip_folds = tuple(skip_folds or ())

    kfold_dir = None
    if save_path:
        kfold_dir = f"kfold_{model_type}_" + get_timestamp()
        # makedirs also creates save_path when missing (os.listdir used to
        # raise) and is a no-op when the directory already exists.
        os.makedirs(os.path.join(save_path, kfold_dir), exist_ok=True)
        print(f'Models will be saved in {os.path.join(save_path, kfold_dir)}.')

    set_seed(seed)
    transform = configure_transform(phase="train", transform_type=transform_type)
    trainset = TrainDataset(root=train_root, transform=transform, task=task, age_filter=age_filter, meta_path=Config.Metadata)
    validloader = get_dataloader(
        task, "valid", valid_root, transform_type, 1024, shuffle=False, drop_last=False
    )

    kfold = KFold(n_splits=cv, shuffle=True)

    # The goal is an ensemble, so each fold's test indices are not used;
    # performance is checked on validloader instead.
    for fold_idx, (train_indices, _) in enumerate(kfold.split(trainset)):
        if fold_idx in skip_folds:
            continue
        print(f"Train Fold #{fold_idx}")
        train_sampler = SubsetRandomSampler(train_indices)
        trainloader = DataLoader(
            trainset, batch_size=batch_size, sampler=train_sampler, drop_last=True
        )
        last_idx = len(trainloader) - 1

        model = load_model(model_type, task, load_state_dict)
        model.cuda()
        model.train()

        optimizer = get_optim(model, optim_type=optim_type, lr=lr)
        criterion = get_criterion(loss_type=loss_type, task=task)

        # None when no LR scheduler is requested; both branches guard on it.
        scheduler = None
        if lr_scheduler is not None:
            scheduler = get_scheduler(scheduler_type=lr_scheduler, optimizer=optimizer)

        best_f1 = 0

        if task != Task.Age:  # classification(main, ageg, mask, gender)
            for epoch in range(epochs):
                print(f"Epoch: {epoch}")

                # F1, ACC
                pred_list = []
                true_list = []

                # Sample-weighted running loss
                total_loss = 0.0
                num_samples = 0

                for idx, (imgs, labels) in tqdm(enumerate(trainloader), desc="Train"):
                    imgs = imgs.cuda()
                    labels = labels.cuda()

                    output = model(imgs)
                    loss = criterion(output, labels)
                    _, preds = torch.max(output.data, dim=1)

                    pred_list.append(preds.data.cpu().numpy())
                    true_list.append(labels.data.cpu().numpy())

                    # .item() detaches the value: accumulating the loss tensor
                    # itself keeps autograd history alive across the epoch.
                    # Weighting by batch size (as the regression branch does)
                    # makes total_loss / num_samples a per-sample mean,
                    # assuming the criterion returns a batch mean.
                    total_loss += loss.item() * imgs.size(0)
                    num_samples += imgs.size(0)

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    if scheduler is not None:
                        scheduler.step()

                    train_loss = total_loss / num_samples

                    pred_arr = np.hstack(pred_list)
                    true_arr = np.hstack(true_list)
                    train_acc = (true_arr == pred_arr).sum() / len(true_arr)
                    train_f1 = f1_score(
                        y_true=true_arr, y_pred=pred_arr, average="macro"
                    )

                    if epoch == 0:  # per-step logs during the first epoch only
                        wandb.log(
                            {
                                f"Fold #{fold_idx} Ep{epoch:0>2d} Train F1": train_f1,
                                f"Fold #{fold_idx} Ep{epoch:0>2d} Train ACC": train_acc,
                                f"Fold #{fold_idx} Ep{epoch:0>2d} Train Loss": train_loss,
                            }
                        )

                    # Also validate on the last batch so the epoch summary and
                    # checkpoint below always use metrics of the weights being
                    # saved (and exist even when the loader is shorter than
                    # VALID_CYCLE batches).
                    if (idx != 0 and idx % VALID_CYCLE == 0) or idx == last_idx:
                        valid_f1, valid_acc, valid_loss = validate(
                            task, model, validloader, criterion
                        )
                        # validate() may leave the model in eval mode; make sure
                        # training continues in train mode (no-op otherwise).
                        model.train()

                        print(
                            f"[Valid] F1: {valid_f1:.4f} ACC: {valid_acc:.4f} Loss: {valid_loss:.4f}"
                        )
                        print(
                            f"[Train] F1: {train_f1:.4f} ACC: {train_acc:.4f} Loss: {train_loss:.4f}"
                        )
                        if epoch == 0:
                            # logs during one epoch
                            wandb.log(
                                {
                                    f"Fold #{fold_idx} Ep{epoch:0>2d} Valid F1": valid_f1,
                                    f"Fold #{fold_idx} Ep{epoch:0>2d} Valid ACC": valid_acc,
                                    f"Fold #{fold_idx} Ep{epoch:0>2d} Valid Loss": valid_loss,
                                }
                            )

                # logs for one epoch in total
                wandb.log(
                    {
                        f"Fold #{fold_idx} Train F1": train_f1,
                        f"Fold #{fold_idx} Valid F1": valid_f1,
                        f"Fold #{fold_idx} Train ACC": train_acc,
                        f"Fold #{fold_idx} Valid ACC": valid_acc,
                        f"Fold #{fold_idx} Train Loss": train_loss,
                        f"Fold #{fold_idx} Valid Loss": valid_loss,
                    }
                )

                if save_path and valid_f1 >= best_f1:
                    name = f"Fold{fold_idx:0>2d}_{model_type}_task({task})ep({epoch:0>2d})f1({valid_f1:.4f})bs({batch_size})loss({valid_loss:.4f})lr({lr})trans({transform_type})optim({optim_type})crit({loss_type})seed({seed}).pth"
                    best_f1 = valid_f1
                    torch.save(
                        model.state_dict(), os.path.join(save_path, kfold_dir, name)
                    )

        # regression(age)
        else:
            for epoch in range(epochs):
                print(f"Epoch: {epoch}")

                pred_list = []
                true_list = []

                # Sample-weighted sum of the criterion value (reported as MSE).
                mse_raw = 0.0
                num_samples = 0

                for idx, (imgs, labels) in tqdm(enumerate(trainloader), desc="Train"):
                    imgs = imgs.cuda()

                    # regression(age)
                    labels_reg = labels.float().cuda()
                    output = model(imgs)
                    loss = criterion(output, labels_reg.unsqueeze(1))

                    mse_raw += loss.item() * len(labels_reg)
                    num_samples += len(labels_reg)

                    # backward
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    # Previously unguarded: raised NameError when lr_scheduler is None.
                    if scheduler is not None:
                        scheduler.step()

                    # Derived age-group classification metrics (ageg)
                    labels_clf = age2ageg(labels.data.numpy())
                    preds_clf = age2ageg(output.data.cpu().numpy().flatten())
                    pred_list.append(preds_clf)
                    true_list.append(labels_clf)

                    train_mse = mse_raw / num_samples
                    # The old separate rmse accumulator always equalled mse_raw.
                    train_rmse = math.sqrt(train_mse)

                    # eval for clf(ageg)
                    pred_arr = np.hstack(pred_list)
                    true_arr = np.hstack(true_list)

                    train_acc = (true_arr == pred_arr).sum() / len(true_arr)
                    train_f1 = f1_score(
                        y_true=true_arr, y_pred=pred_arr, average="macro"
                    )

                    # Also validate on the last batch (see classification branch).
                    if (idx != 0 and idx % VALID_CYCLE == 0) or idx == last_idx:
                        valid_f1, valid_acc, valid_rmse, valid_mse = validate(
                            task, model, validloader, criterion
                        )
                        model.train()
                        print(
                            f"[Valid] F1: {valid_f1:.4f} ACC: {valid_acc:.4f} RMSE: {valid_rmse:.4f} MSE: {valid_mse:.4f}"
                        )
                        print(
                            f"[Train] F1: {train_f1:.4f} ACC: {train_acc:.4f} RMSE: {train_rmse:.4f} MSE: {train_mse:.4f}"
                        )

                wandb.log(
                    {
                        f"Fold #{fold_idx} Train F1": train_f1,
                        f"Fold #{fold_idx} Valid F1": valid_f1,
                        f"Fold #{fold_idx} Train ACC": train_acc,
                        f"Fold #{fold_idx} Valid ACC": valid_acc,
                        f"Fold #{fold_idx} Train RMSE": train_rmse,
                        f"Fold #{fold_idx} Valid RMSE": valid_rmse,
                        f"Fold #{fold_idx} Train MSE": train_mse,
                        f"Fold #{fold_idx} Valid MSE": valid_mse,
                    }
                )

                if save_path:
                    name = f"Fold{fold_idx:0>2d}_{model_type}_task({task})ep({epoch:0>2d})f1({valid_f1:.4f})bs({batch_size})loss({valid_mse:.4f})lr({lr})trans({transform_type})optim({optim_type})crit({loss_type})seed({seed}).pth"
                    torch.save(
                        model.state_dict(), os.path.join(save_path, kfold_dir, name)
                    )
        # Release GPU memory held by this fold's model before the next fold.
        model.cpu()
def train(
    task: str = Task.AgeC,  # task to run (classification: main task, mask state, age group, gender; regression: age)
    model_type: str = Config.VanillaEfficientNet,  # name of the model to load
    load_state_dict: str = None,  # path of saved parameters when resuming training
    train_root: str = Config.TrainS,  # training data path
    valid_root: str = Config.ValidS,
    transform_type: str = Aug.BaseTransform,  # transform to apply
    epochs: int = Config.Epochs,
    batch_size: int = Config.Batch32,
    optim_type: str = Config.Adam,
    loss_type: str = Loss.CE,
    lr: float = Config.LRBase,
    lr_scheduler: str = Config.CosineScheduler,
    save_path: str = Config.ModelPath,
    seed: int = Config.Seed,
):
    """Train a single model on ``train_root`` and validate on ``valid_root``.

    Metrics are printed and logged to wandb. When ``save_path`` is set:
      * classification tasks save a checkpoint only when valid F1 does not
        decrease;
      * the age-regression task saves after every epoch.
    """
    set_seed(seed)
    trainloader = get_dataloader(task, "train", train_root, transform_type, batch_size)
    validloader = get_dataloader(
        task, "valid", valid_root, transform_type, 1024, shuffle=False, drop_last=False
    )
    last_idx = len(trainloader) - 1

    model = load_model(model_type, task, load_state_dict)
    model.cuda()
    model.train()

    optimizer = get_optim(model, optim_type=optim_type, lr=lr)
    criterion = get_criterion(loss_type=loss_type, task=task)

    # None when no LR scheduler is requested; both branches guard on it.
    scheduler = None
    if lr_scheduler is not None:
        scheduler = get_scheduler(scheduler_type=lr_scheduler, optimizer=optimizer)

    best_f1 = 0

    if task != Task.Age:  # classification(main, ageg, mask, gender)
        for epoch in range(epochs):
            print(f"Epoch: {epoch}")

            # F1, ACC
            pred_list = []
            true_list = []

            # Sample-weighted running loss
            total_loss = 0.0
            num_samples = 0

            for idx, (imgs, labels) in tqdm(enumerate(trainloader), desc="Train"):
                imgs = imgs.cuda()
                labels = labels.cuda()

                output = model(imgs)
                loss = criterion(output, labels)
                _, preds = torch.max(output.data, dim=1)

                pred_list.append(preds.data.cpu().numpy())
                true_list.append(labels.data.cpu().numpy())

                # .item() detaches the value: accumulating the loss tensor
                # itself keeps autograd history alive across the epoch.
                # Weighting by batch size (as the regression branch does)
                # makes total_loss / num_samples a per-sample mean, assuming
                # the criterion returns a batch mean.
                total_loss += loss.item() * imgs.size(0)
                num_samples += imgs.size(0)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                if scheduler is not None:
                    scheduler.step()

                train_loss = total_loss / num_samples

                pred_arr = np.hstack(pred_list)
                true_arr = np.hstack(true_list)
                train_acc = (true_arr == pred_arr).sum() / len(true_arr)
                train_f1 = f1_score(y_true=true_arr, y_pred=pred_arr, average="macro")

                if epoch == 0:  # per-step logs during the first epoch only
                    wandb.log(
                        {
                            f"Ep{epoch:0>2d} Train F1": train_f1,
                            f"Ep{epoch:0>2d} Train ACC": train_acc,
                            f"Ep{epoch:0>2d} Train Loss": train_loss,
                        }
                    )

                # Also validate on the last batch so the epoch summary and
                # checkpoint below always use metrics of the weights being
                # saved (and exist even when the loader is shorter than
                # VALID_CYCLE batches).
                if (idx != 0 and idx % VALID_CYCLE == 0) or idx == last_idx:
                    valid_f1, valid_acc, valid_loss = validate(
                        task, model, validloader, criterion
                    )
                    # validate() may leave the model in eval mode; make sure
                    # training continues in train mode (no-op otherwise).
                    model.train()

                    print(
                        f"[Valid] F1: {valid_f1:.4f} ACC: {valid_acc:.4f} Loss: {valid_loss:.4f}"
                    )
                    print(
                        f"[Train] F1: {train_f1:.4f} ACC: {train_acc:.4f} Loss: {train_loss:.4f}"
                    )
                    if epoch == 0:
                        # logs during one epoch
                        wandb.log(
                            {
                                f"Ep{epoch:0>2d} Valid F1": valid_f1,
                                f"Ep{epoch:0>2d} Valid ACC": valid_acc,
                                f"Ep{epoch:0>2d} Valid Loss": valid_loss,
                            }
                        )

            # logs for one epoch in total
            wandb.log(
                {
                    "Train F1": train_f1,
                    "Valid F1": valid_f1,
                    "Train ACC": train_acc,
                    "Valid ACC": valid_acc,
                    "Train Loss": train_loss,
                    "Valid Loss": valid_loss,
                }
            )

            if save_path and valid_f1 >= best_f1:
                name = f"{model_type}_task({task})ep({epoch:0>2d})f1({valid_f1:.4f})bs({batch_size})loss({valid_loss:.4f})lr({lr})trans({transform_type})optim({optim_type})crit({loss_type})seed({seed}).pth"
                best_f1 = valid_f1
                torch.save(model.state_dict(), os.path.join(save_path, name))

    # regression(age)
    else:
        for epoch in range(epochs):
            print(f"Epoch: {epoch}")

            pred_list = []
            true_list = []

            # Sample-weighted sum of the criterion value (reported as MSE).
            mse_raw = 0.0
            num_samples = 0

            for idx, (imgs, labels) in tqdm(enumerate(trainloader), desc="Train"):
                imgs = imgs.cuda()

                # regression(age)
                labels_reg = labels.float().cuda()
                output = model(imgs)
                loss = criterion(output, labels_reg.unsqueeze(1))

                mse_raw += loss.item() * len(labels_reg)
                num_samples += len(labels_reg)

                # backward
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                # Previously unguarded: raised NameError when lr_scheduler is None.
                if scheduler is not None:
                    scheduler.step()

                # Derived age-group classification metrics (ageg)
                labels_clf = age2ageg(labels.data.numpy())
                preds_clf = age2ageg(output.data.cpu().numpy().flatten())
                pred_list.append(preds_clf)
                true_list.append(labels_clf)

                train_mse = mse_raw / num_samples
                # The old separate rmse accumulator always equalled mse_raw.
                train_rmse = math.sqrt(train_mse)

                # eval for clf(ageg)
                pred_arr = np.hstack(pred_list)
                true_arr = np.hstack(true_list)

                train_acc = (true_arr == pred_arr).sum() / len(true_arr)
                train_f1 = f1_score(y_true=true_arr, y_pred=pred_arr, average="macro")

                # Also validate on the last batch (see classification branch).
                if (idx != 0 and idx % VALID_CYCLE == 0) or idx == last_idx:
                    valid_f1, valid_acc, valid_rmse, valid_mse = validate(
                        task, model, validloader, criterion
                    )
                    model.train()
                    print(
                        f"[Valid] F1: {valid_f1:.4f} ACC: {valid_acc:.4f} RMSE: {valid_rmse:.4f} MSE: {valid_mse:.4f}"
                    )
                    print(
                        f"[Train] F1: {train_f1:.4f} ACC: {train_acc:.4f} RMSE: {train_rmse:.4f} MSE: {train_mse:.4f}"
                    )

            wandb.log(
                {
                    "Train F1": train_f1,
                    "Valid F1": valid_f1,
                    "Train ACC": train_acc,
                    "Valid ACC": valid_acc,
                    "Train RMSE": train_rmse,
                    "Valid RMSE": valid_rmse,
                    "Train MSE": train_mse,
                    "Valid MSE": valid_mse,
                }
            )

            if save_path:
                name = f"{model_type}_task({task})ep({epoch:0>2d})f1({valid_f1:.4f})bs({batch_size})loss({valid_mse:.4f})lr({lr})trans({transform_type})optim({optim_type})crit({loss_type})seed({seed}).pth"
                torch.save(model.state_dict(), os.path.join(save_path, name))
# 示例#6 (Example #6)
# 0
def main():
    """Build every training component from a config file and call ``train``.

    Reads the config path from ``parse_args()``, sets up a ``Workspace`` for
    logging, then builds the augmentor, transformers, datasets
    (``bengali_dataset``), loaders, model, criteria, optimizer (optionally
    wrapped by Apex AMP) and LR scheduler before handing them to ``train``.

    Raises:
        ValueError: for an unknown augmentor, sampler, optimizer or scheduler
            type in the config.
    """
    args = parse_args()
    conf = Config(args.conf)

    data_dir = conf.data_dir
    fold_id = conf.fold_id

    workspace = Workspace(conf.run_id).setup()
    workspace.save_conf(args.conf)
    workspace.log(f'{conf.to_dict()}')

    torch.cuda.set_device(0)

    # ---- Augmentation -------------------------------------------------------
    if conf.use_augmentor:
        if conf.augmentor_type == 'v1':
            augmentor = create_augmentor_v1(
                enable_random_morph=conf.enable_random_morph)
        elif conf.augmentor_type == 'v2':
            augmentor = create_augmentor_v2(
                enable_random_morph=conf.enable_random_morph,
                invert_color=conf.invert_color)
        elif conf.augmentor_type == 'v3':
            if conf.input_size_tuple:
                input_size = tuple(conf.input_size_tuple)
            else:
                input_size = (conf.input_size, conf.input_size) if conf.input_size else \
                             (SOURCE_IMAGE_HEIGHT, SOURCE_IMAGE_WIDTH)
            augmentor = create_augmentor_v3(
                input_size,
                enable_random_morph=conf.enable_random_morph,
                invert_color=conf.invert_color)
        else:
            raise ValueError(conf.augmentor_type)
        workspace.log(f'Use augmentor: {conf.augmentor_type}')
    else:
        augmentor = None

    # ---- Transformers: default size unless input_size(_tuple) is set --------
    if not conf.input_size_tuple and conf.input_size == 0:
        train_transformer = create_transformer_v1(augmentor=augmentor)
        val_transformer = create_testing_transformer_v1()
        workspace.log('Input size: default')
    else:
        if conf.input_size_tuple:
            input_size = tuple(conf.input_size_tuple)
        else:
            input_size = (conf.input_size, conf.input_size)
        train_transformer = create_transformer_v1(input_size=input_size,
                                                  augmentor=augmentor)
        val_transformer = create_testing_transformer_v1(input_size=input_size)
        workspace.log(f'Input size: {input_size}')

    # ---- Datasets and loaders -----------------------------------------------
    train_dataset, val_dataset = bengali_dataset(
        data_dir,
        fold_id=fold_id,
        train_transformer=train_transformer,
        val_transformer=val_transformer,
        invert_color=conf.invert_color,
        n_channel=conf.n_channel,
        use_grapheme_code=conf.use_grapheme_code,
        logger=workspace.logger)
    workspace.log(f'#train={len(train_dataset)}, #val={len(val_dataset)}')
    train_dataset.set_low_freq_groups(n_class=conf.n_class_low_freq)

    if conf.sampler_type == 'pk':
        sampler = PKSampler(train_dataset,
                            n_iter_per_epoch=conf.n_iter_per_epoch,
                            p=conf.batch_p,
                            k=conf.batch_k)
        train_loader = DataLoader(train_dataset,
                                  shuffle=False,
                                  num_workers=8,
                                  pin_memory=True,
                                  batch_sampler=sampler)
        workspace.log(f'{sampler} is enabled')
        workspace.log(f'Real batch_size={sampler.batch_size}')
    elif conf.sampler_type == 'random+append':
        batch_sampler = LowFreqSampleMixinBatchSampler(
            train_dataset,
            conf.batch_size,
            n_low_freq_samples=conf.n_low_freq_samples,
            drop_last=True)
        train_loader = DataLoader(train_dataset,
                                  shuffle=False,
                                  num_workers=8,
                                  pin_memory=True,
                                  batch_sampler=batch_sampler)
        workspace.log(f'{batch_sampler} is enabled')
        workspace.log(f'Real batch_size={batch_sampler.batch_size}')
    elif conf.sampler_type == 'random':
        train_loader = DataLoader(train_dataset,
                                  batch_size=conf.batch_size,
                                  shuffle=True,
                                  num_workers=8,
                                  pin_memory=True,
                                  drop_last=True)
    else:
        raise ValueError(f'Invalid sampler_type: {conf.sampler_type}')

    val_loader = DataLoader(val_dataset,
                            batch_size=conf.batch_size,
                            shuffle=False,
                            num_workers=8,
                            pin_memory=True)

    # ---- Model --------------------------------------------------------------
    workspace.log(f'Create init model: arch={conf.arch}')
    model = create_init_model(conf.arch,
                              pretrained=True,
                              pooling=conf.pooling_type,
                              dim=conf.feat_dim,
                              use_maxblurpool=conf.use_maxblurpool,
                              remove_last_stride=conf.remove_last_stride,
                              n_channel=conf.n_channel)
    if conf.weight_file:
        pretrained_weight = torch.load(conf.weight_file, map_location='cpu')
        result = model.load_state_dict(pretrained_weight)
        workspace.log(f'Pretrained weights were loaded: {conf.weight_file}')
        workspace.log(result)

    model = model.cuda()

    # Extra trainable modules whose parameters join the optimizer, with a
    # parallel list naming the variable each one came from so the Apex-returned
    # modules can be mapped back correctly below.
    sub_models = []
    sub_model_names = []

    # ---- Classification criteria --------------------------------------------
    # NOTE(review): the (g) criterion is passed `weight=` while (v) and (c) are
    # passed `weights=`; one spelling is probably ignored or wrong inside
    # get_criterion — confirm against its signature.
    criterion_g = get_criterion(conf.loss_type_g,
                                weight=train_dataset.get_class_weights_g(),
                                rate=conf.ohem_rate)
    workspace.log(f'Loss type (g): {conf.loss_type_g}')

    criterion_v = get_criterion(conf.loss_type_v,
                                weights=train_dataset.get_class_weights_v(),
                                rate=conf.ohem_rate)
    workspace.log(f'Loss type (v): {conf.loss_type_v}')

    criterion_c = get_criterion(conf.loss_type_c,
                                weights=train_dataset.get_class_weights_c(),
                                rate=conf.ohem_rate)
    workspace.log(f'Loss type (c): {conf.loss_type_c}')

    # ---- Optional feature-level criteria (need a model with `multihead`) ----
    feat_head_archs = (M.BengaliResNet34V3, M.BengaliResNet34V4,
                       M.BengaliResNet34AGeMV4, M.BengaliSEResNeXt50V4,
                       M.BengaliEfficientNetB0V4, M.BengaliEfficientNetB3V4)

    if conf.loss_type_feat_g != 'none':
        assert isinstance(model, feat_head_archs)
        criterion_feat_g = get_criterion(conf.loss_type_feat_g,
                                         dim=model.multihead.head_g.dim,
                                         n_class=168,
                                         s=conf.af_scale_g)
        workspace.log(f'Loss type (fg): {conf.loss_type_feat_g}')
        # 'af' criteria have trainable parameters, so they join the optimizer.
        if conf.loss_type_feat_g in ('af', ):
            sub_models.append(criterion_feat_g)
            sub_model_names.append('criterion_feat_g')
            workspace.log('Add criterion_feat_g to sub model')
    else:
        criterion_feat_g = None

    if conf.loss_type_feat_v != 'none':
        assert isinstance(model, feat_head_archs)
        criterion_feat_v = get_criterion(conf.loss_type_feat_v,
                                         dim=model.multihead.head_v.dim,
                                         n_class=11,
                                         s=conf.af_scale_v)
        workspace.log(f'Loss type (fv): {conf.loss_type_feat_v}')
        if conf.loss_type_feat_v in ('af', ):
            sub_models.append(criterion_feat_v)
            sub_model_names.append('criterion_feat_v')
            workspace.log('Add criterion_feat_v to sub model')
    else:
        criterion_feat_v = None

    if conf.loss_type_feat_c != 'none':
        assert isinstance(model, feat_head_archs)
        criterion_feat_c = get_criterion(conf.loss_type_feat_c,
                                         dim=model.multihead.head_c.dim,
                                         n_class=7,
                                         s=conf.af_scale_c)
        workspace.log(f'Loss type (fc): {conf.loss_type_feat_c}')
        if conf.loss_type_feat_c in ('af', ):
            sub_models.append(criterion_feat_c)
            sub_model_names.append('criterion_feat_c')
            workspace.log('Add criterion_feat_c to sub model')
    else:
        criterion_feat_c = None

    # ---- Optional grapheme-code classifier over the concatenated 168+11+7 ---
    if conf.use_grapheme_code:
        workspace.log('Use grapheme code classifier')
        grapheme_classifier = nn.Sequential(nn.BatchNorm1d(168 + 11 + 7),
                                            nn.Linear(168 + 11 + 7, 1295))
        grapheme_classifier = grapheme_classifier.cuda()
        grapheme_classifier.train()
        sub_models.append(grapheme_classifier)
        sub_model_names.append('grapheme_classifier')
        criterion_grapheme = L.OHEMCrossEntropyLoss().cuda()
    else:
        grapheme_classifier = None
        criterion_grapheme = None

    # ---- Optimizer ----------------------------------------------------------
    parameters = [{'params': model.parameters()}] + \
                 [{'params': sub_model.parameters()} for sub_model in sub_models]

    if conf.optimizer_type == 'adam':
        optimizer = torch.optim.Adam(parameters, lr=conf.lr)
    elif conf.optimizer_type == 'sgd':
        optimizer = torch.optim.SGD(parameters,
                                    lr=conf.lr,
                                    momentum=0.9,
                                    weight_decay=1e-4)
    elif conf.optimizer_type == 'ranger':
        optimizer = Ranger(parameters, lr=conf.lr, weight_decay=1e-4)
    elif conf.optimizer_type == 'radam':
        optimizer = RAdam(parameters, lr=conf.lr, weight_decay=1e-4)
    else:
        raise ValueError(conf.optimizer_type)
    workspace.log(f'Optimizer type: {conf.optimizer_type}')

    # ---- Apex mixed precision -----------------------------------------------
    if conf.use_apex:
        workspace.log('Apex initialization')
        _models, optimizer = amp.initialize([model] + sub_models,
                                            optimizer,
                                            opt_level=conf.apex_opt_level)
        # amp.initialize returns the models in the order they were passed.
        # Map each sub model back by name: the previous positional unpacking
        # assumed all three feature criteria were present and never remapped
        # grapheme_classifier, so any other combination mis-assigned modules
        # or raised IndexError.
        model = _models[0]
        returned = dict(zip(sub_model_names, _models[1:]))
        criterion_feat_g = returned.get('criterion_feat_g', criterion_feat_g)
        criterion_feat_v = returned.get('criterion_feat_v', criterion_feat_v)
        criterion_feat_c = returned.get('criterion_feat_c', criterion_feat_c)
        grapheme_classifier = returned.get('grapheme_classifier',
                                           grapheme_classifier)
        workspace.log('Initialized by Apex')
        workspace.log(f'{optimizer.__class__.__name__}')
        for m in _models:
            workspace.log(f'{m.__class__.__name__}')

    # ---- LR scheduler -------------------------------------------------------
    if conf.scheduler_type == 'cosanl':
        # NOTE(review): conf.batch_size is passed even when a PK / low-freq
        # batch sampler makes the real batch size differ — confirm intended.
        scheduler = CosineLRWithRestarts(
            optimizer,
            conf.batch_size,
            len(train_dataset),
            restart_period=conf.cosanl_restart_period,
            t_mult=conf.cosanl_t_mult)
        workspace.log(f'restart_period={scheduler.restart_period}')
        workspace.log(f't_mult={scheduler.t_mult}')
    elif conf.scheduler_type == 'rop':
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            patience=conf.rop_patience,
            mode='max',
            factor=conf.rop_factor,
            min_lr=1e-6,
            verbose=True)
    else:
        raise ValueError(conf.scheduler_type)

    train(model,
          train_loader,
          val_loader,
          optimizer,
          criterion_g,
          criterion_v,
          criterion_c,
          criterion_feat_g,
          criterion_feat_v,
          criterion_feat_c,
          workspace,
          scheduler=scheduler,
          n_epoch=conf.n_epoch,
          cutmix_prob=conf.cutmix_prob,
          mixup_prob=conf.mixup_prob,
          freeze_bn_epochs=conf.freeze_bn_epochs,
          feat_loss_weight=conf.feat_loss_weight,
          use_apex=conf.use_apex,
          decrease_ohem_rate=conf.decrease_ohem_rate,
          use_grapheme_code=conf.use_grapheme_code,
          grapheme_classifier=grapheme_classifier,
          criterion_grapheme=criterion_grapheme,
          final_ft=conf.final_ft)
def main():
    """Run the (optionally distributed) training loop and return 0.

    Flow:
      * Parse arguments into the module-level ``args`` and call ``init_dist``.
      * Build train/dev loaders, the model (optionally converted to apex
        synced BN), the criterion, the optimizer and an AMP ``GradScaler``;
        wrap the model in apex ``DistributedDataParallel`` when distributed.
      * Optionally load model weights from ``args.train.load``.
      * Create a cosine-annealing LR scheduler when
        ``args.opt.scheduler == "cos"`` (otherwise no scheduler is used).
      * Optionally resume epoch counter, history, model, optimizer,
        scheduler and scaler state from ``args.train.resume``.
      * For each epoch in ``[args.train.start_epoch, args.train.epochs]``:
        train, validate (skipped when ``args.train.ft`` is set), step the
        scheduler, and on local rank 0 save ``best.pth`` / ``last.pth`` and
        plot the history.

    Returns:
        0 after all epochs have completed.
    """
    global args

    args = parse_args()
    print(args)

    init_dist(args)

    (train_loader, train_sampler), dev_loader = get_loaders(args)

    model = get_model(args)
    # model = model.to(memory_format=torch.channels_last)
    if args.dist.sync_bn:
        print("using apex synced BN")
        model = apex.parallel.convert_syncbn_model(model)

    model.cuda()

    criterion = get_criterion(args).cuda()

    opt = get_opt(args, model, criterion)

    # Mixed precision is handled by torch.cuda.amp (not apex amp).
    scaler = torch.cuda.amp.GradScaler()

    # For distributed training, wrap the model with
    # apex.parallel.DistributedDataParallel. delay_allreduce=True defers all
    # gradient communication to the end of the backward pass instead of
    # overlapping it with computation.
    if args.dist.distributed:
        model = apex.parallel.DistributedDataParallel(model,
                                                      delay_allreduce=True)

    best_score = 0
    metrics = {"score": Score(), "acc": Accuracy()}

    # history[name][split] -> list with one value per epoch.
    history = {k: {k_: [] for k_ in ["train", "dev"]} for k in ["loss"]}
    history.update({k: {v: [] for v in ["train", "dev"]} for k in metrics})

    work_dir = Path(args.general.work_dir) / f"{args.train.fold}"
    if args.dist.local_rank == 0:
        # exist_ok avoids the check-then-create race of exists() + mkdir().
        work_dir.mkdir(parents=True, exist_ok=True)

    # Optionally load model weights (only weights, no training state).
    if args.train.load:

        def _load():
            path_to_load = Path(args.train.load).expanduser()
            if path_to_load.is_file():
                print(f"=> loading model '{path_to_load}'")
                checkpoint = torch.load(
                    path_to_load,
                    map_location=lambda storage, loc: storage.cuda(args.dist.
                                                                   gpu),
                )
                model.load_state_dict(checkpoint["state_dict"])
                print(f"=> loaded model '{path_to_load}'")
            else:
                print(f"=> no model found at '{path_to_load}'")

        _load()

    scheduler = None
    if args.opt.scheduler == "cos":
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            opt, T_max=args.opt.T_max, eta_min=max(args.opt.lr * 1e-2, 1e-6))

    # Optionally resume the full training state from a checkpoint.
    if args.train.resume:
        # Use a local scope to avoid dangling references
        def _resume():
            nonlocal history, best_score
            path_to_resume = Path(args.train.resume).expanduser()
            if not path_to_resume.is_file():
                print(f"=> no checkpoint found at '{path_to_resume}'")
                return

            print(f"=> loading resume checkpoint '{path_to_resume}'")
            checkpoint = torch.load(
                path_to_resume,
                map_location=lambda storage, loc: storage.cuda(args.dist.gpu),
            )
            args.train.start_epoch = checkpoint["epoch"] + 1
            history = checkpoint["history"]
            # default=0 mirrors the initial best_score when no dev score
            # has been recorded yet (max() would raise on an empty list).
            best_score = max(history["score"]["dev"], default=0)
            model.load_state_dict(checkpoint["state_dict"])
            opt.load_state_dict(checkpoint["opt_state_dict"])
            # saver() stores None when no scheduler was configured, and the
            # current run may not have a scheduler either; only restore the
            # scheduler state when both sides actually have one.
            sched_state = checkpoint.get("sched_state_dict")
            if scheduler is not None and sched_state is not None:
                scheduler.load_state_dict(sched_state)
            scaler.load_state_dict(checkpoint["scaler"])
            print(
                f"=> resume from checkpoint '{path_to_resume}' (epoch {checkpoint['epoch']})"
            )

        _resume()

    def saver(path):
        # Closes over the loop variable `epoch` (bound by the loop below)
        # and the current best_score/history at call time.
        torch.save(
            {
                "epoch": epoch,
                "best_score": best_score,
                "history": history,
                "state_dict": model.state_dict(),
                "opt_state_dict": opt.state_dict(),
                "sched_state_dict":
                scheduler.state_dict() if scheduler is not None else None,
                "scaler": scaler.state_dict(),
                "args": args,
            },
            path,
        )

    for epoch in range(args.train.start_epoch, args.train.epochs + 1):

        if args.dist.distributed:
            # Reshuffle the distributed sampler differently every epoch.
            train_sampler.set_epoch(epoch)

        for metric in metrics.values():
            metric.clean()

        loss = epoch_step(
            train_loader,
            f"[ Training {epoch}/{args.train.epochs}.. ]",
            model=model,
            criterion=criterion,
            metrics=metrics,
            scaler=scaler,
            opt=opt,
            batch_accum=args.train.batch_accum,
        )
        history["loss"]["train"].append(loss)
        for k, metric in metrics.items():
            history[k]["train"].append(metric.evaluate())

        if not args.train.ft:
            with torch.no_grad():
                for metric in metrics.values():
                    metric.clean()
                loss = epoch_step(
                    dev_loader,
                    f"[ Validating {epoch}/{args.train.epochs}.. ]",
                    model=model,
                    criterion=criterion,
                    metrics=metrics,
                    scaler=scaler,
                    opt=None,
                )
                history["loss"]["dev"].append(loss)
                for k, metric in metrics.items():
                    history[k]["dev"].append(metric.evaluate())
        else:
            # Fine-tuning mode skips validation: the training loss/metrics
            # of this epoch are recorded as the "dev" entries so that the
            # best-checkpoint logic below still has a value to compare.
            history["loss"]["dev"].append(loss)
            for k, metric in metrics.items():
                history[k]["dev"].append(metric.evaluate())

        if scheduler is not None:
            scheduler.step()

        # Only local rank 0 writes checkpoints and plots.
        if args.dist.local_rank == 0:
            if history["score"]["dev"][-1] > best_score:
                best_score = history["score"]["dev"][-1]
                saver(work_dir / "best.pth")

            saver(work_dir / "last.pth")
            plot_hist(history, work_dir)

    return 0