def test_raises_on_bad_crop_sizes(self, crop_height, crop_width):
    """Out-of-range crop sizes make task construction raise ValueError.

    Args:
      crop_height: Crop height under test (presumably supplied by a
        parameterization decorator outside this view).
      crop_width: Crop width under test (same caveat as crop_height).
    """
    spec = client_spec.ClientSpec(
        num_epochs=2, batch_size=10, max_elements=3, shuffle_buffer_size=5)
    # Interpreted as a regex by assertRaisesRegex, so the trailing '.' matches
    # any character.
    expected_message = ('The crop_height and crop_width '
                        'must be between 1 and 32.')
    with self.assertRaisesRegex(ValueError, expected_message):
        image_classification_tasks.create_image_classification_task(
            spec,
            model_id='resnet18',
            crop_height=crop_height,
            crop_width=crop_width,
            use_synthetic_data=True)
def test_constructs_with_no_eval_client_spec(self):
    """Task construction succeeds when only a training client spec is given.

    No eval client spec is passed. The returned object must be a
    `baseline_task.BaselineTask`.
    """
    spec_kwargs = dict(
        num_epochs=2, batch_size=10, max_elements=3, shuffle_buffer_size=5)
    train_spec = client_spec.ClientSpec(**spec_kwargs)
    task = image_classification_tasks.create_image_classification_task(
        train_spec, model_id='resnet18', use_synthetic_data=True)
    self.assertIsInstance(task, baseline_task.BaselineTask)
def test_constructs_with_different_models(self, model_id):
    """Task construction with the given model and a 3x3 crop yields a task.

    Args:
      model_id: Model identifier under test (presumably supplied by a
        parameterization decorator outside this view).
    """
    train_spec = client_spec.ClientSpec(
        num_epochs=2,
        batch_size=10,
        max_elements=3,
        shuffle_buffer_size=5,
    )
    task = image_classification_tasks.create_image_classification_task(
        train_spec,
        model_id=model_id,
        crop_height=3,
        crop_width=3,
        use_synthetic_data=True,
    )
    self.assertIsInstance(task, baseline_task.BaselineTask)
def test_no_train_distortion_gives_deterministic_result(self):
    """With train-image distortion disabled, preprocessing ignores the seed.

    The first preprocessed training example is computed under two different
    global TensorFlow seeds (0 and 1). The two results must be close.
    """
    single_spec = client_spec.ClientSpec(
        num_epochs=1, batch_size=1, max_elements=1, shuffle_buffer_size=1)
    task = image_classification_tasks.create_image_classification_task(
        single_spec,
        model_id='resnet18',
        distort_train_images=False,
        use_synthetic_data=True)
    preprocess = task.datasets.train_preprocess_fn
    all_clients_data = (
        task.datasets.train_data.create_tf_dataset_from_all_clients())

    def first_example_with_seed(seed):
        # Reset the global seed right before iterating, so any seed-dependent
        # randomness in preprocessing would differ between the two calls.
        tf.random.set_seed(seed)
        return next(iter(preprocess(all_clients_data)))

    with_seed_zero = first_example_with_seed(0)
    with_seed_one = first_example_with_seed(1)
    self.assertAllClose(with_seed_zero, with_seed_one)