示例#1
0
    def _train_model(self, input_fn, hooks):
        """Build the training graph and run train steps until the session stops.

        A fresh graph is created, the model function is called in
        `Modes.TRAIN`, and the resulting `train_op` is run inside a
        `MonitoredTrainingSession`. Checkpoint and summary saving are
        delegated to chief-only hooks rather than to the session itself.

        Args:
            input_fn: Input function handed to
                `_get_features_and_labels_from_input_fn` to obtain the
                training features and labels.
            hooks: Iterable of session hooks run on every worker, in addition
                to the NaN/logging hooks added here and the estimator spec's
                `training_hooks`.

        Returns:
            The loss fetched by the last executed training step, or `None`
            if the session requested a stop before any step was run.
        """
        with ops.Graph().as_default() as g, g.device(self._device_fn):
            random_seed.set_random_seed(self._config.tf_random_seed)
            global_step = training.get_or_create_global_step(g)
            features, labels = self._get_features_and_labels_from_input_fn(
                input_fn, Modes.TRAIN)
            estimator_spec = self._call_model_fn(features, labels, Modes.TRAIN)
            ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss)

            # Hooks that run on every worker.
            all_hooks = [
                plx_hooks.NanTensorHook(estimator_spec.loss),
                plx_hooks.StepLoggingTensorHook(
                    {
                        'loss': estimator_spec.loss,
                        'step': global_step
                    },
                    every_n_iter=100),
            ]
            all_hooks.extend(hooks)
            all_hooks.extend(estimator_spec.training_hooks)

            # Fall back to a default scaffold so the `.saver` lookup below
            # cannot fail when the estimator spec does not carry one.
            scaffold = estimator_spec.scaffold or monitored_session.Scaffold()
            if not (scaffold.saver
                    or ops.get_collection(ops.GraphKeys.SAVERS)):
                ops.add_to_collection(
                    ops.GraphKeys.SAVERS,  # TODO remove non restorable vars
                    saver.Saver(
                        sharded=True,
                        max_to_keep=self._config.keep_checkpoint_max,
                        keep_checkpoint_every_n_hours=(
                            self._config.keep_checkpoint_every_n_hours),
                        defer_build=True,
                        save_relative_paths=True))

            # Hooks that only run on the chief worker.
            chief_hooks = []

            def hook_registered(hook_cls):
                """Return True if a hook of type `hook_cls` is already present."""
                candidates = (all_hooks + chief_hooks +
                              list(estimator_spec.training_chief_hooks))
                return any(isinstance(h, hook_cls) for h in candidates)

            wants_checkpoints = (self._config.save_checkpoints_secs
                                 or self._config.save_checkpoints_steps)
            if wants_checkpoints and not hook_registered(
                    plx_hooks.StepCheckpointSaverHook):
                chief_hooks.append(
                    plx_hooks.StepCheckpointSaverHook(
                        self._model_dir,
                        save_secs=self._config.save_checkpoints_secs,
                        save_steps=self._config.save_checkpoints_steps,
                        scaffold=scaffold))

            if self._config.save_summary_steps and not hook_registered(
                    plx_hooks.StepSummarySaverHook):
                chief_hooks.append(
                    plx_hooks.StepSummarySaverHook(
                        scaffold=scaffold,
                        save_steps=self._config.save_summary_steps,
                        output_dir=self._model_dir,
                    ))

            with monitored_session.MonitoredTrainingSession(
                    master=self._config.master,
                    is_chief=self._config.is_chief,
                    checkpoint_dir=self._model_dir,
                    scaffold=scaffold,
                    hooks=all_hooks,
                    chief_only_hooks=chief_hooks +
                    list(estimator_spec.training_chief_hooks),
                    # Saving checkpoints is handled by a hook.
                    save_checkpoint_secs=0,
                    # Saving summaries is handled by a hook.
                    save_summaries_steps=0,
                    config=self._session_config) as mon_sess:
                loss = None
                while not mon_sess.should_stop():
                    _, loss = mon_sess.run(
                        [estimator_spec.train_op, estimator_spec.loss])
            return loss
示例#2
0
    def _train_model(self, env, first_update, update_frequency, hooks):
        """Build the training graph and run episodes until the session stops.

        The graph is stored on `self._graph`. Besides the global step, an
        episode counter and a timestep counter are created together with the
        ops that increment them; these are handed to `run_episode`, which
        drives the actual training.

        Args:
            env: Environment passed to `_prepare_input_fn` and `run_episode`.
            first_update: Forwarded unchanged to `run_episode`.
            update_frequency: Forwarded unchanged to `run_episode`.
            hooks: Iterable of session hooks run on every worker, in addition
                to the NaN/logging hooks added here and the estimator spec's
                `training_hooks`.

        Returns:
            The value returned by the last `run_episode` call, or `None` if
            the session requested a stop before any episode was run.
        """
        self._graph = ops.Graph()
        with self._graph.as_default() as graph, graph.device(self._device_fn):
            random_seed.set_random_seed(self._config.tf_random_seed)

            # Counters and the ops that advance them.
            global_step = training.get_or_create_global_step(graph)
            global_episode = get_or_create_global_episode(graph)
            global_timestep = get_or_create_global_timestep(graph)
            update_episode_op = tf.assign_add(global_episode, 1)
            update_timestep_op = tf.assign_add(global_timestep, 1)
            no_run_hooks = tf.no_op(name='no_run_hooks')

            # Inputs are built on the CPU device.
            with ops.device('/cpu:0'):
                features, labels = self._prepare_input_fn(Modes.TRAIN, env)
            estimator_spec = self._call_model_fn(features, labels, Modes.TRAIN)
            ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss)

            # Hooks that run on every worker.
            step_logging_tensors = {
                'loss': estimator_spec.loss,
                'step': global_step,
                'timestep': global_timestep,
                'global_episode': global_episode,
                'max_reward': labels['max_reward'],
                'min_reward': labels['min_reward'],
                'total_reward': labels['total_reward'],
            }
            worker_hooks = [
                plx_hooks.NanTensorHook(estimator_spec.loss),
                plx_hooks.StepLoggingTensorHook(
                    step_logging_tensors, every_n_iter=100),
            ]
            worker_hooks.extend(hooks)
            worker_hooks.extend(estimator_spec.training_hooks)

            scaffold = estimator_spec.scaffold
            if not scaffold:
                scaffold = monitored_session.Scaffold()
            # Register a default saver only when none is available yet.
            if not scaffold.saver and not ops.get_collection(
                    ops.GraphKeys.SAVERS):
                default_saver = saver.Saver(
                    sharded=True,  # TODO `var_list`
                    max_to_keep=self._config.keep_checkpoint_max,
                    defer_build=True)
                # TODO remove non restorable vars
                ops.add_to_collection(ops.GraphKeys.SAVERS, default_saver)

            # Hooks that only run on the chief worker.
            episode_logging_tensors = {
                'loss': estimator_spec.loss,
                'step': global_step,
                'global_timestep': global_timestep,
                'global_episode': global_episode,
                'max_reward': labels['max_reward'],
                'min_reward': labels['min_reward'],
                'total_reward': labels['total_reward'],
            }
            chief_hooks = [
                # TODO: save every episode?
                plx_hooks.EpisodeLoggingTensorHook(
                    episode_logging_tensors, every_n_episodes=1),
                plx_hooks.EpisodeCounterHook(output_dir=self.model_dir),
            ]

            def already_has(hook_cls):
                """Return True if a hook of type `hook_cls` is already present."""
                known_hooks = (worker_hooks + chief_hooks +
                               list(estimator_spec.training_chief_hooks))
                return any(isinstance(hook, hook_cls) for hook in known_hooks)

            if (self._config.save_checkpoints_secs
                    or self._config.save_checkpoints_steps):
                if not already_has(plx_hooks.EpisodeCheckpointSaverHook):
                    chief_hooks.append(
                        plx_hooks.EpisodeCheckpointSaverHook(
                            self._model_dir,
                            save_episodes=1,  # TODO: save every episode?
                            scaffold=scaffold))

            if self._config.save_summary_steps:
                if not already_has(plx_hooks.EpisodeSummarySaverHook):
                    chief_hooks.append(
                        plx_hooks.EpisodeSummarySaverHook(
                            scaffold=scaffold,
                            save_episodes=1,  # TODO: save every episode?
                            output_dir=self._model_dir,
                        ))

            session_kwargs = dict(
                master=self._config.master,
                is_chief=self._config.is_chief,
                checkpoint_dir=self._model_dir,
                scaffold=scaffold,
                hooks=worker_hooks,
                chief_only_hooks=(
                    chief_hooks + list(estimator_spec.training_chief_hooks)),
                save_checkpoint_secs=0,  # Saving checkpoint is handled by a hook.
                save_summaries_steps=0,  # Saving summaries is handled by a hook.
                config=self._session_config,
            )
            last_loss = None
            with monitored_session.MonitoredTrainingSession(
                    **session_kwargs) as session:
                while not session.should_stop():
                    last_loss = self.run_episode(
                        env=env,
                        sess=session,
                        features=features,
                        labels=labels,
                        no_run_hooks=no_run_hooks,
                        global_step=global_step,
                        update_episode_op=update_episode_op,
                        update_timestep_op=update_timestep_op,
                        first_update=first_update,
                        update_frequency=update_frequency,
                        estimator_spec=estimator_spec)
            summary_io.SummaryWriterCache.clear()
            return last_loss