Ejemplo n.º 1
0
def run(method, output_dir, fake_data=False, fake_training=False):
    """Trains a model and records its predictions on configured datasets.

    Args:
      method: Modeling method to experiment with.
      output_dir: Directory to record the trained model and output stats.
      fake_data: If true, use fake data.
      fake_training: If true, train for a trivial number of steps.

    Returns:
      Trained Keras model.
    """
    tf.io.gfile.makedirs(output_dir)

    opts = hparams_lib.model_opts_from_hparams(
        hparams_lib.HPS_DICT[method], method, fake_training=fake_training)
    if fake_training:
        # Shrink the run to a single tiny epoch so it finishes quickly.
        opts.batch_size = 32
        opts.examples_per_epoch = 256
        opts.train_epochs = 1

    experiment_utils.record_config(opts, output_dir + '/model_options.json')

    train_ds = data_lib.build_dataset(
        image_data_utils.DATA_CONFIG_TRAIN,
        is_training=True,
        fake_data=fake_data)
    test_ds = data_lib.build_dataset(
        image_data_utils.DATA_CONFIG_TEST, fake_data=fake_data)

    model = models_lib.build_and_train(opts, train_ds, test_ds, output_dir)

    logging.info('Saving model to output_dir.')
    model.save_weights(output_dir + '/model.ckpt')
    # TODO(yovadia): Figure out why save_model() wants to serialize ModelOptions.
    return model
Ejemplo n.º 2
0
 def test_static_cifar_c(self, split):
   """Static CIFAR-C loads for 'test'; other splits raise ValueError."""
   if flags.FLAGS.fake_data:
     return  # Real corruption files are not exercised under --fake_data.
   config = image_data_utils.DataConfig(
       split, corruption_static=True, corruption_level=3,
       corruption_type='pixelate')
   if split in ('train', 'valid'):
     # Static corruptions only exist for the test split.
     with self.assertRaises(ValueError):
       data_lib.build_dataset(config)
     return
   first_example = next(iter(data_lib.build_dataset(config)))
   self.assertEqual(first_example[0].numpy().shape, data_lib.CIFAR_SHAPE)
Ejemplo n.º 3
0
def run(dataset_name,
        model_dir,
        predictions_per_example,
        max_examples,
        output_dir,
        fake_data=False):
    """Runs predictions on the given dataset using the specified model.

    Args:
      dataset_name: Name of the dataset to evaluate on.
      model_dir: Directory containing the trained model to load.
      predictions_per_example: Number of sampled predictions per example.
      max_examples: If truthy, cap the dataset at this many examples.
      output_dir: Directory where prediction .npz files are written.
      fake_data: If true, use fake data.
    """
    tf.io.gfile.makedirs(output_dir)

    dataset = data_lib.build_dataset(
        image_data_utils.get_data_config(dataset_name), fake_data=fake_data)
    if max_examples:
        dataset = dataset.take(max_examples)

    model = models_lib.load_model(model_dir)

    logging.info('Starting predictions.')
    predictions = experiment_utils.make_predictions(
        model, dataset.batch(_BATCH_SIZE), predictions_per_example)

    logging.info('Done computing predictions; recording results to disk.')
    array_utils.write_npz(
        output_dir, 'predictions_%s.npz' % dataset_name, predictions)
    # Drop the large per-sample logits before writing the small variant.
    del predictions['logits_samples']
    array_utils.write_npz(
        output_dir, 'predictions_small_%s.npz' % dataset_name, predictions)
Ejemplo n.º 4
0
 def test_value_cifar_c(self, split):
   """A value-parameterized corruption yields CIFAR-shaped images."""
   if flags.FLAGS.fake_data:
     return  # Real corruption files are not exercised under --fake_data.
   config = image_data_utils.DataConfig(
       split, corruption_value=.25, corruption_type='brightness')
   first_example = next(iter(data_lib.build_dataset(config)))
   self.assertEqual(first_example[0].numpy().shape, data_lib.CIFAR_SHAPE)
Ejemplo n.º 5
0
 def test_array_cifar_c(self, split):
   """A level-parameterized corruption yields CIFAR-shaped images."""
   if flags.FLAGS.fake_data:
     return  # Real corruption files are not exercised under --fake_data.
   config = image_data_utils.DataConfig(
       split, corruption_level=4, corruption_type='glass_blur')
   first_example = next(iter(data_lib.build_dataset(config)))
   self.assertEqual(first_example[0].numpy().shape, data_lib.CIFAR_SHAPE)
Ejemplo n.º 6
0
 def test_roll_pixels(self, split):
     """Pixel-rolled data still yields CIFAR-shaped images."""
     # Config is built unconditionally; only dataset loading is gated.
     config = image_data_utils.DataConfig(split, roll_pixels=5)
     if flags.FLAGS.fake_data:
         return  # Real data files are not exercised under --fake_data.
     first_example = next(iter(data_lib.build_dataset(config)))
     self.assertEqual(first_example[0].numpy().shape, data_lib.CIFAR_SHAPE)
Ejemplo n.º 7
0
 def test_fake_data(self, split):
     """Fake data yields CIFAR-shaped images regardless of config."""
     # The config is ignored when fake_data=True.
     config = image_data_utils.DataConfig(split)
     first_example = next(iter(data_lib.build_dataset(config, fake_data=True)))
     self.assertEqual(first_example[0].numpy().shape, data_lib.CIFAR_SHAPE)