def main(argv):
    """Extract CIFAR-100 features and save them as ``.npy`` files.

    Loads the CIFAR-100 train and test images in both their distorted and
    undistorted variants (via ``load_data`` with ``DATA_PATH``), runs each set
    through ``NetworkCifar().features`` and writes the result with
    ``np.save`` to ``PATH_MODEL/features_ll_<split>[_distorted].npy``.

    Args:
      argv: Command-line arguments; unused.
    """
    del argv  # Unused.

    # Labels are not needed for feature extraction, so they are discarded.
    (x_train_distorted, _), (x_test_distorted, _) = load_data(
        distorted=True, data_path=DATA_PATH, dataset='cifar100')
    (x_train, _), (x_test, _) = load_data(
        distorted=False, data_path=DATA_PATH, dataset='cifar100')

    model = NetworkCifar()
    # (output file name, images to featurize); processed in this order.
    outputs = (
        ('features_ll_train.npy', x_train),
        ('features_ll_train_distorted.npy', x_train_distorted),
        ('features_ll_test.npy', x_test),
        ('features_ll_test_distorted.npy', x_test_distorted),
    )
    for file_name, images in outputs:
        features = model.features(images)
        output_path = os.path.join(PATH_MODEL, file_name)
        # Binary mode: np.save writes raw .npy bytes.
        with tf.gfile.Open(output_path, 'wb') as f:
            np.save(f, features)
示例#2
0
def train_cifar100(features_cifar100_train=gin.REQUIRED,
                   features_cifar100_test=gin.REQUIRED,
                   data_path_cifar100=gin.REQUIRED,
                   features_cifar100_train_distorted=gin.REQUIRED,
                   distorted=gin.REQUIRED,
                   model_dir=gin.REQUIRED,
                   dim_input=gin.REQUIRED,
                   num_classes=gin.REQUIRED):
    """Train a model on precomputed CIFAR-100 features and evaluate it.

    All path arguments are joined onto ``FLAGS.baseroute``. Labels come from
    ``cifar_input.load_data``, and inputs are feature arrays read with
    ``np.load``. After ``build_model``, ``model.predict`` is run on the test
    features. When ``FLAGS.algorithm`` is one of 'simple', 'dropout' or
    'precond', CIFAR-100 postprocessing and metrics are run under
    ``FLAGS.workdir``. The gin configuration is always written to
    ``FLAGS.workdir``.

    Args:
      features_cifar100_train: Path (relative to ``FLAGS.baseroute``) of the
        ``.npy`` train features; used when ``distorted`` is false.
      features_cifar100_test: Path of the ``.npy`` test features (always used).
      data_path_cifar100: Path of the CIFAR-100 data passed to
        ``cifar_input.load_data`` to obtain the labels.
      features_cifar100_train_distorted: Path of the ``.npy`` train features
        used when ``distorted`` is true.
      distorted: Selects the distorted train features; also forwarded to
        ``cifar_input.load_data``.
      model_dir: Directory (relative to ``FLAGS.baseroute``) passed to
        ``build_model``.
      dim_input: Passed to ``build_model``; presumably the feature
        dimensionality — TODO confirm.
      num_classes: Number of classes, passed to ``build_model``.
    """
    # Resolve every configured path against the base route.
    features_cifar100_train = os.path.join(FLAGS.baseroute,
                                           features_cifar100_train)
    features_cifar100_test = os.path.join(FLAGS.baseroute,
                                          features_cifar100_test)
    data_path_cifar100 = os.path.join(FLAGS.baseroute, data_path_cifar100)
    features_cifar100_train_distorted = os.path.join(
        FLAGS.baseroute, features_cifar100_train_distorted)
    model_dir = os.path.join(FLAGS.baseroute, model_dir)

    # Only the labels are taken from the raw dataset; inputs are features.
    (_, y_train), (_, y_test) = cifar_input.load_data(distorted,
                                                      data_path_cifar100,
                                                      'cifar100')
    train_features_path = (features_cifar100_train_distorted if distorted
                           else features_cifar100_train)
    # np.load needs a binary file object: in text mode ('r') reads return str
    # under Python 3 and the .npy magic-number check fails.
    with tf.gfile.Open(train_features_path, 'rb') as f:
        x_train = np.load(f)
    with tf.gfile.Open(features_cifar100_test, 'rb') as f:
        x_test = np.load(f)
    dataset = (x_train, y_train), (x_test, y_test)

    model = build_model(dataset, model_dir, dim_input, num_classes)

    # Compute output probabilities on x_test.
    # Warning: takes time, passes over all the saved weights.
    model.predict(x_test)

    # Postprocessing and metrics.
    if FLAGS.algorithm in ['simple', 'dropout', 'precond']:
        postprocess.postprocess_cifar(FLAGS.workdir, 'cifar100')
        path_postprocess = os.path.join(FLAGS.workdir, 'cifar100')
        metrics.Metrics(y_test, path_postprocess)

    # Write the gin config in the working directory.
    util.save_gin(os.path.join(FLAGS.workdir, 'gin_configuration.txt'))
    util.write_gin(FLAGS.workdir)