def test_train_loop_then_eval_loop(self):
        """Runs a short train loop, then continuous eval over its checkpoints."""
        model_dir = tf.test.get_temp_dir()
        # Rewrite the canned pipeline config so it has no fine-tune checkpoint.
        rewritten_config_path = os.path.join(model_dir, 'new_pipeline.config')
        config_util.clear_fine_tune_checkpoint(
            get_pipeline_config_path(MODEL_NAME_FOR_TEST),
            rewritten_config_path)
        overrides = _get_config_kwarg_overrides()

        num_steps = 2
        strategy = tf2.distribute.MirroredStrategy(['/cpu:0', '/cpu:1'])
        # Train under a two-replica CPU strategy, checkpointing every step.
        with strategy.scope():
            model_lib_v2.train_loop(
                rewritten_config_path,
                model_dir=model_dir,
                train_steps=num_steps,
                checkpoint_every_n=1,
                **overrides)

        # Evaluate the checkpoints the train loop just wrote; the short
        # timeout keeps the continuous-eval loop from blocking the test.
        model_lib_v2.eval_continuously(
            rewritten_config_path,
            model_dir=model_dir,
            checkpoint_dir=model_dir,
            train_steps=num_steps,
            wait_interval=1,
            timeout=10,
            **overrides)
# Example #2 ("示例#2" — web-extraction artifact converted to a comment)
    def test_export_metrics_json_serializable(self):
        """Trains briefly and checks exported metrics are JSON-serializable."""

        strategy = tf2.distribute.OneDeviceStrategy(device='/cpu:0')

        def _assert_json_serializable(metrics, _):
            # json.dumps raises TypeError if any metric value cannot be
            # serialized, which fails the test.
            json.dumps(metrics)

        with mock.patch.dict(exporter_lib_v2.INPUT_BUILDER_UTIL_MAP,
                             FAKE_BUILDER_MAP):
            with strategy.scope():
                model_dir = tf.test.get_temp_dir()
                # Rewrite the canned config without a fine-tune checkpoint.
                rewritten_config_path = os.path.join(
                    model_dir, 'new_pipeline.config')
                config_util.clear_fine_tune_checkpoint(
                    get_pipeline_config_path(MODEL_NAME_FOR_TEST),
                    rewritten_config_path)
                with strategy.scope():
                    model_lib_v2.train_loop(
                        rewritten_config_path,
                        model_dir=model_dir,
                        train_steps=2,
                        checkpoint_every_n=100,
                        performance_summary_exporter=_assert_json_serializable,
                        **_get_config_kwarg_overrides())
    def test_checkpoint_max_to_keep(self):
        """Test that only the most recent checkpoints are kept."""

        strategy = tf2.distribute.OneDeviceStrategy(device='/cpu:0')
        # Replace the real model builder so training uses a trivial model.
        with mock.patch.object(model_builder, 'build',
                               autospec=True) as mock_builder:
            # NOTE(review): the scope around the mock setup looks redundant,
            # but it mirrors how the model's variables get created — keep it.
            with strategy.scope():
                mock_builder.return_value = SimpleModel()
            model_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
            rewritten_config_path = os.path.join(model_dir,
                                                'new_pipeline.config')
            config_util.clear_fine_tune_checkpoint(
                get_pipeline_config_path(MODEL_NAME_FOR_TEST),
                rewritten_config_path)
            overrides = _get_config_kwarg_overrides()

            # 20 steps with a checkpoint every 2 would produce 10 checkpoints;
            # max_to_keep=3 should prune all but the newest three.
            with strategy.scope():
                model_lib_v2.train_loop(
                    rewritten_config_path,
                    model_dir=model_dir,
                    train_steps=20,
                    checkpoint_every_n=2,
                    checkpoint_max_to_keep=3,
                    **overrides)
            index_files = tf.io.gfile.glob(
                os.path.join(model_dir, 'ckpt-*.index'))
            self.assertEqual(len(index_files), 3,
                             '{} not of length 3.'.format(index_files))
# Example #4 ("示例#4" — web-extraction artifact converted to a comment)
    def test_checkpoint_max_to_keep(self):
        """Test that only the most recent checkpoints are kept."""
        # NOTE(review): this method shares its name with an earlier one in
        # this file (the two came from separate scraped examples); inside one
        # class the later definition would shadow the former — confirm intent.

        strategy = tf2.distribute.OneDeviceStrategy(device='/cpu:0')
        # Swap in fake input builders so no real dataset is required.
        with mock.patch.dict(exporter_lib_v2.INPUT_BUILDER_UTIL_MAP,
                             FAKE_BUILDER_MAP):

            model_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
            rewritten_config_path = os.path.join(model_dir,
                                                'new_pipeline.config')
            config_util.clear_fine_tune_checkpoint(
                get_pipeline_config_path(MODEL_NAME_FOR_TEST),
                rewritten_config_path)
            overrides = _get_config_kwarg_overrides()

            # 20 steps / checkpoint_every_n=2 would yield 10 checkpoints;
            # checkpoint_max_to_keep=3 must leave only the newest three.
            with strategy.scope():
                model_lib_v2.train_loop(
                    rewritten_config_path,
                    model_dir=model_dir,
                    train_steps=20,
                    checkpoint_every_n=2,
                    checkpoint_max_to_keep=3,
                    **overrides)
            index_files = tf.io.gfile.glob(
                os.path.join(model_dir, 'ckpt-*.index'))
            self.assertEqual(len(index_files), 3,
                             '{} not of length 3.'.format(index_files))