def _infer_model(self, mode, input_fn=None, predict_keys=None, hooks=None,
                 checkpoint_path=None):
    """Returns predictions for given features given an inference mode.

    This is a generator: the checkpoint lookup, graph construction and session
    only run once the caller starts iterating over the result.

    Args:
        mode: The inference to use, possible values: PREDICT, GENERATE, ENCODE.
        input_fn: Input function returning features which is a dictionary of
            string feature name to `Tensor` or `SparseTensor`. If it returns a
            tuple, first item is extracted as features. Prediction continues
            until `input_fn` raises an end-of-input exception
            (`OutOfRangeError` or `StopIteration`).
        predict_keys: list of `str`, name of the keys to predict. It is used if
            the `EstimatorSpec.predictions` is a `dict`. If `predict_keys` is
            used then rest of the predictions will be filtered from the
            dictionary. If `None`, returns all.
        hooks: List of `SessionRunHook` subclass instances. Used for callbacks
            inside the prediction call.
        checkpoint_path: Path of a specific checkpoint to predict. If `None`,
            the latest checkpoint in `model_dir` is used.

    Yields:
        Evaluated values of `predictions` tensors, one example at a time: the
        items of the evaluated batch when `predictions` is not a `dict`,
        otherwise one `dict` per batch index with the same keys.

    Raises:
        ValueError: Could not find a trained model in model_dir.
        ValueError: if batch length of predictions are not same.
        ValueError: If there is a conflict between `predict_keys` and
            `predictions`. For example if `predict_keys` is not `None` but
            `EstimatorSpec.predictions` is not a `dict`.
    """
    hooks = self._check_hooks(hooks)

    # Check that model has been trained.
    if not checkpoint_path:
        checkpoint_path = saver.latest_checkpoint(self._model_dir)
    if not checkpoint_path:
        raise ValueError("Could not find trained model at %s."
                         % self._model_dir)

    with ops.Graph().as_default() as g:
        random_seed.set_random_seed(self._config.tf_random_seed)
        training.get_or_create_global_step(g)
        features = self._get_features_from_input_fn(input_fn)
        estimator_spec = self._call_model_fn(features, None, mode)
        predictions = self._extract_keys(estimator_spec.predictions,
                                         predict_keys)
        with monitored_session.MonitoredSession(
                session_creator=monitored_session.ChiefSessionCreator(
                    checkpoint_filename_with_path=checkpoint_path,
                    scaffold=estimator_spec.scaffold,
                    config=self._session_config),
                hooks=hooks) as mon_sess:
            while not mon_sess.should_stop():
                preds_evaluated = mon_sess.run(predictions)
                if not isinstance(predictions, dict):
                    for pred in preds_evaluated:
                        yield pred
                else:
                    # `six.moves.range` is `xrange` on Python 2 and `range` on
                    # Python 3, so this no longer relies on a bare `xrange`
                    # name being available in the module.
                    batch_length = extract_batch_length(preds_evaluated)
                    for i in six.moves.range(batch_length):
                        yield {
                            key: value[i]
                            for key, value in six.iteritems(preds_evaluated)
                        }
def build_subnetwork(self, features, logits_dimension, training,
                     iteration_step, summary, previous_ensemble=None):
    """Check the builder's calling context, then return a dense `Subnetwork`.

    Before building anything, this asserts that the required arguments were
    supplied, that no trainable variables exist yet, that every global-step
    accessor resolves to "subnetwork_test/iteration_step", and that the v1
    summary functions return the expected "fake_*" values.

    Args:
        features: Input features; only checked for being non-None.
        logits_dimension: Number of logits units. With multiple heads each
            head gets `logits_dimension / 2`.
        training: Training flag; only checked for being non-None.
        iteration_step: Iteration step; only checked for being non-None.
        summary: Summary object; only checked for being non-None.
        previous_ensemble: Unused here.

    Returns:
        A `Subnetwork` with complexity 2 and no persisted tensors. Its
        `last_layer` is the logits when `self._use_logits_last_layer` is set,
        otherwise the dummy tensor (or a per-head dict of it).
    """
    for required in (features, training, iteration_step, summary):
        assert required is not None

    # Trainable variables collection should always be empty when
    # build_subnetwork is called.
    assert not tf_compat.v1.get_collection(
        tf_compat.v1.GraphKeys.TRAINABLE_VARIABLES)

    # Subnetworks get iteration steps instead of global steps, whichever
    # accessor is used to look the step up.
    step_name = "subnetwork_test/iteration_step"
    step_getters = (
        tf_compat.v1.train.get_global_step,
        train.get_global_step,
        training_util.get_global_step,
        tf_v1.train.get_global_step,
        tf_compat.v1.train.get_or_create_global_step,
        train.get_or_create_global_step,
        training_util.get_or_create_global_step,
        tf_v1.train.get_or_create_global_step,
    )
    for get_step in step_getters:
        assert step_name == tf_compat.tensor_name(get_step())

    # Subnetworks get scoped summaries.
    assert "fake_scalar" == tf_compat.v1.summary.scalar("scalar", 1.)
    assert "fake_image" == tf_compat.v1.summary.image("image", 1.)
    assert "fake_histogram" == tf_compat.v1.summary.histogram("histogram", 1.)
    assert "fake_audio" == tf_compat.v1.summary.audio("audio", 1., 1.)

    hidden = tu.dummy_tensor(shape=(2, 3))

    def logits_fn(logits_dim):
        """Dense layer over the dummy tensor with a seeded initializer."""
        initializer = tf_compat.v1.glorot_uniform_initializer(seed=self._seed)
        return tf_compat.v1.layers.dense(
            hidden, units=logits_dim, kernel_initializer=initializer)

    if self._multi_head:
        logits = {
            "head1": logits_fn(logits_dimension / 2),
            "head2": logits_fn(logits_dimension / 2),
        }
        last_layer = {"head1": hidden, "head2": hidden}
    else:
        logits = logits_fn(logits_dimension)
        last_layer = hidden

    if self._use_logits_last_layer:
        last_layer = logits
    return Subnetwork(
        last_layer=last_layer,
        logits=logits,
        complexity=2,
        persisted_tensors={})
def _train_model(self, checkpoint_dir, num_steps):
    """Trains a simple classification model.

    Note that the data has been configured such that after around 300 steps,
    the model has memorized the dataset (e.g. we can expect 100% accuracy).

    Args:
        checkpoint_dir: The directory where the checkpoint is written to.
        num_steps: The number of steps to train for.
    """
    with ops.Graph().as_default():
        random_seed.set_random_seed(0)

        inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
        labels = constant_op.constant(self._labels, dtype=dtypes.float32)
        predictions = logistic_classifier(inputs)
        loss_op = losses.log_loss(labels=labels, predictions=predictions)

        sgd = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
        global_step = training.get_or_create_global_step()
        train_op = sgd.minimize(loss_op, global_step)

        stop_hook = basic_session_run_hooks.StopAtStepHook(num_steps)
        with monitored_session.MonitoredTrainingSession(
                checkpoint_dir=checkpoint_dir,
                hooks=[stop_hook]) as session:
            loss = None
            while not session.should_stop():
                _, loss = session.run([train_op, loss_op])

            # Only a sufficiently long run is required to reach the
            # memorized-dataset loss.
            if num_steps >= 300:
                assert loss < .015
def _build_train_op(self, loss):
    """Creates the training operation.

    Args:
        loss: The loss passed to `tf.contrib.layers.optimize_loss`.

    Returns:
        The op returned by `tf.contrib.layers.optimize_loss`, built with the
        optimizer from `self._build_optimizer()`, the graph's global step,
        `self._clip_gradients_fn` for gradient clipping and no summaries.
    """
    optimizer = self._build_optimizer()
    global_step = training.get_or_create_global_step()
    # learning_rate is None because the optimizer is passed in explicitly.
    return tf.contrib.layers.optimize_loss(
        loss=loss,
        global_step=global_step,
        learning_rate=None,
        clip_gradients=self._clip_gradients_fn,
        optimizer=optimizer,
        summaries=[])
def _build_train_op(self, loss):
    """Creates the training operation.

    Args:
        loss: The loss passed to `tf.contrib.layers.optimize_loss`.

    Returns:
        The op returned by `tf.contrib.layers.optimize_loss`, built with the
        optimizer from `self._build_optimizer()`, the graph's global step,
        `self._clip_gradients_fn` for gradient clipping and no summaries.
    """
    optimizer = self._build_optimizer()
    global_step = training.get_or_create_global_step()
    # learning_rate is None because the optimizer is passed in explicitly.
    return tf.contrib.layers.optimize_loss(
        loss=loss,
        global_step=global_step,
        learning_rate=None,
        clip_gradients=self._clip_gradients_fn,
        optimizer=optimizer,
        summaries=[])
def model_fn(features, labels, mode, params):
    """The model_fn argument for creating an Estimator.

    Args:
        features: Either the image input itself or a dict holding it under the
            "image" key.
        labels: Labels used for the loss (and accuracy in EVAL); unused in
            PREDICT.
        mode: One of `estimator.ModeKeys.PREDICT`, `TRAIN` or `EVAL`.
        params: Dict providing "data_format" for the `Model` constructor.

    Returns:
        An `estimator.EstimatorSpec` for the requested mode, or `None` when
        `mode` matches none of the three handled modes.
    """
    model = Model(params["data_format"])
    image = features["image"] if isinstance(features, dict) else features

    if mode == estimator.ModeKeys.PREDICT:
        logits = model(image, training=False)
        predictions = {
            "classes": math_ops.argmax(logits, axis=1),
            "probabilities": nn.softmax(logits),
        }
        export_outputs = {
            "classify": estimator.export.PredictOutput(predictions)
        }
        return estimator.EstimatorSpec(
            mode=estimator.ModeKeys.PREDICT,
            predictions=predictions,
            export_outputs=export_outputs)

    if mode == estimator.ModeKeys.TRAIN:
        optimizer = train.AdamOptimizer(learning_rate=1e-4)
        logits = model(image, training=True)
        loss = losses.sparse_softmax_cross_entropy(labels=labels,
                                                   logits=logits)
        return estimator.EstimatorSpec(
            mode=estimator.ModeKeys.TRAIN,
            loss=loss,
            train_op=optimizer.minimize(loss,
                                        train.get_or_create_global_step()))

    if mode == estimator.ModeKeys.EVAL:
        logits = model(image, training=False)
        loss = losses.sparse_softmax_cross_entropy(labels=labels,
                                                   logits=logits)
        return estimator.EstimatorSpec(
            mode=estimator.ModeKeys.EVAL,
            loss=loss,
            eval_metric_ops={
                "accuracy": ops.metrics.accuracy(
                    labels=labels,
                    predictions=math_ops.argmax(logits, axis=1)),
            })

    # Unhandled modes fall through without building a spec.
    return None
def model_fn(features, labels, mode, params):
    """The model_fn argument for creating an Estimator.

    Args:
        features: Either the image input itself or a dict holding it under the
            "image" key.
        labels: Labels used for the loss (and accuracy in EVAL); unused in
            PREDICT.
        mode: One of `estimator.ModeKeys.PREDICT`, `TRAIN` or `EVAL`.
        params: Dict providing "data_format" for the `Model` constructor.

    Returns:
        An `estimator.EstimatorSpec` for the requested mode, or `None` when
        `mode` matches none of the three handled modes.
    """
    model = Model(params["data_format"])
    image = features["image"] if isinstance(features, dict) else features

    if mode == estimator.ModeKeys.PREDICT:
        logits = model(image, training=False)
        predictions = {
            "classes": math_ops.argmax(logits, axis=1),
            "probabilities": nn.softmax(logits),
        }
        export_outputs = {
            "classify": estimator.export.PredictOutput(predictions)
        }
        return estimator.EstimatorSpec(
            mode=estimator.ModeKeys.PREDICT,
            predictions=predictions,
            export_outputs=export_outputs)

    if mode == estimator.ModeKeys.TRAIN:
        optimizer = train.AdamOptimizer(learning_rate=1e-4)
        logits = model(image, training=True)
        loss = losses.sparse_softmax_cross_entropy(labels=labels,
                                                   logits=logits)
        return estimator.EstimatorSpec(
            mode=estimator.ModeKeys.TRAIN,
            loss=loss,
            train_op=optimizer.minimize(loss,
                                        train.get_or_create_global_step()))

    if mode == estimator.ModeKeys.EVAL:
        logits = model(image, training=False)
        loss = losses.sparse_softmax_cross_entropy(labels=labels,
                                                   logits=logits)
        return estimator.EstimatorSpec(
            mode=estimator.ModeKeys.EVAL,
            loss=loss,
            eval_metric_ops={
                "accuracy": ops.metrics.accuracy(
                    labels=labels,
                    predictions=math_ops.argmax(logits, axis=1)),
            })

    # Unhandled modes fall through without building a spec.
    return None
def _train_model(self, checkpoint_dir, num_steps):
    """Trains a simple classification model.

    Note that the data has been configured such that after around 300 steps,
    the model has memorized the dataset (e.g. we can expect 100% accuracy).

    Args:
        checkpoint_dir: The directory where the checkpoint is written to.
        num_steps: The number of steps to train for.
    """
    with ops.Graph().as_default():
        random_seed.set_random_seed(0)

        inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
        labels = constant_op.constant(self._labels, dtype=dtypes.float32)
        predictions = logistic_classifier(inputs)
        loss_op = losses.log_loss(labels=labels, predictions=predictions)

        sgd = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
        global_step = training.get_or_create_global_step()
        train_op = sgd.minimize(loss_op, global_step)

        stop_hook = basic_session_run_hooks.StopAtStepHook(num_steps)
        with monitored_session.MonitoredTrainingSession(
                checkpoint_dir=checkpoint_dir,
                hooks=[stop_hook]) as session:
            loss = None
            while not session.should_stop():
                _, loss = session.run([train_op, loss_op])

            # Only a sufficiently long run is required to reach the
            # memorized-dataset loss.
            if num_steps >= 300:
                assert loss < .015
def test_build_ensemble_spec(
        self,
        want_logits,
        want_loss=None,
        want_adanet_loss=None,
        want_ensemble_trainable_vars=None,
        adanet_lambda=0.,
        adanet_beta=0.,
        ensemble_spec_fn=lambda: None,
        use_bias=False,
        use_logits_last_layer=False,
        mixture_weight_type=MixtureWeightType.MATRIX,
        mixture_weight_initializer=tf_compat.v1.zeros_initializer(),
        warm_start_mixture_weights=True,
        subnetwork_builder_class=_Builder,
        mode=tf.estimator.ModeKeys.TRAIN,
        multi_head=False,
        want_subnetwork_trainable_vars=2):
    """Builds a subnetwork spec and an ensemble spec and checks both.

    The train-op callbacks handed to the subnetwork builder assert that,
    inside the builder's context, the global step resolves to a scoped
    iteration step and the v1 summary functions return fake values. Outside
    that context the test asserts the real "global_step" and real summaries
    are visible again. In PREDICT mode only logits and the absence of
    loss/train_op are checked; otherwise three training steps are run and the
    loss must decrease.

    Args:
        want_logits: Expected ensemble logits (compared with atol=1e-3).
        want_loss: Expected ensemble loss after training (3 places).
        want_adanet_loss: Expected AdaNet loss after training (3 places).
        want_ensemble_trainable_vars: Expected number of variables passed to
            the mixture-weights train-op callback.
        adanet_lambda: Forwarded to `ComplexityRegularizedEnsembler`.
        adanet_beta: Forwarded to `ComplexityRegularizedEnsembler`.
        ensemble_spec_fn: Callable returning the previous ensemble spec, or a
            falsy value when there is none.
        use_bias: Forwarded to the ensembler; also selects whether the learned
            bias must be non-zero or (almost) zero.
        use_logits_last_layer: Forwarded to the subnetwork builder.
        mixture_weight_type: Forwarded to the ensembler.
        mixture_weight_initializer: Forwarded to the ensembler.
        warm_start_mixture_weights: Forwarded to the ensembler.
        subnetwork_builder_class: Class used to construct the builder.
        mode: Estimator mode used for both specs.
        multi_head: Whether to use a two-head `MultiHead` and dict labels.
        want_subnetwork_trainable_vars: Expected number of variables passed to
            the subnetwork train-op callback.
    """
    seed = 64

    if multi_head:
        head = multi_head_lib.MultiHead(heads=[
            binary_class_head.BinaryClassHead(
                name="head1", loss_reduction=tf_compat.SUM),
            binary_class_head.BinaryClassHead(
                name="head2", loss_reduction=tf_compat.SUM),
        ])
    else:
        head = binary_class_head.BinaryClassHead(loss_reduction=tf_compat.SUM)
    builder = _EnsembleBuilder(head=head)

    def _check_train_op_context(var_list, want_num_vars, want_step_name):
        """Asserts the variables, step and summaries seen by a train-op fn."""
        self.assertLen(var_list, want_num_vars)
        self.assertEqual(
            var_list,
            tf_compat.v1.get_collection(
                tf_compat.v1.GraphKeys.TRAINABLE_VARIABLES))
        # The scoped iteration step replaces the global step here.
        self.assertEqual(want_step_name,
                         tf_compat.v1.train.get_global_step().op.name)
        # Summaries are scoped as well, so the fake values come back.
        self.assertEqual("fake_scalar",
                         tf_compat.v1.summary.scalar("scalar", 1.))
        self.assertEqual("fake_image",
                         tf_compat.v1.summary.image("image", 1.))
        self.assertEqual("fake_histogram",
                         tf_compat.v1.summary.histogram("histogram", 1.))
        self.assertEqual("fake_audio",
                         tf_compat.v1.summary.audio("audio", 1., 1.))

    def _sgd_minimize(loss, var_list):
        """Returns a gradient-descent train op over `var_list`."""
        optimizer = tf_compat.v1.train.GradientDescentOptimizer(
            learning_rate=.1)
        return optimizer.minimize(loss, var_list=var_list)

    def _subnetwork_train_op_fn(loss, var_list):
        # Subnetworks get iteration steps instead of global steps.
        _check_train_op_context(var_list, want_subnetwork_trainable_vars,
                                "subnetwork_test/iteration_step")
        return _sgd_minimize(loss, var_list)

    def _mixture_weights_train_op_fn(loss, var_list):
        # Ensembles get iteration steps instead of global steps.
        _check_train_op_context(var_list, want_ensemble_trainable_vars,
                                "ensemble_test/iteration_step")
        return _sgd_minimize(loss, var_list)

    previous_ensemble_spec = ensemble_spec_fn()
    previous_ensemble = (
        previous_ensemble_spec.ensemble if previous_ensemble_spec else None)

    subnetwork_manager = _SubnetworkManager(head)
    subnetwork_builder = subnetwork_builder_class(
        _subnetwork_train_op_fn,
        _mixture_weights_train_op_fn,
        use_logits_last_layer,
        seed,
        multi_head=multi_head)

    with tf.Graph().as_default() as g:
        # A trainable variable to later verify that creating models does not
        # affect the global variables collection.
        _ = tf_compat.v1.get_variable("some_var", 0., trainable=True)

        features = {"x": tf.constant([[1.], [2.]])}
        if multi_head:
            labels = {
                "head1": tf.constant([0, 1]),
                "head2": tf.constant([0, 1]),
            }
        else:
            labels = tf.constant([0, 1])

        subnetwork_spec = subnetwork_manager.build_subnetwork_spec(
            name="test",
            subnetwork_builder=subnetwork_builder,
            iteration_step=tf_compat.v1.train.get_or_create_global_step(),
            summary=_FakeSummary(),
            features=features,
            mode=mode,
            labels=labels,
            previous_ensemble=previous_ensemble)
        ensemble_spec = builder.build_ensemble_spec(
            # Note: when ensemble_spec is not None and
            # warm_start_mixture_weights is True, we need to make sure that
            # the bias and mixture weights are already saved to the
            # checkpoint_dir.
            name="test",
            previous_ensemble_spec=previous_ensemble_spec,
            candidate=EnsembleCandidate("foo", [subnetwork_builder], None),
            ensembler=ComplexityRegularizedEnsembler(
                mixture_weight_type=mixture_weight_type,
                mixture_weight_initializer=mixture_weight_initializer,
                warm_start_mixture_weights=warm_start_mixture_weights,
                model_dir=self.test_subdirectory,
                adanet_lambda=adanet_lambda,
                adanet_beta=adanet_beta,
                use_bias=use_bias),
            subnetwork_specs=[subnetwork_spec],
            summary=_FakeSummary(),
            features=features,
            iteration_number=1,
            iteration_step=tf_compat.v1.train.get_or_create_global_step(),
            labels=labels,
            mode=mode)

        with tf_compat.v1.Session(graph=g).as_default() as sess:
            sess.run(tf_compat.v1.global_variables_initializer())

            # Equals the number of subnetwork and ensemble trainable
            # variables, plus the one 'some_var' created earlier.
            self.assertLen(
                tf_compat.v1.trainable_variables(),
                want_subnetwork_trainable_vars + want_ensemble_trainable_vars
                + 1)

            # Get the real global step outside a subnetwork's context,
            # whichever accessor is used.
            real_step_getters = (
                tf_compat.v1.train.get_global_step,
                train.get_global_step,
                tf_v1.train.get_global_step,
                training_util.get_global_step,
                tf_compat.v1.train.get_or_create_global_step,
                train.get_or_create_global_step,
                tf_v1.train.get_or_create_global_step,
                training_util.get_or_create_global_step,
            )
            for get_step in real_step_getters:
                self.assertEqual("global_step", get_step().op.name)

            # Get global tf.summary outside a subnetwork's context.
            self.assertNotEqual("fake_scalar",
                                tf_compat.v1.summary.scalar("scalar", 1.))
            self.assertNotEqual("fake_image",
                                tf_compat.v1.summary.image("image", 1.))
            self.assertNotEqual(
                "fake_histogram",
                tf_compat.v1.summary.histogram("histogram", 1.))
            self.assertNotEqual(
                "fake_audio", tf_compat.v1.summary.audio("audio", 1., 1.))

            if mode == tf.estimator.ModeKeys.PREDICT:
                self.assertAllClose(
                    want_logits,
                    sess.run(ensemble_spec.ensemble.logits),
                    atol=1e-3)
                self.assertIsNone(ensemble_spec.loss)
                self.assertIsNone(ensemble_spec.adanet_loss)
                self.assertIsNone(ensemble_spec.train_op)
                self.assertIsNotNone(ensemble_spec.export_outputs)
                return

            # Verify that train_op works, previous loss should be greater
            # than loss after a train op.
            initial_loss = sess.run(ensemble_spec.loss)
            train_op = tf.group(subnetwork_spec.train_op.train_op,
                                ensemble_spec.train_op.train_op)
            for _ in range(3):
                sess.run(train_op)
            self.assertGreater(initial_loss, sess.run(ensemble_spec.loss))

            self.assertAllClose(
                want_logits,
                sess.run(ensemble_spec.ensemble.logits),
                atol=1e-3)

            # Bias should learn a non-zero value when used.
            bias = sess.run(ensemble_spec.ensemble.bias)
            if isinstance(bias, dict):
                bias = sum(abs(b) for b in bias.values())
            if use_bias:
                self.assertNotEqual(0., bias)
            else:
                self.assertAlmostEqual(0., bias)

            self.assertAlmostEqual(
                want_loss, sess.run(ensemble_spec.loss), places=3)
            self.assertAlmostEqual(
                want_adanet_loss,
                sess.run(ensemble_spec.adanet_loss),
                places=3)
def _train_model(self, input_fn, hooks):
    """Builds the training graph and runs it until the session should stop.

    A fresh graph is created and stored on `self._graph`. NaN-loss and
    loss/step logging hooks are always installed, followed by the caller's
    hooks and the hooks from the `EstimatorSpec`. When checkpoint saving is
    configured and no `CheckpointSaverHook` is already present, one is added
    as a chief-only hook.

    Args:
        input_fn: Callable returning a `(features, labels)` pair.
        hooks: Iterable of extra session-run hooks for the training session.

    Returns:
        The loss value from the last executed training step, or `None` if the
        session stopped before running any step.
    """
    all_hooks = []
    self._graph = ops.Graph()
    with self._graph.as_default() as g, g.device(self._device_fn):
        random_seed.set_random_seed(self._config.tf_random_seed)
        global_step = training.get_or_create_global_step(g)
        features, labels = input_fn()
        estimator_spec = self._call_model_fn(features, labels, ModeKeys.TRAIN)

        all_hooks.extend([
            plx_hooks.NanTensorHook(estimator_spec.loss),
            plx_hooks.LoggingTensorHook(
                {
                    'loss': estimator_spec.loss,
                    'step': global_step
                },
                every_n_iter=100)
        ])
        all_hooks.extend(hooks)
        # From here on `all_hooks` already contains the spec's training
        # hooks; they must not be appended a second time further down.
        all_hooks.extend(estimator_spec.training_hooks)
        # Normalise to a list so concatenation works whether the spec stores
        # its chief hooks as a list or as a tuple.
        training_chief_hooks = list(estimator_spec.training_chief_hooks)

        scaffold = estimator_spec.scaffold or monitored_session.Scaffold()
        if not (scaffold.saver or ops.get_collection(ops.GraphKeys.SAVERS)):
            ops.add_to_collection(
                ops.GraphKeys.SAVERS,  # TODO remove non restorable vars
                saver.Saver(
                    sharded=True,  # TODO `var_list`
                    max_to_keep=self._config.keep_checkpoint_max,
                    defer_build=True))

        chief_hooks = []
        if (self._config.save_checkpoints_secs or
                self._config.save_checkpoints_steps):
            saver_hook_exists = any(
                isinstance(h, plx_hooks.CheckpointSaverHook)
                for h in all_hooks + training_chief_hooks)
            if not saver_hook_exists:
                chief_hooks = [
                    plx_hooks.CheckpointSaverHook(
                        self._model_dir,
                        save_secs=self._config.save_checkpoints_secs,
                        save_steps=self._config.save_checkpoints_steps,
                        scaffold=scaffold)
                ]

        with monitored_session.MonitoredTrainingSession(
                master=self._config.master,
                is_chief=self._config.is_chief,
                checkpoint_dir=self._model_dir,
                scaffold=scaffold,
                # Previously `all_hooks + estimator_spec.training_hooks`,
                # which registered every spec training hook twice.
                hooks=all_hooks,
                chief_only_hooks=chief_hooks + training_chief_hooks,
                save_checkpoint_secs=0,  # Saving is handled by a hook.
                save_summaries_steps=self._config.save_summary_steps,
                config=self._session_config) as mon_sess:
            loss = None
            while not mon_sess.should_stop():
                _, loss = mon_sess.run(
                    [estimator_spec.train_op, estimator_spec.loss])
        summary_io.SummaryWriterCache.clear()
        return loss
def predict(self, input_fn=None, predict_keys=None, hooks=None,
            checkpoint_path=None):
    """Returns predictions for given features.

    This is a generator: the checkpoint lookup, graph construction and session
    only run once the caller starts iterating over the result.

    Args:
        input_fn: Input function returning features which is a dictionary of
            string feature name to `Tensor` or `SparseTensor`. If it returns a
            tuple, first item is extracted as features. Prediction continues
            until `input_fn` raises an end-of-input exception
            (`OutOfRangeError` or `StopIteration`).
        predict_keys: list of `str`, name of the keys to predict. It is used if
            the `EstimatorSpec.predictions` is a `dict`. If `predict_keys` is
            used then rest of the predictions will be filtered from the
            dictionary. If `None`, returns all.
        hooks: List of `SessionRunHook` subclass instances. Used for callbacks
            inside the prediction call.
        checkpoint_path: Path of a specific checkpoint to predict. If `None`,
            the latest checkpoint in `model_dir` is used.

    Yields:
        Evaluated values of `predictions` tensors, one example at a time.

    Raises:
        ValueError: Could not find a trained model in model_dir.
        ValueError: if batch length of predictions are not same.
        ValueError: If there is a conflict between `predict_keys` and
            `predictions`. For example if `predict_keys` is not `None` but
            `EstimatorSpec.predictions` is not a `dict`.
    """
    hooks = self._check_hooks(hooks)

    # Check that model has been trained, falling back to the newest
    # checkpoint in the model directory when none was given.
    if not checkpoint_path:
        checkpoint_path = saver.latest_checkpoint(self._model_dir)
    if not checkpoint_path:
        raise ValueError("Could not find trained model at %s."
                         % self._model_dir)

    with ops.Graph().as_default() as g:
        random_seed.set_random_seed(self._config.tf_random_seed)
        training.get_or_create_global_step(g)
        features = self._get_features_from_input_fn(input_fn)
        estimator_spec = self._call_model_fn(features, None, ModeKeys.PREDICT)
        predictions = self._extract_keys(estimator_spec.predictions,
                                         predict_keys)
        session_creator = monitored_session.ChiefSessionCreator(
            checkpoint_filename_with_path=checkpoint_path,
            scaffold=estimator_spec.scaffold,
            config=self._session_config)
        with monitored_session.MonitoredSession(
                session_creator=session_creator, hooks=hooks) as mon_sess:
            while not mon_sess.should_stop():
                batch = mon_sess.run(predictions)
                if isinstance(predictions, dict):
                    # Split the dict of batched values into per-example dicts.
                    for i in range(extract_batch_length(batch)):
                        yield {
                            key: value[i]
                            for key, value in six.iteritems(batch)
                        }
                else:
                    for pred in batch:
                        yield pred
def export_savedmodel(self, export_dir_base, serving_input_receiver_fn,
                      assets_extra=None, as_text=False, checkpoint_path=None):
    """Exports inference graph as a SavedModel into given dir.

    This method builds a new graph by first calling the
    serving_input_receiver_fn to obtain feature `Tensor`s, and then calling
    this `Estimator`'s model_fn to generate the model graph based on those
    features. It restores the given checkpoint (or, lacking that, the most
    recent checkpoint) into this graph in a fresh session. Finally it creates
    a timestamped export directory below the given export_dir_base, and writes
    a `SavedModel` into it containing a single `MetaGraphDef` saved from this
    session.

    The exported `MetaGraphDef` will provide one `SignatureDef` for each
    element of the export_outputs dict returned from the model_fn, named using
    the same keys. One of these keys is always
    signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, indicating which
    signature will be served when a serving request does not specify one.
    For each signature, the outputs are provided by the corresponding
    `ExportOutput`s, and the inputs are always the input receivers provided by
    the serving_input_receiver_fn.

    Extra assets may be written into the SavedModel via the assets_extra
    argument. This should be a dict, where each key gives a destination path
    (including the filename) relative to the assets.extra directory. The
    corresponding value gives the full path of the source file to be copied.
    For example, the simple case of copying a single file without renaming it
    is specified as `{'my_asset_file.txt': '/path/to/my_asset_file.txt'}`.

    Args:
        export_dir_base: A string containing a directory in which to create
            timestamped subdirectories containing exported SavedModels.
        serving_input_receiver_fn: A function that takes no argument and
            returns a `ServingInputReceiver`.
        assets_extra: A dict specifying how to populate the assets.extra
            directory within the exported SavedModel, or `None` if no extra
            assets are needed.
        as_text: whether to write the SavedModel proto in text format.
        checkpoint_path: The checkpoint path to export. If `None` (the
            default), the most recent checkpoint found within the model
            directory is chosen.

    Returns:
        The string path to the exported directory.

    Raises:
        ValueError: if no serving_input_receiver_fn is provided, no
            export_outputs are provided, or no checkpoint can be found.
    """
    if serving_input_receiver_fn is None:
        raise ValueError('serving_input_receiver_fn must be defined.')

    with ops.Graph().as_default() as g:
        training.get_or_create_global_step(g)
        random_seed.set_random_seed(self._config.tf_random_seed)
        receiver = serving_input_receiver_fn()

        # Call the model_fn and collect the export_outputs.
        spec = self._call_model_fn(
            features=receiver.features,
            labels=None,
            mode=model_fn_lib.ModeKeys.PREDICT)

        # Build the SignatureDefs from receivers and all outputs
        signature_def_map = build_all_signature_defs(
            receiver.receiver_tensors, spec.export_outputs)

        if not checkpoint_path:
            # Locate the latest checkpoint
            checkpoint_path = saver.latest_checkpoint(self._model_dir)
        if not checkpoint_path:
            raise ValueError("Couldn't find trained model at %s."
                             % self._model_dir)

        export_dir = get_timestamped_export_dir(export_dir_base)

        with tf_session.Session() as sess:
            restore_saver = spec.scaffold.saver or saver.Saver(sharded=True)
            restore_saver.restore(sess, checkpoint_path)

            # pylint: disable=protected-access
            local_init_op = (
                spec.scaffold.local_init_op or
                monitored_session.Scaffold._default_local_init_op())
            # pylint: enable=protected-access

            # Perform the export
            model_builder = saved_model_builder.SavedModelBuilder(export_dir)
            model_builder.add_meta_graph_and_variables(
                sess, [tag_constants.SERVING],
                signature_def_map=signature_def_map,
                assets_collection=ops.get_collection(
                    ops.GraphKeys.ASSET_FILEPATHS),
                legacy_init_op=local_init_op)
            model_builder.save(as_text)

        # Add the extra assets
        if assets_extra:
            extra_root = os.path.join(
                compat.as_bytes(export_dir), compat.as_bytes('assets.extra'))
            for rel_dest, src in assets_extra.items():
                abs_dest = os.path.join(
                    compat.as_bytes(extra_root), compat.as_bytes(rel_dest))
                # Make sure the destination's parent directory exists
                # before copying into it.
                gfile.MakeDirs(os.path.dirname(abs_dest))
                gfile.Copy(src, abs_dest)

        return export_dir
def _train_model(self, input_fn, hooks):
    """Builds the training graph and runs it until the session should stop.

    A fresh graph is created and stored on `self._graph`; `input_fn` is
    called under the '/cpu:0' device. NaN-loss and step-logging hooks are
    always installed, followed by the caller's hooks and the spec's training
    hooks. Checkpoint and summary saver hooks are added as chief-only hooks
    when they are configured and no hook of that class is already present.

    Args:
        input_fn: Callable returning a `(features, labels)` pair.
        hooks: Iterable of extra session-run hooks for the training session.

    Returns:
        The loss value from the last executed training step, or `None` if the
        session stopped before running any step.
    """
    all_hooks = []
    self._graph = ops.Graph()
    with self._graph.as_default() as g, g.device(self._device_fn):
        random_seed.set_random_seed(self._config.tf_random_seed)
        global_step = training.get_or_create_global_step(g)
        with ops.device('/cpu:0'):
            features, labels = input_fn()
        estimator_spec = self._call_model_fn(features, labels, Modes.TRAIN)
        ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss)

        logged_tensors = {
            'loss': estimator_spec.loss,
            'step': global_step
        }
        all_hooks.append(plx_hooks.NanTensorHook(estimator_spec.loss))
        all_hooks.append(
            plx_hooks.StepLoggingTensorHook(logged_tensors, every_n_iter=100))
        all_hooks.extend(hooks)
        all_hooks.extend(estimator_spec.training_hooks)

        scaffold = estimator_spec.scaffold or monitored_session.Scaffold()
        if not (scaffold.saver or ops.get_collection(ops.GraphKeys.SAVERS)):
            ops.add_to_collection(
                ops.GraphKeys.SAVERS,  # TODO remove non restorable vars
                saver.Saver(
                    sharded=True,  # TODO `var_list`
                    max_to_keep=self._config.keep_checkpoint_max,
                    defer_build=True))

        chief_hooks = []

        def _has_hook(hook_cls):
            """True if a hook of `hook_cls` is already registered anywhere."""
            registered = (all_hooks + chief_hooks +
                          list(estimator_spec.training_chief_hooks))
            return any(isinstance(h, hook_cls) for h in registered)

        if (self._config.save_checkpoints_secs or
                self._config.save_checkpoints_steps):
            if not _has_hook(plx_hooks.StepCheckpointSaverHook):
                chief_hooks.append(
                    plx_hooks.StepCheckpointSaverHook(
                        self._model_dir,
                        save_secs=self._config.save_checkpoints_secs,
                        save_steps=self._config.save_checkpoints_steps,
                        scaffold=scaffold))

        if self._config.save_summary_steps:
            if not _has_hook(plx_hooks.StepSummarySaverHook):
                chief_hooks.append(
                    plx_hooks.StepSummarySaverHook(
                        scaffold=scaffold,
                        save_steps=self._config.save_summary_steps,
                        output_dir=self._model_dir,
                    ))

        with monitored_session.MonitoredTrainingSession(
                master=self._config.master,
                is_chief=self._config.is_chief,
                checkpoint_dir=self._model_dir,
                scaffold=scaffold,
                hooks=all_hooks,
                chief_only_hooks=(
                    chief_hooks + list(estimator_spec.training_chief_hooks)),
                save_checkpoint_secs=0,  # Saving checkpoint is handled by a hook.
                save_summaries_steps=0,  # Saving summaries is handled by a hook.
                config=self._session_config) as mon_sess:
            loss = None
            while not mon_sess.should_stop():
                _, loss = mon_sess.run(
                    [estimator_spec.train_op, estimator_spec.loss])
        summary_io.SummaryWriterCache.clear()
        return loss
def export_savedmodel(self, export_dir_base, serving_input_receiver_fn,
                      assets_extra=None, as_text=False, checkpoint_path=None):
    """Exports inference graph as a SavedModel into given dir.

    This method builds a new graph by first calling the
    serving_input_receiver_fn to obtain feature `Tensor`s, and then calling
    this `Estimator`'s model_fn to generate the model graph based on those
    features. It restores the given checkpoint (or, lacking that, the most
    recent checkpoint) into this graph in a fresh session. Finally it creates
    a timestamped export directory below the given export_dir_base, and writes
    a `SavedModel` into it containing a single `MetaGraphDef` saved from this
    session.

    The exported `MetaGraphDef` will provide one `SignatureDef` for each
    element of the export_outputs dict returned from the model_fn, named using
    the same keys. One of these keys is always
    signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, indicating which
    signature will be served when a serving request does not specify one.
    For each signature, the outputs are provided by the corresponding
    `ExportOutput`s, and the inputs are always the input receivers provided by
    the serving_input_receiver_fn.

    Extra assets may be written into the SavedModel via the assets_extra
    argument. This should be a dict, where each key gives a destination path
    (including the filename) relative to the assets.extra directory. The
    corresponding value gives the full path of the source file to be copied.
    For example, the simple case of copying a single file without renaming it
    is specified as `{'my_asset_file.txt': '/path/to/my_asset_file.txt'}`.

    Args:
        export_dir_base: A string containing a directory in which to create
            timestamped subdirectories containing exported SavedModels.
        serving_input_receiver_fn: A function that takes no argument and
            returns a `ServingInputReceiver`.
        assets_extra: A dict specifying how to populate the assets.extra
            directory within the exported SavedModel, or `None` if no extra
            assets are needed.
        as_text: whether to write the SavedModel proto in text format.
        checkpoint_path: The checkpoint path to export. If `None` (the
            default), the most recent checkpoint found within the model
            directory is chosen.

    Returns:
        The string path to the exported directory.

    Raises:
        ValueError: if no serving_input_receiver_fn is provided, no
            export_outputs are provided, or no checkpoint can be found.
    """
    if serving_input_receiver_fn is None:
        raise ValueError('serving_input_receiver_fn must be defined.')

    with ops.Graph().as_default() as g:
        training.get_or_create_global_step(g)
        random_seed.set_random_seed(self._config.tf_random_seed)
        receiver = serving_input_receiver_fn()

        # Call the model_fn and collect the export_outputs.
        spec = self._call_model_fn(
            features=receiver.features,
            labels=None,
            mode=Modes.PREDICT)

        # Build the SignatureDefs from receivers and all outputs
        signature_def_map = build_all_signature_defs(
            receiver.receiver_tensors, spec.export_outputs)

        if not checkpoint_path:
            # Locate the latest checkpoint
            checkpoint_path = saver.latest_checkpoint(self._model_dir)
        if not checkpoint_path:
            raise ValueError("Couldn't find trained model at %s."
                             % self._model_dir)

        export_dir = get_timestamped_export_dir(export_dir_base)

        with tf_session.Session() as sess:
            restore_saver = spec.scaffold.saver or saver.Saver(sharded=True)
            restore_saver.restore(sess, checkpoint_path)

            # Fall back to the scaffold's default local init op when the
            # spec does not supply one.
            local_init_op = (
                spec.scaffold.local_init_op or
                monitored_session.Scaffold._default_local_init_op())

            # Perform the export
            model_builder = saved_model_builder.SavedModelBuilder(export_dir)
            model_builder.add_meta_graph_and_variables(
                sess, [tag_constants.SERVING],
                signature_def_map=signature_def_map,
                assets_collection=ops.get_collection(
                    ops.GraphKeys.ASSET_FILEPATHS),
                legacy_init_op=local_init_op)
            model_builder.save(as_text)

        # Add the extra assets
        if assets_extra:
            extra_root = os.path.join(
                compat.as_bytes(export_dir), compat.as_bytes('assets.extra'))
            for rel_dest, src in assets_extra.items():
                abs_dest = os.path.join(
                    compat.as_bytes(extra_root), compat.as_bytes(rel_dest))
                # Make sure the destination's parent directory exists
                # before copying into it.
                gfile.MakeDirs(os.path.dirname(abs_dest))
                gfile.Copy(src, abs_dest)

        return export_dir
def test_build_ensemble_spec(
        self,
        want_logits,
        want_loss=None,
        want_adanet_loss=None,
        want_ensemble_trainable_vars=None,
        adanet_lambda=0.,
        adanet_beta=0.,
        ensemble_spec_fn=lambda: None,
        use_bias=False,
        use_logits_last_layer=False,
        mixture_weight_type=MixtureWeightType.MATRIX,
        mixture_weight_initializer=tf_compat.v1.zeros_initializer(),
        warm_start_mixture_weights=True,
        subnetwork_builder_class=_Builder,
        mode=tf.estimator.ModeKeys.TRAIN,
        multi_head=False,
        want_subnetwork_trainable_vars=2,
        ensembler_class=ComplexityRegularizedEnsembler,
        my_ensemble_index=None,
        want_replay_indices=None,
        want_predictions=None,
        export_subnetworks=False,
        previous_ensemble_spec=None,
        previous_iteration_checkpoint=None):
    """Builds one subnetwork spec and one ensemble spec and checks the result.

    The specs are built in a fresh graph. In PREDICT mode only logits and
    export outputs are checked; otherwise three training steps are run and
    the resulting logits, loss, AdaNet loss, bias and predictions are
    compared against the `want_*` arguments.

    Args:
      want_logits: Expected `ensemble_spec.ensemble.logits` (atol=1e-3).
      want_loss: Expected `ensemble_spec.loss` after training (3 places).
        Not read in PREDICT mode.
      want_adanet_loss: Expected `ensemble_spec.adanet_loss` after training
        (3 places). Not read in PREDICT mode.
      want_ensemble_trainable_vars: Expected length of the `var_list` handed
        to the mixture-weights train op fn; it is also summed into the total
        trainable-variable count check, so it must be an int.
      adanet_lambda: Forwarded to the ensembler constructor when
        `ensembler_class` is `ComplexityRegularizedEnsembler`.
      adanet_beta: Same as `adanet_lambda`.
      ensemble_spec_fn: Zero-argument callable; its return value is used as
        the previous ensemble spec (may be falsy for "no previous ensemble").
      use_bias: Forwarded to `ComplexityRegularizedEnsembler`; also selects
        whether the learned bias is asserted non-zero or (almost) zero.
      use_logits_last_layer: Forwarded to `subnetwork_builder_class`.
      mixture_weight_type: Forwarded to `ComplexityRegularizedEnsembler`.
      mixture_weight_initializer: Forwarded to
        `ComplexityRegularizedEnsembler`.
      warm_start_mixture_weights: Forwarded to
        `ComplexityRegularizedEnsembler`.
      subnetwork_builder_class: Class instantiated with the two train-op
        functions, `use_logits_last_layer`, the seed and `multi_head`.
      mode: Mode passed to both spec builders; PREDICT selects the
        export-output checks instead of the training checks.
      multi_head: If true, a `MultiHead` with heads "head1" and "head2" and
        dict labels are used instead of a single binary head.
      want_subnetwork_trainable_vars: Expected length of the `var_list`
        handed to the subnetwork train op fn.
      ensembler_class: Ensembler class to instantiate.
      my_ensemble_index: Forwarded to `build_ensemble_spec`.
      want_replay_indices: If truthy, compared with
        `ensemble_spec.architecture.replay_indices`.
      want_predictions: If truthy, compared with
        `ensemble_spec.ensemble.predictions` (atol=1e-3).
      export_subnetworks: Enables both subnetwork-logits and
        subnetwork-last-layer exports on the builder and, in PREDICT mode,
        the checks on those export outputs.
      previous_ensemble_spec: NOTE(review): this argument is overwritten by
        `ensemble_spec_fn()` below before it is ever read, so a value passed
        by the caller has no effect -- confirm whether that is intended.
      previous_iteration_checkpoint: Forwarded to `build_ensemble_spec`.
    """
    seed = 64

    if multi_head:
        head = multi_head_lib.MultiHead(heads=[
            binary_class_head.BinaryClassHead(
                name="head1", loss_reduction=tf_compat.SUM),
            binary_class_head.BinaryClassHead(name="head2",
                                              loss_reduction=tf_compat.SUM)
        ])
    else:
        head = binary_class_head.BinaryClassHead(
            loss_reduction=tf_compat.SUM)
    builder = _EnsembleBuilder(
        head=head,
        export_subnetwork_logits=export_subnetworks,
        export_subnetwork_last_layer=export_subnetworks)

    def _subnetwork_train_op_fn(loss, var_list):
        # The subnetwork must be handed exactly the trainable variables that
        # exist in the graph at this point.
        self.assertLen(var_list, want_subnetwork_trainable_vars)
        self.assertEqual(
            var_list,
            tf_compat.v1.get_collection(
                tf_compat.v1.GraphKeys.TRAINABLE_VARIABLES))
        # Subnetworks get iteration steps instead of global steps.
        self.assertEqual("subnetwork_test/iteration_step",
                         tf_compat.v1.train.get_global_step().op.name)

        # Subnetworks get scoped summaries.
        self.assertEqual("fake_scalar",
                         tf_compat.v1.summary.scalar("scalar", 1.))
        self.assertEqual("fake_image",
                         tf_compat.v1.summary.image("image", 1.))
        self.assertEqual("fake_histogram",
                         tf_compat.v1.summary.histogram("histogram", 1.))
        self.assertEqual("fake_audio",
                         tf_compat.v1.summary.audio("audio", 1., 1.))
        optimizer = tf_compat.v1.train.GradientDescentOptimizer(
            learning_rate=.1)
        return optimizer.minimize(loss, var_list=var_list)

    def _mixture_weights_train_op_fn(loss, var_list):
        # Same checks as for the subnetwork, but against the expected number
        # of ensemble (mixture weight) variables.
        self.assertLen(var_list, want_ensemble_trainable_vars)
        self.assertEqual(
            var_list,
            tf_compat.v1.get_collection(
                tf_compat.v1.GraphKeys.TRAINABLE_VARIABLES))
        # The ensemble gets an iteration step instead of the global step.
        self.assertEqual("ensemble_test/iteration_step",
                         tf_compat.v1.train.get_global_step().op.name)

        # Summaries are expected to be scoped (faked) here as well.
        self.assertEqual("fake_scalar",
                         tf_compat.v1.summary.scalar("scalar", 1.))
        self.assertEqual("fake_image",
                         tf_compat.v1.summary.image("image", 1.))
        self.assertEqual("fake_histogram",
                         tf_compat.v1.summary.histogram("histogram", 1.))
        self.assertEqual("fake_audio",
                         tf_compat.v1.summary.audio("audio", 1., 1.))
        # Nothing to optimize when the ensembler has no trainable variables.
        if not var_list:
            return tf.no_op()
        optimizer = tf_compat.v1.train.GradientDescentOptimizer(
            learning_rate=.1)
        return optimizer.minimize(loss, var_list=var_list)

    previous_ensemble = None
    # NOTE(review): this unconditionally replaces the `previous_ensemble_spec`
    # argument with the result of `ensemble_spec_fn()`.
    previous_ensemble_spec = ensemble_spec_fn()
    if previous_ensemble_spec:
        previous_ensemble = previous_ensemble_spec.ensemble

    subnetwork_manager = _SubnetworkManager(head)
    subnetwork_builder = subnetwork_builder_class(
        _subnetwork_train_op_fn,
        _mixture_weights_train_op_fn,
        use_logits_last_layer,
        seed,
        multi_head=multi_head)

    with tf.Graph().as_default() as g:
        tf_compat.v1.train.get_or_create_global_step()
        # A trainable variable to later verify that creating models does not
        # affect the global variables collection.
        _ = tf_compat.v1.get_variable("some_var", shape=0, trainable=True)

        features = {"x": tf.constant([[1.], [2.]])}
        if multi_head:
            labels = {
                "head1": tf.constant([0, 1]),
                "head2": tf.constant([0, 1])
            }
        else:
            labels = tf.constant([0, 1])

        session_config = tf.compat.v1.ConfigProto(
            gpu_options=tf.compat.v1.GPUOptions(allow_growth=True))

        subnetwork_spec = subnetwork_manager.build_subnetwork_spec(
            name="test",
            subnetwork_builder=subnetwork_builder,
            summary=_FakeSummary(),
            features=features,
            mode=mode,
            labels=labels,
            previous_ensemble=previous_ensemble)
        # Constructor arguments depend on which ensembler class is under test.
        ensembler_kwargs = {}
        if ensembler_class is ComplexityRegularizedEnsembler:
            ensembler_kwargs.update({
                "mixture_weight_type": mixture_weight_type,
                "mixture_weight_initializer": mixture_weight_initializer,
                "warm_start_mixture_weights": warm_start_mixture_weights,
                "model_dir": self.test_subdirectory,
                "adanet_lambda": adanet_lambda,
                "adanet_beta": adanet_beta,
                "use_bias": use_bias
            })
        if ensembler_class is MeanEnsembler:
            ensembler_kwargs.update(
                {"add_mean_last_layer_predictions": True})
        ensemble_spec = builder.build_ensemble_spec(
            # Note: when ensemble_spec is not None and
            # warm_start_mixture_weights is True, we need to make sure that
            # the bias and mixture weights are already saved to the
            # checkpoint_dir.
            name="test",
            previous_ensemble_spec=previous_ensemble_spec,
            candidate=EnsembleCandidate("foo", [subnetwork_builder], None),
            ensembler=ensembler_class(**ensembler_kwargs),
            subnetwork_specs=[subnetwork_spec],
            summary=_FakeSummary(),
            features=features,
            iteration_number=1,
            labels=labels,
            my_ensemble_index=my_ensemble_index,
            mode=mode,
            previous_iteration_checkpoint=previous_iteration_checkpoint)

        if want_replay_indices:
            self.assertAllEqual(want_replay_indices,
                                ensemble_spec.architecture.replay_indices)

        with tf_compat.v1.Session(
                graph=g, config=session_config).as_default() as sess:
            sess.run(tf_compat.v1.global_variables_initializer())

            # Equals the number of subnetwork and ensemble trainable
            # variables, plus the one 'some_var' created earlier.
            self.assertLen(
                tf_compat.v1.trainable_variables(),
                want_subnetwork_trainable_vars +
                want_ensemble_trainable_vars + 1)

            # Get the real global step outside a subnetwork's context.
            # Every access path to the global step is checked in turn.
            self.assertEqual("global_step",
                             tf_compat.v1.train.get_global_step().op.name)
            self.assertEqual("global_step",
                             train.get_global_step().op.name)
            self.assertEqual("global_step",
                             tf_v1.train.get_global_step().op.name)
            self.assertEqual("global_step",
                             training_util.get_global_step().op.name)
            self.assertEqual(
                "global_step",
                tf_compat.v1.train.get_or_create_global_step().op.name)
            self.assertEqual("global_step",
                             train.get_or_create_global_step().op.name)
            self.assertEqual(
                "global_step",
                tf_v1.train.get_or_create_global_step().op.name)
            self.assertEqual(
                "global_step",
                training_util.get_or_create_global_step().op.name)

            # Get global tf.summary outside a subnetwork's context.
            self.assertNotEqual("fake_scalar",
                                tf_compat.v1.summary.scalar("scalar", 1.))
            self.assertNotEqual("fake_image",
                                tf_compat.v1.summary.image("image", 1.))
            self.assertNotEqual(
                "fake_histogram",
                tf_compat.v1.summary.histogram("histogram", 1.))
            self.assertNotEqual(
                "fake_audio",
                tf_compat.v1.summary.audio("audio", 1., 1.))

            if mode == tf.estimator.ModeKeys.PREDICT:
                # In PREDICT mode there is nothing to train: only logits and
                # export outputs are checked, then the test returns.
                self.assertAllClose(want_logits,
                                    sess.run(
                                        ensemble_spec.ensemble.logits),
                                    atol=1e-3)
                self.assertIsNone(ensemble_spec.loss)
                self.assertIsNone(ensemble_spec.adanet_loss)
                self.assertIsNone(ensemble_spec.train_op)
                self.assertIsNotNone(ensemble_spec.export_outputs)
                if not export_subnetworks:
                    return
                if not multi_head:
                    subnetwork_logits = sess.run(
                        ensemble_spec.export_outputs[
                            _EnsembleBuilder.
                            _SUBNETWORK_LOGITS_EXPORT_SIGNATURE].outputs)
                    self.assertAllClose(
                        subnetwork_logits["test"],
                        sess.run(subnetwork_spec.subnetwork.logits))
                    subnetwork_last_layer = sess.run(
                        ensemble_spec.export_outputs[
                            _EnsembleBuilder.
                            _SUBNETWORK_LAST_LAYER_EXPORT_SIGNATURE].
                        outputs)
                    self.assertAllClose(
                        subnetwork_last_layer["test"],
                        sess.run(subnetwork_spec.subnetwork.last_layer))
                else:
                    self.assertIn("subnetwork_logits_head2",
                                  ensemble_spec.export_outputs)
                    subnetwork_logits_head1 = sess.run(
                        ensemble_spec.
                        export_outputs["subnetwork_logits_head1"].outputs)
                    self.assertAllClose(
                        subnetwork_logits_head1["test"],
                        sess.run(
                            subnetwork_spec.subnetwork.logits["head1"]))
                    # NOTE(review): this repeats the logits_head2 check made
                    # a few lines above; a last-layer key for head2 may have
                    # been intended here -- confirm before changing.
                    self.assertIn("subnetwork_logits_head2",
                                  ensemble_spec.export_outputs)
                    subnetwork_last_layer_head1 = sess.run(
                        ensemble_spec.export_outputs[
                            "subnetwork_last_layer_head1"].outputs)
                    self.assertAllClose(
                        subnetwork_last_layer_head1["test"],
                        sess.run(subnetwork_spec.subnetwork.
                                 last_layer["head1"]))
                return

            # Verify that train_op works, previous loss should be greater
            # than loss after a train op.
            loss = sess.run(ensemble_spec.loss)
            train_op = tf.group(subnetwork_spec.train_op.train_op,
                                ensemble_spec.train_op.train_op)
            for _ in range(3):
                sess.run(train_op)
            self.assertGreater(loss, sess.run(ensemble_spec.loss))

            self.assertAllClose(want_logits,
                                sess.run(ensemble_spec.ensemble.logits),
                                atol=1e-3)

            if ensembler_class is ComplexityRegularizedEnsembler:
                # Bias should learn a non-zero value when used.
                bias = sess.run(ensemble_spec.ensemble.bias)
                if isinstance(bias, dict):
                    # A dict of biases is reduced to the sum of absolute
                    # values before comparing.
                    bias = sum(abs(b) for b in bias.values())
                if use_bias:
                    self.assertNotEqual(0., bias)
                else:
                    self.assertAlmostEqual(0., bias)

            self.assertAlmostEqual(want_loss,
                                   sess.run(ensemble_spec.loss),
                                   places=3)
            self.assertAlmostEqual(want_adanet_loss,
                                   sess.run(ensemble_spec.adanet_loss),
                                   places=3)

            if want_predictions:
                self.assertAllClose(
                    want_predictions,
                    sess.run(ensemble_spec.ensemble.predictions),
                    atol=1e-3)
def _train_model(self, input_fn, hooks):
    """Build the training graph and run train steps until the session stops.

    A new graph is created, the model function is called in `Modes.TRAIN`,
    NaN/step-logging hooks are attached, and checkpoint and summary saver
    hooks are added for the chief unless a hook of the same type has
    already been supplied.

    Args:
        input_fn: Input function forwarded, together with `Modes.TRAIN`, to
            `self._get_features_and_labels_from_input_fn`.
        hooks: Iterable of session hooks run on every worker in addition to
            the hooks created here and `estimator_spec.training_hooks`.

    Returns:
        The loss fetched by the last executed training step, or `None` if
        the session requested a stop before any step ran.
    """
    all_hooks = []
    with ops.Graph().as_default() as g, g.device(self._device_fn):
        random_seed.set_random_seed(self._config.tf_random_seed)
        global_step = training.get_or_create_global_step(g)
        features, labels = self._get_features_and_labels_from_input_fn(
            input_fn, Modes.TRAIN)
        estimator_spec = self._call_model_fn(features, labels, Modes.TRAIN)
        ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss)
        all_hooks.extend([
            plx_hooks.NanTensorHook(estimator_spec.loss),
            plx_hooks.StepLoggingTensorHook(
                {
                    'loss': estimator_spec.loss,
                    'step': global_step
                },
                every_n_iter=100)
        ])
        all_hooks.extend(hooks)
        all_hooks.extend(estimator_spec.training_hooks)

        # Fall back to a default scaffold when the spec carries none, as the
        # other `_train_model` implementations do; otherwise `scaffold.saver`
        # below would fail on `None`.
        scaffold = estimator_spec.scaffold or monitored_session.Scaffold()
        if not (scaffold.saver or ops.get_collection(ops.GraphKeys.SAVERS)):
            ops.add_to_collection(
                ops.GraphKeys.SAVERS,  # TODO remove non restorable vars
                saver.Saver(
                    sharded=True,
                    max_to_keep=self._config.keep_checkpoint_max,
                    keep_checkpoint_every_n_hours=(
                        self._config.keep_checkpoint_every_n_hours),
                    defer_build=True,
                    save_relative_paths=True))

        # Materialize once so the same hooks are used for the checks below
        # and for the session, even if the spec exposes a one-shot iterable.
        training_chief_hooks = list(estimator_spec.training_chief_hooks)
        chief_hooks = []

        def _hook_exists(hook_type):
            """Return True if any hook known so far is a `hook_type`."""
            known_hooks = all_hooks + chief_hooks + training_chief_hooks
            return any(isinstance(h, hook_type) for h in known_hooks)

        if (self._config.save_checkpoints_secs or
                self._config.save_checkpoints_steps):
            if not _hook_exists(plx_hooks.StepCheckpointSaverHook):
                chief_hooks.append(
                    plx_hooks.StepCheckpointSaverHook(
                        self._model_dir,
                        save_secs=self._config.save_checkpoints_secs,
                        save_steps=self._config.save_checkpoints_steps,
                        scaffold=scaffold))

        if self._config.save_summary_steps:
            if not _hook_exists(plx_hooks.StepSummarySaverHook):
                chief_hooks.append(
                    plx_hooks.StepSummarySaverHook(
                        scaffold=scaffold,
                        save_steps=self._config.save_summary_steps,
                        output_dir=self._model_dir,
                    ))

        with monitored_session.MonitoredTrainingSession(
                master=self._config.master,
                is_chief=self._config.is_chief,
                checkpoint_dir=self._model_dir,
                scaffold=scaffold,
                hooks=all_hooks,
                chief_only_hooks=chief_hooks + training_chief_hooks,
                save_checkpoint_secs=0,  # Saving checkpoint is handled by a hook.
                save_summaries_steps=0,  # Saving summaries is handled by a hook.
                config=self._session_config) as mon_sess:
            loss = None
            while not mon_sess.should_stop():
                _, loss = mon_sess.run(
                    [estimator_spec.train_op, estimator_spec.loss])
        return loss
def _global_step(self):
    """Return the global step tensor, creating it first if necessary."""
    step_tensor = training.get_or_create_global_step()
    return step_tensor
def bad_input_fn():
    """Input function that touches the global step before building its data.

    Returns:
        A single-element dataset holding a `(features, labels)` tuple, where
        `features` maps 'x' to an int64 constant and `labels` is a float32
        constant.
    """
    # Side effect: the global step is requested (and created if missing)
    # from inside the input function.
    training.get_or_create_global_step()
    features = {'x': constant_op.constant([[1], [1]], dtype=dtypes.int64)}
    labels = constant_op.constant([[1], [1]], dtype=dtypes.float32)
    return dataset_ops.Dataset.from_tensors((features, labels))
def _train_model(self, env, hooks):
    """Create the training graph for `env` and run episodes until stopped.

    Args:
        env: Environment handed to `self._prepare_input_fn` and to every
            `self.run_episode` call.
        hooks: Iterable of hooks added to the hooks run by all workers.

    Returns:
        The value returned by the last `self.run_episode` call, or `None`
        when the session asked to stop before the first episode.
    """
    self._graph = ops.Graph()
    all_hooks = []
    with self._graph.as_default() as g, g.device(self._device_fn):
        random_seed.set_random_seed(self._config.tf_random_seed)

        # Counters, plus the ops that advance episode/timestep by one.
        global_step = training.get_or_create_global_step(g)
        global_episode = get_or_create_global_episode(g)
        global_timestep = get_or_create_global_timestep(g)
        update_episode_op = tf.assign_add(global_episode, 1)
        update_timestep_op = tf.assign_add(global_timestep, 1)
        no_run_hooks = tf.no_op(name='no_run_hooks')

        with ops.device('/cpu:0'):
            features, labels = self._prepare_input_fn(Modes.TRAIN, env)
        estimator_spec = self._call_model_fn(features, labels, Modes.TRAIN)
        ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss)

        nan_hook = plx_hooks.NanTensorHook(estimator_spec.loss)
        step_logging_hook = plx_hooks.StepLoggingTensorHook(
            {
                'loss': estimator_spec.loss,
                'step': global_step,
                'timestep': global_timestep,
                'global_episode': global_episode,
                'max_reward': labels['max_reward'],
                'min_reward': labels['min_reward'],
                'total_reward': labels['total_reward'],
            },
            every_n_iter=100)
        all_hooks.extend([nan_hook, step_logging_hook])
        all_hooks.extend(hooks)
        all_hooks.extend(estimator_spec.training_hooks)

        scaffold = estimator_spec.scaffold or monitored_session.Scaffold()
        saver_available = (scaffold.saver or
                           ops.get_collection(ops.GraphKeys.SAVERS))
        if not saver_available:
            # TODO remove non restorable vars
            default_saver = saver.Saver(
                sharded=True,  # TODO `var_list`
                max_to_keep=self._config.keep_checkpoint_max,
                defer_build=True)
            ops.add_to_collection(ops.GraphKeys.SAVERS, default_saver)

        episode_logging_hook = plx_hooks.EpisodeLoggingTensorHook(
            {
                'loss': estimator_spec.loss,
                'step': global_step,
                'global_timestep': global_timestep,
                'global_episode': global_episode,
                'max_reward': labels['max_reward'],
                'min_reward': labels['min_reward'],
                'total_reward': labels['total_reward'],
            },
            every_n_episodes=100)  # TODO: save every episode?
        chief_hooks = [
            episode_logging_hook,
            plx_hooks.EpisodeCounterHook(output_dir=self.model_dir),
        ]

        def _already_present(hook_type):
            """Tell whether a hook of `hook_type` is registered so far."""
            registered = (all_hooks + chief_hooks +
                          list(estimator_spec.training_chief_hooks))
            return any(isinstance(hook, hook_type) for hook in registered)

        checkpoints_enabled = (self._config.save_checkpoints_secs or
                               self._config.save_checkpoints_steps)
        if checkpoints_enabled and not _already_present(
                plx_hooks.EpisodeCheckpointSaverHook):
            chief_hooks.append(
                plx_hooks.EpisodeCheckpointSaverHook(
                    self._model_dir,
                    save_episodes=100,  # TODO: save every episode?
                    scaffold=scaffold))

        if self._config.save_summary_steps and not _already_present(
                plx_hooks.EpisodeSummarySaverHook):
            chief_hooks.append(
                plx_hooks.EpisodeSummarySaverHook(
                    scaffold=scaffold,
                    save_episodes=100,  # TODO: save every episode?
                    output_dir=self._model_dir,
                ))

        # Everything `run_episode` needs except the session itself.
        episode_kwargs = dict(
            env=env,
            features=features,
            labels=labels,
            no_run_hooks=no_run_hooks,
            global_step=global_step,
            update_episode_op=update_episode_op,
            update_timestep_op=update_timestep_op,
            estimator_spec=estimator_spec)

        with monitored_session.MonitoredTrainingSession(
                master=self._config.master,
                is_chief=self._config.is_chief,
                checkpoint_dir=self._model_dir,
                scaffold=scaffold,
                hooks=all_hooks,
                chief_only_hooks=(
                    chief_hooks + list(estimator_spec.training_chief_hooks)),
                save_checkpoint_secs=0,  # Saving checkpoint is handled by a hook.
                save_summaries_steps=0,  # Saving summaries is handled by a hook.
                config=self._session_config) as mon_sess:
            loss = None
            while not mon_sess.should_stop():
                loss = self.run_episode(sess=mon_sess, **episode_kwargs)
        summary_io.SummaryWriterCache.clear()
        return loss
def bad_input_fn():
    """Return a one-element dataset after requesting the global step.

    The dataset element is a `(features, labels)` tuple: `features` is a
    dict with an int64 constant under 'x', `labels` is a float32 constant.
    """
    # The global step is fetched (or created) inside the input function
    # itself, before any data tensors are built.
    training.get_or_create_global_step()
    x_values = constant_op.constant([[1], [1]], dtype=dtypes.int64)
    label_values = constant_op.constant([[1], [1]], dtype=dtypes.float32)
    element = ({'x': x_values}, label_values)
    return dataset_ops.Dataset.from_tensors(element)
def _train_model(self, env, first_update, update_frequency, hooks):
    """Create the training graph for `env` and run episodes until stopped.

    Args:
        env: Environment handed to `self._prepare_input_fn` and to every
            `self.run_episode` call.
        first_update: Forwarded unchanged to `self.run_episode`.
        update_frequency: Forwarded unchanged to `self.run_episode`.
        hooks: Iterable of hooks added to the hooks run by all workers.

    Returns:
        The value returned by the last `self.run_episode` call, or `None`
        when the session asked to stop before the first episode.
    """
    self._graph = ops.Graph()
    all_hooks = []
    with self._graph.as_default() as g, g.device(self._device_fn):
        random_seed.set_random_seed(self._config.tf_random_seed)

        # Counters, plus the ops that advance episode/timestep by one.
        global_step = training.get_or_create_global_step(g)
        global_episode = get_or_create_global_episode(g)
        global_timestep = get_or_create_global_timestep(g)
        update_episode_op = tf.assign_add(global_episode, 1)
        update_timestep_op = tf.assign_add(global_timestep, 1)
        no_run_hooks = tf.no_op(name='no_run_hooks')

        with ops.device('/cpu:0'):
            features, labels = self._prepare_input_fn(Modes.TRAIN, env)
        estimator_spec = self._call_model_fn(features, labels, Modes.TRAIN)
        ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss)

        nan_hook = plx_hooks.NanTensorHook(estimator_spec.loss)
        step_logging_hook = plx_hooks.StepLoggingTensorHook(
            {
                'loss': estimator_spec.loss,
                'step': global_step,
                'timestep': global_timestep,
                'global_episode': global_episode,
                'max_reward': labels['max_reward'],
                'min_reward': labels['min_reward'],
                'total_reward': labels['total_reward'],
            },
            every_n_iter=100)
        all_hooks.extend([nan_hook, step_logging_hook])
        all_hooks.extend(hooks)
        all_hooks.extend(estimator_spec.training_hooks)

        scaffold = estimator_spec.scaffold or monitored_session.Scaffold()
        saver_available = (scaffold.saver or
                           ops.get_collection(ops.GraphKeys.SAVERS))
        if not saver_available:
            # TODO remove non restorable vars
            default_saver = saver.Saver(
                sharded=True,  # TODO `var_list`
                max_to_keep=self._config.keep_checkpoint_max,
                defer_build=True)
            ops.add_to_collection(ops.GraphKeys.SAVERS, default_saver)

        episode_logging_hook = plx_hooks.EpisodeLoggingTensorHook(
            {
                'loss': estimator_spec.loss,
                'step': global_step,
                'global_timestep': global_timestep,
                'global_episode': global_episode,
                'max_reward': labels['max_reward'],
                'min_reward': labels['min_reward'],
                'total_reward': labels['total_reward'],
            },
            every_n_episodes=1)  # TODO: save every episode?
        chief_hooks = [
            episode_logging_hook,
            plx_hooks.EpisodeCounterHook(output_dir=self.model_dir),
        ]

        def _already_present(hook_type):
            """Tell whether a hook of `hook_type` is registered so far."""
            registered = (all_hooks + chief_hooks +
                          list(estimator_spec.training_chief_hooks))
            return any(isinstance(hook, hook_type) for hook in registered)

        checkpoints_enabled = (self._config.save_checkpoints_secs or
                               self._config.save_checkpoints_steps)
        if checkpoints_enabled and not _already_present(
                plx_hooks.EpisodeCheckpointSaverHook):
            chief_hooks.append(
                plx_hooks.EpisodeCheckpointSaverHook(
                    self._model_dir,
                    save_episodes=1,  # TODO: save every episode?
                    scaffold=scaffold))

        if self._config.save_summary_steps and not _already_present(
                plx_hooks.EpisodeSummarySaverHook):
            chief_hooks.append(
                plx_hooks.EpisodeSummarySaverHook(
                    scaffold=scaffold,
                    save_episodes=1,  # TODO: save every episode?
                    output_dir=self._model_dir,
                ))

        # Everything `run_episode` needs except the session itself.
        episode_kwargs = dict(
            env=env,
            features=features,
            labels=labels,
            no_run_hooks=no_run_hooks,
            global_step=global_step,
            update_episode_op=update_episode_op,
            update_timestep_op=update_timestep_op,
            first_update=first_update,
            update_frequency=update_frequency,
            estimator_spec=estimator_spec)

        with monitored_session.MonitoredTrainingSession(
                master=self._config.master,
                is_chief=self._config.is_chief,
                checkpoint_dir=self._model_dir,
                scaffold=scaffold,
                hooks=all_hooks,
                chief_only_hooks=(
                    chief_hooks + list(estimator_spec.training_chief_hooks)),
                save_checkpoint_secs=0,  # Saving checkpoint is handled by a hook.
                save_summaries_steps=0,  # Saving summaries is handled by a hook.
                config=self._session_config) as mon_sess:
            loss = None
            while not mon_sess.should_stop():
                loss = self.run_episode(sess=mon_sess, **episode_kwargs)
        summary_io.SummaryWriterCache.clear()
        return loss