def main(argv):
  """Compute CIFAR-100 features with NetworkCifar and save them as .npy files.

  Loads the distorted and the clean CIFAR-100 images, runs
  `NetworkCifar().features` on the train and test images of each, and writes
  one .npy file per result under PATH_MODEL.

  Args:
    argv: Command-line arguments; unused.
  """
  del argv  # Unused.

  # Only the images are needed here; labels are discarded.
  (x_train_distorted, _), (x_test_distorted, _) = load_data(
      distorted=True, data_path=DATA_PATH, dataset='cifar100')
  (x_train, _), (x_test, _) = load_data(
      distorted=False, data_path=DATA_PATH, dataset='cifar100')

  model = NetworkCifar()

  # (output file name, images to featurize), processed in this order.
  jobs = (
      ('features_ll_train.npy', x_train),
      ('features_ll_train_distorted.npy', x_train_distorted),
      ('features_ll_test.npy', x_test),
      ('features_ll_test_distorted.npy', x_test_distorted),
  )
  for file_name, images in jobs:
    feats = model.features(images)
    out_path = os.path.join(PATH_MODEL, file_name)
    # Binary mode: np.save writes raw bytes.
    with tf.gfile.Open(out_path, 'wb') as handle:
      np.save(handle, feats)
def train_cifar100(features_cifar100_train=gin.REQUIRED,
                   features_cifar100_test=gin.REQUIRED,
                   data_path_cifar100=gin.REQUIRED,
                   features_cifar100_train_distorted=gin.REQUIRED,
                   distorted=gin.REQUIRED,
                   model_dir=gin.REQUIRED,
                   dim_input=gin.REQUIRED,
                   num_classes=gin.REQUIRED):
  """Training function.

  Builds a model on precomputed CIFAR-100 features via `build_model`, then
  runs `model.predict` on the test features. For the 'simple', 'dropout' and
  'precond' algorithms it also post-processes the outputs and computes
  metrics. Finally it writes the gin configuration into FLAGS.workdir.

  All path arguments are joined onto FLAGS.baseroute.

  Args:
    features_cifar100_train: .npy file with features of the clean train set;
      used when `distorted` is false.
    features_cifar100_test: .npy file with features of the test set.
    data_path_cifar100: location of the CIFAR-100 data; only the labels are
      used here.
    features_cifar100_train_distorted: .npy file with features of the
      distorted train set; used when `distorted` is true.
    distorted: passed to `cifar_input.load_data`; also selects which train
      feature file is loaded.
    model_dir: directory passed to `build_model`.
    dim_input: passed to `build_model`.
    num_classes: passed to `build_model`.
  """

  def _load_npy(path):
    """Load a NumPy array from `path` through TensorFlow's file API."""
    # .npy files are binary (main() writes them with 'wb'). In Python 3, a
    # text-mode ('r') GFile yields str, which np.load cannot parse.
    with tf.gfile.Open(path, 'rb') as f:
      return np.load(f)

  # Define the paths
  features_cifar100_train = os.path.join(FLAGS.baseroute,
                                         features_cifar100_train)
  features_cifar100_test = os.path.join(FLAGS.baseroute,
                                        features_cifar100_test)
  data_path_cifar100 = os.path.join(FLAGS.baseroute, data_path_cifar100)
  features_cifar100_train_distorted = os.path.join(
      FLAGS.baseroute, features_cifar100_train_distorted)
  model_dir = os.path.join(FLAGS.baseroute, model_dir)

  # Load the labels for cifar100; the images are replaced by saved features.
  (_, y_train), (_, y_test) = cifar_input.load_data(distorted,
                                                    data_path_cifar100,
                                                    'cifar100')
  train_features_path = (features_cifar100_train_distorted if distorted
                         else features_cifar100_train)
  x_train = _load_npy(train_features_path)
  x_test = _load_npy(features_cifar100_test)
  dataset = (x_train, y_train), (x_test, y_test)

  model = build_model(dataset, model_dir, dim_input, num_classes)

  # Compute output probabilities on x_test
  # Warning: take time, pass over all the saved weights.
  model.predict(x_test)

  # Postprocessing and metrics
  if FLAGS.algorithm in ['simple', 'dropout', 'precond']:
    postprocess.postprocess_cifar(FLAGS.workdir, 'cifar100')
    path_postprocess = os.path.join(FLAGS.workdir, 'cifar100')
    metrics.Metrics(y_test, path_postprocess)

  # Write the gin config in the working directory
  util.save_gin(os.path.join(FLAGS.workdir, 'gin_configuration.txt'))
  util.write_gin(FLAGS.workdir)