def test_train_loop_then_eval_loop(self):
    """Run a short distributed training job, then evaluate from its output.

    The test pipeline config is passed through
    ``config_util.clear_fine_tune_checkpoint`` and written to the temp
    directory. ``train_loop`` is then run for two steps, checkpointing every
    step, under a ``MirroredStrategy`` spanning two CPU devices. Finally
    ``eval_continuously`` is called with ``checkpoint_dir`` set to the same
    model directory. There are no explicit assertions; the test passes if
    neither call raises.
    """
    work_dir = tf.test.get_temp_dir()
    source_config = get_pipeline_config_path(MODEL_NAME_FOR_TEST)
    stripped_config = os.path.join(work_dir, 'new_pipeline.config')
    config_util.clear_fine_tune_checkpoint(source_config, stripped_config)
    overrides = _get_config_kwarg_overrides()
    num_steps = 2

    mirrored = tf2.distribute.MirroredStrategy(['/cpu:0', '/cpu:1'])
    with mirrored.scope():
        model_lib_v2.train_loop(
            stripped_config,
            model_dir=work_dir,
            train_steps=num_steps,
            checkpoint_every_n=1,
            **overrides)

    # NOTE(review): wait_interval/timeout presumably bound how long the
    # continuous-eval loop waits for new checkpoints -- confirm against
    # model_lib_v2.eval_continuously.
    model_lib_v2.eval_continuously(
        stripped_config,
        model_dir=work_dir,
        checkpoint_dir=work_dir,
        train_steps=num_steps,
        wait_interval=1,
        timeout=10,
        **overrides)
def test_export_metrics_json_serializable(self):
    """Check that the performance summary from training is JSON serializable.

    ``train_loop`` is given a ``performance_summary_exporter`` callback that
    passes its first argument to ``json.dumps``, which raises ``TypeError``
    for objects it cannot serialize. There are no explicit assertions; the
    test relies on such an error surfacing from ``train_loop``
    (NOTE(review): confirm ``train_loop`` does not swallow exporter errors).
    """
    strategy = tf2.distribute.OneDeviceStrategy(device='/cpu:0')

    def export(data, _):
        # Only the serialization attempt matters; the result is discarded.
        json.dumps(data)

    with mock.patch.dict(exporter_lib_v2.INPUT_BUILDER_UTIL_MAP,
                         FAKE_BUILDER_MAP):
        # One scope covers both config preparation and training. Entering
        # the same strategy's scope a second time, nested, adds nothing.
        with strategy.scope():
            model_dir = tf.test.get_temp_dir()
            new_pipeline_config_path = os.path.join(
                model_dir, 'new_pipeline.config')
            pipeline_config_path = get_pipeline_config_path(
                MODEL_NAME_FOR_TEST)
            config_util.clear_fine_tune_checkpoint(
                pipeline_config_path, new_pipeline_config_path)
            train_steps = 2
            model_lib_v2.train_loop(
                new_pipeline_config_path,
                model_dir=model_dir,
                train_steps=train_steps,
                # Larger than train_steps (2).
                checkpoint_every_n=100,
                performance_summary_exporter=export,
                **_get_config_kwarg_overrides())
def test_checkpoint_max_to_keep(self):
    """Verify that training leaves exactly three checkpoint index files.

    ``model_builder.build`` is patched to return a ``SimpleModel`` created
    under the strategy scope. Training then runs for 20 steps with
    ``checkpoint_every_n=2`` and ``checkpoint_max_to_keep=3``, and the
    ``ckpt-*.index`` files in the model directory are counted.

    NOTE(review): another method with this same name appears later in this
    part of the file. If both are defined in the same class, the later
    definition replaces this one and this test never runs -- confirm.
    """
    one_device = tf2.distribute.OneDeviceStrategy(device='/cpu:0')
    with mock.patch.object(
            model_builder, 'build', autospec=True) as fake_build:
        # The stand-in model is created under the strategy scope.
        with one_device.scope():
            fake_build.return_value = SimpleModel()

        # A fresh subdirectory, so the glob below only sees this run's files.
        run_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
        source_config = get_pipeline_config_path(MODEL_NAME_FOR_TEST)
        stripped_config = os.path.join(run_dir, 'new_pipeline.config')
        config_util.clear_fine_tune_checkpoint(source_config,
                                               stripped_config)
        overrides = _get_config_kwarg_overrides()

        with one_device.scope():
            model_lib_v2.train_loop(
                stripped_config,
                model_dir=run_dir,
                train_steps=20,
                checkpoint_every_n=2,
                checkpoint_max_to_keep=3,
                **overrides)

        index_pattern = os.path.join(run_dir, 'ckpt-*.index')
        ckpt_files = tf.io.gfile.glob(index_pattern)
        self.assertEqual(len(ckpt_files), 3,
                         '{} not of length 3.'.format(ckpt_files))
def test_checkpoint_max_to_keep(self):
    """Verify that training leaves exactly three checkpoint index files.

    ``exporter_lib_v2.INPUT_BUILDER_UTIL_MAP`` is patched with
    ``FAKE_BUILDER_MAP`` for the duration of the test. Training runs for
    20 steps with ``checkpoint_every_n=2`` and ``checkpoint_max_to_keep=3``,
    and the ``ckpt-*.index`` files in the model directory are counted.

    NOTE(review): another method with this same name appears earlier in
    this part of the file. If both are defined in the same class, this
    definition replaces the earlier one, which then never runs -- confirm.
    """
    one_device = tf2.distribute.OneDeviceStrategy(device='/cpu:0')
    with mock.patch.dict(exporter_lib_v2.INPUT_BUILDER_UTIL_MAP,
                         FAKE_BUILDER_MAP):
        # A fresh subdirectory, so the glob below only sees this run's files.
        run_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
        source_config = get_pipeline_config_path(MODEL_NAME_FOR_TEST)
        stripped_config = os.path.join(run_dir, 'new_pipeline.config')
        config_util.clear_fine_tune_checkpoint(source_config,
                                               stripped_config)
        overrides = _get_config_kwarg_overrides()

        with one_device.scope():
            model_lib_v2.train_loop(
                stripped_config,
                model_dir=run_dir,
                train_steps=20,
                checkpoint_every_n=2,
                checkpoint_max_to_keep=3,
                **overrides)

        index_pattern = os.path.join(run_dir, 'ckpt-*.index')
        ckpt_files = tf.io.gfile.glob(index_pattern)
        self.assertEqual(len(ckpt_files), 3,
                         '{} not of length 3.'.format(ckpt_files))