def test_end2end_inception_v3_warm_up_from(self):
    """Trains inception_v3 for one step, warm-started from a checkpoint.

    A fake inception_v3 checkpoint is written into a test temp directory and
    the `model` path under that directory is passed to the tiny training run
    as its warm-start source.
    """
    warm_start_dir = tf_test_utils.test_tmpdir('inception_v3_warm_up_from')
    session = self.test_session()
    tf_test_utils.write_fake_checkpoint('inception_v3', session, warm_start_dir)

    golden_dataset = data_providers_test.make_golden_dataset(
        use_tpu=FLAGS.use_tpu)
    self._run_tiny_training(
        model_name='inception_v3',
        dataset=golden_dataset,
        warm_start_from=warm_start_dir + '/model')
Exemplo n.º 2
0
 def test_end2end_inception_v3_warm_up_allow_different_num_channels(self):
   """Warm-starts inception_v3 from a checkpoint with one extra channel.

   The fake checkpoint is written with `PILEUP_NUM_CHANNELS + 1` channels and
   `FLAGS.allow_warmstart_from_different_num_channels` is switched on for the
   duration of the tiny training run, which is expected to complete.
   """
   # FLAGS is process-global: remember the current value and put it back on
   # exit so this override cannot leak into tests that rely on the default
   # (e.g. the test expecting a shape-mismatch failure).
   saved_allow = FLAGS.allow_warmstart_from_different_num_channels
   FLAGS.allow_warmstart_from_different_num_channels = True
   try:
     checkpoint_dir = tf_test_utils.test_tmpdir(
         'inception_v3_warm_up_allow_different_num_channels')
     tf_test_utils.write_fake_checkpoint(
         'inception_v3',
         self.test_session(),
         checkpoint_dir,
         num_channels=dv_constants.PILEUP_NUM_CHANNELS + 1)
     self._run_tiny_training(
         model_name='inception_v3',
         dataset=data_providers_test.make_golden_dataset(
             use_tpu=FLAGS.use_tpu),
         warm_start_from=checkpoint_dir + '/model')
   finally:
     FLAGS.allow_warmstart_from_different_num_channels = saved_allow
Exemplo n.º 3
0
 def test_end2end_inception_v3_warm_up_from_mobilenet_v1(self):
   """Checks training inception_v3 when warm-starting from mobilenet_v1.

   The warm-start checkpoint must have exactly the top-level scopes
   `MobilenetV1` and `global_step`; the checkpoint saved after the tiny
   training run must have `InceptionV3` and `global_step`.
   """
   warm_start_dir = tf_test_utils.test_tmpdir(
       'inception_v3_warm_up_from_mobilenet_v1')
   warm_start_ckpt = warm_start_dir + '/model'

   tf_test_utils.write_fake_checkpoint(
       'mobilenet_v1', self.test_session(), warm_start_dir)
   source_scopes_ok = tf_test_utils.check_equals_checkpoint_top_scopes(
       warm_start_ckpt, ['MobilenetV1', 'global_step'])
   self.assertTrue(source_scopes_ok)

   golden_dataset = data_providers_test.make_golden_dataset(
       use_tpu=FLAGS.use_tpu)
   self._run_tiny_training(
       model_name='inception_v3',
       dataset=golden_dataset,
       warm_start_from=warm_start_ckpt)

   # _run_tiny_training points FLAGS.train_dir at the directory it trained in.
   trained_ckpt = FLAGS.train_dir + '/model.ckpt-1'
   trained_scopes_ok = tf_test_utils.check_equals_checkpoint_top_scopes(
       trained_ckpt, ['InceptionV3', 'global_step'])
   self.assertTrue(trained_scopes_ok)
Exemplo n.º 4
0
 def test_end2end_inception_v3_warm_up_by_default_fail_diff_num_channels(self):
   """Expects warm start to fail on a checkpoint with one extra channel.

   The fake inception_v3 checkpoint is written with
   `PILEUP_NUM_CHANNELS + 1` channels. This test overrides no warm-start
   flag, and the tiny training run must raise a ValueError reporting the
   shape mismatch on `InceptionV3/Conv2d_1a_3x3/weights`.
   """
   mismatched_ckpt_dir = tf_test_utils.test_tmpdir(
       'test_end2end_inception_v3_warm_up_by_default_fail_diff_num_channels')
   extra_channel_count = dv_constants.PILEUP_NUM_CHANNELS + 1
   tf_test_utils.write_fake_checkpoint(
       'inception_v3',
       self.test_session(),
       mismatched_ckpt_dir,
       num_channels=extra_channel_count)

   shape_mismatch_pattern = (
       r'Shape of variable InceptionV3/Conv2d_1a_3x3/weights:0 \(\(.*\)\) '
       r'doesn\'t match with shape of tensor '
       r'InceptionV3/Conv2d_1a_3x3/weights \(\[.*\]\) from checkpoint reader.'
   )
   with self.assertRaisesRegex(ValueError, shape_mismatch_pattern):
     self._run_tiny_training(
         model_name='inception_v3',
         dataset=data_providers_test.make_golden_dataset(
             use_tpu=FLAGS.use_tpu),
         warm_start_from=mismatched_ckpt_dir + '/model')
Exemplo n.º 5
0
 def _run_tiny_training(self, model_name, dataset, warm_start_from=''):
   """Runs a single training step in a newly named train_dir.

   `deepvariant.data_providers.get_input_fn_from_dataset` is patched to return
   `dataset`, the training flags are overridden, and
   `model_train.parse_and_run()` is invoked. Afterwards the patched function
   must have been called exactly once in TRAIN mode and a checkpoint must
   exist in `FLAGS.train_dir`.

   Args:
     model_name: Value assigned to `FLAGS.model_name`.
     dataset: Object returned by the patched `get_input_fn_from_dataset`.
     warm_start_from: Value assigned to `FLAGS.start_from_checkpoint`.
   """
   patch_target = 'deepvariant.data_providers.get_input_fn_from_dataset'
   with mock.patch(patch_target) as patched_input_fn_getter:
     patched_input_fn_getter.return_value = dataset

     # Applied in this exact order; the train_dir name is random so every
     # call trains in its own directory.
     flag_overrides = (
         ('train_dir', tf_test_utils.test_tmpdir(uuid.uuid4().hex)),
         ('batch_size', 2),
         ('model_name', model_name),
         ('save_interval_secs', -1),
         ('save_interval_steps', 1),
         ('number_of_steps', 1),
         ('dataset_config_pbtxt', '/path/to/mock.pbtxt'),
         ('start_from_checkpoint', warm_start_from),
         ('master', ''),
     )
     for flag_name, flag_value in flag_overrides:
       setattr(FLAGS, flag_name, flag_value)

     model_train.parse_and_run()

     patched_input_fn_getter.assert_called_once_with(
         dataset_config_filename=FLAGS.dataset_config_pbtxt,
         mode=tf.estimator.ModeKeys.TRAIN,
         use_tpu=mock.ANY,
     )
     # We have a checkpoint after training.
     self.assertIsNotNone(tf.train.latest_checkpoint(FLAGS.train_dir))