def _make_loaders(dataset_train, dataset_val, batch_size, num_workers=4):
    """Build the ``{"train": ..., "val": ...}`` DataLoader pair for one model.

    The training loader reshuffles every epoch so mini-batches are not drawn
    in the same fixed order each time; the validation loader keeps dataset
    order because evaluation does not depend on it.
    """
    return {
        "train": DataLoader(dataset_train,
                            batch_size=batch_size,
                            shuffle=True,
                            num_workers=num_workers),
        "val": DataLoader(dataset_val,
                          batch_size=batch_size,
                          num_workers=num_workers),
    }


def train_deep_model():
    """Train the VGG-face, VGG-f and VGG-13 networks on FER2013.

    VGG-face (batch 128) and VGG-f (batch 256) use the 224x224 datasets;
    VGG-13 (batch 512) uses the 64x64 datasets. Each ``train_*`` function
    receives a train loader and a validation loader.
    """
    dataset_train = dataset.FER2013Dataset("train", (224, 224))
    dataset_val = dataset.FER2013Dataset("val", (224, 224))

    loaders = _make_loaders(dataset_train, dataset_val, batch_size=128)
    train_vgg_face(loaders["train"], loaders["val"])
    del loaders

    loaders = _make_loaders(dataset_train, dataset_val, batch_size=256)
    train_vgg_f(loaders["train"], loaders["val"])
    # Drop the 224x224 datasets before building the 64x64 copies.
    del dataset_train, dataset_val, loaders

    dataset_train = dataset.FER2013Dataset("train", (64, 64))
    dataset_val = dataset.FER2013Dataset("val", (64, 64))

    loaders = _make_loaders(dataset_train, dataset_val, batch_size=512)
    train_vgg_13(loaders["train"], loaders["val"])
def _chunked_accuracy(model, samples, labels, chunk_size=1000):
    """Return the accuracy of ``model`` over every row of ``samples``.

    ``samples`` (e.g. an h5py dataset) is read ``chunk_size`` rows at a time
    so it never has to be loaded in full. Each chunk's score is weighted by
    its real length, so a short final chunk does not count as much as a
    full one.
    """
    total = len(samples)
    if total == 0:
        return 0.0
    n_correct = 0.0
    for start in range(0, total, chunk_size):
        stop = min(start + chunk_size, total)
        n_correct += model.score(samples[start:stop],
                                 labels[start:stop]) * (stop - start)
    return n_correct / total


def train_combine_global():
    """Train a global linear SVM (``SGDClassifier``) on the combined features.

    Features are read from the ``features`` dataset of
    ./data/combined_features_{train,test}.hdf5; labels come from the FER2013
    train/test splits. The classifier is updated with mini-batch
    ``partial_fit`` for ``max_iter`` passes, printing test accuracy after each
    pass, then saved to ./models/global_svm_2.pth and evaluated once more.
    """
    dataset_train = dataset.FER2013Dataset("train", (48, 48))
    dataset_test = dataset.FER2013Dataset("test", (48, 48))

    data_train, label_train = dataset_train.get_data()
    data_test, label_test = dataset_test.get_data()
    # Only the labels are needed; the features come from the HDF5 files.
    del data_train, data_test

    classes = list(range(7))
    max_iter = 30
    alpha = 10**-4  # alternative value seen in this code: 3.5714285714285716e-07

    with h5py.File("./data/combined_features_train.hdf5", "r") as sample_train_f, \
            h5py.File("./data/combined_features_test.hdf5", "r") as sample_test_f:
        sample_train = sample_train_f["features"]
        sample_test = sample_test_f["features"]

        # Train global SVM
        print("Training global SVM ...")

        global_svm = SGDClassifier(alpha=alpha, n_jobs=-1)
        for _ in trange(max_iter):
            minibatches = handcraft_model.iter_minibatches(sample_train,
                                                           label_train, 1000)
            for x, y in tqdm(minibatches):
                global_svm.partial_fit(x, y, classes)

            accuracy = _chunked_accuracy(global_svm, sample_test, label_test)
            print("Global SVM validation accuracy: {:.4f}".format(accuracy))

        torch.save({
            "svm": global_svm,
        }, "./models/global_svm_2.pth")

        accuracy = _chunked_accuracy(global_svm, sample_test, label_test)
        print("Global SVM validation accuracy: {:.4f}".format(accuracy))
def _extract_split_features(model, split, out_path):
    """Write L2-normalised ``model.extract_features`` outputs for one split.

    Builds the 224x224 FER2013 dataset for ``split`` and (over)writes
    ``out_path`` with a ``features`` dataset of shape
    ``(len(dataset), 4096)``, filled in dataset order. ``model`` is expected
    to already be on ``DEVICE`` and in eval mode.
    """
    dataset_ = dataset.FER2013Dataset(split, (224, 224))
    dataloader = DataLoader(dataset_, batch_size=1, num_workers=4)

    with h5py.File(out_path, "w") as f:
        features = f.create_dataset("features", (len(dataset_), 4096))

        with torch.no_grad():
            row = 0
            for data, _ in tqdm(dataloader):
                preds = model.extract_features(data.to(DEVICE))
                preds = preprocessing.normalize(preds.cpu().numpy(), norm="l2")

                features[row:row + len(preds)] = preds
                row += len(preds)


def extract_vgg_f():
    """Extract L2-normalised VGG-f features for the FER2013 train/test splits.

    Loads the fine-tuned weights from ./models/vgg_f_no_dsd.pth and writes
    ./data/vgg_f_features_train.hdf5 and ./data/vgg_f_features_test.hdf5.
    """
    print("Extracting vgg_f model training sample features ...")

    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    check = torch.load("./models/vgg_f_no_dsd.pth", map_location=DEVICE)
    model = deep_model.Vggf()

    # Match the checkpoint's shapes: 7-way classifier, single-channel input.
    model.fc8 = nn.Linear(model.fc8.in_features, 7)
    model.conv1 = nn.Conv2d(1, 64, kernel_size=11, stride=4)

    model = model.to(DEVICE)
    model.load_state_dict(check["model_state_dict"])
    model.eval()

    _extract_split_features(model, "train", "./data/vgg_f_features_train.hdf5")

    print("Extracting vgg_f model testing sample features ...")

    _extract_split_features(model, "test", "./data/vgg_f_features_test.hdf5")
def train_combine():
    """Classify each test sample with a local SVM fit on its neighbours.

    Row ``i`` of the ``similarity`` dataset in ./data/cosine_similarity.hdf5
    holds indices into the training set. For each test sample ``i`` a
    ``LinearSVC`` is fit on those training rows of the combined features
    (./data/combined_features_train.hdf5) and predicts sample ``i`` of
    ./data/combined_features_test.hdf5. The running accuracy is printed after
    every sample, and the overall accuracy is printed at the end.
    """
    dataset_train = dataset.FER2013Dataset("train", (48, 48))
    dataset_test = dataset.FER2013Dataset("test", (48, 48))
    data_train, label_train = dataset_train.get_data()
    data_test, label_test = dataset_test.get_data()
    del dataset_train, dataset_test, data_train, data_test

    preds = []
    n_correct = 0
    with h5py.File("./data/combined_features_train.hdf5", "r") as sample_train_f, \
            h5py.File("./data/combined_features_test.hdf5", "r") as sample_test_f, \
            h5py.File("./data/cosine_similarity.hdf5", "r") as similarity_f:
        sample_train = sample_train_f["features"]
        sample_test = sample_test_f["features"]
        similarity = similarity_f["similarity"]

        for i in trange(similarity.shape[0]):
            neighbours = similarity[i]

            # Size the local training set from the stored neighbour count
            # instead of a fixed 200, so unused zero rows (labelled 0) never
            # enter the fit and longer rows do not overflow.
            n_neighbours = len(neighbours)
            near_train = np.zeros((n_neighbours, sample_train.shape[1]))
            near_train_label = np.zeros((n_neighbours, ))
            for row, train_idx in enumerate(neighbours):
                near_train[row] = sample_train[train_idx]
                near_train_label[row] = label_train[train_idx]

            test = sample_test[i]

            local_svm = LinearSVC(C=100,
                                  verbose=1,
                                  max_iter=100,
                                  class_weight="balanced")
            local_svm.fit(near_train, near_train_label)
            pred = local_svm.predict(test.reshape(1, -1)).astype(np.int64)[0]
            preds.append(pred)

            # Running accuracy over the first i + 1 samples, kept as a
            # counter instead of re-scoring the whole prefix each time.
            n_correct += int(pred == label_test[i])
            accuracy = n_correct / (i + 1)
            print("Local SVM validation accuracy: {:.4f}".format(accuracy))

    accuracy = accuracy_score(label_test, preds)
    print("Local SVM validation accuracy: {:.4f}".format(accuracy))
def extract_handcraft_model(resume=True):
    """Build dense-SIFT / k-means spatial-pyramid histograms for FER2013.

    Writes ./data/histogram_train.hdf5 (train split) and
    ./data/histogram_test.hdf5 (test split).

    Args:
        resume: if True, load the k-means model from ./models/kmeans.pth;
            otherwise fit it on the training descriptors and save it there.
    """
    dataset_train = dataset.FER2013Dataset("train", (48, 48))
    dataset_val = dataset.FER2013Dataset("test", (48, 48))

    data_train, _ = dataset_train.get_data()
    data_val, _ = dataset_val.get_data()

    if resume:
        print("Loading kmeans models ...")
        checkpoint = torch.load("./models/kmeans.pth")
        kmeans_train = checkpoint["kmeans_train"]
    else:
        # Descriptors are only needed to fit k-means, so the dense-SIFT pass
        # is skipped entirely when resuming from a saved model.
        descriptors_train = handcraft_model.calc_densen_SIFT(data_train)
        kmeans_train = handcraft_model.kmeans([17000, 14000, 8000],
                                              descriptors_train)

        torch.save(
            {
                "kmeans_train": kmeans_train,
                # Same model stored under the legacy key so readers of the
                # old checkpoint layout still find it.
                "kmeans_val": kmeans_train,
            },
            "./models/kmeans.pth",
        )

    # Both splits are quantised with the k-means model fit on TRAINING data.
    # A model fit on the test descriptors would give test histograms whose
    # bins refer to different cluster centres than the training histograms,
    # and would fit on test data.
    handcraft_model.get_histogram_spatial_pyramid(
        data_train, kmeans_train, "./data/histogram_train.hdf5")
    print("histogram train file saved")

    handcraft_model.get_histogram_spatial_pyramid(
        data_val, kmeans_train, "./data/histogram_test.hdf5")
    print("histogram val file saved")
def test_vgg_f(imgs=None):
    """Evaluate the fine-tuned VGG-f model from ./models/vgg_f_no_dsd.pth.

    Args:
        imgs: optional ``(data, labels)`` pair. When given, ``data`` is run
            through the model and the predicted class indices are returned as
            a NumPy array (``labels`` is not used). When None, the 224x224
            FER2013 test split is evaluated and its accuracy printed.

    Returns:
        The predictions for ``imgs``, or None when evaluating the test split.
    """
    check = torch.load("./models/vgg_f_no_dsd.pth", map_location=DEVICE)
    model = deep_model.Vggf(pretrain=False)

    # Match the checkpoint's shapes: 7-way classifier, single-channel input.
    model.fc8 = nn.Linear(model.fc8.in_features, 7)
    model.conv1 = nn.Conv2d(1, 64, kernel_size=11, stride=4)

    model = model.to(DEVICE)
    model.load_state_dict(check["model_state_dict"])
    model.eval()

    if imgs is not None:
        with torch.no_grad():
            preds = model(imgs[0].to(DEVICE))
            return torch.argmax(preds, 1).cpu().numpy()

    dataset_val = dataset.FER2013Dataset("test", (224, 224))
    data_loader = DataLoader(dataset_val, batch_size=256, num_workers=4)

    # Count correct predictions across batches so the final, smaller batch
    # is weighted by its real size (a mean of per-batch accuracies is not).
    n_correct = 0
    n_total = 0
    with torch.no_grad():
        for data, label in tqdm(data_loader):
            preds = torch.argmax(model(data.to(DEVICE)), 1).cpu().numpy()
            n_correct += accuracy_score(label.numpy(), preds, normalize=False)
            n_total += len(preds)

    accuracy = float(n_correct) / n_total if n_total else 0.0
    print("Finish vgg_f model test with accuracy {:.4f}".format(accuracy))
def test_vgg_13(imgs=None):
    check = torch.load("./models/vgg_13_no_dsd.pth", map_location=DEVICE)
    model = deep_model.Vgg13()

    model = model.to(DEVICE)
    model.load_state_dict(check["model_state_dict"])

    if imgs is None:
        dataset_val = dataset.FER2013Dataset("test", (64, 64))
        data_loader = DataLoader(dataset_val, batch_size=512, num_workers=4)

        model.eval()
        val_acc = []
        with torch.set_grad_enabled(False):
            for data, label in tqdm(data_loader):
                data = data.to(DEVICE)
                label = label.to(DEVICE)

                preds = model(data)

                preds = torch.argmax(preds, 1).cpu().numpy()
                label = label.cpu().numpy()
                val_acc.append(accuracy_score(label, preds))

        accuracy = np.mean(np.array(val_acc).astype(np.float))

    else:
        model.eval()
        val_acc = []
        with torch.set_grad_enabled(False):
            data = imgs[0].to(DEVICE)
            label = imgs[1].to(DEVICE)

            preds = model(data)

            preds = torch.argmax(preds, 1).cpu().numpy()
            label = label.cpu().numpy()
            val_acc.append(accuracy_score(label, preds))

        accuracy = np.mean(np.array(val_acc).astype(np.float))
        return preds

    print("Finish vgg_13 model test with accuracy {:.4f}".format(accuracy))
    train_loss = checkpoint["train_loss"]
    val_loss = checkpoint["val_loss"]
    val_acc = checkpoint["val_acc"]

    fig = plt.figure()
    line1, = plt.plot(train_loss, label='train_loss')
    line2, = plt.plot(val_loss, label='val_loss')
    line3, = plt.plot(val_acc, label='val_acc')
    plt.legend(handles=[line1, line2, line3], loc='best')
    plt.xlabel('epoch')
    plt.ylabel('training loss')
    plt.grid()
    fig.savefig("./vgg_13_process.png")

    print("\nSelecting 3 images from testing set ...")
    test_dataset_224 = dataset.FER2013Dataset(data_type="test",
                                              scale_size=(224, 224))
    test_dataset_64 = dataset.FER2013Dataset(data_type="test",
                                             scale_size=(64, 64))

    # Make testing samples
    idx = [0, 3, 4]
    imgs_tensor_224 = []
    labels_224 = []
    imgs_tensor_64 = []
    labels_64 = []
    for i in range(len(idx)):
        data, label = test_dataset_224[idx[i]]

        # Save test sample image
        img = transforms.ToPILImage()(data)
        img.save("test_sample_{}.png".format(i))