def test_static_imagenet_c(self, split):
  """Static ImageNet-C: train/valid must raise; test yields valid images."""
  if flags.FLAGS.fake_data:
    return
  config = image_data_utils.DataConfig(
      split,
      corruption_static=True,
      corruption_level=3,
      corruption_type='pixelate')
  if split in ('train', 'valid'):
    # Statically corrupted data is only available for the test split.
    with self.assertRaises(ValueError):
      data_lib.build_dataset(config, BATCH_SIZE)
    return
  dataset = data_lib.build_dataset(config, BATCH_SIZE)
  batch_images = next(iter(dataset))[0].numpy()
  self.assertEqual(batch_images.shape, BATCHED_IMAGES_SHAPE)
  self.assertAllInRange(batch_images, 0., 1.)
  # Guard against an all-black (degenerate) batch.
  self.assertTrue((batch_images > 1. / 255).any())
def run(dataset_name, model_dir, batch_size, predictions_per_example,
        max_examples, output_dir, fake_data=False):
  """Runs predictions on the given dataset using the specified model.

  Writes per-example prediction statistics (full and reduced "small"
  variants) under `output_dir`.

  Args:
    dataset_name: Name of the dataset to predict on.
    model_dir: Directory holding `model_options.json` and `model.ckpt`.
    batch_size: Batch size for the input pipeline.
    predictions_per_example: Number of stochastic forward passes per example.
    max_examples: If truthy, cap the dataset at this many examples.
    output_dir: Directory where prediction stats are written.
    fake_data: If True, use synthetic data instead of the real dataset.
  """
  gfile.makedirs(output_dir)
  data_config = image_data_utils.get_data_config(dataset_name)
  dataset = data_lib.build_dataset(data_config, batch_size,
                                   fake_data=fake_data)
  if max_examples:
    dataset = dataset.take(max_examples)

  # Use os.path.join consistently (was mixed with '+' string concatenation).
  model_opts = experiment_utils.load_config(
      os.path.join(model_dir, 'model_options.json'))
  model_opts = models_lib.ModelOptions(**model_opts)
  logging.info('Loaded model options: %s', model_opts)

  model = models_lib.build_model(model_opts)
  logging.info('Loading model weights...')
  model.load_weights(os.path.join(model_dir, 'model.ckpt'))
  logging.info('done loading model weights.')

  writer = array_utils.StatsWriter(
      os.path.join(output_dir, 'predictions_%s' % dataset_name))
  writer_small = array_utils.StatsWriter(
      os.path.join(output_dir, 'predictions_small_%s' % dataset_name))
  writers = {'full': writer, 'small': writer_small}

  # ImageNet validation-set size; caps prediction at one pass over it.
  # NOTE(review): presumably this targets the 50k ImageNet eval split —
  # confirm if this runner is used with other datasets.
  num_eval_examples = 50000
  max_batches = num_eval_examples // batch_size
  experiment_utils.make_predictions(
      model, dataset, predictions_per_example, writers, max_batches)
def test_roll_pixels(self, split):
  """A pixel-rolled dataset still yields valid, non-black images."""
  config = image_data_utils.DataConfig(split, roll_pixels=5)
  if flags.FLAGS.fake_data:
    return
  dataset = data_lib.build_dataset(config, BATCH_SIZE)
  batch_images = next(iter(dataset))[0].numpy()
  self.assertEqual(batch_images.shape, BATCHED_IMAGES_SHAPE)
  self.assertAllInRange(batch_images, 0., 1.)
  # At least one pixel must exceed one grey level above zero.
  self.assertTrue((batch_images > 1. / 255).any())
def test_fake_data(self, split):
  """Fake data yields correctly shaped, in-range, non-black batches."""
  # The config contents are ignored when fake_data=True.
  config = image_data_utils.DataConfig(split)
  dataset = data_lib.build_dataset(config, BATCH_SIZE, fake_data=True)
  batch_images = next(iter(dataset))[0].numpy()
  self.assertEqual(batch_images.shape, BATCHED_IMAGES_SHAPE)
  self.assertAllInRange(batch_images, 0., 1.)
  # At least one pixel must exceed one grey level above zero.
  self.assertTrue((batch_images > 1. / 255).any())
def test_alt_dataset(self):
  """Loading an alternative dataset (celeb_a) yields valid images."""
  if flags.FLAGS.fake_data:
    return
  config = image_data_utils.DataConfig('test', alt_dataset_name='celeb_a')
  dataset = data_lib.build_dataset(config, BATCH_SIZE)
  batch_images = next(iter(dataset))[0].numpy()
  self.assertEqual(batch_images.shape, BATCHED_IMAGES_SHAPE)
  self.assertAllInRange(batch_images, 0., 1.)
  # At least one pixel must exceed one grey level above zero.
  self.assertTrue((batch_images > 1. / 255).any())
def test_value_imagenet_c(self, split):
  """Continuous-valued ImageNet-C corruption yields valid images."""
  if flags.FLAGS.fake_data:
    return
  config = image_data_utils.DataConfig(
      split, corruption_value=.25, corruption_type='brightness')
  dataset = data_lib.build_dataset(config, BATCH_SIZE)
  batch_images = next(iter(dataset))[0].numpy()
  self.assertEqual(batch_images.shape, BATCHED_IMAGES_SHAPE)
  self.assertAllInRange(batch_images, 0., 1.)
  # At least one pixel must exceed one grey level above zero.
  self.assertTrue((batch_images > 1. / 255).any())
def run(method, output_dir, task_number, use_tpu, tpu, metrics,
        fake_data=False, fake_training=False):
  """Train a ResNet model on ImageNet."""
  gfile.makedirs(output_dir)
  model_opts = hparams_lib.model_opts_from_hparams(
      hparams_lib.HPS_DICT[method], method, use_tpu, tpu,
      fake_training=fake_training)

  if fake_training:
    # Shrink the run to a single tiny epoch so it finishes quickly.
    model_opts.batch_size = 32
    model_opts.examples_per_epoch = 256
    model_opts.train_epochs = 1

  if use_tpu:
    # NOTE(review): tf.contrib is a TF1-era API; confirm the pinned
    # TensorFlow version before migrating to tf.distribute.
    tpu_address = model_opts.tpu if model_opts.tpu is not None else 'local'
    logging.info('Use TPU at %s', tpu_address)
    resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
        tpu=model_opts.tpu)
    tf.contrib.distribute.initialize_tpu_system(resolver)
    strategy = tf.contrib.distribute.TPUStrategy(resolver)
  else:
    strategy = experiment_utils.get_distribution_strategy(
        distribution_strategy='default',
        num_gpus=model_opts.num_gpus,
        num_workers=model_opts.num_replicas,
        all_reduce_alg=None)

  logging.info('Use global batch size: %s.', model_opts.batch_size)
  logging.info('Use bfloat16: %s.', model_opts.use_bfloat16)
  experiment_utils.record_config(model_opts,
                                 output_dir + '/model_options.json')

  train_data = data_lib.build_dataset(
      image_data_utils.DATA_CONFIG_TRAIN,
      batch_size=model_opts.batch_size,
      is_training=True,
      fake_data=fake_data,
      use_bfloat16=model_opts.use_bfloat16)
  eval_data = data_lib.build_dataset(
      image_data_utils.DATA_CONFIG_TEST,
      batch_size=model_opts.batch_size,
      fake_data=fake_data,
      use_bfloat16=model_opts.use_bfloat16)

  if fake_training:
    # Fake training deliberately skips the distribution-strategy scope.
    model = models_lib.build_and_train(
        model_opts, train_data, eval_data, output_dir, metrics)
  else:
    with strategy.scope():
      model = models_lib.build_and_train(
          model_opts, train_data, eval_data, output_dir, metrics)

  save_model(model, output_dir, method, use_tpu, task_number)