コード例 #1
0
 def test_static_imagenet_c(self, split):
   """Static ImageNet-C: train/valid splits are rejected; test yields a batch."""
   if flags.FLAGS.fake_data:
     return  # Nothing to check when the suite runs on fake data.
   config = image_data_utils.DataConfig(
       split,
       corruption_static=True,
       corruption_level=3,
       corruption_type='pixelate')
   if split in ('train', 'valid'):
     # Building a statically-corrupted dataset for a non-test split is
     # expected to fail loudly.
     with self.assertRaises(ValueError):
       data_lib.build_dataset(config, BATCH_SIZE)
     return
   dataset = data_lib.build_dataset(config, BATCH_SIZE)
   images = next(iter(dataset))[0].numpy()
   self.assertEqual(images.shape, BATCHED_IMAGES_SHAPE)
   self.assertAllInRange(images, 0., 1.)
   # At least one pixel should exceed 1/255, i.e. the batch is not all-black.
   self.assertTrue((images > 1. / 255).any())
コード例 #2
0
def run(dataset_name,
        model_dir,
        batch_size,
        predictions_per_example,
        max_examples,
        output_dir,
        fake_data=False):
    """Runs predictions on the given dataset using the specified model."""
    gfile.makedirs(output_dir)

    # Build the evaluation dataset, optionally truncated for quick runs.
    dataset = data_lib.build_dataset(
        image_data_utils.get_data_config(dataset_name),
        batch_size,
        fake_data=fake_data)
    if max_examples:
        dataset = dataset.take(max_examples)

    # Reconstruct the model from its recorded options and restore weights.
    opts_dict = experiment_utils.load_config(model_dir +
                                             '/model_options.json')
    model_opts = models_lib.ModelOptions(**opts_dict)
    logging.info('Loaded model options: %s', model_opts)
    model = models_lib.build_model(model_opts)
    logging.info('Loading model weights...')
    model.load_weights(model_dir + '/model.ckpt')
    logging.info('done loading model weights.')

    # One writer for full prediction stats, one for the reduced variant.
    writers = {
        'full': array_utils.StatsWriter(
            os.path.join(output_dir, 'predictions_%s' % dataset_name)),
        'small': array_utils.StatsWriter(
            os.path.join(output_dir, 'predictions_small_%s' % dataset_name)),
    }
    max_batches = 50000 // batch_size
    experiment_utils.make_predictions(model, dataset, predictions_per_example,
                                      writers, max_batches)
コード例 #3
0
 def test_roll_pixels(self, split):
     """Rolling pixels by 5 still yields a well-formed, non-black batch."""
     config = image_data_utils.DataConfig(split, roll_pixels=5)
     if flags.FLAGS.fake_data:
         return  # Requires the real dataset.
     batch = next(iter(data_lib.build_dataset(config, BATCH_SIZE)))
     images = batch[0].numpy()
     self.assertEqual(images.shape, BATCHED_IMAGES_SHAPE)
     self.assertAllInRange(images, 0., 1.)
     # Some pixel must exceed 1/255, i.e. the images are not all-black.
     self.assertTrue((images > 1. / 255).any())
コード例 #4
0
 def test_fake_data(self, split):
     """fake_data=True yields a valid batch; the split config is ignored."""
     config = image_data_utils.DataConfig(split)  # Ignored for fake data.
     batch = next(iter(
         data_lib.build_dataset(config, BATCH_SIZE, fake_data=True)))
     images = batch[0].numpy()
     self.assertEqual(images.shape, BATCHED_IMAGES_SHAPE)
     self.assertAllInRange(images, 0., 1.)
     # Some pixel must exceed 1/255, i.e. the images are not all-black.
     self.assertTrue((images > 1. / 255).any())
コード例 #5
0
 def test_alt_dataset(self):
   """An alternative dataset (celeb_a) yields a normal image batch."""
   if flags.FLAGS.fake_data:
     return  # Only meaningful against real data.
   config = image_data_utils.DataConfig('test', alt_dataset_name='celeb_a')
   batch = next(iter(data_lib.build_dataset(config, BATCH_SIZE)))
   images = batch[0].numpy()
   self.assertEqual(images.shape, BATCHED_IMAGES_SHAPE)
   self.assertAllInRange(images, 0., 1.)
   # Some pixel must exceed 1/255, i.e. the images are not all-black.
   self.assertTrue((images > 1. / 255).any())
コード例 #6
0
 def test_value_imagenet_c(self, split):
   """Value-parameterized ImageNet-C corruption yields a valid batch."""
   if flags.FLAGS.fake_data:
     return  # Requires the real dataset.
   config = image_data_utils.DataConfig(
       split, corruption_value=.25, corruption_type='brightness')
   batch = next(iter(data_lib.build_dataset(config, BATCH_SIZE)))
   images = batch[0].numpy()
   self.assertEqual(images.shape, BATCHED_IMAGES_SHAPE)
   self.assertAllInRange(images, 0., 1.)
   # Some pixel must exceed 1/255, i.e. the images are not all-black.
   self.assertTrue((images > 1. / 255).any())
コード例 #7
0
def run(method,
        output_dir,
        task_number,
        use_tpu,
        tpu,
        metrics,
        fake_data=False,
        fake_training=False):
    """Train a ResNet model on ImageNet."""
    gfile.makedirs(output_dir)

    model_opts = hparams_lib.model_opts_from_hparams(
        hparams_lib.HPS_DICT[method],
        method,
        use_tpu,
        tpu,
        fake_training=fake_training)
    if fake_training:
        # Shrink the run to smoke-test size so it finishes quickly.
        model_opts.batch_size = 32
        model_opts.examples_per_epoch = 256
        model_opts.train_epochs = 1

    if use_tpu:
        logging.info('Use TPU at %s',
                     model_opts.tpu if model_opts.tpu is not None else 'local')
        tpu_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
            tpu=model_opts.tpu)
        # The TPU system must be initialized before building the strategy.
        tf.contrib.distribute.initialize_tpu_system(tpu_resolver)
        strategy = tf.contrib.distribute.TPUStrategy(tpu_resolver)
    else:
        strategy = experiment_utils.get_distribution_strategy(
            distribution_strategy='default',
            num_gpus=model_opts.num_gpus,
            num_workers=model_opts.num_replicas,
            all_reduce_alg=None)

    logging.info('Use global batch size: %s.', model_opts.batch_size)
    logging.info('Use bfloat16: %s.', model_opts.use_bfloat16)
    # Persist the resolved options next to the checkpoints for later reloads.
    experiment_utils.record_config(model_opts,
                                   output_dir + '/model_options.json')

    train_dataset = data_lib.build_dataset(
        image_data_utils.DATA_CONFIG_TRAIN,
        batch_size=model_opts.batch_size,
        is_training=True,
        fake_data=fake_data,
        use_bfloat16=model_opts.use_bfloat16)
    eval_dataset = data_lib.build_dataset(
        image_data_utils.DATA_CONFIG_TEST,
        batch_size=model_opts.batch_size,
        fake_data=fake_data,
        use_bfloat16=model_opts.use_bfloat16)

    if fake_training:
        # Smoke test: train without entering a distribution-strategy scope.
        model = models_lib.build_and_train(model_opts, train_dataset,
                                           eval_dataset, output_dir, metrics)
    else:
        with strategy.scope():
            model = models_lib.build_and_train(model_opts, train_dataset,
                                               eval_dataset, output_dir,
                                               metrics)

    save_model(model, output_dir, method, use_tpu, task_number)