def _train_model(self, input_fn, hooks):
    all_hooks = []
    with ops.Graph().as_default() as g, g.device(self._device_fn):
        random_seed.set_random_seed(self._config.tf_random_seed)
        global_step = training.get_or_create_global_step(g)
        features, labels = self._get_features_and_labels_from_input_fn(
            input_fn, Modes.TRAIN)
        estimator_spec = self._call_model_fn(features, labels, Modes.TRAIN)
        ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss)
        all_hooks.extend([
            plx_hooks.NanTensorHook(estimator_spec.loss),
            plx_hooks.StepLoggingTensorHook(
                {
                    'loss': estimator_spec.loss,
                    'step': global_step
                },
                every_n_iter=100)
        ])
        all_hooks.extend(hooks)
        all_hooks.extend(estimator_spec.training_hooks)

        scaffold = estimator_spec.scaffold
        if not (scaffold.saver or ops.get_collection(ops.GraphKeys.SAVERS)):
            ops.add_to_collection(
                ops.GraphKeys.SAVERS,  # TODO remove non restorable vars
                saver.Saver(
                    sharded=True,
                    max_to_keep=self._config.keep_checkpoint_max,
                    keep_checkpoint_every_n_hours=(
                        self._config.keep_checkpoint_every_n_hours),
                    defer_build=True,
                    save_relative_paths=True))

        chief_hooks = []
        if self._config.save_checkpoints_secs or self._config.save_checkpoints_steps:
            saver_hook_exists = any(
                isinstance(h, plx_hooks.StepCheckpointSaverHook)
                for h in (all_hooks + chief_hooks +
                          list(estimator_spec.training_chief_hooks)))
            if not saver_hook_exists:
                chief_hooks += [
                    plx_hooks.StepCheckpointSaverHook(
                        self._model_dir,
                        save_secs=self._config.save_checkpoints_secs,
                        save_steps=self._config.save_checkpoints_steps,
                        scaffold=scaffold)
                ]
        if self._config.save_summary_steps:
            saver_hook_exists = any(
                isinstance(h, plx_hooks.StepSummarySaverHook)
                for h in (all_hooks + chief_hooks +
                          list(estimator_spec.training_chief_hooks)))
            if not saver_hook_exists:
                chief_hooks += [
                    plx_hooks.StepSummarySaverHook(
                        scaffold=scaffold,
                        save_steps=self._config.save_summary_steps,
                        output_dir=self._model_dir)
                ]
        with monitored_session.MonitoredTrainingSession(
                master=self._config.master,
                is_chief=self._config.is_chief,
                checkpoint_dir=self._model_dir,
                scaffold=scaffold,
                hooks=all_hooks,
                chief_only_hooks=(
                    chief_hooks + list(estimator_spec.training_chief_hooks)),
                save_checkpoint_secs=0,  # Saving checkpoint is handled by a hook.
                save_summaries_steps=0,  # Saving summaries is handled by a hook.
                config=self._session_config) as mon_sess:
            loss = None
            while not mon_sess.should_stop():
                _, loss = mon_sess.run(
                    [estimator_spec.train_op, estimator_spec.loss])
        return loss
def _train_model(self, env, first_update, update_frequency, hooks):
    all_hooks = []
    self._graph = ops.Graph()
    with self._graph.as_default() as g, g.device(self._device_fn):
        random_seed.set_random_seed(self._config.tf_random_seed)
        global_step = training.get_or_create_global_step(g)
        global_episode = get_or_create_global_episode(g)
        global_timestep = get_or_create_global_timestep(g)
        update_episode_op = tf.assign_add(global_episode, 1)
        update_timestep_op = tf.assign_add(global_timestep, 1)
        no_run_hooks = tf.no_op(name='no_run_hooks')
        with ops.device('/cpu:0'):
            features, labels = self._prepare_input_fn(Modes.TRAIN, env)
        estimator_spec = self._call_model_fn(features, labels, Modes.TRAIN)
        ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss)
        all_hooks.extend([
            plx_hooks.NanTensorHook(estimator_spec.loss),
            plx_hooks.StepLoggingTensorHook(
                {
                    'loss': estimator_spec.loss,
                    'step': global_step,
                    'timestep': global_timestep,
                    'global_episode': global_episode,
                    'max_reward': labels['max_reward'],
                    'min_reward': labels['min_reward'],
                    'total_reward': labels['total_reward'],
                },
                every_n_iter=100)
        ])
        all_hooks.extend(hooks)
        all_hooks.extend(estimator_spec.training_hooks)

        scaffold = estimator_spec.scaffold or monitored_session.Scaffold()
        if not (scaffold.saver or ops.get_collection(ops.GraphKeys.SAVERS)):
            ops.add_to_collection(
                ops.GraphKeys.SAVERS,  # TODO remove non restorable vars
                saver.Saver(
                    sharded=True,  # TODO `var_list`
                    max_to_keep=self._config.keep_checkpoint_max,
                    defer_build=True))

        chief_hooks = [
            plx_hooks.EpisodeLoggingTensorHook(
                {
                    'loss': estimator_spec.loss,
                    'step': global_step,
                    'global_timestep': global_timestep,
                    'global_episode': global_episode,
                    'max_reward': labels['max_reward'],
                    'min_reward': labels['min_reward'],
                    'total_reward': labels['total_reward'],
                },
                every_n_episodes=1),  # TODO: save every episode?
            plx_hooks.EpisodeCounterHook(output_dir=self.model_dir)
        ]
        if self._config.save_checkpoints_secs or self._config.save_checkpoints_steps:
            saver_hook_exists = any(
                isinstance(h, plx_hooks.EpisodeCheckpointSaverHook)
                for h in (all_hooks + chief_hooks +
                          list(estimator_spec.training_chief_hooks)))
            if not saver_hook_exists:
                chief_hooks += [
                    plx_hooks.EpisodeCheckpointSaverHook(
                        self._model_dir,
                        save_episodes=1,  # TODO: save every episode?
                        scaffold=scaffold)
                ]
        if self._config.save_summary_steps:
            saver_hook_exists = any(
                isinstance(h, plx_hooks.EpisodeSummarySaverHook)
                for h in (all_hooks + chief_hooks +
                          list(estimator_spec.training_chief_hooks)))
            if not saver_hook_exists:
                chief_hooks += [
                    plx_hooks.EpisodeSummarySaverHook(
                        scaffold=scaffold,
                        save_episodes=1,  # TODO: save every episode?
                        output_dir=self._model_dir)
                ]
        with monitored_session.MonitoredTrainingSession(
                master=self._config.master,
                is_chief=self._config.is_chief,
                checkpoint_dir=self._model_dir,
                scaffold=scaffold,
                hooks=all_hooks,
                chief_only_hooks=(
                    chief_hooks + list(estimator_spec.training_chief_hooks)),
                save_checkpoint_secs=0,  # Saving checkpoint is handled by a hook.
                save_summaries_steps=0,  # Saving summaries is handled by a hook.
                config=self._session_config) as mon_sess:
            loss = None
            while not mon_sess.should_stop():
                loss = self.run_episode(
                    env=env,
                    sess=mon_sess,
                    features=features,
                    labels=labels,
                    no_run_hooks=no_run_hooks,
                    global_step=global_step,
                    update_episode_op=update_episode_op,
                    update_timestep_op=update_timestep_op,
                    first_update=first_update,
                    update_frequency=update_frequency,
                    estimator_spec=estimator_spec)
        summary_io.SummaryWriterCache.clear()
        return loss