コード例 #1
0
 def test_raises_on_bad_crop_sizes(self, crop_height, crop_width):
     """Checks that task construction rejects unsupported crop sizes.

     `crop_height` / `crop_width` are presumably supplied by a parameterized
     decorator outside this view. Building the task with them must raise a
     `ValueError` whose message matches `expected_message`.
     """
     spec_kwargs = dict(num_epochs=2,
                        batch_size=10,
                        max_elements=3,
                        shuffle_buffer_size=5)
     train_client_spec = client_spec.ClientSpec(**spec_kwargs)
     # assertRaisesRegex searches the exception text with this pattern; the
     # trailing '.' is a regex wildcard rather than a literal period.
     expected_message = ('The crop_height and crop_width '
                         'must be between 1 and 32.')
     with self.assertRaisesRegex(ValueError, expected_message):
         image_classification_tasks.create_image_classification_task(
             train_client_spec,
             model_id='resnet18',
             crop_height=crop_height,
             crop_width=crop_width,
             use_synthetic_data=True)
コード例 #2
0
 def test_constructs_with_no_eval_client_spec(self):
     """Builds a task from a training ClientSpec alone.

     No eval client spec is passed to the factory; the returned object must
     be a `baseline_task.BaselineTask`.
     """
     train_client_spec = client_spec.ClientSpec(
         num_epochs=2, batch_size=10, max_elements=3, shuffle_buffer_size=5)
     task_factory = image_classification_tasks.create_image_classification_task
     baseline_task_spec = task_factory(
         train_client_spec, model_id='resnet18', use_synthetic_data=True)
     self.assertIsInstance(baseline_task_spec, baseline_task.BaselineTask)
コード例 #3
0
 def test_constructs_with_different_models(self, model_id):
     """Builds a task for the given `model_id` and checks its type.

     `model_id` is presumably supplied by a parameterized decorator outside
     this view. A square 3x3 crop is requested and the result must be a
     `baseline_task.BaselineTask`.
     """
     train_client_spec = client_spec.ClientSpec(
         num_epochs=2, batch_size=10, max_elements=3, shuffle_buffer_size=5)
     # Same value for both crop dimensions; presumably kept tiny so model
     # construction on synthetic data stays cheap.
     crop_size = 3
     baseline_task_spec = (
         image_classification_tasks.create_image_classification_task(
             train_client_spec,
             model_id=model_id,
             crop_height=crop_size,
             crop_width=crop_size,
             use_synthetic_data=True))
     self.assertIsInstance(baseline_task_spec, baseline_task.BaselineTask)
コード例 #4
0
 def test_no_train_distortion_gives_deterministic_result(self):
     """Checks undistorted train preprocessing does not depend on the seed.

     With `distort_train_images=False`, the first preprocessed training
     example drawn under global TF seed 0 must match (via assertAllClose)
     the one drawn under seed 1.
     """
     train_client_spec = client_spec.ClientSpec(
         num_epochs=1, batch_size=1, max_elements=1, shuffle_buffer_size=1)
     task = image_classification_tasks.create_image_classification_task(
         train_client_spec,
         model_id='resnet18',
         distort_train_images=False,
         use_synthetic_data=True)
     preprocess = task.datasets.train_preprocess_fn
     dataset = task.datasets.train_data.create_tf_dataset_from_all_clients()

     def first_example_with_seed(seed):
         # Set the global seed *before* applying preprocessing so that any
         # random ops created inside it would see a different seed per call.
         tf.random.set_seed(seed)
         return next(iter(preprocess(dataset)))

     # Evaluated in order: seed 0 first, then seed 1.
     example1, example2 = [first_example_with_seed(s) for s in (0, 1)]
     self.assertAllClose(example1, example2)