def train_mnist(features_mnist_train=gin.REQUIRED,
                features_mnist_test=gin.REQUIRED,
                data_path_mnist=gin.REQUIRED,
                features_notmnist_test=gin.REQUIRED,
                labels_notmnist_test=gin.REQUIRED,
                model_dir=gin.REQUIRED,
                dim_input=gin.REQUIRED,
                num_classes=gin.REQUIRED):
  """Builds a model on MNIST features and evaluates it on MNIST and notMNIST.

  Every path argument is resolved relative to `FLAGS.baseroute`. After the
  model is built, predictions are run on the concatenated MNIST + notMNIST
  test features. If the algorithm needs it, the outputs are post-processed.
  Metrics are then computed for each dataset, and the gin configuration is
  written to `FLAGS.workdir`.

  Args:
    features_mnist_train: path to the `.npy` file of MNIST training features.
    features_mnist_test: path to the `.npy` file of MNIST test features.
    data_path_mnist: path passed to `mnist_input.load_data`; only the labels
      it returns are used.
    features_notmnist_test: path to the `.npy` file of notMNIST test features.
    labels_notmnist_test: path to the `.npy` file of notMNIST class labels,
      which are one-hot encoded here.
    model_dir: model directory passed to `build_model`.
    dim_input: input dimension passed to `build_model`.
    num_classes: number of classes passed to `build_model`. Also used as the
      one-hot width of the notMNIST labels so they match the model output.
  """
  # Resolve every configured path against the common base directory.
  data_path_mnist = os.path.join(FLAGS.baseroute, data_path_mnist)
  features_mnist_train = os.path.join(FLAGS.baseroute, features_mnist_train)
  features_mnist_test = os.path.join(FLAGS.baseroute, features_mnist_test)
  features_notmnist_test = os.path.join(FLAGS.baseroute,
                                        features_notmnist_test)
  labels_notmnist_test = os.path.join(FLAGS.baseroute, labels_notmnist_test)
  model_dir = os.path.join(FLAGS.baseroute, model_dir)

  # Only the labels come from the raw MNIST data; the inputs are precomputed
  # features stored as .npy files. .npy is a binary format, so open it in
  # 'rb': in text mode np.load would receive str instead of bytes on Python 3.
  (_, y_train), (_, y_test) = mnist_input.load_data(data_path_mnist)
  with tf.gfile.Open(features_mnist_train, 'rb') as f:
    x_train = np.load(f)
  with tf.gfile.Open(features_mnist_test, 'rb') as f:
    x_test = np.load(f)
  dataset = (x_train, y_train), (x_test, y_test)
  model = build_model(dataset, model_dir, dim_input, num_classes)

  # Load the out-of-distribution (notMNIST) test features and labels.
  with tf.gfile.Open(features_notmnist_test, 'rb') as f:
    x_notmnist = np.load(f)
  with tf.gfile.Open(labels_notmnist_test, 'rb') as f:
    y_notmnist = np.load(f)
  # One-hot width follows the configured class count, not a hard-coded 10.
  y_notmnist = tf.keras.utils.to_categorical(y_notmnist,
                                             num_classes=num_classes)

  # Compute output probabilities on x_test and x_notmnist in a single pass.
  # Warning: this is slow, because it passes over all the saved weights.
  # The return value is discarded; presumably predict() persists its outputs
  # for the postprocessing/metrics steps below (confirm in build_model).
  x = np.vstack((x_test, x_notmnist))
  model.predict(x)

  # Ground-truth labels per evaluated dataset, keyed by sub-directory name.
  y_dic = {'mnist': y_test, 'notmnist': y_notmnist}

  # Postprocessing (only for these algorithms) and metrics.
  if FLAGS.algorithm in ['simple', 'dropout', 'precond']:
    postprocess.postprocess_mnist(FLAGS.workdir)
  for dataset_str in ['mnist', 'notmnist']:
    path_postprocess = os.path.join(FLAGS.workdir, dataset_str)
    metrics.Metrics(y_dic[dataset_str], path_postprocess)

  # Write the gin config in the working directory.
  util.save_gin(os.path.join(FLAGS.workdir, 'gin_configuration.txt'))
  util.write_gin(FLAGS.workdir)
def train_cifar100(features_cifar100_train=gin.REQUIRED,
                   features_cifar100_test=gin.REQUIRED,
                   data_path_cifar100=gin.REQUIRED,
                   features_cifar100_train_distorted=gin.REQUIRED,
                   distorted=gin.REQUIRED,
                   model_dir=gin.REQUIRED,
                   dim_input=gin.REQUIRED,
                   num_classes=gin.REQUIRED):
  """Builds a model on CIFAR-100 features and evaluates it on the test set.

  Every path argument is resolved relative to `FLAGS.baseroute`. After the
  model is built, predictions are run on the CIFAR-100 test features. If the
  algorithm needs it, the outputs are post-processed. Metrics are then
  computed, and the gin configuration is written to `FLAGS.workdir`.

  Args:
    features_cifar100_train: path to the `.npy` file of training features,
      used when `distorted` is falsy.
    features_cifar100_test: path to the `.npy` file of test features.
    data_path_cifar100: path passed to `cifar_input.load_data`; only the
      labels it returns are used.
    features_cifar100_train_distorted: path to the `.npy` file of training
      features, used when `distorted` is truthy.
    distorted: selects the distorted training features; also forwarded to
      `cifar_input.load_data`.
    model_dir: model directory passed to `build_model`.
    dim_input: input dimension passed to `build_model`.
    num_classes: number of classes passed to `build_model`.
  """
  # Resolve every configured path against the common base directory.
  features_cifar100_train = os.path.join(FLAGS.baseroute,
                                         features_cifar100_train)
  features_cifar100_test = os.path.join(FLAGS.baseroute,
                                        features_cifar100_test)
  data_path_cifar100 = os.path.join(FLAGS.baseroute, data_path_cifar100)
  features_cifar100_train_distorted = os.path.join(
      FLAGS.baseroute, features_cifar100_train_distorted)
  model_dir = os.path.join(FLAGS.baseroute, model_dir)

  # Only the labels come from the raw CIFAR-100 data; the inputs are
  # precomputed features stored as .npy files.
  (_, y_train), (_, y_test) = cifar_input.load_data(distorted,
                                                    data_path_cifar100,
                                                    'cifar100')
  train_path = (features_cifar100_train_distorted if distorted
                else features_cifar100_train)
  # .npy is a binary format, so open it in 'rb': in text mode np.load would
  # receive str instead of bytes on Python 3.
  with tf.gfile.Open(train_path, 'rb') as f:
    x_train = np.load(f)
  with tf.gfile.Open(features_cifar100_test, 'rb') as f:
    x_test = np.load(f)
  dataset = (x_train, y_train), (x_test, y_test)
  model = build_model(dataset, model_dir, dim_input, num_classes)

  # Compute output probabilities on x_test.
  # Warning: this is slow, because it passes over all the saved weights.
  # The return value is discarded; presumably predict() persists its outputs
  # for the postprocessing/metrics steps below (confirm in build_model).
  model.predict(x_test)

  # Postprocessing (only for these algorithms) and metrics.
  if FLAGS.algorithm in ['simple', 'dropout', 'precond']:
    postprocess.postprocess_cifar(FLAGS.workdir, 'cifar100')
  path_postprocess = os.path.join(FLAGS.workdir, 'cifar100')
  metrics.Metrics(y_test, path_postprocess)

  # Write the gin config in the working directory.
  util.save_gin(os.path.join(FLAGS.workdir, 'gin_configuration.txt'))
  util.write_gin(FLAGS.workdir)