Example #1
0
def train_mnist(features_mnist_train=gin.REQUIRED,
                features_mnist_test=gin.REQUIRED,
                data_path_mnist=gin.REQUIRED,
                features_notmnist_test=gin.REQUIRED,
                labels_notmnist_test=gin.REQUIRED,
                model_dir=gin.REQUIRED,
                dim_input=gin.REQUIRED,
                num_classes=gin.REQUIRED):
    """Train a model on MNIST features and evaluate it on MNIST and notMNIST.

    Every argument defaults to ``gin.REQUIRED`` and must be bound through the
    gin configuration. All path arguments are joined onto
    ``FLAGS.baseroute``.

    Args:
      features_mnist_train: path of the ``.npy`` file with the MNIST training
        features.
      features_mnist_test: path of the ``.npy`` file with the MNIST test
        features.
      data_path_mnist: path handed to ``mnist_input.load_data``; only the
        train/test labels it returns are used.
      features_notmnist_test: path of the ``.npy`` file with the notMNIST test
        features.
      labels_notmnist_test: path of the ``.npy`` file with the notMNIST test
        labels, which are one-hot encoded over 10 classes here.
      model_dir: model directory forwarded to ``build_model``.
      dim_input: input dimension forwarded to ``build_model``.
      num_classes: number of classes forwarded to ``build_model``.
    """
    def _abs(path):
        # Configured paths are relative to the base route flag.
        return os.path.join(FLAGS.baseroute, path)

    def _load_npy(path):
        # .npy files are binary; opening them in text mode ('r') makes
        # np.load fail on Python 3, so they must be read with 'rb'.
        with tf.gfile.Open(path, 'rb') as f:
            return np.load(f)

    # Labels come from the raw MNIST data; features are precomputed arrays.
    (_, y_train), (_, y_test) = mnist_input.load_data(_abs(data_path_mnist))
    x_train = _load_npy(_abs(features_mnist_train))
    x_test = _load_npy(_abs(features_mnist_test))
    dataset = (x_train, y_train), (x_test, y_test)

    model = build_model(dataset, _abs(model_dir), dim_input, num_classes)

    # Load the notMNIST test features and labels.
    x_notmnist = _load_npy(_abs(features_notmnist_test))
    y_notmnist = _load_npy(_abs(labels_notmnist_test))
    y_notmnist = tf.keras.utils.to_categorical(y_notmnist, num_classes=10)

    # Run the model on the MNIST test and notMNIST features stacked together.
    # The return value is not used, so predict is called for its side effects
    # (presumably writing predictions under FLAGS.workdir -- confirm).
    # Warning: slow, it passes over all the saved weights.
    model.predict(np.vstack((x_test, x_notmnist)))

    # Ground-truth labels per evaluated dataset, consumed by the metrics.
    y_dic = {'mnist': y_test, 'notmnist': y_notmnist}

    # Postprocessing and metrics are only produced for these algorithms.
    if FLAGS.algorithm in ['simple', 'dropout', 'precond']:
        postprocess.postprocess_mnist(FLAGS.workdir)
        for dataset_str in ['mnist', 'notmnist']:
            path_postprocess = os.path.join(FLAGS.workdir, dataset_str)
            metrics.Metrics(y_dic[dataset_str], path_postprocess)

    # Persist the gin configuration in the working directory.
    util.save_gin(os.path.join(FLAGS.workdir, 'gin_configuration.txt'))
    util.write_gin(FLAGS.workdir)
Example #2
0
def train_cifar100(features_cifar100_train=gin.REQUIRED,
                   features_cifar100_test=gin.REQUIRED,
                   data_path_cifar100=gin.REQUIRED,
                   features_cifar100_train_distorted=gin.REQUIRED,
                   distorted=gin.REQUIRED,
                   model_dir=gin.REQUIRED,
                   dim_input=gin.REQUIRED,
                   num_classes=gin.REQUIRED):
    """Train a model on CIFAR-100 features and evaluate it on the test set.

    Every argument defaults to ``gin.REQUIRED`` and must be bound through the
    gin configuration. All path arguments are joined onto
    ``FLAGS.baseroute``.

    Args:
      features_cifar100_train: path of the ``.npy`` file with the training
        features, used when ``distorted`` is falsy.
      features_cifar100_test: path of the ``.npy`` file with the test features.
      data_path_cifar100: path handed to ``cifar_input.load_data``; only the
        train/test labels it returns are used.
      features_cifar100_train_distorted: path of the ``.npy`` file with the
        training features, used when ``distorted`` is truthy.
      distorted: forwarded to ``cifar_input.load_data`` and selects which
        training-features file is loaded.
      model_dir: model directory forwarded to ``build_model``.
      dim_input: input dimension forwarded to ``build_model``.
      num_classes: number of classes forwarded to ``build_model``.
    """
    def _abs(path):
        # Configured paths are relative to the base route flag.
        return os.path.join(FLAGS.baseroute, path)

    def _load_npy(path):
        # .npy files are binary; opening them in text mode ('r') makes
        # np.load fail on Python 3, so they must be read with 'rb'.
        with tf.gfile.Open(path, 'rb') as f:
            return np.load(f)

    # Labels come from the raw CIFAR-100 data; features are precomputed.
    (_, y_train), (_, y_test) = cifar_input.load_data(
        distorted, _abs(data_path_cifar100), 'cifar100')
    train_features_path = (features_cifar100_train_distorted if distorted
                           else features_cifar100_train)
    x_train = _load_npy(_abs(train_features_path))
    x_test = _load_npy(_abs(features_cifar100_test))
    dataset = (x_train, y_train), (x_test, y_test)

    model = build_model(dataset, _abs(model_dir), dim_input, num_classes)

    # The return value is not used, so predict is called for its side effects
    # (presumably writing predictions under FLAGS.workdir -- confirm).
    # Warning: slow, it passes over all the saved weights.
    model.predict(x_test)

    # Postprocessing and metrics are only produced for these algorithms.
    if FLAGS.algorithm in ['simple', 'dropout', 'precond']:
        postprocess.postprocess_cifar(FLAGS.workdir, 'cifar100')
        path_postprocess = os.path.join(FLAGS.workdir, 'cifar100')
        metrics.Metrics(y_test, path_postprocess)

    # Persist the gin configuration in the working directory.
    util.save_gin(os.path.join(FLAGS.workdir, 'gin_configuration.txt'))
    util.write_gin(FLAGS.workdir)