def LoadBatch(filename):
    """
    Loads a CIFAR-10 batch and produces the X, Y, and y arrays.

    :param filename: Path of the pickled CIFAR-10 batch file.
    :return: X (3072 x N data matrix, scaled to [0, 1]),
             Y (10 x N one-hot label matrix),
             y (list of N integer class labels)
    """

    # Format documented at https://www.cs.toronto.edu/~kriz/cifar.html
    # With encoding='bytes' the dictionary keys are bytes objects.
    with open(filename, 'rb') as fo:
        dictionary = pickle.load(fo, encoding='bytes')

    # Access the known CIFAR-10 keys directly instead of relying on the
    # dictionary's key insertion order (previously fetched by index, which
    # silently returned None on failure).
    y = dictionary[b'labels']
    Y = np.transpose(make_class_categorical(y, 10))
    # Transpose so each column is one image; scale pixel values to [0, 1].
    X = np.transpose(dictionary[b'data']) / 255

    return X, Y, y
def LoadBatch(filename):
    """
    Loads a CIFAR-10 batch of data.

    :param filename: The path of the file in your local computer.
    :return: CIFAR-10 data X, their one-hot representation Y, and their true labels y
    """

    # Format documented at https://www.cs.toronto.edu/~kriz/cifar.html
    # With encoding='latin1' the dictionary keys are plain strings.
    with open(filename, 'rb') as fo:
        dictionary = pickle.load(fo, encoding='latin1')

    # Use the known CIFAR-10 keys directly rather than fetching the label key
    # by insertion-order index (which silently returned None on failure).
    y = dictionary['labels']
    Y = np.transpose(make_class_categorical(y, 10))
    # Each column of X is one image; pixel values scaled to [0, 1].
    X = np.transpose(dictionary['data']) / 255.0

    return X, Y, y
def create_augmented_dataset(X, y):
    """
    Creates an augmented dataset by applying random transformations to each datum.
    The transformed images are concatenated to the originals, doubling the
    dataset size. One-hot representations and true labels of the generated
    images are also added to the Y and y outputs.

    :param X: Training data, one flattened 3072-element image per column
              (CIFAR-10 channel-major layout: 3 x 32 x 32).
    :param y: Data true labels (a list; must support .extend()).

    :return: Extended training data, one-hot representations and true labels.
    """
    # NOTE: requires keras to be installed (imported lazily below).
    from keras.preprocessing.image import ImageDataGenerator
    datagen = ImageDataGenerator(rotation_range=90,
                                 width_shift_range=0.1,
                                 height_shift_range=0.1,
                                 horizontal_flip=True)

    n_samples = X.shape[1]

    # Convert each column from flattened CHW layout to an HWC image.
    data = np.ndarray(shape=(n_samples, 32, 32, 3))
    for datum in range(n_samples):
        data[datum, :] = X[:, datum].reshape(3, 32, 32).transpose(1, 2, 0)

    cnt = 0
    augmented = np.ndarray(shape=(3072, n_samples))
    labels = []
    for X_batch, y_batch in datagen.flow(data, y, batch_size=1):
        # Invert the HWC layout back to CHW before flattening.
        # The forward transform was transpose(1, 2, 0), so the inverse is
        # transpose(2, 0, 1) -- NOT transpose(2, 1, 0), which would also
        # swap the two spatial axes and store a diagonally-flipped image.
        augmented[:, cnt] = X_batch[0].transpose(2, 0, 1).reshape(3072)
        labels.append(y_batch[0])
        cnt += 1
        # datagen.flow loops forever; stop after one pass over the data.
        if cnt == n_samples:
            break

    X_augmented = np.concatenate((X, augmented), axis=1)
    y_augmented = y.copy()
    y_augmented.extend(labels)
    Y_augmented = make_class_categorical(y_augmented, 10).T

    return X_augmented, Y_augmented, y_augmented