def predict(model_name, model_class, weight_pth, image_size, normalize):
    """Run test-time-augmentation inference for one model on two splits.

    Loads `weight_pth` into a fresh `model_class` model, builds the TTA
    transform set (plain, hflip, and five-crop variants with and without a
    horizontal flip), then saves the stacked predictions for the 'test' and
    'val' splits to '{model_name}_{split}_prediction.pth'.
    """
    print(f'[+] predict {model_name}')
    model = get_model(model_class)
    model.load_state_dict(torch.load(weight_pth))
    model.eval()  # inference mode: freeze dropout / batch-norm statistics

    # Plain and horizontally-flipped single-image pipelines.
    tta_preprocess = [
        preprocess(normalize, image_size),
        preprocess_hflip(normalize, image_size)
    ]
    # Five-crop pipelines: resize slightly larger than the target size,
    # then take the five crops — once as-is, once horizontally flipped.
    tta_preprocess += make_transforms(
        [transforms.Resize((image_size + 20, image_size + 20))],
        [transforms.ToTensor(), normalize], five_crops(image_size))
    tta_preprocess += make_transforms(
        [transforms.Resize((image_size + 20, image_size + 20))],
        [HorizontalFlip(), transforms.ToTensor(), normalize],
        five_crops(image_size))
    print(f'[+] tta size: {len(tta_preprocess)}')

    # The test and val passes were duplicated verbatim; run one helper per split.
    for split in ('test', 'val'):
        _predict_split(model, model_name, split, tta_preprocess)


def _predict_split(model, model_name, split, tta_preprocess):
    """Predict one dataset split with every TTA transform and save lx/px."""
    data_loaders = [
        DataLoader(dataset=FurnitureDataset(split, transform=transform),
                   num_workers=1,
                   batch_size=BATCH_SIZE,
                   shuffle=False)
        for transform in tta_preprocess
    ]
    lx, px = utils.predict_tta(model, data_loaders)
    data = {
        'lx': lx.cpu(),
        'px': px.cpu(),
    }
    torch.save(data, f'{model_name}_{split}_prediction.pth')
# ----- 示例 2 (Example 2) -----
def ensemble():
    """Ensemble saved DenseNet TTA predictions with NASNet probabilities.

    Loads every prediction file described by the module-level DIRS / EPOCHS /
    NAME settings, stacks their softmax probabilities along the augmentation
    axis, averages them, adds the NASNet probabilities, calibrates, and
    writes the final 1-based class predictions to 'ensemble.csv'.
    """
    test_dataset = FurnitureDataset('test', transform=preprocess)
    probs = []
    # for e in EPOCHS:
    for d, eps in zip(DIRS, EPOCHS):
        for e in eps:
            pth = d + NAME % e
            test_pred = torch.load(pth)
            # Softmax over the class axis; px presumably has shape
            # (n_samples, n_classes, n_augmentations) — TODO confirm.
            test_prob = F.softmax(Variable(test_pred['px']),
                                  dim=1).data.numpy()
            if len(probs) == 0:
                probs = test_prob
            else:
                # Stack each checkpoint's augmentation columns side by side.
                probs = np.concatenate((probs, test_prob), axis=-1)

    # Per-augmentation argmax taken BEFORE averaging; +1 because the
    # dataset's labels are 1-based.
    den_preds = np.argmax(probs, axis=1) + 1
    # Average over the stacked augmentation/checkpoint axis.
    probs = probs.mean(axis=2)
    # import pdb
    # pdb.set_trace()
    # probs = 0.851 * probs[:, :, :21].mean(axis=2) + 0.863 * probs[:, :, 21:36].mean(axis=2) + 0.855 * probs[:, :, 36:].mean(axis=2)
    nas_probs, nas_preds = read_nasnet()
    en_preds = np.concatenate([den_preds, nas_preds], axis=1)
    # Unweighted sum of the two model families' averaged probabilities.
    probs += nas_probs
    # probs = np.concatenate([probs, nas_probs], axis=2)
    # probs = scistates.gmean(probs, axis=2)
    # probs = 0.85 * probs + 0.86 * nas_probs
    probs = calibrate_probs(probs)
    preds = np.argmax(probs, axis=1)
    preds += 1  # back to 1-based labels

    # preds = bin_count(en_preds)

    sx = pd.read_csv('../data/sample_submission_randomlabel.csv')
    sx.loc[sx.id.isin(test_dataset.data.image_id), 'predicted'] = preds
    sx.to_csv('ensemble.csv', index=False)
# ----- 示例 3 (Example 3) -----
def predict():
    """Run TTA inference on the 'test2' split with the best checkpoint
    and save the stacked labels/logits to 'test_prediction.pth'."""
    model = get_model()
    model.load_state_dict(torch.load('best_val_weight.pth'))
    model.eval()

    # Plain plus horizontally-flipped pipelines form the TTA set.
    tta_preprocess = [preprocess, preprocess_hflip]

    data_loaders = [
        DataLoader(dataset=FurnitureDataset('test2', transform=tf),
                   num_workers=1,
                   batch_size=BATCH_SIZE,
                   shuffle=False)
        for tf in tta_preprocess
    ]

    lx, px = utils.predict_tta(model, data_loaders)
    torch.save({'lx': lx.cpu(), 'px': px.cpu()}, 'test_prediction.pth')
def predict(args):
    """Run TTA inference for the model named in `args` on test and val.

    Loads the best checkpoint from models_trained/{name}_{aug}_{alpha}/ and
    writes '{split}_prediction_{name}.pth' files back to the same folder.
    """
    model = get_model(args.name)
    model.load_state_dict(
        torch.load('models_trained/{}_{}_{}/best_val_weight_{}.pth'.format(
            args.name, args.aug, args.alpha, args.name)))
    model.eval()

    #tta_preprocess = [preprocess_five_crop, preprocess_five_crop_hflip]
    tta_preprocess = [preprocess, preprocess_hflip]

    # The test and val passes were identical apart from the split name.
    for split in ('test', 'val'):
        _predict_split_and_save(model, args, split, tta_preprocess)


def _predict_split_and_save(model, args, split, tta_preprocess):
    """Predict one split with every TTA transform and save lx/px tensors."""
    data_loaders = []
    for transform in tta_preprocess:
        dataset = FurnitureDataset(split, transform=transform)
        data_loaders.append(
            DataLoader(dataset=dataset,
                       num_workers=1,
                       batch_size=BATCH_SIZE,
                       shuffle=False))

    lx, px = utils.predict_tta(model, data_loaders)
    data = {
        'lx': lx.cpu(),
        'px': px.cpu(),
    }
    torch.save(
        data, 'models_trained/{}_{}_{}/{}_prediction_{}.pth'.format(
            args.name, args.aug, args.alpha, split, args.name))
def generate_final_predictions(model_name):
    """Turn the saved TTA logits for `model_name` into a submission CSV."""
    test_dataset = FurnitureDataset('test', transform=preprocess)

    saved = torch.load('test_prediction_' + model_name + '.pth')
    probs = F.softmax(Variable(saved['px']), dim=1).data.numpy()
    probs = probs.mean(axis=2)  # average over the TTA axis

    result = np.argmax(probs, axis=1) + 1  # dataset labels are 1-based

    sx = pd.read_csv('data/sample_submission_randomlabel.csv')
    sx.loc[sx.id.isin(test_dataset.data.image_id), 'predicted'] = result
    sx.to_csv('sx_' + model_name + '.csv', index=False)
def predict(model_name, outputDir):
    """Run TTA inference for `model_name` on the test and val splits.

    Loads the best-validation-accuracy checkpoint from `outputDir` and
    writes '{split}_prediction_{model_name}.pth' files back to it.
    """
    model = get_model(model_name)
    model_checkpoint = torch.load(os.path.join(outputDir, 'best_val_acc_weight_' + model_name + '.pth'))
    model.load_state_dict(model_checkpoint)
    model.eval()

    tta_preprocess = [preprocess, preprocess_hflip]

    # The test and val passes were duplicated verbatim; share one helper.
    for split in ('test', 'val'):
        _tta_predict_and_save(model, model_name, outputDir, split,
                              tta_preprocess)


def _tta_predict_and_save(model, model_name, outputDir, split, tta_preprocess):
    """Predict one dataset split with every TTA transform and save results."""
    data_loaders = []
    for transform in tta_preprocess:
        dataset = FurnitureDataset(split, transform=transform)
        data_loaders.append(
            DataLoader(dataset=dataset, num_workers=1,
                       batch_size=BATCH_SIZE,
                       shuffle=False))

    lx, px = utils.predict_tta(model, data_loaders, device)
    data = {
        'lx': lx.cpu(),
        'px': px.cpu(),
    }
    torch.save(data, os.path.join(outputDir, split + '_prediction_' + model_name + '.pth'))
# ----- 示例 7 (Example 7) -----
def main():
    """Convert one saved prediction file into a submission CSV."""
    test_dataset = FurnitureDataset('test', transform=preprocess)

    # pth = DIRS[2] + 'test_prediction_e6.pth'
    pth = 'test_prediction_e10.pth'
    print("loading {}".format(pth))
    saved = torch.load(pth)
    probs = F.softmax(Variable(saved['px']), dim=1).data.numpy()
    probs = probs.mean(axis=2)  # average over the TTA axis

    result = np.argmax(probs, axis=1) + 1  # dataset labels are 1-based

    sx = pd.read_csv('../data/sample_submission_randomlabel.csv')
    sx.loc[sx.id.isin(test_dataset.data.image_id), 'predicted'] = result
    sx.to_csv(pth.split('.')[0] + '.csv', index=False)
def train(model_name, outputDir):
    """Train `model_name`, checkpointing to `outputDir`.

    Saves the latest weights every epoch plus separate best-validation-loss
    and best-validation-accuracy snapshots; `patience` counts epochs since
    either metric last improved, while ReduceLROnPlateau drops the learning
    rate when validation loss stalls.
    """
    train_dataset = FurnitureDataset('train', transform=preprocess_with_augmentation)
    val_dataset = FurnitureDataset('val', transform=preprocess)
    training_data_loader = DataLoader(dataset=train_dataset, num_workers=12,
                                      batch_size=BATCH_SIZE,
                                      shuffle=True)
    validation_data_loader = DataLoader(dataset=val_dataset, num_workers=1,
                                        batch_size=BATCH_SIZE,
                                        shuffle=False)

    model = get_model(model_name)

    nb_learnable_params = sum(p.numel() for p in model.fresh_params())
    print('Number of learnable params: %s' % str(nb_learnable_params))

    # Use model.fresh_params() to train only the newly initialized weights
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2)

    if model_name.endswith("_focal"):
        print ("Using Focal loss instead of normal cross-entropy")
        criterion = FocalLoss(NB_CLASSES).to(device)
    else:
        criterion = nn.CrossEntropyLoss().to(device)

    min_loss = float("inf")
    max_acc = 0.0
    patience = 0
    for epoch in range(NUM_EPOCHS):
        print('Epoch: %d' % epoch)

        running_loss = RunningMean()
        running_error = RunningMean()
        running_accuracy = RunningMean()

        model.train()
        pbar = tqdm(training_data_loader, total=len(training_data_loader))
        for inputs, labels in pbar:
            batch_size = inputs.size(0)

            inputs = Variable(inputs)
            labels = Variable(labels)
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            outputs = model(inputs)
            _, preds = torch.max(outputs.data, dim=1)

            loss = criterion(outputs, labels)
            # BUG FIX: loss.data[0] raises IndexError on PyTorch >= 0.4
            # (loss is a 0-dim tensor); .to(device) above already requires
            # >= 0.4, so use .item() — same value, supported API.
            running_loss.update(loss.item(), 1)
            running_error.update(torch.sum(preds != labels.data).item(), batch_size)
            running_accuracy.update(torch.sum(preds == labels.data).item(), batch_size)

            loss.backward()
            optimizer.step()

            pbar.set_description('%.5f %.3f %.3f' % (running_loss.value, running_accuracy.value, running_error.value))
        print('Epoch: %d | Running loss: %.5f | Running accuracy: %.3f | Running error: %.3f' % (epoch, running_loss.value, running_accuracy.value, running_error.value))

        lx, px = utils.predict(model, validation_data_loader, device)
        log_loss = criterion(Variable(px), Variable(lx))
        log_loss = log_loss.item()  # was .data[0]: breaks on torch >= 0.4
        _, preds = torch.max(px, dim=1)
        accuracy = torch.mean((preds == lx).float()).item()
        error = torch.mean((preds != lx).float()).item()
        print('Validation loss: %.5f | Accuracy: %.3f | Error: %.3f' % (log_loss, accuracy, error))
        scheduler.step(log_loss)

        # Save model after each epoch
        torch.save(model.state_dict(), os.path.join(outputDir, 'weight_' + model_name + '.pth'))

        betterModelFound = False
        if log_loss < min_loss:
            torch.save(model.state_dict(), os.path.join(outputDir, 'best_val_loss_weight_' + model_name + '.pth'))
            print('Validation score improved from %.5f to %.5f. Model snapshot saved!' % (min_loss, log_loss))
            min_loss = log_loss
            patience = 0
            betterModelFound = True

        if accuracy > max_acc:
            torch.save(model.state_dict(), os.path.join(outputDir, 'best_val_acc_weight_' + model_name + '.pth'))
            print('Validation accuracy improved from %.5f to %.5f. Model snapshot saved!' % (max_acc, accuracy))
            max_acc = accuracy
            patience = 0
            betterModelFound = True

        if not betterModelFound:
            patience += 1
# ----- 示例 9 (Example 9) -----
import torch
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
import pandas as pd
from misc import FurnitureDataset, preprocess

# Post-process a saved DenseNet201 TTA prediction file: optionally collapse
# five-crop groups, then average probabilities over the augmentation axis.
five_crop = False

test_dataset = FurnitureDataset('test', transform=preprocess)

test_pred = torch.load('test_prediction_densenet201.pth')
# Softmax over the class axis; px presumably has shape
# (n_samples, n_classes, n_augmentations) — TODO confirm.
test_prob = F.softmax(Variable(test_pred['px']), dim=1).data.numpy()

if five_crop:
    nsamples, nclass, naug = test_prob.shape
    test_prob = test_prob.transpose(1,2,0)
    #print(test_prob.shape)
    # NOTE(review): groups every 5 consecutive samples as the five crops of
    # one image and averages them — assumes the loader emitted crops in
    # contiguous runs of 5; verify against the dataset ordering.
    test_prob = test_prob.reshape(nclass, naug, -1, 5)
    test_prob = test_prob.mean(axis=-1)
    test_prob = test_prob.transpose(2,0,1)
#print(test_prob.shape)
# Average the class probabilities over the augmentation axis.
test_prob = test_prob.mean(axis=2)

test_id_ord = test_pred['lx'].numpy()
if five_crop:
    # Collapse the per-crop id copies back to one id per image.
    test_id_ord = test_id_ord.reshape(int(test_id_ord.shape[0]/5), -1)
    test_id_ord = test_id_ord.mean(axis=-1)
#print(test_id_ord)

def train():
    """Fine-tune the model on the furniture dataset.

    Epoch 0 trains only the freshly initialised head (model.fresh_params());
    later epochs train all parameters with weight decay.  Whenever the
    validation loss fails to improve for 2 consecutive epochs, the best
    checkpoint is reloaded and the learning rate divided by 10.
    """
    train_dataset = FurnitureDataset('train',
                                     transform=preprocess_with_augmentation)
    val_dataset = FurnitureDataset('val', transform=preprocess)
    training_data_loader = DataLoader(dataset=train_dataset,
                                      num_workers=8,
                                      batch_size=BATCH_SIZE,
                                      shuffle=True)
    validation_data_loader = DataLoader(dataset=val_dataset,
                                        num_workers=1,
                                        batch_size=BATCH_SIZE,
                                        shuffle=False)

    model = get_model()

    criterion = nn.CrossEntropyLoss().cuda()

    nb_learnable_params = sum(p.numel() for p in model.fresh_params())
    print(f'[+] nb learnable params {nb_learnable_params}')

    min_loss = float("inf")
    lr = 0
    patience = 0
    for epoch in range(20):
        print(f'epoch {epoch}')
        if epoch == 1:
            # After the head-only warm-up, switch to a small full-network rate.
            lr = 0.00003
            print(f'[+] set lr={lr}')
        if patience == 2:
            # Plateau: reload the best weights and drop the learning rate.
            patience = 0
            model.load_state_dict(torch.load('best_val_weight.pth'))
            lr = lr / 10
            print(f'[+] set lr={lr}')
        if epoch == 0:
            # Warm-up: train only the newly initialised layers.
            lr = 0.001
            print(f'[+] set lr={lr}')
            optimizer = torch.optim.Adam(model.fresh_params(), lr=lr)
        else:
            optimizer = torch.optim.Adam(model.parameters(),
                                         lr=lr,
                                         weight_decay=0.0001)

        running_loss = RunningMean()
        # Tracks the fraction of misclassified training samples.
        running_score = RunningMean()

        model.train()
        pbar = tqdm(training_data_loader, total=len(training_data_loader))
        for inputs, labels in pbar:
            batch_size = inputs.size(0)

            # NOTE(review): Variable / loss.data[0] are pre-0.4 PyTorch
            # idioms; this snippet assumes an old torch version.
            inputs = Variable(inputs)
            labels = Variable(labels)
            if use_gpu:
                inputs = inputs.cuda()
                labels = labels.cuda()

            optimizer.zero_grad()
            outputs = model(inputs)
            _, preds = torch.max(outputs.data, dim=1)

            loss = criterion(outputs, labels)
            running_loss.update(loss.data[0], 1)
            running_score.update(torch.sum(preds != labels.data), batch_size)

            loss.backward()
            optimizer.step()

            pbar.set_description(
                f'{running_loss.value:.5f} {running_score.value:.3f}')
        print(
            f'[+] epoch {epoch} {running_loss.value:.5f} {running_score.value:.3f}'
        )

        lx, px = utils.predict(model, validation_data_loader)
        log_loss = criterion(Variable(px), Variable(lx))
        log_loss = log_loss.data[0]
        _, preds = torch.max(px, dim=1)
        # NOTE(review): despite the name, this is the error rate
        # (preds != lx), not the accuracy.
        accuracy = torch.mean((preds != lx).float())
        print(f'[+] val {log_loss:.5f} {accuracy:.3f}')

        if log_loss < min_loss:
            torch.save(model.state_dict(), 'best_val_weight.pth')
            print(
                f'[+] val score improved from {min_loss:.5f} to {log_loss:.5f}. Saved!'
            )
            min_loss = log_loss
            patience = 0
        else:
            patience += 1
# ----- 示例 11 (Example 11) -----
def predict():
    """Run augmentation-based TTA inference on the test and val splits.

    For each split, runs the model once per TTA transform, stacks the
    per-transform logits along a new trailing axis via safe_stack_2array,
    and saves {'lx', 'px'} to '{split}_prediction.pth'.
    """
    model = get_model()
    model.load_state_dict(torch.load('best_val_weight.pth'))
    model.eval()

    # Three random-augmentation passes form the TTA ensemble.
    tta_preprocess = [
        preprocess_with_augmentation, preprocess_with_augmentation,
        preprocess_with_augmentation
    ]

    # The two split loops were duplicated verbatim; iterate the splits.
    for split, out_path in (('test', 'test_prediction.pth'),
                            ('val', 'val_prediction.pth')):
        prediction = None
        lx = None

        for idx, transform in enumerate(tta_preprocess):
            dataset = FurnitureDataset(split, transform=transform)
            data_loader = DataLoader(dataset=dataset,
                                     num_workers=1,
                                     batch_size=BATCH_SIZE,
                                     shuffle=False)
            lx, px = utils_predict(model, data_loader)

            print("{:} stack predictions".format(idx), "\n\n")
            prediction = safe_stack_2array(prediction, px, dim=-1)
            # BUG FIX: the test loop printed the whole tensor
            # (prediction.cpu()) where the val loop printed .shape;
            # print the shape in both cases.
            print("prediction shape:", prediction.shape, "\n\n")

        data = {
            'lx': lx.cpu(),
            'px': prediction.cpu(),
        }
        torch.save(data, out_path)
def predict(epoch=None, attention=False):
    """Load a checkpoint (optionally the attention variant) and save TTA
    test-set predictions to '<save_name>test_prediction_e<epoch>.pth'.

    `epoch` defaults to the module-level EPOCH; `attention` switches both
    the model architecture and the 'att_' file-name prefix.
    """
    BATCH_SIZE = 8  # small inference batch; shadows the module constant
    if not epoch:
        epoch = EPOCH
    save_name = "att_" if attention else ""
    pth_path = OUTPUT_PATH + save_name + 'best_val_weight_%s.pth' % epoch
    print("loading %s" % pth_path)
    if attention:
        print("loading model...")
        model = dense_attention201(pretrained=False, num_classes=128)
        if use_gpu:
            model.cuda()
        print("done.")
    else:
        model = get_model()
    model.load_state_dict(torch.load(pth_path))
    model.eval()
    # Three test-augmentation passes plus plain and hflip pipelines.
    tta_preprocess = [
        preprocess_for_test, preprocess_for_test, preprocess_for_test,
        preprocess, preprocess_hflip
    ]

    print("number of tta: {}".format(len(tta_preprocess)))
    data_loaders = []
    for transform in tta_preprocess:
        loader = DataLoader(dataset=FurnitureDataset('test',
                                                     transform=transform),
                            num_workers=0,
                            batch_size=BATCH_SIZE,
                            shuffle=False)
        data_loaders.append(loader)

    lx, px = utils.predict_tta(model, data_loaders, test=True)
    # Only the stacked logits are saved for the test split (no labels).
    torch.save({'px': px.cpu()},
               save_name + 'test_prediction_e%s.pth' % epoch)
def train(attention=False):
    """Train either the plain model or the dense-attention variant.

    Epoch 0 trains only the fresh head; later epochs train the full
    network.  The best checkpoint is reloaded and the learning rate
    reduced on validation plateaus and every third epoch.
    """
    train_dataset = FurnitureDataset('train',
                                     transform=preprocess_with_augmentation)
    train_val_dataset = FurnitureDataset(
        'validation', transform=preprocess_with_augmentation)
    val_dataset = FurnitureDataset('validation', transform=preprocess)

    training_data_loader = DataLoader(dataset=train_dataset,
                                      num_workers=8,
                                      batch_size=BATCH_SIZE,
                                      shuffle=True)
    # NOTE(review): built from val_dataset, not train_val_dataset, which is
    # currently unused — confirm whether that is intentional.
    train_val_data_loader = DataLoader(dataset=val_dataset,
                                       num_workers=8,
                                       batch_size=BATCH_SIZE,
                                       shuffle=True)
    validation_data_loader = DataLoader(dataset=val_dataset,
                                        num_workers=0,
                                        batch_size=BATCH_SIZE // 2,
                                        shuffle=False)

    if USE_FOCAL_LOSS:
        # Per-sample losses are needed so they can be re-weighted below.
        criterion = nn.CrossEntropyLoss(reduce=False).cuda()
    else:
        criterion = nn.CrossEntropyLoss().cuda()

    print("loading model...")
    if not attention:
        model = get_model()
        save_name = ""
    else:
        save_name = "att_"
        model = dense_attention201(num_classes=128)
        if use_gpu:
            model.cuda()
        fresh_params = [p['params'] for p in model.fresh_params()]
        nb_learnable_params = 0
        for pp in fresh_params:
            nb_learnable_params += sum(p.numel() for p in pp)
        print('[+] nb learnable params {}'.format(nb_learnable_params))
    print("done.")

    min_loss = float("inf")
    patience = 0

    for epoch in range(STARTER, STARTER + 10):
        print('epoch {}'.format(epoch))
        # NOTE(review): `lr` is only assigned inside these branches; if
        # STARTER > 1 and none fires on the first iteration, the optimizer
        # construction raises UnboundLocalError — confirm STARTER is 0.
        if epoch == 1:
            lr = 0.00002
            model.load_state_dict(torch.load('best_val_weight_0.pth'))
            print("[+] loading best_val_weight_0.pth")
        if patience == 2:
            patience = 0
            model.load_state_dict(torch.load('best_val_weight.pth'))
            lr = lr / 5
        # BUG FIX: the original `epoch + 1 % 3 == 0` parsed as
        # `epoch + (1 % 3) == 0`, never true for epoch >= 0, so this
        # periodic reload/halving branch was dead code.
        elif (epoch + 1) % 3 == 0:
            ckpt = save_name + 'best_val_weight_%s.pth' % (epoch - 1)
            if not os.path.exists(ckpt):
                ckpt = save_name + 'best_val_weight.pth'
            print("loading {}".format(ckpt))
            model.load_state_dict(torch.load(ckpt))
            lr = lr / 2

        if epoch == 0:
            lr = 0.001
            optimizer = torch.optim.Adam(model.fresh_params(), lr=lr)
        else:
            # NOTE(review): `wd` must be defined at module level — confirm.
            optimizer = torch.optim.Adam(model.parameters(),
                                         lr=lr,
                                         weight_decay=wd)

        print('[+] set lr={}'.format(lr))
        running_loss = RunningMean()
        running_score = RunningMean()  # running misclassification rate

        model.train()
        ### FOR TRAINING VALIDATION SET
        # if epoch - STARTER + 1 % 2 == 0 and epoch - STARTER > 4:
        #     loader =train_val_data_loader
        #     print("[+] trianing with validation set")
        # else:
        #     loader = training_data_loader
        ### FOR TRAINING VALIDATION SET
        loader = training_data_loader
        pbar = tqdm(loader, total=len(loader))
        for inputs, labels in pbar:
            batch_size = inputs.size(0)

            inputs = Variable(inputs)
            target = Variable(labels)
            if use_gpu:
                inputs = inputs.cuda()
                target = target.cuda()

            optimizer.zero_grad()
            outputs = model(inputs)
            _, preds = torch.max(outputs.data, dim=1)
            loss = criterion(outputs, target)
            if USE_FOCAL_LOSS:
                # Re-weight each per-sample loss by 4 * p(true class),
                # detached so the weight itself is not differentiated.
                y_index = torch.LongTensor(np.arange(labels.shape[0])).cpu()
                l_weight = F.softmax(outputs,
                                     dim=1).cpu()[y_index,
                                                  torch.LongTensor(labels)]
                l_weight = l_weight.detach()
                loss = torch.mean(4 * l_weight.cuda() * loss)
            # BUG FIX: loss.data[0] breaks on PyTorch >= 0.4 (0-dim tensor);
            # the dtype= kwarg below already requires >= 0.4, so use .item().
            running_loss.update(loss.item(), 1)
            running_score.update(
                torch.sum(preds != target.data, dtype=torch.float32),
                batch_size)
            loss.backward()
            optimizer.step()

            pbar.set_description('{:.5f} {:.3f}'.format(
                running_loss.value, running_score.value))
        print('[+] epoch {} {:.5f} {:.3f}'.format(epoch, running_loss.value,
                                                  running_score.value))

        torch.save(model.state_dict(),
                   save_name + 'best_val_weight_%s.pth' % epoch)

        lx, px = utils.predict(model, validation_data_loader)
        log_loss = criterion(Variable(px), Variable(lx))
        # BUG FIX: .data[0] fails for the 0-dim loss on >= 0.4 and, with the
        # reduce=False focal criterion, silently kept only the first
        # sample's loss; the mean is identical in the scalar case and the
        # intended validation loss in the vector case.
        log_loss = log_loss.mean().item()
        _, preds = torch.max(px, dim=1)
        # NOTE(review): named `accuracy` but computes the error rate.
        accuracy = torch.mean((preds != lx).float())
        print('[+] val {:.5f} {:.3f}'.format(log_loss, accuracy))

        if log_loss < min_loss:
            torch.save(model.state_dict(), 'best_val_weight.pth')
            print(
                '[+] val score improved from {:.5f} to {:.5f}. Saved!'.format(
                    min_loss, log_loss))
            min_loss = log_loss
            patience = 0
        else:
            patience += 1
def train(args):
    """Train the model named in `args`, optionally with mixup augmentation.

    args.name selects the architecture, args.aug toggles mixup, and
    args.alpha is the mixup Beta parameter; checkpoints are stored under
    models_trained/{name}_{aug}_{alpha}/.  Epoch 0 trains only the fresh
    head; plateaus reload the best checkpoint and divide the rate by 10
    (floored at 3e-6).
    """

    train_dataset = FurnitureDataset('train',
                                     transform=preprocess_with_augmentation)
    val_dataset = FurnitureDataset('val', transform=preprocess)
    training_data_loader = DataLoader(dataset=train_dataset,
                                      num_workers=8,
                                      batch_size=BATCH_SIZE,
                                      shuffle=True)
    validation_data_loader = DataLoader(dataset=val_dataset,
                                        num_workers=1,
                                        batch_size=BATCH_SIZE,
                                        shuffle=False)

    model = get_model(args.name)

    # Loaded but unused: the weighted-criterion line below is commented out.
    class_weight = np.load('./class_weight.npy')

    #criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor(class_weight)).cuda()
    criterion = nn.CrossEntropyLoss().cuda()
    #criterion = FocalLoss(alpha=alpha, gamma=0).cuda()

    nb_learnable_params = sum(p.numel() for p in model.fresh_params())
    print(f'[+] nb learnable params {nb_learnable_params}')

    min_loss = float("inf")
    lr = 0
    patience = 0
    for epoch in range(30):
        print(f'epoch {epoch}')
        if epoch == 1:
            # After the head-only warm-up, use a small full-network rate.
            lr = 0.00003
            print(f'[+] set lr={lr}')
        if patience == 2:
            # Plateau: reload the best checkpoint and lower the rate.
            patience = 0
            model.load_state_dict(
                torch.load(
                    'models_trained/{}_{}_{}/best_val_weight_{}.pth'.format(
                        args.name, args.aug, args.alpha, args.name)))
            lr = lr / 10
            if lr < 3e-6:
                lr = 3e-6
            print(f'[+] set lr={lr}')
        if epoch == 0:
            # Warm-up: train only the newly initialised layers.
            lr = 0.001
            print(f'[+] set lr={lr}')
            optimizer = torch.optim.Adam(model.fresh_params(), lr=lr)
        else:
            optimizer = torch.optim.Adam(model.parameters(),
                                         lr=lr,
                                         weight_decay=0.0001)

        running_loss = RunningMean()
        # Tracks the (mixup-weighted) fraction of misclassified samples.
        running_score = RunningMean()

        model.train()
        pbar = tqdm(training_data_loader, total=len(training_data_loader))
        for inputs, labels in pbar:
            batch_size = inputs.size(0)

            # NOTE(review): Variable / loss.data[0] are pre-0.4 PyTorch
            # idioms; this snippet assumes an old torch version.
            inputs = Variable(inputs)
            labels = Variable(labels)
            if use_gpu:
                inputs = inputs.cuda()
                labels = labels.cuda()

            optimizer.zero_grad()

            if args.aug:
                # Mixup: blend input pairs and keep both label targets.
                inputs, targets_a, targets_b, lam = mixup_data(
                    inputs, labels, args.alpha, use_gpu)

            outputs = model(inputs)

            if args.aug:
                loss_func = mixup_criterion(targets_a, targets_b, lam)
                loss = loss_func(criterion, outputs)
            else:
                loss = criterion(outputs, labels)

            _, preds = torch.max(outputs.data, dim=1)
            running_loss.update(loss.data[0], 1)

            if args.aug:
                # Lam-weighted error: credit matches against both mixup targets.
                running_score.update(
                    batch_size - lam * preds.eq(targets_a.data).cpu().sum() -
                    (1 - lam) * preds.eq(targets_b.data).cpu().sum(),
                    batch_size)
            else:
                running_score.update(torch.sum(preds != labels.data),
                                     batch_size)

            loss.backward()
            optimizer.step()

            pbar.set_description(
                f'{running_loss.value:.5f} {running_score.value:.3f}')
        print(
            f'[+] epoch {epoch} {running_loss.value:.5f} {running_score.value:.3f}'
        )

        lx, px = utils.predict(model, validation_data_loader)
        log_loss = criterion(Variable(px), Variable(lx))
        log_loss = log_loss.data[0]
        _, preds = torch.max(px, dim=1)
        # NOTE(review): despite the name, this is the error rate
        # (preds != lx), not the accuracy.
        accuracy = torch.mean((preds != lx).float())
        print(f'[+] val {log_loss:.5f} {accuracy:.3f}')

        if log_loss < min_loss:
            torch.save(
                model.state_dict(),
                'models_trained/{}_{}_{}/best_val_weight_{}.pth'.format(
                    args.name, args.aug, args.alpha, args.name))
            print(
                f'[+] val score improved from {min_loss:.5f} to {log_loss:.5f}. Saved!'
            )
            min_loss = log_loss
            patience = 0
        else:
            patience += 1