Example #1
def train_no_val(img_dir, model_dir, args):
    seed_everything(args.seed)

    start = time.time()
    get_current_time()

    save_dir = increment_path(os.path.join(model_dir, args.name))

    # settings
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # dataset
    dataset_module = getattr(import_module("dataset"), args.dataset)
    dataset = dataset_module(
        img_dir=img_dir,
        val_ratio=args.val_ratio,
    )
    num_classes = dataset.num_classes

    transform_module = getattr(import_module("dataset"), args.augmentation)
    transform = transform_module(mean=dataset.mean, std=dataset.std)

    dataset.set_transform(transform["train"])

    train_loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        num_workers=2,
        shuffle=True,
        pin_memory=torch.cuda.is_available(),
        drop_last=True,
    )

    model_module = getattr(import_module("model"), args.model)
    model = model_module(num_classes=num_classes).to(device)

    model = torch.nn.DataParallel(model)

    criterion = create_criterion(args.criterion)

    optimizer = None
    if args.optimizer == "AdamP":
        optimizer = AdamP(model.parameters())
    else:
        opt_module = getattr(import_module("torch.optim"), args.optimizer)
        optimizer = opt_module(
            model.parameters(),
            # filter(lambda p: p.requires_grad, model.parameters()),
            lr=args.lr,
            # weight_decay=5e-4,
        )

    # scheduler = StepLR(optimizer, args.lr_decay_step, gamma=0.5)

    logger = SummaryWriter(log_dir=save_dir)

    best_val_acc = 0
    best_val_loss = np.inf
    best_val_f1 = 0

    for epoch in range(args.epochs):
        model.train()
        train_loss = 0
        train_acc = 0
        train_f1 = 0
        for i, data in enumerate(tqdm(train_loader)):
            imgs, labels = data
            imgs = imgs.float().to(device)
            labels = labels.long().to(device)

            optimizer.zero_grad()
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            preds = torch.argmax(outputs, 1)
            acc = (preds == labels).sum().item() / len(imgs)
            t_f1_score = f1_score(
                labels.cpu().detach().numpy(),
                preds.cpu().detach().numpy(),
                average="macro",
            )

            train_loss += loss.item()  # use .item() so the computation graph is not kept alive
            train_acc += acc
            train_f1 += t_f1_score

            if (i + 1) % args.log_interval == 0:
                train_loss /= args.log_interval
                train_acc /= args.log_interval
                train_f1 /= args.log_interval
                current_lr = get_lr(optimizer)
                print(
                    f"Epoch[{epoch + 1}/{args.epochs}]({i + 1}/{len(train_loader)}) || training loss {train_loss:.4f} || training acc {train_acc:.4f} || train f1_score {train_f1:.4f} || lr {current_lr}"
                )

                logger.add_scalar("Train/loss", train_loss,
                                  epoch * len(train_loader) + i)
                logger.add_scalar("Train/accuracy", train_acc,
                                  epoch * len(train_loader) + i)
                logger.add_scalar("Train/F1-score", train_f1,
                                  epoch * len(train_loader) + i)

                train_loss = 0
                train_acc = 0
                train_f1 = 0

    torch.save(model.module.state_dict(), f"{save_dir}/last.pth")

    # Report elapsed training time
    elapsed = time.time() - start
    minute, sec = divmod(elapsed, 60)
    print(f"Finished training! Elapsed time: {int(minute)} minutes {sec:.1f} seconds")
Example #2
def train(img_dir, model_dir, args):
    seed_everything(args.seed)

    start = time.time()
    get_current_time()

    save_dir = increment_path(os.path.join(model_dir, args.name))

    # settings
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # dataset
    dataset_module = getattr(import_module("dataset"), args.dataset)
    dataset = dataset_module(
        img_dir=img_dir,
        val_ratio=args.val_ratio,
    )
    num_classes = dataset.num_classes

    transform_module = getattr(import_module("dataset"), args.augmentation)
    transform = transform_module(mean=dataset.mean, std=dataset.std)

    train_dataset, val_dataset = dataset.split_dataset()
    # NOTE: if split_dataset returns two Subsets of the same underlying dataset,
    # the second call below overrides the first and both splits end up with the val transform.
    train_dataset.dataset.set_transform(transform["train"])
    val_dataset.dataset.set_transform(transform["val"])

    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        num_workers=2,
        shuffle=True,
        pin_memory=torch.cuda.is_available(),
        drop_last=True,
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=args.valid_batch_size,
        num_workers=2,
        shuffle=False,
        pin_memory=torch.cuda.is_available(),
        drop_last=True,
    )

    model_module = getattr(import_module("model"), args.model)
    model = model_module(num_classes=num_classes).to(device)

    model = torch.nn.DataParallel(model)

    criterion = create_criterion(args.criterion)

    optimizer = None
    if args.optimizer == "AdamP":
        optimizer = AdamP(model.parameters(), lr=args.lr)
    else:
        opt_module = getattr(import_module("torch.optim"), args.optimizer)
        optimizer = opt_module(
            model.parameters(),
            # filter(lambda p: p.requires_grad, model.parameters()),
            lr=args.lr,
            # weight_decay=5e-4,
        )

    # scheduler = StepLR(optimizer, args.lr_decay_step, gamma=0.5)

    logger = SummaryWriter(log_dir=save_dir)

    best_val_acc = 0
    best_val_loss = np.inf
    best_val_f1 = 0

    for epoch in range(args.epochs):
        model.train()
        train_loss = 0
        train_acc = 0
        train_f1 = 0
        for i, data in enumerate(tqdm(train_loader)):
            imgs, labels = data
            imgs = imgs.float().to(device)
            labels = labels.long().to(device)

            optimizer.zero_grad()
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            preds = torch.argmax(outputs, 1)
            acc = (preds == labels).sum().item() / len(imgs)
            t_f1_score = f1_score(
                labels.cpu().detach().numpy(),
                preds.cpu().detach().numpy(),
                average="macro",
            )

            train_loss += loss.item()  # use .item() so the computation graph is not kept alive
            train_acc += acc
            train_f1 += t_f1_score

            if (i + 1) % args.log_interval == 0:
                train_loss /= args.log_interval
                train_acc /= args.log_interval
                train_f1 /= args.log_interval
                current_lr = get_lr(optimizer)
                print(
                    f"Epoch[{epoch + 1}/{args.epochs}]({i + 1}/{len(train_loader)}) || training loss {train_loss:.4f} || training acc {train_acc:.4f} || train f1_score {train_f1:.4f} || lr {current_lr}"
                )

                logger.add_scalar("Train/loss", train_loss,
                                  epoch * len(train_loader) + i)
                logger.add_scalar("Train/accuracy", train_acc,
                                  epoch * len(train_loader) + i)
                logger.add_scalar("Train/F1-score", train_f1,
                                  epoch * len(train_loader) + i)

                train_loss = 0
                train_acc = 0
                train_f1 = 0

        # scheduler.step()

        # Training for this epoch is only complete once the full epoch has finished.
        # After each completed epoch, save the checkpoint with the best validation score.
        with torch.no_grad():
            print("Validation step---------------------")
            model.eval()
            val_loss_items = []
            val_acc_items = []
            val_f1_items = []

            for data in tqdm(val_loader):
                imgs, labels = data
                imgs = imgs.float().to(device)
                labels = labels.long().to(device)

                outputs = model(imgs)
                preds = torch.argmax(outputs, 1)

                loss = criterion(outputs, labels).item()
                acc = (labels == preds).sum().item()
                val_f1 = f1_score(
                    labels.cpu().detach().numpy(),
                    preds.cpu().detach().numpy(),
                    average="macro",
                )

                val_loss_items.append(loss)
                val_acc_items.append(acc)
                val_f1_items.append(val_f1)

            val_loss = np.sum(val_loss_items) / len(val_loader)
            val_acc = np.sum(val_acc_items) / len(val_dataset)
            val_f1 = np.sum(val_f1_items) / len(val_loader)

            print(
                f"val_loader: {len(val_loader)} | val_dataset: {len(val_dataset)}"
            )

            best_val_loss = min(best_val_loss, val_loss)
            best_val_acc = max(val_acc, best_val_acc)
            # best_val_f1 is updated below, after the comparison, so the save check still triggers

            # if val_acc > best_val_acc:
            # print(
            #     f"New best model for val acc: {val_acc:4.2%}! saving the best model..."
            # )
            #     torch.save(model.module.state_dict(), f"{save_dir}/best.pth")
            #     best_val_acc = val_acc

            if val_f1 > best_val_f1:
                print(
                    f"New best model for val f1: {val_f1:.4f}! saving the best model..."
                )
                torch.save(model.module.state_dict(), f"{save_dir}/best.pth")
                best_val_f1 = val_f1

            # TODO: is this really the right place to save the last model?
            # torch.save(model.module.state_dict(), f"{save_dir}/last.pth")
            print(
                f"[Val] acc: {val_acc:.4f}, loss: {val_loss:.4f} || best acc: {best_val_acc:.4f}, best loss: {best_val_loss:.4f}"
            )

            logger.add_scalar("Val/loss", val_loss, epoch)
            logger.add_scalar("Val/accuracy", val_acc, epoch)
            logger.add_scalar("Val/f1-score", val_f1, epoch)
            print()

    torch.save(model.module.state_dict(), f"{save_dir}/last.pth")

    # Report elapsed training time
    elapsed = time.time() - start
    minute, sec = divmod(elapsed, 60)
    print(f"Finished training! Elapsed time: {int(minute)} minutes {sec:.1f} seconds")
def train(data_dir, model_dir, args):
    seed_everything(args.seed)

    save_dir = increment_path(os.path.join(model_dir, args.name))

    # -- settings
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # -- dataset
    dataset_module = getattr(import_module("dataset"),
                             args.dataset)  # MaskBaseDataset
    dataset = dataset_module(data_dir=data_dir, val_ratio=args.val_ratio)
    num_classes = dataset.num_classes  # 18

    # -- augmentation
    transform_module = getattr(import_module("dataset"),
                               args.augmentation)  # default: BaseAugmentation
    transform = transform_module(
        resize=args.resize,
        mean=dataset.mean,
        std=dataset.std,
    )
    dataset.set_transform(transform)

    # -- data_loader
    train_set, val_set = dataset.split_dataset()

    train_loader = DataLoader(train_set,
                              batch_size=args.batch_size,
                              num_workers=8,
                              shuffle=True,
                              pin_memory=use_cuda,
                              drop_last=True)

    val_loader = DataLoader(val_set,
                            batch_size=args.batch_size,
                            num_workers=8,
                            shuffle=False,
                            pin_memory=use_cuda,
                            drop_last=True)

    # -- model
    models = []
    model_module_gender = getattr(import_module("model"),
                                  args.model_gender)  # default: BaseModel
    model_gender = model_module_gender(num_classes=args.num_classes_gender,
                                       grad_point=args.grad_point).to(device)
    model_gender = torch.nn.DataParallel(model_gender)

    # -- loss & metric

    criterion_gender = create_criterion(
        args.criterion_gender, classes=args.num_classes_gender)  # default: f1
    if args.optimizer == "AdamP":
        optimizer_gender = AdamP(filter(lambda p: p.requires_grad,
                                        model_gender.parameters()),
                                 lr=args.lr,
                                 weight_decay=5e-4)
    else:
        opt_module = getattr(import_module('torch.optim'),
                             args.optimizer)  # default: Adam
        optimizer_gender = opt_module(filter(lambda p: p.requires_grad,
                                             model_gender.parameters()),
                                      lr=args.lr,
                                      weight_decay=5e-4)
    scheduler_gender = StepLR(optimizer_gender, args.lr_decay_step, gamma=0.5)

    # -- logging
    logger_gender = SummaryWriter(log_dir=os.path.join(save_dir, 'gender'))
    with open(Path(save_dir) / 'gender' / 'config.json', 'w',
              encoding='utf-8') as f:
        json.dump(vars(args), f, ensure_ascii=False, indent=4)

    best_val_acc_gender = 0
    best_val_loss_gender = np.inf
    for epoch in range(args.epochs):
        # train loop
        model_gender.train()
        loss_value_gender = 0
        matches_gender = 0
        for idx, train_batch in enumerate(train_loader):
            inputs, labels_mask, labels_gender, labels_age = train_batch
            inputs = inputs.to(device)
            labels_gender = labels_gender.to(device)

            optimizer_gender.zero_grad()

            outs_gender = model_gender(inputs)
            preds_gender = torch.argmax(outs_gender, dim=-1)
            loss_gender = criterion_gender(outs_gender, labels_gender)

            loss_gender.backward()
            optimizer_gender.step()

            loss_value_gender += loss_gender.item()
            matches_gender += (preds_gender == labels_gender).sum().item()
            if (idx + 1) % args.log_interval == 0:
                train_loss_gender = loss_value_gender / args.log_interval
                train_acc_gender = matches_gender / args.batch_size / args.log_interval
                current_lr_gender = get_lr(optimizer_gender)
                print(
                    f"Epoch[{epoch}/{args.epochs}]({idx + 1}/{len(train_loader)}) || "
                    f"training loss {train_loss_gender:4.4} || training accuracy {train_acc_gender:4.2%} || lr {current_lr_gender}"
                )
                logger_gender.add_scalar("Train/loss", train_loss_gender,
                                         epoch * len(train_loader) + idx)
                logger_gender.add_scalar("Train/accuracy", train_acc_gender,
                                         epoch * len(train_loader) + idx)

                loss_value_gender = 0
                matches_gender = 0

        scheduler_gender.step()

        #val loop
        with torch.no_grad():
            print("Calculating validation results...")
            model_gender.eval()
            val_loss_items_gender = []
            val_acc_items_gender = []
            figure = None
            for val_batch in val_loader:
                inputs, labels_mask, labels_gender, labels_age = val_batch
                inputs = inputs.to(device)
                labels_gender = labels_gender.to(device)

                outs_gender = model_gender(inputs)
                preds_gender = torch.argmax(outs_gender, dim=-1)

                loss_item_gender = criterion_gender(outs_gender,
                                                    labels_gender).item()
                acc_item_gender = (labels_gender == preds_gender).sum().item()
                val_loss_items_gender.append(loss_item_gender)
                val_acc_items_gender.append(acc_item_gender)

                if figure is None:
                    # inputs_np = torch.clone(inputs).detach().cpu().permute(0, 2, 3, 1).numpy()
                    inputs_np = torch.clone(inputs).detach().cpu()
                    inputs_np = inputs_np.permute(0, 2, 3, 1).numpy()
                    inputs_np = dataset_module.denormalize_image(
                        inputs_np, dataset.mean, dataset.std)
                    figure = grid_image(
                        inputs_np, labels_mask, preds_gender,
                        args.dataset != "MaskSplitByProfileDataset")
                    plt.show()

            val_loss_gender = np.sum(val_loss_items_gender) / len(val_loader)
            val_acc_gender = np.sum(val_acc_items_gender) / len(val_set)
            if val_loss_gender < best_val_loss_gender or val_acc_gender > best_val_acc_gender:
                save_model(model_gender, epoch,
                           val_loss_gender, val_acc_gender,
                           os.path.join(save_dir, "gender"), args.model_gender)
                if val_loss_gender < best_val_loss_gender and val_acc_gender > best_val_acc_gender:
                    print(
                        f"New best model_gender for val acc and val loss : {val_acc_gender:4.2%} {val_loss_gender:4.2}! saving the best model_gender.."
                    )
                    best_val_loss_gender = val_loss_gender
                    best_val_acc_gender = val_acc_gender
                elif val_loss_gender < best_val_loss_gender:
                    print(
                        f"New best model_gender for val loss : {val_loss_gender:4.2}! saving the best model_gender.."
                    )
                    best_val_loss_gender = val_loss_gender
                elif val_acc_gender > best_val_acc_gender:
                    print(
                        f"New best model_gender for val accuracy : {val_acc_gender:4.2%}! saving the best model_gender.."
                    )
                    best_val_acc_gender = val_acc_gender

            print(
                f"[Val] acc: {val_acc_gender:4.2%}, loss: {val_loss_gender:4.2} || "
                f"best acc: {best_val_acc_gender:4.2%}, best loss: {best_val_loss_gender:4.2}"
            )
            logger_gender.add_scalar("Val/loss", val_loss_gender, epoch)
            logger_gender.add_scalar("Val/accuracy", val_acc_gender, epoch)
            logger_gender.add_figure("results", figure, epoch)
            print()
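save_model is used above (and again in the later train example) but is not defined in these snippets. A plausible sketch matching the call save_model(model, epoch, val_loss, val_acc, save_dir, model_name); the file-name format is an assumption:

import os

import torch


def save_model(model, epoch, val_loss, val_acc, save_dir, model_name):
    # Hypothetical helper: save the (possibly DataParallel-wrapped) weights with
    # the epoch and validation metrics encoded in the file name.
    os.makedirs(save_dir, exist_ok=True)
    state_dict = model.module.state_dict() if hasattr(model, "module") else model.state_dict()
    file_name = f"{model_name}_epoch{epoch:03d}_loss{val_loss:.4f}_acc{val_acc:.4f}.pth"
    torch.save(state_dict, os.path.join(save_dir, file_name))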
Example #4
total_iteration_per_epoch = int(np.ceil(len(train_dataset) / batch_size))

for epoch in range(1, total_epoch + 1):
    model.train()
    for iteration, (input, target) in enumerate(train_loader):
        images = input.to(device)
        labels = target.to(device)

        # Forward pass
        outputs = model(images)
        loss = CEloss(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(loss)
    scheduler.step()
    #if epoch % 2 == 0:
    torch.save(model.state_dict(),
               model_weight_save_path + 'model_' + str(epoch) + ".pt")
    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for input, target in test_loader:
            images = input.to(device)
            labels = target.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += len(labels)
            correct += (predicted == labels).sum().item()
        print(f"Test accuracy: {100 * correct / total:.2f}%")
def pseudo_labeling(num_epochs, model, data_loader, val_loader,
                    unlabeled_loader, device, val_every, file_name):
    # Instead of using current epoch we use a "step" variable to calculate alpha_weight
    # This helps the model converge faster
    from torch.optim.swa_utils import AveragedModel, SWALR
    from segmentation_models_pytorch.losses import SoftCrossEntropyLoss, JaccardLoss
    from adamp import AdamP
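    # NOTE (assumption): alpha_weight(step) is used below but is not defined in this
    # snippet. In the common pseudo-labeling recipe it ramps the unlabeled-loss weight
    # linearly from 0 up to a fixed alpha_f between steps T1 and T2, e.g.:
    #
    #     def alpha_weight(step, T1=100, T2=700, alpha_f=3.0):
    #         if step < T1:
    #             return 0.0
    #         if step < T2:
    #             return (step - T1) / (T2 - T1) * alpha_f
    #         return alpha_f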

    criterion = [
        SoftCrossEntropyLoss(smooth_factor=0.1),
        JaccardLoss('multiclass', classes=12)
    ]
    optimizer = AdamP(params=model.parameters(), lr=0.0001, weight_decay=1e-6)
    swa_scheduler = SWALR(optimizer, swa_lr=0.0001)
    swa_model = AveragedModel(model)
    optimizer = Lookahead(optimizer, la_alpha=0.5)

    step = 100
    size = 256
    best_mIoU = 0
    model.train()
    print('Start Pseudo-Labeling..')
    for epoch in range(num_epochs):
        hist = np.zeros((12, 12))
        for batch_idx, (imgs, image_infos) in enumerate(unlabeled_loader):

            # Forward pass to get the pseudo labels
            # --------------------------------------------- pass the unlabeled (test) images through the model
            model.eval()
            with torch.no_grad():
                # generate pseudo labels without tracking gradients
                outs = model(torch.stack(imgs).to(device))
                # argmax already yields long tensors on the model's device; no numpy round-trip is needed
                oms = torch.argmax(outs.squeeze(), dim=1).detach()

            # --------------------------------------------- training step

            model.train()
            # Now calculate the unlabeled loss using the pseudo label
            imgs = torch.stack(imgs)
            imgs = imgs.to(device)
            # preds_array = preds_array.to(device)

            output = model(imgs)
            loss = 0
            for each in criterion:
                loss += each(output, oms)

            unlabeled_loss = alpha_weight(step) * loss

            # Backpropagate
            optimizer.zero_grad()
            unlabeled_loss.backward()
            optimizer.step()
            output = torch.argmax(output.squeeze(),
                                  dim=1).detach().cpu().numpy()
            hist = add_hist(hist,
                            oms.detach().cpu().numpy(),
                            output,
                            n_class=12)

            if (batch_idx + 1) % 25 == 0:
                acc, acc_cls, mIoU, fwavacc = label_accuracy_score(hist)
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, mIoU:{:.4f}'.
                      format(epoch + 1, num_epochs, batch_idx + 1,
                             len(unlabeled_loader), unlabeled_loss.item(),
                             mIoU))
            # Every 50 batches, train one epoch on the labeled data
            if batch_idx % 50 == 0:

                # Normal training procedure (plain loop so the outer batch_idx is not shadowed)
                for images, masks, _ in data_loader:
                    labeled_loss = 0
                    images = torch.stack(images)
                    # (batch, channel, height, width)
                    masks = torch.stack(masks).long()

                    # move tensors to the device for GPU computation
                    images, masks = images.to(device), masks.to(device)

                    output = model(images)

                    for each in criterion:
                        labeled_loss += each(output, masks)

                    optimizer.zero_grad()
                    labeled_loss.backward()
                    optimizer.step()

                # Now we increment step by 1
                step += 1

        if (epoch + 1) % val_every == 0:
            avrg_loss, val_mIoU = validation(epoch + 1, model, val_loader,
                                             criterion, device)
            if val_mIoU > best_mIoU:
                print('Best performance at epoch: {}'.format(epoch + 1))
                print('Save model in', saved_dir)
                best_mIoU = val_mIoU
                save_model(model, file_name=file_name)

        model.train()

        if epoch > 3:
            swa_model.update_parameters(model)
            swa_scheduler.step()
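The SWA model built above is updated but never finalized in this snippet. A minimal sketch of the usual follow-up with torch.optim.swa_utils, assuming a loader whose batches carry stacked image tensors as their first element:

from torch.optim.swa_utils import update_bn

# Recompute BatchNorm statistics for the averaged weights, then save them.
update_bn(data_loader, swa_model, device=device)
torch.save(swa_model.module.state_dict(), "swa_model.pth")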
def train(data_dir, model_dir, args):
    seed_everything(args.seed)
    # args.__dict__ == vars(args)

    save_dir = increment_path(os.path.join(model_dir, args.name))

    # -- settings
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # -- dataset
    dataset_module = getattr(import_module("dataset"),
                             args.dataset)  # MaskBaseDataset
    dataset = dataset_module(data_dir=data_dir, val_ratio=args.val_ratio)
    num_classes = dataset.num_classes  # 18

    # -- augmentation
    transform_module = getattr(import_module("dataset"),
                               args.augmentation)  # default: BaseAugmentation
    transform = transform_module(
        resize=args.resize,
        mean=dataset.mean,
        std=dataset.std,
    )
    dataset.set_transform(transform)

    # -- data_loader
    train_set, val_set = dataset.split_dataset()

    train_loader = DataLoader(train_set,
                              batch_size=args.batch_size,
                              num_workers=8,
                              shuffle=True,
                              pin_memory=use_cuda,
                              drop_last=True)

    val_loader = DataLoader(val_set,
                            batch_size=args.batch_size,
                            num_workers=8,
                            shuffle=False,
                            pin_memory=use_cuda,
                            drop_last=True)

    # -- model
    model_module = getattr(import_module("model"),
                           args.model)  # default: BaseModel
    model = model_module(num_classes=num_classes,
                         grad_point=args.grad_point).to(device)
    model = torch.nn.DataParallel(model)
    # if want model train begin from args.continue_epoch checkpoint.
    if args.continue_train:
        try_dir = find_dir_try(args.continue_try_num, model_dir,
                               args.continue_name)
        epoch_dir = find_dir_epoch(args.continue_epoch, try_dir)
        model.load_state_dict(torch.load(epoch_dir))

    # -- loss & metric
    if args.criterion == "cross_entropy":
        criterion = create_criterion(args.criterion)  # default: cross_entropy
    else:
        criterion = create_criterion(
            args.criterion, classes=num_classes)  # default: cross_entropy
    if args.optimizer == "AdamP":
        optimizer = AdamP(filter(lambda p: p.requires_grad,
                                 model.parameters()),
                          lr=args.lr,
                          weight_decay=5e-4)
    else:
        opt_module = getattr(import_module('torch.optim'),
                             args.optimizer)  # default: Adam
        optimizer = opt_module(filter(lambda p: p.requires_grad,
                                      model.parameters()),
                               lr=args.lr,
                               weight_decay=5e-4)
    scheduler = StepLR(optimizer, args.lr_decay_step, gamma=0.5)

    # -- logging
    if not os.path.exists(save_dir):
        os.mkdir(save_dir)
    with open(Path(save_dir) / 'config.json', 'w', encoding='utf-8') as f:
        json.dump(vars(args), f, ensure_ascii=False, indent=4)

    best_val_acc = 0
    best_val_loss = np.inf
    for epoch in range(args.epochs):
        # train loop
        model.train()
        loss_value = 0
        matches = 0
        for idx, train_batch in enumerate(train_loader):
            inputs, labels = train_batch
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            outs = model(inputs)
            preds = torch.argmax(outs, dim=-1)
            loss = criterion(outs, labels)

            loss.backward()
            optimizer.step()

            loss_value += loss.item()
            matches += (preds == labels).sum().item()
            if (idx + 1) % args.log_interval == 0:
                train_loss = loss_value / args.log_interval
                train_acc = matches / args.batch_size / args.log_interval
                current_lr = get_lr(optimizer)
                print(
                    f"Epoch[{epoch}/{args.epochs}]({idx + 1}/{len(train_loader)}) || "
                    f"training loss {train_loss:4.4} || training accuracy {train_acc:4.2%} || lr {current_lr}"
                )

                loss_value = 0
                matches = 0

        scheduler.step()

        #val loop
        with torch.no_grad():
            print("Calculating validation results...")
            model.eval()
            val_loss_items = []
            val_acc_items = []
            figure = None
            for val_batch in val_loader:
                inputs, labels = val_batch
                inputs = inputs.to(device)
                labels = labels.to(device)

                outs = model(inputs)
                preds = torch.argmax(outs, dim=-1)

                loss_item = criterion(outs, labels).item()
                acc_item = (labels == preds).sum().item()
                val_loss_items.append(loss_item)
                val_acc_items.append(acc_item)

                if figure is None:
                    # inputs_np = torch.clone(inputs).detach().cpu().permute(0, 2, 3, 1).numpy()
                    inputs_np = torch.clone(inputs).detach().cpu()
                    inputs_np = inputs_np.permute(0, 2, 3, 1).numpy()
                    inputs_np = dataset_module.denormalize_image(
                        inputs_np, dataset.mean, dataset.std)
                    figure = grid_image(
                        inputs_np, labels, preds,
                        args.dataset != "MaskSplitByProfileDataset")
                    plt.show()

            val_loss = np.sum(val_loss_items) / len(val_loader)
            val_acc = np.sum(val_acc_items) / len(val_set)
            if val_loss < best_val_loss or val_acc > best_val_acc:
                save_model(model, epoch, val_loss, val_acc, save_dir,
                           args.model)
                if val_loss < best_val_loss and val_acc > best_val_acc:
                    print(
                        f"New best model for val acc and val loss : {val_acc:4.2%} {val_loss:4.2}! saving the best model.."
                    )
                    best_val_loss = val_loss
                    best_val_acc = val_acc
                elif val_loss < best_val_loss:  # checkpoint was already written above
                    print(
                        f"New best model for val loss : {val_loss:4.2}! saving the best model.."
                    )
                    best_val_loss = val_loss
                elif val_acc > best_val_acc:  # checkpoint was already written above
                    print(
                        f"New best model for val accuracy : {val_acc:4.2%}! saving the best model.."
                    )
                    best_val_acc = val_acc

            print(
                f"[Val] acc: {val_acc:4.2%}, loss: {val_loss:4.2} || "
                f"best acc: {best_val_acc:4.2%}, best loss: {best_val_loss:4.2}"
            )
            print()
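create_criterion appears in most of these examples but is never shown. A plausible sketch of such a factory; the registry entries beyond cross-entropy are assumptions about the project's custom losses:

import torch.nn as nn

_criterion_entrypoints = {
    "cross_entropy": nn.CrossEntropyLoss,
    # assumed custom losses, e.g. "focal": FocalLoss, "f1": F1Loss,
    # "label_smoothing": LabelSmoothingLoss, ...
}


def create_criterion(criterion_name, **kwargs):
    # Look up the loss class by name and instantiate it with any extra kwargs
    # (such as classes=num_classes for the custom losses above).
    if criterion_name not in _criterion_entrypoints:
        raise RuntimeError(f"Unknown loss ({criterion_name})")
    return _criterion_entrypoints[criterion_name](**kwargs)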
Example #7
def train_model(config, wandb):

    seed_everything(config.seed)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model_module = getattr(import_module("model"), config.model)
    model = model_module(num_classes=18).to(device)

    #model = torch.nn.DataParallel(model)

    ########  DataSet

    transform = DataAugmentation(type=config.transform)  #center_384_1
    dataset = MaskDataset(config.data_dir, transform=transform)

    len_valid_set = int(config.data_ratio * len(dataset))
    len_train_set = len(dataset) - len_valid_set
    dataloaders, batch_num = {}, {}

    train_dataset, valid_dataset = torch.utils.data.random_split(
        dataset, [len_train_set, len_valid_set])
    if config.random_split == 0:
        print("tbd")

    sampler = None

    if config.sampler == 'ImbalancedDatasetSampler':
        sampler = ImbalancedDatasetSampler(train_dataset)

    use_cuda = torch.cuda.is_available()

    dataloaders['train'] = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        sampler=sampler,
        # shuffle must stay False when a sampler is given; shuffle only when no sampler is used
        shuffle=(sampler is None),
        num_workers=4,
        pin_memory=use_cuda)

    dataloaders['valid'] = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=4,
        pin_memory=use_cuda)

    batch_num['train'], batch_num['valid'] = len(dataloaders['train']), len(
        dataloaders['valid'])

    #Loss
    criterion = create_criterion(config.criterion)

    #Optimizer
    optimizer = optim.SGD(model.parameters(), lr=config.lr, momentum=0.9)

    if config.optim == "AdamP":
        optimizer = AdamP(model.parameters(),
                          lr=config.lr,
                          betas=(0.9, 0.999),
                          weight_decay=config.weight_decay)
    elif config.optim == "AdamW":
        optimizer = optim.AdamW(model.parameters(),
                                lr=config.lr,
                                weight_decay=config.weight_decay)

    #Scheduler
    # Decay LR by a factor of 0.1 every 7 epochs
    #exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)
    if config.lr_scheduler == "cosine":
        print('cosine')
        Q = math.floor(len(train_dataset) / config.batch_size +
                       1) * config.epochs / 7
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=Q)
        #CosineAnnealingWarmRestarts

    since = time.time()
    low_train = 0
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    train_loss, train_acc, valid_loss, valid_acc = [], [], [], []
    num_epochs = config.epochs
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()  # Set model to evaluate mode

            running_loss, running_corrects, num_cnt = 0.0, 0, 0
            running_f1 = 0

            # Iterate over data.
            idx = 0
            for inputs, labels in dataloaders[phase]:
                idx += 1
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)
                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                    else:
                        running_f1 += f1_score(labels.data.detach().cpu(),
                                               preds.detach().cpu(),
                                               average='macro')

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
                num_cnt += len(labels)
                if idx % 100 == 0:
                    _loss = loss.item() / config.batch_size
                    print(
                        f"Epoch[{epoch}/{config.epochs}]({idx}/{batch_num[phase]}) || "
                        f"{phase} loss {_loss:4.4} ")

            if phase == 'train':
                scheduler.step()

            epoch_loss = float(running_loss / num_cnt)
            epoch_acc = float(
                (running_corrects.double() / num_cnt).cpu() * 100)
            epoch_f1 = float(running_f1 / batch_num[phase])  # average per-batch macro F1 (non-zero only in the valid phase)
            if phase == 'train':
                train_loss.append(epoch_loss)
                train_acc.append(epoch_acc)
                if config.wandb:
                    wandb.log({"Train acc": epoch_acc})
            else:
                valid_loss.append(epoch_loss)
                valid_acc.append(epoch_acc)
                if config.wandb:
                    wandb.log({"Valid acc": epoch_acc})
                    wandb.log({"F1 Score": epoch_f1})

            print('{} Loss: {:.2f} Acc: {:.1f} f1 :{:.3f}'.format(
                phase, epoch_loss, epoch_acc, epoch_f1))

            # deep copy the model
            if phase == 'valid':
                if epoch_acc > best_acc:
                    best_idx = epoch
                    best_acc = epoch_acc
                    best_model_wts = copy.deepcopy(model.state_dict())
                    print('==> best model saved - %d / %.1f' %
                          (best_idx, best_acc))
                    low_train = 0
                elif epoch_acc < best_acc:
                    print('==> no improvement over the best model')
                    low_train += 1

        if low_train > 0 and epoch > 4:
            break

        if phase == 'valid':
            if epoch_acc < 80:
                print('Stopping early: validation accuracy is too low')
                break

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best valid Acc: %d - %.1f' % (best_idx, best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    #torch.save(model.state_dict(), 'mask_model.pt')
    torch.save(model.state_dict(), config.name + '.pt')
    print('model saved')
    if config.wandb:
        wandb.finish()
    return model, best_idx, best_acc, train_loss, train_acc, valid_loss, valid_acc
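A hedged usage sketch for train_model above; the config fields (and their values) are assumptions inferred from the attributes the function reads:

from argparse import Namespace

config = Namespace(
    seed=42, model="BaseModel", transform="center_384_1", data_dir="./data",
    data_ratio=0.2, random_split=1, sampler=None, batch_size=64,
    criterion="cross_entropy", lr=1e-3, optim="AdamP", weight_decay=1e-4,
    lr_scheduler="cosine", epochs=10, wandb=False, name="exp",
)

model, best_idx, best_acc, train_loss, train_acc, valid_loss, valid_acc = \
    train_model(config, wandb=None)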
Example #8
def train(data_dir, model_dir, args):
    seed_everything(args.seed)

    s_dir = args.model + str(args.num_hidden_layers) + '-' + args.preprocess + '-epoch' + str(args.epochs) + \
            '-' + args.criterion + '-' + args.scheduler + '-' + args.optimizer + '-' + args.dataset + '-' + args.tokenize

    if args.name:
        s_dir += '-' + args.name
    save_dir = increment_path(os.path.join(model_dir, s_dir))

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    print("This notebook use [%s]." % (device))

    # load model and tokenizer
    MODEL_NAME = args.model
    if MODEL_NAME == "monologg/kobert":
        tokenizer = KoBertTokenizer.from_pretrained(MODEL_NAME)
    else:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    # load dataset
    dataset = load_data("/opt/ml/input/data/train/train.tsv")
    labels = dataset['label'].values

    # setting model hyperparameter
    bert_config = BertConfig.from_pretrained(MODEL_NAME)
    bert_config.num_labels = args.num_labels
    bert_config.num_hidden_layers = args.num_hidden_layers
    model = BertForSequenceClassification.from_pretrained(MODEL_NAME,
                                                          config=bert_config)
    model.dropout = nn.Dropout(p=args.drop)
    model.to(device)

    summary(model)

    # loss & optimizer
    if args.criterion in ('f1', 'label_smoothing', 'f1cross'):
        criterion = create_criterion(args.criterion,
                                     classes=args.num_labels,
                                     smoothing=0.1)
    else:
        criterion = create_criterion(args.criterion)

    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0.01
    }, {
        'params': [
            p for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0.0
    }]

    if args.optimizer == 'AdamP':
        optimizer = AdamP(filter(lambda p: p.requires_grad,
                                 model.parameters()),
                          lr=args.lr,
                          betas=(0.9, 0.999),
                          weight_decay=args.weight_decay)
    else:
        opt_module = getattr(import_module("torch.optim"),
                             args.optimizer)  # default: SGD
        optimizer = opt_module(
            optimizer_grouped_parameters,
            lr=args.lr,
        )

    # logging
    logger = SummaryWriter(log_dir=save_dir)
    with open(os.path.join(save_dir, 'config.json'), 'w',
              encoding='utf-8') as f:
        json.dump(vars(args), f, ensure_ascii=False, indent=4)

    set_neptune(save_dir, args)

    # preprocess dataset
    if args.preprocess != 'no':
        pre_module = getattr(import_module("preprocess"), args.preprocess)
        dataset = pre_module(dataset, model, tokenizer)

    # train / val split: keep only the first of the 5 stratified folds
    kfold = StratifiedKFold(n_splits=5)
    train_idx, val_idx = next(kfold.split(dataset, labels))
    train_dataset, val_dataset = dataset.loc[train_idx], dataset.loc[val_idx]

    tok_module = getattr(import_module("load_data"), args.tokenize)

    train_tokenized = tok_module(train_dataset,
                                 tokenizer,
                                 max_len=args.max_len)
    val_tokenized = tok_module(val_dataset, tokenizer, max_len=args.max_len)

    # make dataset for pytorch.
    RE_train_dataset = RE_Dataset(
        train_tokenized, train_dataset['label'].reset_index(drop=True))
    RE_val_dataset = RE_Dataset(val_tokenized,
                                val_dataset['label'].reset_index(drop=True))

    train_loader = DataLoader(
        RE_train_dataset,
        batch_size=args.batch_size,
        num_workers=4,
        shuffle=True,
        pin_memory=use_cuda,
    )

    val_loader = DataLoader(
        RE_val_dataset,
        batch_size=12,
        num_workers=1,
        shuffle=False,
        pin_memory=use_cuda,
    )

    if args.scheduler == 'cosine':
        scheduler = CosineAnnealingLR(optimizer, T_max=2, eta_min=1e-6)
    elif args.scheduler == 'reduce':
        scheduler = ReduceLROnPlateau(optimizer, factor=0.5, patience=5)
    elif args.scheduler == 'step':
        scheduler = StepLR(optimizer, args.lr_decay_step, gamma=0.5)
    elif args.scheduler == 'cosine_warmup':
        t_total = len(train_loader) * args.epochs
        warmup_step = int(t_total * args.warmup_ratio)
        scheduler = get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=warmup_step,
            num_training_steps=t_total)
    else:
        scheduler = None

    print("Training Start!!!")

    best_val_acc = 0
    best_val_loss = np.inf

    for epoch in range(args.epochs):
        # train loop
        model.train()

        train_loss, train_acc = AverageMeter(), AverageMeter()

        for idx, train_batch in enumerate(train_loader):
            optimizer.zero_grad()

            try:
                inputs, token_types, attention_mask, labels = train_batch.values(
                )
                inputs = inputs.to(device)
                token_types = token_types.to(device)
                attention_mask = attention_mask.to(device)
                labels = labels.to(device)
                outs = model(input_ids=inputs,
                             token_type_ids=token_types,
                             attention_mask=attention_mask)
            except ValueError:  # batch without token_type_ids (e.g. RoBERTa-style tokenizers)
                inputs, attention_mask, labels = train_batch.values()
                inputs = inputs.to(device)
                attention_mask = attention_mask.to(device)
                labels = labels.to(device)
                outs = model(input_ids=inputs, attention_mask=attention_mask)

            preds = torch.argmax(outs.logits, dim=-1)
            loss = criterion(outs.logits, labels)
            acc = (preds == labels).sum().item() / len(labels)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.7)
            optimizer.step()

            if scheduler:
                scheduler.step()

            neptune.log_metric('learning_rate', get_lr(optimizer))

            train_loss.update(loss.item(), len(labels))
            train_acc.update(acc, len(labels))

            if (idx + 1) % args.log_interval == 0:
                current_lr = get_lr(optimizer)
                print(
                    f"Epoch[{epoch + 1}/{args.epochs}]({idx + 1}/{len(train_loader)}) || "
                    f"training loss {train_loss.avg:.4f} || training accuracy {train_acc.avg:4.2%} || lr {current_lr}"
                )
                logger.add_scalar("Train/loss", train_loss.avg,
                                  epoch * len(train_loader) + idx)
                logger.add_scalar("Train/accuracy", train_acc.avg,
                                  epoch * len(train_loader) + idx)

        neptune.log_metric('Train_loss', train_loss.avg)
        neptune.log_metric('Train_avg', train_acc.avg)
        neptune.log_metric('learning_rate', current_lr)

        val_loss, val_acc = AverageMeter(), AverageMeter()
        # val loop
        with torch.no_grad():
            print("Calculating validation results...")
            model.eval()

            for val_batch in val_loader:
                try:
                    inputs, token_types, attention_mask, labels = val_batch.values(
                    )
                    inputs = inputs.to(device)
                    token_types = token_types.to(device)
                    attention_mask = attention_mask.to(device)
                    labels = labels.to(device)
                    outs = model(input_ids=inputs,
                                 token_type_ids=token_types,
                                 attention_mask=attention_mask)
                except ValueError:  # batch without token_type_ids
                    inputs, attention_mask, labels = val_batch.values()
                    inputs = inputs.to(device)
                    attention_mask = attention_mask.to(device)
                    labels = labels.to(device)
                    outs = model(input_ids=inputs,
                                 attention_mask=attention_mask)

                preds = torch.argmax(outs.logits, dim=-1)
                loss = criterion(outs.logits, labels)
                acc = (preds == labels).sum().item() / len(labels)

                val_loss.update(loss.item(), len(labels))
                val_acc.update(acc, len(labels))

            if val_acc.avg > best_val_acc:
                print(
                    f"New best model for val acc : {val_acc.avg:4.2%}! saving the best model.."
                )
                torch.save(model.state_dict(), f"{save_dir}/best.pth")
                best_val_acc = val_acc.avg
                best_val_loss = min(best_val_loss, val_loss.avg)

            print(
                f"[Val] acc : {val_acc.avg:4.2%}, loss : {val_loss.avg:.4f} || "
                f"best acc : {best_val_acc:4.2%}, best loss : {best_val_loss:.4f}"
            )
            logger.add_scalar("Val/loss", val_loss.avg, epoch)
            logger.add_scalar("Val/accuracy", val_acc.avg, epoch)
            neptune.log_metric('Val_loss', val_loss.avg)
            neptune.log_metric('Val_avg', val_acc.avg)

            print()
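AverageMeter is used here (and again in Example #11) but is not defined in the snippet. A minimal sketch of the standard running-average helper it appears to be:

class AverageMeter:
    # Track a running sum and count; .avg always holds the current average.
    def __init__(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, value, n=1):
        # value is a per-batch average; n is the number of samples it covers.
        self.sum += value * n
        self.count += n
        self.avg = self.sum / self.count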
Example #9
        model_v.train()
        model_b.train()
        for idx, batch in enumerate(tqdm(train_loader)):
            image, mnist_label, cifar_label = batch[0].to(device), batch[1].to(
                device), batch[2].to(device)
            model_f_out, model_g_out = model_f(image), model_g(image)
            f_logits, f_feats = model_f_out[0], model_f_out[1]
            _, g_feats = model_g_out[0], model_g_out[1]

            ### Update model_f
            model_f.zero_grad()
            loss_f = criterionCE(
                f_logits, mnist_label
            ) + args.reg_lambda[0] * criterionHSIC.forward(f_feats, g_feats)
            loss_f.backward()
            optimizer_f.step()

            model_f_out, model_g_out = model_f(image), model_g(image)
            _, f_feats = model_f_out[0], model_f_out[1]
            g_logits, g_feats = model_g_out[0], model_g_out[1]

            ### Update model_g
            model_g.zero_grad()
            loss_g = criterionCE(
                g_logits, mnist_label
            ) - args.reg_lambda[1] * criterionHSIC.forward(f_feats, g_feats)
            loss_g.backward()
            optimizer_g.step()

            model_v_out = model_v(image)
            v_logits, _ = model_v_out[0], model_v_out[1]
Example #10
def main(root_dir=cfg.root_dir,batch_size=cfg.batch_size,lr=cfg.lr,model_name=cfg.model_name,\
    weight_decay=cfg.weight_decay,n_epochs=cfg.n_epochs,log_dir=cfg.log_dir,k_fold=cfg.k_fold,\
    patience=cfg.patience,steplr_step_size=cfg.steplr_step_size,steplr_gamma=cfg.steplr_gamma,\
    save_dir=cfg.save_dir):
    
    #Seed 
    torch.manual_seed(42)
    np.random.seed(42)

    # Available CUDA
    use_cuda= True if torch.cuda.is_available() else False
    device=torch.device('cuda:0' if use_cuda else 'cpu') # CPU or GPU


    #Transforms
    train_transforms=A.Compose([
        A.CenterCrop(230,230),
        A.RandomCrop(224,224),
        A.ElasticTransform(),         
        A.IAAPerspective(),
        A.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ToTensorV2(),
    ])

    valid_transforms=A.Compose([
        A.CenterCrop(224,224),
        A.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ToTensorV2()

    ])

    #Dataset
    dataset=CustomDataset(root_dir,transforms=train_transforms)

    
    # Stratified K-fold Cross Validation
    kf=StratifiedKFold(n_splits=k_fold,shuffle=True,random_state=42)

    # Tensorboard Writer
    writer=SummaryWriter(log_dir)


    for n_fold, (train_indices,test_indices) in enumerate(kf.split(dataset.x_data,dataset.y_data),start=1):
        print(f'=====Stratified {k_fold}-fold : {n_fold}=====')

        

        #Dataloader
        train_sampler=torch.utils.data.SubsetRandomSampler(train_indices)
        valid_sampler=torch.utils.data.SubsetRandomSampler(test_indices)
        train_loader=torch.utils.data.DataLoader(dataset,batch_size=batch_size,sampler=train_sampler)
        valid_loader=torch.utils.data.DataLoader(dataset,batch_size=batch_size,sampler=valid_sampler)

        #Model,Criterion,Optimizer,scheduler,regularization
        model=Net(model_name).to(device)
        criterion=nn.CrossEntropyLoss().to(device)
        optimizer=AdamP(model.parameters(),lr=lr)
        scheduler_steplr=optim.lr_scheduler.StepLR(optimizer,step_size=steplr_step_size,gamma=steplr_gamma)
        regularization=EarlyStopping(patience=patience)

        
        
        #Train
        for epoch in range(n_epochs):
            model.train()  # re-enable train mode each epoch (the valid loop below switches to eval)
            print(f'Learning rate : {optimizer.param_groups[0]["lr"]}')
            train_metric_monitor=MetricMonitor()
            train_stream=tqdm(train_loader)

            for batch_idx,sample in enumerate(train_stream,start=1):
                img,label=sample['img'].to(device),sample['label'].to(device)
                output=model(img)

                optimizer.zero_grad()
                loss=criterion(output,label)
                _,preds=torch.max(output,dim=1)
                correct=torch.sum(preds==label.data)


                train_metric_monitor.update('Loss',loss.item())
                train_metric_monitor.update('Accuracy',100.*correct/len(img))

                loss.backward()
                optimizer.step()
                train_stream.set_description(
                    f'Train epoch : {epoch} | {train_metric_monitor}'
                )


            # Valid
            valid_metric_monitor=MetricMonitor()
            valid_stream=tqdm(valid_loader)
            model.eval()
            
            with torch.no_grad():
                for batch_idx,sample in enumerate(valid_stream):

                    img,label=sample['img'].to(device),sample['label'].to(device)
                    output=model(img) 
                    
                    loss=criterion(output,label)
                    _,preds=torch.max(output,dim=1)
                    correct=torch.sum(preds==label.data)

                    valid_metric_monitor.update('Loss',loss.item())
                    valid_metric_monitor.update('Accuracy',100.*correct/len(img))
                    valid_stream.set_description(
                    f'Test epoch : {epoch} | {valid_metric_monitor}'
                    )
            
            # Tensorboard
            train_loss=train_metric_monitor.metrics['Loss']['avg']
            train_accuracy=train_metric_monitor.metrics['Accuracy']['avg']
            valid_loss=valid_metric_monitor.metrics['Loss']['avg']
            valid_accuracy=valid_metric_monitor.metrics['Accuracy']['avg']
            
            writer.add_scalars(f'{n_fold}-fold Loss',{'train':train_loss,'valid':valid_loss},epoch)
            writer.add_scalars(f'{n_fold}-fold Accuracy',{'train':train_accuracy,'valid':valid_accuracy},epoch)
            
            #Save Model
            if regularization.early_stopping:
                break
            
            regularization.path=os.path.join(save_dir,f'{n_fold}_fold_{epoch}_epoch.pt')
            regularization(val_loss=valid_loss,model=model)
            scheduler_steplr.step()
    writer.close()
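EarlyStopping is used above through its early_stopping flag, its path attribute and a call with val_loss and model, but its definition is not shown. A minimal sketch consistent with that usage (an assumption, not the original class):

import torch


class EarlyStopping:
    def __init__(self, patience=5, delta=0.0, path="checkpoint.pt"):
        self.patience = patience
        self.delta = delta
        self.path = path  # overwritten per epoch in the loop above
        self.counter = 0
        self.best_loss = float("inf")
        self.early_stopping = False

    def __call__(self, val_loss, model):
        if val_loss < self.best_loss - self.delta:
            # Improvement: checkpoint the model and reset the patience counter.
            self.best_loss = val_loss
            torch.save(model.state_dict(), self.path)
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stopping = True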
Example #11
def train(data_dir, model_dir, args):
    seed_everything(args.seed)

    save_dir = increment_path(os.path.join(model_dir, args.name))

    # -- settings
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    info = pd.read_csv('/opt/ml/input/data/train/train.csv')

    info['gender_age'] = info.apply(
        lambda x: convert_gender_age(x.gender, x.age), axis=1)
    n_fold = int(1 / args.val_ratio)

    skf = StratifiedKFold(n_splits=n_fold, shuffle=True)
    info.loc[:, 'fold'] = 0
    for fold_num, (train_index, val_index) in enumerate(
            skf.split(X=info.index, y=info.gender_age.values)):
        info.loc[info.iloc[val_index].index, 'fold'] = fold_num

    fold_idx = 0
    train = info[info.fold != fold_idx].reset_index(drop=True)
    val = info[info.fold == fold_idx].reset_index(drop=True)

    # -- dataset
    dataset_module = getattr(import_module("dataset"),
                             args.dataset)  # default: MaskDataset

    # -- augmentation
    train_transform_module = getattr(
        import_module("dataset"),
        args.train_augmentation)  # default: BaseAugmentation
    val_transform_module = getattr(
        import_module("dataset"),
        args.val_augmentation)  # default: BaseAugmentation

    train_transform = train_transform_module(
        resize=args.resize,
        mean=MEAN,
        std=STD,
    )
    val_transform = val_transform_module(
        resize=args.resize,
        mean=MEAN,
        std=STD,
    )

    print(train_transform.transform, val_transform.transform)

    if args.dataset == 'MaskDataset' or args.dataset == 'MaskOldDataset':
        if args.dataset == 'MaskOldDataset':
            old_transform_module = getattr(import_module('dataset'),
                                           args.old_augmentation)

            old_transform = old_transform_module(
                resize=args.resize,
                mean=MEAN,
                std=STD,
            )
            train_dataset = dataset_module(data_dir, train, train_transform,
                                           old_transform)
            if args.val_old:
                val_dataset = dataset_module(data_dir, val, val_transform,
                                             old_transform)
            else:
                val_dataset = dataset_module(data_dir, val, val_transform)
        else:
            train_dataset = dataset_module(data_dir, train, train_transform)
            val_dataset = dataset_module(data_dir, val, val_transform)
    else:
        dataset = dataset_module(data_dir=data_dir)

        # dataset.set_transform(transform)
        # -- data_loader
        train_set, val_set = dataset.split_dataset()
        if args.val_old:
            old_transform_module = getattr(import_module('dataset'),
                                           args.old_augmentation)

            old_transform = old_transform_module(
                resize=args.resize,
                mean=MEAN,
                std=STD,
            )

            train_dataset = DatasetFromSubset(train_set,
                                              transform=train_transform,
                                              old_transform=old_transform)
        else:
            train_dataset = DatasetFromSubset(train_set,
                                              transform=train_transform)
        val_dataset = DatasetFromSubset(val_set, transform=val_transform)

    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        num_workers=4,
        shuffle=True,
        pin_memory=use_cuda,
        #drop_last=True,
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=args.valid_batch_size,
        num_workers=1,
        shuffle=False,
        pin_memory=use_cuda,
        #drop_last=True,
    )

    # -- model
    model_module = getattr(import_module("model"),
                           args.model)  # default: BaseModel
    model = model_module(num_classes=args.num_classes).to(device)
    model = torch.nn.DataParallel(model)

    # -- loss & metric
    if args.criterion in ('f1', 'label_smoothing', 'f1cross'):
        criterion = create_criterion(args.criterion, classes=args.num_classes)
    else:
        criterion = create_criterion(args.criterion)

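    # -- optimizer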
    if args.optimizer == 'AdamP':
        optimizer = AdamP(filter(lambda p: p.requires_grad,
                                 model.parameters()),
                          lr=args.lr,
                          betas=(0.9, 0.999),
                          weight_decay=args.weight_decay)
    else:
        opt_module = getattr(import_module("torch.optim"),
                             args.optimizer)  # default: SGD
        optimizer = opt_module(filter(lambda p: p.requires_grad,
                                      model.parameters()),
                               lr=args.lr,
                               weight_decay=args.weight_decay)
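
    # -- lr scheduler (optional): cosine annealing, plateau reduction, or step decay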
    if args.scheduler == 'cosine':
        scheduler = CosineAnnealingLR(optimizer, T_max=2, eta_min=1e-6)
    elif args.scheduler == 'reduce':
        scheduler = ReduceLROnPlateau(optimizer, factor=0.5, patience=5)
    elif args.scheduler == 'step':
        scheduler = StepLR(optimizer, args.lr_decay_step, gamma=0.5)
    else:
        scheduler = None

    # -- logging
    logger = SummaryWriter(log_dir=save_dir)
    with open(os.path.join(save_dir, 'config.json'), 'w',
              encoding='utf-8') as f:
        json.dump(vars(args), f, ensure_ascii=False, indent=4)

    best_val_acc = 0
    best_val_loss = np.inf
    print("This notebook use [%s]." % (device))

    early_stopping = EarlyStopping(patience=args.patience, verbose=True)

    for epoch in range(args.epochs):
        # train loop
        model.train()
        loss_value = 0
        matches = 0

        train_loss, train_acc = AverageMeter(), AverageMeter()

        for idx, train_batch in enumerate(train_loader):
            inputs, labels = train_batch
            if args.dataset in ('MaskDataset', 'MaskOldDataset'):
                labels = labels.argmax(dim=-1)
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

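            # On alternating steps, apply mixup: blend pairs of samples and take a lam-weighted loss over both label sets.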
            if args.mixup and (idx + epoch) % 2:
                inputs, labels_a, labels_b, lam = mixup_data(inputs,
                                                             labels,
                                                             alpha=1.0)
                inputs, labels_a, labels_b = map(Variable,
                                                 (inputs, labels_a, labels_b))

                outputs = model(inputs)
                loss = mixup_criterion(criterion, outputs, labels_a, labels_b,
                                       lam)

                _, predicted = torch.max(outputs.data, 1)

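                # Accuracy under mixup: lam-weighted agreement with each of the two label sets.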
                correct = (
                    lam * predicted.eq(labels_a.data).cpu().sum().float() +
                    (1 - lam) *
                    predicted.eq(labels_b.data).cpu().sum().float())
                acc = correct / len(labels)

            else:
                outs = model(inputs)
                preds = torch.argmax(outs, dim=-1)
                loss = criterion(outs, labels)
                acc = (preds == labels).sum().item() / len(labels)

            loss.backward()
            optimizer.step()

            #loss_value += loss.item()
            #matches += (preds == labels).sum().item()

            train_loss.update(loss.item(), len(labels))
            train_acc.update(acc, len(labels))

            if (idx + 1) % args.log_interval == 0:
                #train_loss = loss_value / args.log_interval
                #train_acc = matches / args.batch_size / args.log_interval
                train_f1_acc = f1_score(preds.cpu().detach().type(torch.int),
                                        labels.cpu().detach().type(torch.int),
                                        average='macro')
                current_lr = get_lr(optimizer)
                print(
                    f"Epoch[{epoch + 1}/{args.epochs}]({idx + 1}/{len(train_loader)}) || "
                    f"training loss {train_loss.avg:.4f} || training accuracy {train_acc.avg:4.2%} || train_f1_acc {train_f1_acc:.4} || lr {current_lr}"
                )
                logger.add_scalar("Train/loss", train_loss.avg,
                                  epoch * len(train_loader) + idx)
                logger.add_scalar("Train/accuracy", train_acc.avg,
                                  epoch * len(train_loader) + idx)

                loss_value = 0
                matches = 0

        # ReduceLROnPlateau is stepped with the validation loss after the val loop.
        if scheduler is not None and not isinstance(scheduler, ReduceLROnPlateau):
            scheduler.step()

        val_loss, val_acc = AverageMeter(), AverageMeter()
        # val loop
        with torch.no_grad():
            print("Calculating validation results...")
            model.eval()
            val_labels_items = np.array([])
            val_preds_items = np.array([])
            figure = None
            for val_batch in val_loader:
                inputs, labels = val_batch
                if args.dataset in ('MaskDataset', 'MaskOldDataset'):
                    labels = labels.argmax(dim=-1)

                inputs = inputs.to(device)
                labels = labels.to(device)

                outs = model(inputs)
                preds = torch.argmax(outs, dim=-1)

                #loss_item = criterion(outs, labels).item()
                #acc_item = (labels == preds).sum().item()
                #val_loss_items.append(loss_item)
                #val_acc_items.append(acc_item)

                loss = criterion(outs, labels)
                acc = (preds == labels).sum().item() / len(labels)

                val_loss.update(loss.item(), len(labels))
                val_acc.update(acc, len(labels))

                val_labels_items = np.concatenate(
                    [val_labels_items, labels.cpu().numpy()])
                val_preds_items = np.concatenate(
                    [val_preds_items, preds.cpu().numpy()])

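                # Log one diagnostic figure per epoch: a confusion matrix on odd epochs, a grid of sample predictions otherwise.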
                if figure is None:
                    if epoch % 2:
                        images, labels, preds = get_all_datas(
                            model, device, val_loader)
                        figure = log_confusion_matrix(
                            labels.cpu().numpy(),
                            np.argmax(preds.cpu().numpy(), axis=1),
                            args.num_classes)
                        # figure2 = plots_result(images.cpu().numpy()[:36], labels.cpu().numpy()[:36], preds.cpu().numpy()[:36], args.num_classes, title="plots_result")
                    else:
                        inputs_np = torch.clone(inputs).detach().cpu().permute(
                            0, 2, 3, 1).numpy()
                        inputs_np = val_dataset.denormalize_image(
                            inputs_np, MEAN, STD)
                        figure = grid_image(inputs_np, labels, preds, 9, False)

            # val_loss = np.sum(val_loss_items) / len(val_loader)
            # val_acc = np.sum(val_acc_items) / len(val_set)
            val_f1_acc = f1_score(val_labels_items.astype(int),
                                  val_preds_items.astype(int),
                                  average='macro')

            best_val_acc = max(best_val_acc, val_acc.avg)
            # best_val_loss = min(best_val_loss, val_loss)
            if val_loss.avg < best_val_loss:
                print(
                    f"New best model for val loss : {val_loss.avg:.4f}! saving the best model.."
                )
                torch.save(model.module.state_dict(), f"{save_dir}/best.pth")
                best_val_loss = val_loss.avg
            torch.save(model.module.state_dict(), f"{save_dir}/last.pth")
            print(
                f"[Val] acc : {val_acc.avg:4.2%}, loss : {val_loss.avg:.4f} || val_f1_acc : {val_f1_acc:.4} || "
                f"best acc : {best_val_acc:4.2%}, best loss : {best_val_loss:.4f}"
            )
            logger.add_scalar("Val/loss", val_loss.avg, epoch)
            logger.add_scalar("Val/accuracy", val_acc.avg, epoch)
            logger.add_figure("results", figure, epoch)
            # logger.add_figure("results1", figure2, epoch)

            if isinstance(scheduler, ReduceLROnPlateau):
                # Plateau scheduling reacts to the validation loss.
                scheduler.step(val_loss.avg)

            early_stopping(val_loss.avg, model)

            if early_stopping.early_stop:
                print('Early stopping...')
                break

            print()