def test_predictor(self):
  """Trains a mock model and exercises CheckpointPredictor around restore()."""
  generator = default_input_generator.DefaultRandomInputGenerator(
      batch_size=_BATCH_SIZE)
  checkpoint_dir = self.create_tempdir().full_path
  model = mocks.MockT2RModel()
  train_eval.train_eval_model(
      t2r_model=model,
      input_generator_train=generator,
      max_train_steps=_MAX_TRAIN_STEPS,
      model_dir=checkpoint_dir)

  predictor = checkpoint_predictor.CheckpointPredictor(
      t2r_model=model, checkpoint_dir=checkpoint_dir, use_gpu=False)

  # Before restore(), predict() raises and the version/step accessors
  # report the -1 sentinel.
  with self.assertRaises(ValueError):
    predictor.predict({'does_not_matter': np.zeros(1)})
  self.assertEqual(predictor.model_version, -1)
  self.assertEqual(predictor.global_step, -1)

  # Once restored, the predictor exposes the trained checkpoint.
  self.assertTrue(predictor.restore())
  self.assertGreater(predictor.model_version, 0)
  self.assertEqual(predictor.global_step, 3)

  expected_spec = model.preprocessor.get_in_feature_specification(
      tf.estimator.ModeKeys.PREDICT)
  tensorspec_utils.assert_equal(predictor.get_feature_specification(),
                                expected_spec)

  random_features = tensorspec_utils.make_random_numpy(
      expected_spec, batch_size=_BATCH_SIZE)
  outputs = predictor.predict(random_features)
  self.assertLen(outputs, 1)
  self.assertCountEqual(list(outputs), ['logit'])
  self.assertEqual(outputs['logit'].shape, (2, 1))
def test_with_mock_training(self):
  """Trains a TPU mock model with an async export hook, then loads an export."""
  model_dir = self.create_tempdir().full_path
  t2r_model = mocks.MockT2RModel(
      preprocessor_cls=noop_preprocessor.NoOpPreprocessor,
      device_type='tpu',
      use_avg_model_params=True)
  input_generator = mocks.MockInputGenerator(batch_size=_BATCH_SIZE)
  export_dir = os.path.join(model_dir, _EXPORT_DIR)
  export_hook = async_export_hook_builder.AsyncExportHookBuilder(
      export_dir=export_dir,
      create_export_fn=async_export_hook_builder.default_create_export_fn)

  # One step per TPU loop and a checkpoint every step (presumably so the
  # async export hook triggers within this short run).
  for binding in ('tf.contrib.tpu.TPUConfig.iterations_per_loop=1',
                  'tf.estimator.RunConfig.save_checkpoints_steps=1'):
    gin.parse_config(binding)

  # We optimize our network.
  train_eval.train_eval_model(
      t2r_model=t2r_model,
      input_generator_train=input_generator,
      train_hook_builders=[export_hook],
      model_dir=model_dir,
      max_train_steps=_MAX_STEPS)

  # Training must have produced checkpoints and at least one non-empty export.
  self.assertNotEmpty(tf.io.gfile.listdir(model_dir))
  self.assertNotEmpty(tf.io.gfile.listdir(export_dir))
  for version_name in tf.io.gfile.listdir(export_dir):
    version_path = os.path.join(export_dir, version_name)
    self.assertNotEmpty(tf.io.gfile.listdir(version_path))

  predictor = exported_savedmodel_predictor.ExportedSavedModelPredictor(
      export_dir=export_dir)
  self.assertTrue(predictor.restore())
def recordio_train(self, module_name, model_name, file_patterns,
                   **module_kwargs):
  """Trains the model with a RecordIO dataset for a few steps.

  Args:
    module_name: Object holding the model class. The class is looked up on it
      with getattr, so this is presumably a module object rather than a string.
    model_name: Attribute name of the model class inside `module_name`.
    file_patterns: File patterns handed to DefaultRecordInputGenerator.
    **module_kwargs: Forwarded both to the model constructor and to
      self._get_params.
  """
  model_cls = getattr(module_name, model_name)
  t2r_model = model_cls(**module_kwargs)
  params = self._get_params(
      model_dir=self._test_case.create_tempdir().full_path, **module_kwargs)
  record_generator = default_input_generator.DefaultRecordInputGenerator(
      file_patterns, batch_size=params['batch_size'])

  # Keep a handle on the real function so the spy delegates to it via
  # side_effect: the real initialization still runs, and the spy records
  # whether it was invoked.
  real_initialize_system = tpu.initialize_system
  with mock.patch.object(
      tpu, 'initialize_system', autospec=True) as initialize_spy:
    initialize_spy.side_effect = real_initialize_system
    train_eval.train_eval_model(
        t2r_model=t2r_model,
        input_generator_train=record_generator,
        max_train_steps=params['max_train_steps'],
        model_dir=params['model_dir'],
        use_tpu_wrapper=params['use_tpu_wrapper'])
    if self._use_tpu:
      initialize_spy.assert_called()
    train_eval_test_utils.assert_output_files(
        test_case=self._test_case,
        model_dir=params['model_dir'],
        expected_output_filename_patterns=(
            train_eval_test_utils.DEFAULT_TRAIN_FILENAME_PATTERNS))
def test_with_mock_training(self):
  """Trains a CPU mock model with an async export hook, then loads an export.

  The export function is built once. The original code rebuilt the identical
  functools.partial three times, and the last copy was never used.
  """
  model_dir = self.create_tempdir().full_path
  mock_t2r_model = mocks.MockT2RModel(
      preprocessor_cls=noop_preprocessor.NoOpPreprocessor, device_type='cpu')
  mock_input_generator = mocks.MockInputGenerator(batch_size=_BATCH_SIZE)
  export_dir = os.path.join(model_dir, _EXPORT_DIR)

  # Bind the export batch sizes into the default export function.
  default_create_export_fn = functools.partial(
      async_export_hook_builder.default_create_export_fn,
      batch_sizes_for_export=_BATCH_SIZES_FOR_EXPORT)
  hook_builder = async_export_hook_builder.AsyncExportHookBuilder(
      export_dir=export_dir, create_export_fn=default_create_export_fn)

  # We optimize our network.
  train_eval.train_eval_model(
      t2r_model=mock_t2r_model,
      input_generator_train=mock_input_generator,
      train_hook_builders=[hook_builder],
      model_dir=model_dir,
      max_train_steps=_MAX_STEPS)

  # Training must have produced checkpoints and at least one non-empty export.
  self.assertNotEmpty(tf.io.gfile.listdir(model_dir))
  self.assertNotEmpty(tf.io.gfile.listdir(export_dir))
  for exported_model_dir in tf.io.gfile.listdir(export_dir):
    self.assertNotEmpty(
        tf.io.gfile.listdir(os.path.join(export_dir, exported_model_dir)))

  predictor = exported_savedmodel_predictor.ExportedSavedModelPredictor(
      export_dir=export_dir)
  self.assertTrue(predictor.restore())
def test_predictor(self):
  """Checks ExportedSavedModelPredictor on the latest numpy export."""
  generator = default_input_generator.DefaultRandomInputGenerator(
      batch_size=_BATCH_SIZE)
  model_dir = self.create_tempdir().full_path
  model = mocks.MockT2RModel()
  train_eval.train_eval_model(
      t2r_model=model,
      input_generator_train=generator,
      input_generator_eval=generator,
      max_train_steps=_MAX_TRAIN_STEPS,
      eval_steps=_MAX_EVAL_STEPS,
      model_dir=model_dir,
      create_exporters_fn=train_eval.create_default_exporters)

  latest_numpy_export = os.path.join(model_dir, 'export',
                                     'latest_exporter_numpy')
  predictor = exported_savedmodel_predictor.ExportedSavedModelPredictor(
      export_dir=latest_numpy_export)

  # Until restore() is called, each accessor checked here raises ValueError.
  with self.assertRaises(ValueError):
    predictor.get_feature_specification()
  with self.assertRaises(ValueError):
    predictor.predict({'does_not_matter': np.zeros(1)})
  with self.assertRaises(ValueError):
    _ = predictor.model_version

  self.assertTrue(predictor.restore())
  self.assertGreater(predictor.model_version, 0)

  expected_spec = model.preprocessor.get_in_feature_specification(
      tf.estimator.ModeKeys.PREDICT)
  tensorspec_utils.assert_equal(predictor.get_feature_specification(),
                                expected_spec)

  random_features = tensorspec_utils.make_random_numpy(
      expected_spec, batch_size=_BATCH_SIZE)
  outputs = predictor.predict(random_features)
  self.assertLen(outputs, 1)
  self.assertEqual(outputs['logit'].shape, (2, 1))
def test_regression_maml(self):
  """Runs train_eval on the MAML wrapper around PoseEnvRegressionModel."""
  base_model = pose_env_models.PoseEnvRegressionModel()
  maml_model = pose_env_maml_models.PoseEnvRegressionModelMAML(
      base_model=base_model)
  train_eval.train_eval_model(
      t2r_model=maml_model,
      input_generator_train=self._meta_record_input_generator_train,
      input_generator_eval=self._meta_record_input_generator_eval,
      create_exporters_fn=None)
def test_init_from_checkpoint_global_step(self):
  """Checks that warm-starting from a checkpoint also restores global_step."""
  gin.bind_parameter('tf.estimator.RunConfig.save_checkpoints_steps', 100)
  gin.bind_parameter('tf.estimator.RunConfig.keep_checkpoint_max', 10)

  def _meta_files(directory):
    # One model*.meta graph file is left behind per retained checkpoint.
    return tf.io.gfile.glob(os.path.join(directory, 'model*.meta'))

  initial_dir = self.create_tempdir().full_path
  initial_model = mocks.MockT2RModel(
      preprocessor_cls=noop_preprocessor.NoOpPreprocessor)
  initial_generator = mocks.MockInputGenerator(batch_size=_BATCH_SIZE)
  train_eval.train_eval_model(
      t2r_model=initial_model,
      input_generator_train=initial_generator,
      max_train_steps=_MAX_TRAIN_STEPS,
      model_dir=initial_dir,
      eval_steps=_EVAL_STEPS,
      eval_throttle_secs=_EVAL_THROTTLE_SECS,
      create_exporters_fn=train_eval.create_default_exporters)
  # A checkpoint every 100 steps with at most 10 kept -> 10 meta files
  # (assumes _MAX_TRAIN_STEPS >= 1000).
  self.assertLen(_meta_files(initial_dir), 10)

  # The warm-started run writes into its own directory.
  warm_start_dir = self.create_tempdir().full_path
  init_fn = functools.partial(
      abstract_model.default_init_from_checkpoint_fn, checkpoint=initial_dir)
  warm_model = mocks.MockT2RModel(
      preprocessor_cls=noop_preprocessor.NoOpPreprocessor,
      init_from_checkpoint_fn=init_fn)
  warm_generator = mocks.MockInputGenerator(batch_size=_BATCH_SIZE)
  train_eval.train_eval_model(
      t2r_model=warm_model,
      input_generator_train=warm_generator,
      model_dir=warm_start_dir,
      max_train_steps=_MAX_TRAIN_STEPS + 100,
      eval_steps=_EVAL_STEPS,
      eval_throttle_secs=_EVAL_THROTTLE_SECS,
      create_exporters_fn=train_eval.create_default_exporters)
  # If global_step was restored too, only one checkpoint is added to the
  # initial one -> 2 meta files.
  self.assertLen(_meta_files(warm_start_dir), 2)
def test_predictor_with_async_hook(self):
  """Checks ExportedSavedModelPredictor on models exported by the async hook."""
  model_dir = self.create_tempdir().full_path
  create_export_fn = functools.partial(
      async_export_hook_builder.default_create_export_fn,
      batch_sizes_for_export=_BATCH_SIZES_FOR_EXPORT)
  export_dir = os.path.join(model_dir, _EXPORT_DIR)
  export_hook = async_export_hook_builder.AsyncExportHookBuilder(
      export_dir=export_dir, create_export_fn=create_export_fn)
  generator = default_input_generator.DefaultRandomInputGenerator(
      batch_size=_BATCH_SIZE)
  model = mocks.MockT2RModel()
  train_eval.train_eval_model(
      t2r_model=model,
      input_generator_train=generator,
      train_hook_builders=[export_hook],
      max_train_steps=_MAX_TRAIN_STEPS,
      model_dir=model_dir)

  predictor = exported_savedmodel_predictor.ExportedSavedModelPredictor(
      export_dir=export_dir)

  # Until restore() is called, the accessors below raise ValueError and
  # global_step reports -1.
  with self.assertRaises(ValueError):
    predictor.get_feature_specification()
  with self.assertRaises(ValueError):
    predictor.predict({'does_not_matter': np.zeros(1)})
  with self.assertRaises(ValueError):
    _ = predictor.model_version
  self.assertEqual(predictor.global_step, -1)

  self.assertTrue(predictor.restore())
  self.assertGreater(predictor.model_version, 0)
  # NOTE: The async hook builder exports the global step with the model.
  self.assertEqual(predictor.global_step, 3)

  expected_spec = model.preprocessor.get_in_feature_specification(
      tf.estimator.ModeKeys.PREDICT)
  tensorspec_utils.assert_equal(predictor.get_feature_specification(),
                                expected_spec)

  random_features = tensorspec_utils.make_random_numpy(
      expected_spec, batch_size=_BATCH_SIZE)
  outputs = predictor.predict(random_features)
  self.assertLen(outputs, 1)
  self.assertCountEqual(list(outputs), ['logit'])
  self.assertEqual(outputs['logit'].shape, (2, 1))
def test_train_eval_model(self):
  """Tests that a simple model trains and exported models are valid."""
  gin.bind_parameter('tf.estimator.RunConfig.save_checkpoints_steps', 100)
  model_dir = self.create_tempdir().full_path
  t2r_model = mocks.MockT2RModel(
      preprocessor_cls=noop_preprocessor.NoOpPreprocessor)
  train_generator = mocks.MockInputGenerator(batch_size=_BATCH_SIZE)
  eval_generator = mocks.MockInputGenerator(batch_size=1)
  hook_builder = FakeHookBuilder()

  train_eval.train_eval_model(
      t2r_model=t2r_model,
      input_generator_train=train_generator,
      input_generator_eval=eval_generator,
      max_train_steps=_MAX_TRAIN_STEPS,
      model_dir=model_dir,
      train_hook_builders=[hook_builder],
      eval_hook_builders=[hook_builder],
      eval_steps=_EVAL_STEPS,
      eval_throttle_secs=_EVAL_THROTTLE_SECS,
      create_exporters_fn=train_eval.create_default_exporters)
  self.assertTrue(hook_builder.hook_mock.begin.called)

  def _export_paths(exporter_name):
    # Sorted so the last entry is treated as the latest export.
    pattern = os.path.join(model_dir, 'export', exporter_name, '*')
    return sorted(tf.io.gfile.glob(pattern))

  def _assert_sign_matches(label, predicted):
    # Perfect classification: positive labels get positive logits, all
    # other labels get negative logits.
    if label > 0:
      self.assertGreater(predicted[0], 0)
    else:
      self.assertLess(predicted[0], 0)

  # Both the numpy and the tf.Example inference models must be exported.
  # The mock network converges nicely, so there are several best models. By
  # default the best 5 are kept, and the latest one is always the best.
  numpy_model_paths = _export_paths('best_exporter_numpy')
  self.assertLen(numpy_model_paths, 5)
  tf_example_model_paths = _export_paths('best_exporter_tf_example')
  self.assertLen(tf_example_model_paths, 5)

  # Both saved models are checked in one test because training dominates the
  # runtime. A plain Estimator on the same model_fn/model_dir serves as the
  # reference the exports must reproduce.
  reference_estimator = tf.estimator.Estimator(
      model_fn=t2r_model.model_fn,
      config=tf.estimator.RunConfig(model_dir=model_dir))
  reference_predictions = reference_estimator.predict(
      input_fn=eval_generator.create_dataset_input_fn(
          mode=tf.estimator.ModeKeys.EVAL))

  # The numpy feed_dict export no longer depends on the model_fn or the
  # preprocessor. Load the latest export, which had the best eval performance.
  numpy_predictor_fn = tf.contrib.predictor.from_saved_model(
      numpy_model_paths[-1])
  features, labels = eval_generator.create_numpy_data()
  ref_error = self._compute_total_loss(
      labels, [out['logit'].flatten() for out in reference_predictions])

  numpy_predictions = []
  for feature, label in zip(features, labels):
    logits = numpy_predictor_fn({'x': feature.reshape(1, -1)})['logit']
    predicted = logits.flatten()
    numpy_predictions.append(predicted)
    _assert_sign_matches(label, predicted)
  numpy_error = self._compute_total_loss(labels, numpy_predictions)

  # Same check through the serialized tf.Example export (latest = best).
  tf_example_predictor_fn = tf.contrib.predictor.from_saved_model(
      tf_example_model_paths[-1])
  tf_example_predictions = []
  for feature, label in zip(features, labels):
    # Build the serialized tf.Example proto that the export consumes.
    example = tf.train.Example()
    example.features.feature['measured_position'].float_list.value.extend(
        feature)
    serialized = np.array(example.SerializeToString()).reshape(1,)
    predicted = tf_example_predictor_fn(
        {'input_example_tensor': serialized})['logit'].flatten()
    tf_example_predictions.append(predicted)
    _assert_sign_matches(label, predicted)
  tf_example_error = self._compute_total_loss(labels, tf_example_predictions)

  # Both exports must perform identically. Training and evaluation use the
  # same fixed dataset, so the latest model must also match the reference.
  np.testing.assert_almost_equal(tf_example_error, numpy_error)
  np.testing.assert_almost_equal(ref_error, tf_example_error)
def test_init_from_checkpoint_use_avg_model_params_and_weights(self):
  """Warm-starts from an averaged-params checkpoint; compares weights and logits."""
  gin.bind_parameter('tf.estimator.RunConfig.save_checkpoints_steps', 100)
  gin.bind_parameter('tf.estimator.RunConfig.keep_checkpoint_max', 10)
  model_dir = self.create_tempdir().full_path
  mock_t2r_model = mocks.MockT2RModel(
      preprocessor_cls=noop_preprocessor.NoOpPreprocessor,
      use_avg_model_params=True)
  train_generator = mocks.MockInputGenerator(batch_size=_BATCH_SIZE)
  eval_generator = mocks.MockInputGenerator(batch_size=1)
  eval_generator.set_specification_from_model(
      mock_t2r_model, tf.estimator.ModeKeys.TRAIN)

  def _eval_logits(estimator):
    # Collects the 'logit' output for every EVAL-mode example.
    input_fn = eval_generator.create_dataset_input_fn(
        mode=tf.estimator.ModeKeys.EVAL)
    return [out['logit'] for out in list(estimator.predict(input_fn=input_fn))]

  train_eval.train_eval_model(
      t2r_model=mock_t2r_model,
      input_generator_train=train_generator,
      max_train_steps=_MAX_TRAIN_STEPS,
      model_dir=model_dir)
  init_checkpoint = tf.train.NewCheckpointReader(
      tf.train.latest_checkpoint(model_dir))

  # Reference predictions from a plain Estimator on the trained model_dir.
  initial_predictions = _eval_logits(
      tf.estimator.Estimator(
          model_fn=mock_t2r_model.model_fn,
          config=tf.estimator.RunConfig(model_dir=model_dir)))

  # The warm-started run writes into its own directory.
  continue_model_dir = self.create_tempdir().full_path
  init_from_checkpoint_fn = functools.partial(
      abstract_model.default_init_from_checkpoint_fn, checkpoint=model_dir)
  continue_t2r_model = mocks.MockT2RModel(
      preprocessor_cls=noop_preprocessor.NoOpPreprocessor,
      init_from_checkpoint_fn=init_from_checkpoint_fn)
  continue_generator = mocks.MockInputGenerator(batch_size=_BATCH_SIZE)
  # Re-initialize the model and train for one step; performance should be
  # essentially that of the original model.
  train_eval.train_eval_model(
      t2r_model=continue_t2r_model,
      input_generator_train=continue_generator,
      model_dir=continue_model_dir,
      max_train_steps=_MAX_TRAIN_STEPS)
  continue_checkpoint = tf.train.NewCheckpointReader(
      tf.train.latest_checkpoint(continue_model_dir))

  # Variables excluded from the weight comparison:
  #   ExponentialMovingAverage - replaced by the swapping saver when
  #     use_avg_model_params is set;
  #   Adam - optimizer slots are not required;
  #   global_step - incremented by 1 in the continued run.
  skipped_fragments = ('ExponentialMovingAverage', 'Adam', 'global_step')
  for tensor_name, _ in tf.train.list_variables(model_dir):
    if any(fragment in tensor_name for fragment in skipped_fragments):
      continue
    self.assertAllClose(
        init_checkpoint.get_tensor(tensor_name),
        continue_checkpoint.get_tensor(tensor_name),
        atol=1e-3)

  # The warm-started model must predict like the original one.
  continue_predictions = _eval_logits(
      tf.estimator.Estimator(
          model_fn=mock_t2r_model.model_fn,
          config=tf.estimator.RunConfig(model_dir=continue_model_dir)))
  self.assertTrue(
      np.allclose(initial_predictions, continue_predictions, atol=1e-2))

  # A randomly initialized estimator (no model_dir given) must differ.
  random_predictions = _eval_logits(
      tf.estimator.Estimator(model_fn=mock_t2r_model.model_fn))
  self.assertFalse(
      np.allclose(initial_predictions, random_predictions, atol=1e-2))
def test_init_from_checkpoint_use_avg_model_params(self):
  """Warm-starts from an averaged-params checkpoint and compares predictions."""
  gin.bind_parameter('tf.estimator.RunConfig.save_checkpoints_steps', 100)
  gin.bind_parameter('tf.estimator.RunConfig.keep_checkpoint_max', 10)
  model_dir = self.create_tempdir().full_path
  mock_t2r_model = mocks.MockT2RModel(
      preprocessor_cls=noop_preprocessor.NoOpPreprocessor,
      use_avg_model_params=True)
  train_generator = mocks.MockInputGenerator(batch_size=_BATCH_SIZE)
  eval_generator = mocks.MockInputGenerator(batch_size=1)
  eval_generator.set_specification_from_model(
      mock_t2r_model, tf.estimator.ModeKeys.TRAIN)

  def _eval_logits(estimator):
    # Collects the 'logit' output for every EVAL-mode example.
    input_fn = eval_generator.create_dataset_input_fn(
        mode=tf.estimator.ModeKeys.EVAL)
    return [out['logit'] for out in list(estimator.predict(input_fn=input_fn))]

  train_eval.train_eval_model(
      t2r_model=mock_t2r_model,
      input_generator_train=train_generator,
      max_train_steps=_MAX_TRAIN_STEPS,
      model_dir=model_dir)

  # Reference predictions from a plain Estimator on the trained model_dir.
  initial_predictions = _eval_logits(
      tf.estimator.Estimator(
          model_fn=mock_t2r_model.model_fn,
          config=tf.estimator.RunConfig(model_dir=model_dir)))

  # The warm-started run writes into its own directory.
  continue_model_dir = self.create_tempdir().full_path
  init_from_checkpoint_fn = functools.partial(
      abstract_model.default_init_from_checkpoint_fn, checkpoint=model_dir)
  continue_t2r_model = mocks.MockT2RModel(
      preprocessor_cls=noop_preprocessor.NoOpPreprocessor,
      init_from_checkpoint_fn=init_from_checkpoint_fn)
  continue_generator = mocks.MockInputGenerator(batch_size=_BATCH_SIZE)
  # Re-initialize the model and train for one step; performance should be
  # essentially that of the original model.
  train_eval.train_eval_model(
      t2r_model=continue_t2r_model,
      input_generator_train=continue_generator,
      model_dir=continue_model_dir,
      max_train_steps=_MAX_TRAIN_STEPS + 1)

  # The warm-started model must predict like the original one.
  continue_predictions = _eval_logits(
      tf.estimator.Estimator(
          model_fn=mock_t2r_model.model_fn,
          config=tf.estimator.RunConfig(model_dir=continue_model_dir)))
  self.assertTrue(
      np.allclose(initial_predictions, continue_predictions, atol=1e-2))

  # A randomly initialized estimator (no model_dir given) must differ.
  random_predictions = _eval_logits(
      tf.estimator.Estimator(model_fn=mock_t2r_model.model_fn))
  self.assertFalse(
      np.allclose(initial_predictions, random_predictions, atol=1e-2))
def test_regression(self):
  """Runs train_eval on PoseEnvRegressionModel without exporters."""
  regression_model = pose_env_models.PoseEnvRegressionModel()
  train_eval.train_eval_model(
      t2r_model=regression_model,
      input_generator_train=self._record_input_generator,
      input_generator_eval=self._record_input_generator,
      create_exporters_fn=None)
def test_mc(self):
  """Runs train_eval on PoseEnvContinuousMCModel without exporters."""
  mc_model = pose_env_models.PoseEnvContinuousMCModel()
  train_eval.train_eval_model(
      t2r_model=mc_model,
      input_generator_train=self._record_input_generator,
      input_generator_eval=self._record_input_generator,
      create_exporters_fn=None)
def test_train_eval_gin(test_case,
                        model_dir,
                        full_gin_path,
                        max_train_steps,
                        eval_steps,
                        gin_overwrites_fn=None,
                        assert_train_output_files=True,
                        assert_eval_output_files=True):
  """Trains and evaluates a runnable gin config.

  Until a proper gen_rule creates an individual target for every gin file
  automatically, gin files can be tested with the pattern below. Please name
  the test function 'test_train_eval_gin' so these tests are easy to convert
  once the gen_rule is available.

    @parameterized.parameters(
        ('first.gin',),
        ('second.gin',),
        ('third.gin',),
    )
    def test_train_eval_gin(self, gin_file):
      full_gin_path = os.path.join(FLAGS.test_srcdir, BASE_GIN_PATH, gin_file)
      model_dir = os.path.join(FLAGS.test_tmpdir, 'test_train_eval_gin',
                               gin_file)
      train_eval_test_utils.test_train_eval_gin(
          test_case=self,
          model_dir=model_dir,
          full_gin_path=full_gin_path,
          max_train_steps=MAX_TRAIN_STEPS,
          eval_steps=EVAL_STEPS)

  Args:
    test_case: The test instance used to assert that the output files exist.
    model_dir: The path where the model is stored.
    full_gin_path: Path of the gin file that parameterizes train_eval.
    max_train_steps: Maximum number of training steps. Keep it small; this is
      only a test.
    eval_steps: Number of eval steps. Keep it small; this is only a test.
    gin_overwrites_fn: Optional callable that binds gin parameters to override
      values from the gin file.
    assert_train_output_files: If True, check the expected output files of the
      training run. Set to False when only evaluation is performed. When False,
      model_dir is also not deleted, so the model from a previous training run
      can be loaded.
    assert_eval_output_files: If True, check the expected output files of the
      evaluation run. Set to False when only training is performed.
  """
  # Reset all gin state (constants included) so this helper can be invoked
  # repeatedly, e.g. once per parameterized test case.
  gin.clear_config(clear_constants=True)
  gin.parse_config_file(full_gin_path)
  gin.bind_parameter('train_eval_model.model_dir', model_dir)
  if gin_overwrites_fn is not None:
    gin_overwrites_fn()

  # Start training from an empty model_dir; this matters for local runs.
  if tf.io.gfile.exists(model_dir) and assert_train_output_files:
    tf.io.gfile.rmtree(model_dir)

  train_eval.train_eval_model(
      model_dir=model_dir,
      max_train_steps=max_train_steps,
      eval_steps=eval_steps,
      create_exporters_fn=None)

  common_kwargs = dict(test_case=test_case, model_dir=model_dir)
  if assert_train_output_files:
    assert_output_files(
        expected_output_filename_patterns=DEFAULT_TRAIN_FILENAME_PATTERNS,
        **common_kwargs)
  if assert_eval_output_files:
    assert_output_files(
        expected_output_filename_patterns=DEFAULT_EVAL_FILENAME_PATTERNS,
        **common_kwargs)
def test_maml_model(self, num_inner_loop_steps):
  """Trains a mock MAML model and checks the feed tensors of its exports.

  Args:
    num_inner_loop_steps: Number of inner-loop steps passed to MockMAMLModel;
      also used as the name of the model directory.
  """
  model_dir = os.path.join(FLAGS.test_tmpdir, str(num_inner_loop_steps))
  gin.bind_parameter('tf.estimator.RunConfig.save_checkpoints_steps',
                     _MAX_STEPS // 2)
  # Start from a clean directory so exports from earlier runs don't leak in.
  if tf.io.gfile.exists(model_dir):
    tf.io.gfile.rmtree(model_dir)

  mock_base_model = mocks.MockT2RModel(
      preprocessor_cls=noop_preprocessor.NoOpPreprocessor)
  mock_tf_model = MockMAMLModel(
      base_model=mock_base_model, num_inner_loop_steps=num_inner_loop_steps)

  # Note, we by choice use the same amount of conditioning samples for
  # inference as well during train and change the model for eval/inference
  # to only produce one output sample.
  mock_input_generator_train = MockMetaInputGenerator(
      batch_size=_BATCH_SIZE,
      num_condition_samples_per_task=_NUM_CONDITION_SAMPLES_PER_TASK,
      num_inference_samples_per_task=_NUM_CONDITION_SAMPLES_PER_TASK)
  mock_input_generator_train.set_specification_from_model(
      mock_tf_model, mode=tf.estimator.ModeKeys.TRAIN)

  mock_input_generator_eval = MockMetaInputGenerator(
      batch_size=_BATCH_SIZE,
      num_condition_samples_per_task=_NUM_CONDITION_SAMPLES_PER_TASK,
      num_inference_samples_per_task=1)
  mock_input_generator_eval.set_specification_from_model(
      mock_tf_model, mode=tf.estimator.ModeKeys.TRAIN)

  mock_export_generator = MockMetaExportGenerator(
      num_condition_samples_per_task=_NUM_CONDITION_SAMPLES_PER_TASK,
      num_inference_samples_per_task=1)

  train_eval.train_eval_model(
      t2r_model=mock_tf_model,
      input_generator_train=mock_input_generator_train,
      input_generator_eval=mock_input_generator_eval,
      max_train_steps=_MAX_STEPS,
      model_dir=model_dir,
      export_generator=mock_export_generator,
      create_exporters_fn=train_eval.create_default_exporters)

  export_dir = os.path.join(model_dir, 'export')
  # create_default_exporters writes four exporter directories:
  # best_exporter_numpy and best_exporter_tf_example plus, presumably, their
  # latest_exporter_* counterparts.
  self.assertLen(tf.io.gfile.glob(os.path.join(export_dir, '*')), 4)

  def _latest_export(exporter_name):
    # glob() ordering is unspecified, so sort before taking [-1]. This picks
    # the lexicographically last (presumably newest, timestamp-named) export
    # deterministically.
    return sorted(
        tf.io.gfile.glob(os.path.join(export_dir, exporter_name, '*')))[-1]

  numpy_predictor_fn = tf.contrib.predictor.from_saved_model(
      _latest_export('best_exporter_numpy'))
  feed_tensor_keys = sorted(numpy_predictor_fn.feed_tensors.keys())
  self.assertCountEqual(
      [
          'condition/features/x', 'condition/labels/y',
          'inference/features/x'
      ],
      feed_tensor_keys,
  )

  tf_example_predictor_fn = tf.contrib.predictor.from_saved_model(
      _latest_export('best_exporter_tf_example'))
  self.assertCountEqual(['input_example_tensor'],
                        tf_example_predictor_fn.feed_tensors.keys())
def main(unused_argv):
  """Applies the gin config files and bindings from flags, then trains."""
  del unused_argv  # Unused.
  gin.parse_config_files_and_bindings(FLAGS.gin_configs, FLAGS.gin_bindings)
  train_eval.train_eval_model()