def run(method, output_dir, fake_data=False, fake_training=False):
  """Trains a model and records its predictions on configured datasets.

  Args:
    method: Modeling method to experiment with.
    output_dir: Directory to record the trained model and output stats.
    fake_data: If true, use fake data.
    fake_training: If true, train for a trivial number of steps.

  Returns:
    Trained Keras model.
  """
  tf.io.gfile.makedirs(output_dir)
  opts = hparams_lib.model_opts_from_hparams(
      hparams_lib.HPS_DICT[method], method, fake_training=fake_training)
  if fake_training:
    # Shrink the workload so a full train/eval cycle finishes quickly.
    opts.batch_size = 32
    opts.examples_per_epoch = 256
    opts.train_epochs = 1
  experiment_utils.record_config(opts, output_dir + '/model_options.json')

  train_dataset = data_lib.build_dataset(
      image_data_utils.DATA_CONFIG_TRAIN, is_training=True,
      fake_data=fake_data)
  eval_dataset = data_lib.build_dataset(
      image_data_utils.DATA_CONFIG_TEST, fake_data=fake_data)
  trained_model = models_lib.build_and_train(
      opts, train_dataset, eval_dataset, output_dir)

  logging.info('Saving model to output_dir.')
  trained_model.save_weights(output_dir + '/model.ckpt')
  # TODO(yovadia): Figure out why save_model() wants to serialize ModelOptions.
  return trained_model
def test_static_cifar_c(self, split):
  """Checks static-corruption dataset building for the given split."""
  if flags.FLAGS.fake_data:
    return
  config = image_data_utils.DataConfig(
      split,
      corruption_static=True,
      corruption_level=3,
      corruption_type='pixelate')
  if split in ('train', 'valid'):
    # Static corruptions are expected to reject non-test splits.
    with self.assertRaises(ValueError):
      data_lib.build_dataset(config)
  else:
    ds = data_lib.build_dataset(config)
    first_image = next(iter(ds))[0]
    self.assertEqual(first_image.numpy().shape, data_lib.CIFAR_SHAPE)
def run(dataset_name, model_dir, predictions_per_example, max_examples,
        output_dir, fake_data=False):
  """Runs predictions on the given dataset using the specified model.

  Args:
    dataset_name: Name of the dataset to predict on.
    model_dir: Directory holding the trained model to load.
    predictions_per_example: Number of prediction samples per example.
    max_examples: If truthy, cap the dataset at this many examples.
    output_dir: Directory where prediction .npz files are written.
    fake_data: If true, use fake data.
  """
  tf.io.gfile.makedirs(output_dir)
  config = image_data_utils.get_data_config(dataset_name)
  ds = data_lib.build_dataset(config, fake_data=fake_data)
  if max_examples:
    ds = ds.take(max_examples)

  model = models_lib.load_model(model_dir)
  logging.info('Starting predictions.')
  predictions = experiment_utils.make_predictions(
      model, ds.batch(_BATCH_SIZE), predictions_per_example)

  logging.info('Done computing predictions; recording results to disk.')
  array_utils.write_npz(
      output_dir, 'predictions_%s.npz' % dataset_name, predictions)
  # Drop the bulky per-sample logits before writing the compact copy.
  del predictions['logits_samples']
  array_utils.write_npz(
      output_dir, 'predictions_small_%s.npz' % dataset_name, predictions)
def test_value_cifar_c(self, split):
  """Checks value-based corruption yields CIFAR-shaped images."""
  if flags.FLAGS.fake_data:
    return
  config = image_data_utils.DataConfig(
      split, corruption_value=.25, corruption_type='brightness')
  ds = data_lib.build_dataset(config)
  first_image = next(iter(ds))[0]
  self.assertEqual(first_image.numpy().shape, data_lib.CIFAR_SHAPE)
def test_array_cifar_c(self, split):
  """Checks level-based corruption yields CIFAR-shaped images."""
  if flags.FLAGS.fake_data:
    return
  config = image_data_utils.DataConfig(
      split, corruption_level=4, corruption_type='glass_blur')
  ds = data_lib.build_dataset(config)
  first_image = next(iter(ds))[0]
  self.assertEqual(first_image.numpy().shape, data_lib.CIFAR_SHAPE)
def test_roll_pixels(self, split):
  """Checks pixel-rolling config yields CIFAR-shaped images."""
  # Config is built unconditionally, mirroring the sibling tests' intent
  # of exercising DataConfig construction even under --fake_data.
  config = image_data_utils.DataConfig(split, roll_pixels=5)
  if flags.FLAGS.fake_data:
    return
  ds = data_lib.build_dataset(config)
  first_image = next(iter(ds))[0]
  self.assertEqual(first_image.numpy().shape, data_lib.CIFAR_SHAPE)
def test_fake_data(self, split):
  """Checks the fake-data path yields CIFAR-shaped images."""
  # The config contents are ignored when fake_data=True.
  config = image_data_utils.DataConfig(split)
  ds = data_lib.build_dataset(config, fake_data=True)
  first_image = next(iter(ds))[0]
  self.assertEqual(first_image.numpy().shape, data_lib.CIFAR_SHAPE)