def test_end2end_inception_v3_warm_up_from(self):
  """Trains inception_v3 for one step, warm-started from a fake checkpoint.

  A fake inception_v3 checkpoint is written to a temporary directory and its
  '<dir>/model' prefix is handed to the training run as the warm-start source.
  """
  ckpt_root = tf_test_utils.test_tmpdir('inception_v3_warm_up_from')
  tf_test_utils.write_fake_checkpoint('inception_v3', self.test_session(),
                                      ckpt_root)
  golden = data_providers_test.make_golden_dataset(use_tpu=FLAGS.use_tpu)
  self._run_tiny_training(
      model_name='inception_v3',
      dataset=golden,
      warm_start_from=ckpt_root + '/model')
def test_end2end_inception_v3_warm_up_allow_different_num_channels(self):
  """Warm-starts from a checkpoint that has one extra input channel.

  With FLAGS.allow_warmstart_from_different_num_channels enabled, one step of
  inception_v3 training is expected to run without raising, even though the
  fake checkpoint was written with dv_constants.PILEUP_NUM_CHANNELS + 1
  channels.
  """
  # FLAGS is global state shared by every test in the process. Restore the
  # previous value so this override does not leak into tests that rely on the
  # default (e.g. the test expecting a channel-mismatch ValueError).
  saved_allow = FLAGS.allow_warmstart_from_different_num_channels
  FLAGS.allow_warmstart_from_different_num_channels = True
  try:
    checkpoint_dir = tf_test_utils.test_tmpdir(
        'inception_v3_warm_up_allow_different_num_channels')
    tf_test_utils.write_fake_checkpoint(
        'inception_v3',
        self.test_session(),
        checkpoint_dir,
        num_channels=dv_constants.PILEUP_NUM_CHANNELS + 1)
    self._run_tiny_training(
        model_name='inception_v3',
        dataset=data_providers_test.make_golden_dataset(use_tpu=FLAGS.use_tpu),
        warm_start_from=checkpoint_dir + '/model')
  finally:
    FLAGS.allow_warmstart_from_different_num_channels = saved_allow
def test_end2end_inception_v3_warm_up_from_mobilenet_v1(self):
  """Trains inception_v3 while warm-starting from a mobilenet_v1 checkpoint.

  The top-level scopes are checked on both sides:
    * the source checkpoint: ['MobilenetV1', 'global_step'];
    * the checkpoint written after one training step ('model.ckpt-1' in
      FLAGS.train_dir): ['InceptionV3', 'global_step'].
  """
  src_dir = tf_test_utils.test_tmpdir('inception_v3_warm_up_from_mobilenet_v1')
  src_ckpt = src_dir + '/model'
  tf_test_utils.write_fake_checkpoint('mobilenet_v1', self.test_session(),
                                      src_dir)
  self.assertTrue(
      tf_test_utils.check_equals_checkpoint_top_scopes(
          src_ckpt, ['MobilenetV1', 'global_step']))

  self._run_tiny_training(
      model_name='inception_v3',
      dataset=data_providers_test.make_golden_dataset(use_tpu=FLAGS.use_tpu),
      warm_start_from=src_ckpt)

  # FLAGS.train_dir is (re)assigned inside _run_tiny_training, so it must be
  # read only after training has run.
  trained_ckpt = FLAGS.train_dir + '/model.ckpt-1'
  self.assertTrue(
      tf_test_utils.check_equals_checkpoint_top_scopes(
          trained_ckpt, ['InceptionV3', 'global_step']))
def test_end2end_inception_v3_warm_up_by_default_fail_diff_num_channels(self):
  """Warm-starting from a checkpoint with a different channel count fails.

  With FLAGS.allow_warmstart_from_different_num_channels disabled, warm
  starting inception_v3 from a checkpoint written with
  dv_constants.PILEUP_NUM_CHANNELS + 1 channels is expected to raise a
  ValueError reporting a shape mismatch for
  InceptionV3/Conv2d_1a_3x3/weights.
  """
  # Pin the flag to its "disallow" setting explicitly rather than relying on
  # whatever value other tests (which may enable it) left in the global FLAGS;
  # restore the previous value afterwards.
  saved_allow = FLAGS.allow_warmstart_from_different_num_channels
  FLAGS.allow_warmstart_from_different_num_channels = False
  try:
    checkpoint_dir = tf_test_utils.test_tmpdir(
        'test_end2end_inception_v3_warm_up_by_default_fail_diff_num_channels')
    tf_test_utils.write_fake_checkpoint(
        'inception_v3',
        self.test_session(),
        checkpoint_dir,
        num_channels=dv_constants.PILEUP_NUM_CHANNELS + 1)
    with self.assertRaisesRegex(
        ValueError,
        r'Shape of variable InceptionV3/Conv2d_1a_3x3/weights:0 \(\(.*\)\) '
        r'doesn\'t match with shape of tensor '
        r'InceptionV3/Conv2d_1a_3x3/weights \(\[.*\]\) from checkpoint reader.'
    ):
      self._run_tiny_training(
          model_name='inception_v3',
          dataset=data_providers_test.make_golden_dataset(
              use_tpu=FLAGS.use_tpu),
          warm_start_from=checkpoint_dir + '/model')
  finally:
    FLAGS.allow_warmstart_from_different_num_channels = saved_allow
def _run_tiny_training(self, model_name, dataset, warm_start_from=''):
  """Runs a single training step in a fresh, uniquely named train_dir.

  The input pipeline is replaced by patching
  deepvariant.data_providers.get_input_fn_from_dataset to return `dataset`.

  Args:
    model_name: value assigned to FLAGS.model_name.
    dataset: object returned by the patched get_input_fn_from_dataset.
    warm_start_from: path assigned to FLAGS.start_from_checkpoint; '' by
      default.

  FLAGS are mutated and deliberately left set afterwards; callers read
  FLAGS.train_dir to inspect the checkpoint written by this step.
  """
  patch_target = 'deepvariant.data_providers.get_input_fn_from_dataset'
  with mock.patch(patch_target) as mocked_input_fn:
    mocked_input_fn.return_value = dataset

    # Applied in order; train_dir gets a random (uuid4) subdirectory so every
    # call starts from an empty directory.
    flag_values = (
        ('train_dir', tf_test_utils.test_tmpdir(uuid.uuid4().hex)),
        ('batch_size', 2),
        ('model_name', model_name),
        ('save_interval_secs', -1),
        ('save_interval_steps', 1),
        ('number_of_steps', 1),
        ('dataset_config_pbtxt', '/path/to/mock.pbtxt'),
        ('start_from_checkpoint', warm_start_from),
        ('master', ''),
    )
    for flag_name, flag_value in flag_values:
      setattr(FLAGS, flag_name, flag_value)

    model_train.parse_and_run()

    mocked_input_fn.assert_called_once_with(
        dataset_config_filename=FLAGS.dataset_config_pbtxt,
        mode=tf.estimator.ModeKeys.TRAIN,
        use_tpu=mock.ANY,
    )
    # Training must have left at least one checkpoint in train_dir.
    self.assertIsNotNone(tf.train.latest_checkpoint(FLAGS.train_dir))