Example #1
  def _train_model(self, input_fn, hooks):
    all_hooks = []
    with ops.Graph().as_default() as g, g.device(self._device_fn):
      random_seed.set_random_seed(self._config.tf_random_seed)
      global_step_tensor = training.create_global_step(g)
      with ops.device('/cpu:0'):
        features, labels = input_fn()
      estimator_spec = self._call_model_fn(features, labels,
                                           model_fn_lib.ModeKeys.FIT)
      ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss)
      all_hooks.extend([
          training.NanTensorHook(estimator_spec.loss),
          training.LoggingTensorHook(
              {
                  'loss': estimator_spec.loss,
                  'step': global_step_tensor
              },
              every_n_iter=100)
      ])
      all_hooks.extend(hooks)
      all_hooks.extend(estimator_spec.training_hooks)

      scaffold = estimator_spec.scaffold or training.Scaffold()
      if not (scaffold.saver or ops.get_collection(ops.GraphKeys.SAVERS)):
        ops.add_to_collection(ops.GraphKeys.SAVERS,
                              training.Saver(
                                  sharded=True,
                                  max_to_keep=self._config.keep_checkpoint_max,
                                  defer_build=True))

      chief_hooks = []
      if (self._config.save_checkpoints_secs or
          self._config.save_checkpoints_steps):
        saver_hook_exists = any([
            isinstance(h, training.CheckpointSaverHook)
            for h in (all_hooks + chief_hooks +
                      list(estimator_spec.training_chief_hooks))
        ])
        if not saver_hook_exists:
          chief_hooks = [
              training.CheckpointSaverHook(
                  self._model_dir,
                  save_secs=self._config.save_checkpoints_secs,
                  save_steps=self._config.save_checkpoints_steps,
                  scaffold=scaffold)
          ]
      with training.MonitoredTrainingSession(
          master=self._config.master,
          is_chief=self._config.is_chief,
          checkpoint_dir=self._model_dir,
          scaffold=scaffold,
          hooks=all_hooks,
          chief_only_hooks=(
              chief_hooks + list(estimator_spec.training_chief_hooks)),
          save_checkpoint_secs=0,  # Saving is handled by a hook.
          save_summaries_steps=self._config.save_summary_steps,
          config=config_pb2.ConfigProto(allow_soft_placement=True)) as mon_sess:
        loss = None
        while not mon_sess.should_stop():
          _, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss])
      return loss
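
  # --- Hedged sketch, not part of the original source: the `hooks` argument
  # above takes tf.train.SessionRunHook instances. A minimal custom hook might
  # look like this (the class name and logging format are illustrative).
  class _StepPrinterHook(training.SessionRunHook):

    def begin(self):
      self._step = 0

    def after_run(self, run_context, run_values):
      self._step += 1
      if self._step % 100 == 0:
        print('ran %d session steps' % self._step)
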
  def _restore(self, path):
    """Restores this estimator from given path.

    Note: this rebuilds the graph, reinitializes all parameters,
    and ignores any previously provided model.

    Args:
      path: Path to checkpoints and other information.
    """
    # Currently Saver requires absolute path to work correctly.
    path = os.path.abspath(path)

    self._graph = ops.Graph()
    with self._graph.as_default():
      endpoints_filename = os.path.join(path, 'endpoints')
      if not os.path.exists(endpoints_filename):
        raise ValueError("Restore folder doesn't contain endpoints.")
      with gfile.Open(endpoints_filename) as foutputs:
        endpoints = foutputs.read().split('\n')
      graph_filename = os.path.join(path, 'graph.pbtxt')
      if not os.path.exists(graph_filename):
        raise ValueError("Restore folder doesn't contain graph definition.")
      with gfile.Open(graph_filename) as fgraph:
        graph_def = graph_pb2.GraphDef()
        text_format.Merge(fgraph.read(), graph_def)
        (self._inp, self._out, self._model_predictions,
         self._model_loss) = importer.import_graph_def(
             graph_def, name='', return_elements=endpoints)
      saver_filename = os.path.join(path, 'saver.pbtxt')
      if not os.path.exists(saver_filename):
        raise ValueError("Restore folder doesn't contain saver definition.")
      with gfile.Open(saver_filename) as fsaver:
        saver_def = train.SaverDef()
        text_format.Merge(fsaver.read(), saver_def)
        self._saver = train.Saver(saver_def=saver_def)

      # Restore trainer
      self._global_step = self._graph.get_tensor_by_name('global_step:0')
      self._train = self._graph.get_operation_by_name('OptimizeLoss/train')

      # Restore summaries.
      self._summaries = self._graph.get_operation_by_name(
          'MergeSummary/MergeSummary')

      # Restore session.
      if not isinstance(self._config, RunConfig):
        self._config = RunConfig(verbose=self.verbose)
      self._session = session.Session(self._config.master,
                                      config=self._config.tf_config)
      checkpoint_path = train.latest_checkpoint(path)
      if checkpoint_path is None:
        raise ValueError(
            'Missing checkpoint files in %s. Please make sure the '
            '"checkpoint" file listing the latest checkpoints exists and '
            'that the checkpoint files it references are present. If you '
            'have moved the folder, you need to manually update the paths '
            'in the checkpoint file.' % path)
      self._saver.restore(self._session, checkpoint_path)
    # Set to be initialized.
    self._initialized = True
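
For context, here is a hedged sketch of the save-side counterpart that would produce the three files _restore reads. The real counterpart is not shown on this page, so the method name and file layout are assumptions; write_graph, as_saver_def, and text_format.MessageToString are standard TF 1.x / protobuf APIs.

  def _save_graph_artifacts(self, path):
    """Illustrative only: writes the files _restore() expects under path."""
    with self._graph.as_default():
      # graph.pbtxt: text-format GraphDef.
      train.write_graph(self._graph.as_graph_def(), path, 'graph.pbtxt')
      # saver.pbtxt: text-format SaverDef.
      with gfile.Open(os.path.join(path, 'saver.pbtxt'), 'w') as fsaver:
        fsaver.write(text_format.MessageToString(self._saver.as_saver_def()))
      # endpoints: newline-separated tensor names to re-import.
      with gfile.Open(os.path.join(path, 'endpoints'), 'w') as fendpoints:
        fendpoints.write('\n'.join([self._inp.name, self._out.name,
                                    self._model_predictions.name,
                                    self._model_loss.name]))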
Example #3
 def model_fn(features, labels, mode):
   _, _ = features, labels
   return estimator_lib.EstimatorSpec(
       mode,
       loss=constant_op.constant(3.),
       scaffold=training.Scaffold(saver=training.Saver()),
       train_op=constant_op.constant(5.),
       eval_metric_ops={
           'mean_of_features': metrics_lib.mean(constant_op.constant(2.))
       })
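
One caveat worth noting: the constant train_op above never advances the global step, so an Estimator built from this model_fn cannot make progress under train(). A hedged sketch of the usual test-style fix, advancing the global step explicitly (assumes state_ops is tensorflow.python.ops.state_ops; the function name is illustrative):

 def trainable_model_fn(features, labels, mode):
   _, _ = features, labels
   return estimator_lib.EstimatorSpec(
       mode,
       loss=constant_op.constant(3.),
       train_op=state_ops.assign_add(training.get_global_step(), 1),
       eval_metric_ops={
           'mean_of_features': metrics_lib.mean(constant_op.constant(2.))
       })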
Example #4
 def model_fn(features, labels, mode):
     _, _ = features, labels
     mean = metrics_module.Mean()
     mean.update_state(constant_op.constant(2.))
     return estimator_lib.EstimatorSpec(
         mode,
         loss=constant_op.constant(3.),
         scaffold=training.Scaffold(saver=training.Saver()),
         train_op=constant_op.constant(5.),
         eval_metric_ops={
             'mean_of_features': mean,
         })
Example #5
 def test_stop_if_checkpoint_step_is_laststep(self):
   model_dir = tempfile.mkdtemp()
   with ops.Graph().as_default():
     step = training.create_global_step()
     assign_ten = step.assign(10)
     no_op = control_flow_ops.no_op()
     hook = hooks_lib._StopAtCheckpointStepHook(
         model_dir=model_dir, last_step=10)
     with tf_session.Session() as sess:
       sess.run(assign_ten)
       training.Saver().save(sess, os.path.join(model_dir, 'model.ckpt'))
     with training.SingularMonitoredSession(hooks=[hook]) as mon_sess:
       mon_sess.raw_session().run(assign_ten)
       with test.mock.patch.object(time, 'sleep') as mock_sleep:
         mon_sess.run(no_op)
         self.assertFalse(mock_sleep.called)
       self.assertTrue(mon_sess.should_stop())
Example #6
  def _train_model(self, input_fn, hooks, saving_listeners):
    worker_hooks = []
    with ops.Graph().as_default() as g, g.device(self._device_fn):
      random_seed.set_random_seed(self._config.tf_random_seed)
      global_step_tensor = self._create_and_assert_global_step(g)
      global_step_read_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
      features, labels = self._get_features_and_labels_from_input_fn(
          input_fn, model_fn_lib.ModeKeys.TRAIN)
      with ops.control_dependencies([global_step_read_tensor]):
        estimator_spec = self._call_model_fn(
            features, labels, model_fn_lib.ModeKeys.TRAIN, self.config)
      # Check if the user created a loss summary, and add one if they didn't.
      # We assume here that the summary is called 'loss'. If it is not, we will
      # make another one with the name 'loss' to ensure it shows up in the right
      # graph in TensorBoard.
      if not any([x.op.name == 'loss'
                  for x in ops.get_collection(ops.GraphKeys.SUMMARIES)]):
        summary.scalar('loss', estimator_spec.loss)
      ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss)
      worker_hooks.extend(hooks)
      worker_hooks.extend([
          training.NanTensorHook(estimator_spec.loss),
          training.LoggingTensorHook(
              {
                  'loss': estimator_spec.loss,
                  'step': global_step_tensor
              },
              every_n_iter=100)
      ])
      worker_hooks.extend(estimator_spec.training_hooks)

      if not (estimator_spec.scaffold.saver or
              ops.get_collection(ops.GraphKeys.SAVERS)):
        ops.add_to_collection(
            ops.GraphKeys.SAVERS,
            training.Saver(
                sharded=True,
                max_to_keep=self._config.keep_checkpoint_max,
                keep_checkpoint_every_n_hours=(
                    self._config.keep_checkpoint_every_n_hours),
                defer_build=True,
                save_relative_paths=True))

      chief_hooks = []
      all_hooks = worker_hooks + list(estimator_spec.training_chief_hooks)
      saver_hooks = [
          h for h in all_hooks if isinstance(h, training.CheckpointSaverHook)]
      if (self._config.save_checkpoints_secs or
          self._config.save_checkpoints_steps):
        if not saver_hooks:
          chief_hooks = [
              training.CheckpointSaverHook(
                  self._model_dir,
                  save_secs=self._config.save_checkpoints_secs,
                  save_steps=self._config.save_checkpoints_steps,
                  scaffold=estimator_spec.scaffold)
          ]
          saver_hooks = [chief_hooks[0]]
      if saving_listeners:
        if not saver_hooks:
          raise ValueError(
              'There should be a CheckpointSaverHook to use saving_listeners. '
              'Please set one of the RunConfig.save_checkpoints_steps or '
              'RunConfig.save_checkpoints_secs.')
        else:
          # A single CheckpointSaverHook is expected; if there are several,
          # the listeners are attached to the first one.
          saver_hooks[0]._listeners.extend(saving_listeners)  # pylint: disable=protected-access
      with training.MonitoredTrainingSession(
          master=self._config.master,
          is_chief=self._config.is_chief,
          checkpoint_dir=self._model_dir,
          scaffold=estimator_spec.scaffold,
          hooks=worker_hooks,
          chief_only_hooks=(
              tuple(chief_hooks) + tuple(estimator_spec.training_chief_hooks)),
          save_checkpoint_secs=0,  # Saving is handled by a hook.
          save_summaries_steps=self._config.save_summary_steps,
          config=self._session_config,
          log_step_count_steps=self._config.log_step_count_steps) as mon_sess:
        loss = None
        while not mon_sess.should_stop():
          _, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss])
      return loss
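
A hedged sketch of a saving_listener this code path accepts. CheckpointSaverListener (reached here through the same training module alias) is the real TF 1.x base class; the class name and log line are illustrative.

  class _LogAfterSave(training.CheckpointSaverListener):

    def after_save(self, session, global_step_value):
      print('checkpoint written at global step %d' % global_step_value)

  # Usage: estimator.train(input_fn, saving_listeners=[_LogAfterSave()])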
Example #7
    def _setup_training(self):
        """Sets up graph, model and trainer."""
        # Create config if not given.
        if self._config is None:
            self._config = RunConfig(verbose=self.verbose)
        # Create new graph.
        self._graph = ops.Graph()
        self._graph.add_to_collection("IS_TRAINING", True)
        with self._graph.as_default():
            random_seed.set_random_seed(self._config.tf_random_seed)
            self._global_step = variables.Variable(0,
                                                   name="global_step",
                                                   trainable=False)

            # Setting up inputs and outputs.
            self._inp, self._out = self._data_feeder.input_builder()

            # If class weights are provided, add them to the graph.
            # Different loss functions can use this tensor by name.
            if self.class_weight:
                self._class_weight_node = constant_op.constant(
                    self.class_weight, name='class_weight')

            # Add histograms for X and y if they are floats.
            if self._data_feeder.input_dtype in (np.float32, np.float64):
                logging_ops.histogram_summary("X", self._inp)
            if self._data_feeder.output_dtype in (np.float32, np.float64):
                logging_ops.histogram_summary("y", self._out)

            # Create model's graph.
            self._model_predictions, self._model_loss = self.model_fn(
                self._inp, self._out)

            # Set up a single operator to merge all the summaries
            self._summaries = logging_ops.merge_all_summaries()

            # Create trainer and augment graph with gradients and optimizer.
            # Additionally creates initialization ops.
            learning_rate = self.learning_rate
            optimizer = self.optimizer
            if callable(learning_rate):
                learning_rate = learning_rate(self._global_step)
            if callable(optimizer):
                optimizer = optimizer(learning_rate)
            self._train = optimizers.optimize_loss(
                self._model_loss,
                self._global_step,
                learning_rate=learning_rate,
                optimizer=optimizer,
                clip_gradients=self.clip_gradients)

            # Update ops during training, e.g. batch_norm_ops
            self._train = control_flow_ops.group(
                self._train, *ops.get_collection('update_ops'))

            # Get all initializers for all trainable variables.
            self._initializers = variables.initialize_all_variables()

            # Create model's saver capturing all the nodes created up until now.
            self._saver = train.Saver(
                max_to_keep=self._config.keep_checkpoint_max,
                keep_checkpoint_every_n_hours=(
                    self._config.keep_checkpoint_every_n_hours))

            # Let the monitor build its validation feed dict with the
            # appropriate TF placeholders.
            self._monitor.create_val_feed_dict(self._inp, self._out)

            # Create session to run model with.
            self._session = session.Session(self._config.tf_master,
                                            config=self._config.tf_config)

            # Run parameter initializers.
            self._session.run(self._initializers)
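
Since _setup_training accepts callables for both learning_rate and optimizer, here is a hedged sketch of such callables. exponential_decay and AdagradOptimizer are standard TF 1.x APIs under the same train module alias; the schedule values are illustrative.

    learning_rate = lambda global_step: train.exponential_decay(
        0.1, global_step, decay_steps=10000, decay_rate=0.96)
    optimizer = lambda lr: train.AdagradOptimizer(lr)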
Example #8
    def _train_model(self, input_fn, hooks):
        all_hooks = []
        with ops.Graph().as_default() as g, g.device(self._device_fn):
            random_seed.set_random_seed(self._config.tf_random_seed)
            global_step_tensor = self._create_and_assert_global_step(g)
            features, labels = self._get_features_and_labels_from_input_fn(
                input_fn, model_fn_lib.ModeKeys.TRAIN)
            estimator_spec = self._call_model_fn(features, labels,
                                                 model_fn_lib.ModeKeys.TRAIN,
                                                 self.config)
            ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss)
            all_hooks.extend(hooks)
            all_hooks.extend([
                training.NanTensorHook(estimator_spec.loss),
                training.LoggingTensorHook(
                    {
                        'loss': estimator_spec.loss,
                        'step': global_step_tensor
                    },
                    every_n_iter=100)
            ])
            all_hooks.extend(estimator_spec.training_hooks)

            if not (estimator_spec.scaffold.saver
                    or ops.get_collection(ops.GraphKeys.SAVERS)):
                ops.add_to_collection(
                    ops.GraphKeys.SAVERS,
                    training.Saver(
                        sharded=True,
                        max_to_keep=self._config.keep_checkpoint_max,
                        keep_checkpoint_every_n_hours=(
                            self._config.keep_checkpoint_every_n_hours),
                        defer_build=True,
                        save_relative_paths=True))

            chief_hooks = []
            if (self._config.save_checkpoints_secs
                    or self._config.save_checkpoints_steps):
                saver_hook_exists = any([
                    isinstance(h, training.CheckpointSaverHook)
                    for h in (all_hooks + chief_hooks +
                              list(estimator_spec.training_chief_hooks))
                ])
                if not saver_hook_exists:
                    chief_hooks = [
                        training.CheckpointSaverHook(
                            self._model_dir,
                            save_secs=self._config.save_checkpoints_secs,
                            save_steps=self._config.save_checkpoints_steps,
                            scaffold=estimator_spec.scaffold)
                    ]
            with training.MonitoredTrainingSession(
                    master=self._config.master,
                    is_chief=self._config.is_chief,
                    checkpoint_dir=self._model_dir,
                    scaffold=estimator_spec.scaffold,
                    hooks=all_hooks,
                    chief_only_hooks=(
                        tuple(chief_hooks) +
                        tuple(estimator_spec.training_chief_hooks)),
                    save_checkpoint_secs=0,  # Saving is handled by a hook.
                    save_summaries_steps=self._config.save_summary_steps,
                    config=self._session_config,
                    log_step_count_steps=self._config.log_step_count_steps
            ) as mon_sess:
                loss = None
                while not mon_sess.should_stop():
                    _, loss = mon_sess.run(
                        [estimator_spec.train_op, estimator_spec.loss])
            return loss
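
A hedged sketch of the RunConfig fields this method consults. These keyword arguments exist on TF 1.x tf.estimator.RunConfig (assumed to be the RunConfig in use here); the values shown are illustrative.

    config = RunConfig(
        tf_random_seed=42,
        save_summary_steps=100,
        save_checkpoints_steps=1000,
        keep_checkpoint_max=5,
        keep_checkpoint_every_n_hours=10000,
        log_step_count_steps=100)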
Example #9
    def _train_with_estimator_spec(self, estimator_spec, worker_hooks, hooks,
                                   global_step_tensor, saving_listeners,
                                   save_best_ckpt):
        """Train a model with the given Estimator Spec."""
        if self._warm_start_settings:
            logging.info('Warm-starting with WarmStartSettings: %s' %
                         (self._warm_start_settings, ))
            warm_starting_util.warm_start(*self._warm_start_settings)
        worker_hooks.extend(hooks)
        worker_hooks.append(training.NanTensorHook(estimator_spec.loss))
        if self._config.log_step_count_steps is not None:
            tensors = {"loss": estimator_spec.loss, "step": global_step_tensor}
            tensors.update({
                key.replace("/", ""): val
                for key, val in estimator_spec.predictions.items()
                if "/" in key
            })
            worker_hooks.append(
                training.LoggingTensorHook(
                    tensors, every_n_iter=self._config.log_step_count_steps))
        worker_hooks.extend(estimator_spec.training_hooks)

        # Create Saver object
        if not (estimator_spec.scaffold.saver
                or ops.get_collection(ops.GraphKeys.SAVERS)):
            ops.add_to_collection(
                ops.GraphKeys.SAVERS,
                training.Saver(sharded=True,
                               max_to_keep=self._config.keep_checkpoint_max,
                               keep_checkpoint_every_n_hours=(
                                   self._config.keep_checkpoint_every_n_hours),
                               defer_build=True,
                               save_relative_paths=True))

        chief_hooks = []
        all_hooks = worker_hooks + list(estimator_spec.training_chief_hooks)
        saver_hooks = [
            h for h in all_hooks if isinstance(h, training.CheckpointSaverHook)
        ]
        if (self._config.save_checkpoints_secs
                or self._config.save_checkpoints_steps):
            if not saver_hooks:
                chief_hooks = [
                    training.CheckpointSaverHook(
                        self._model_dir,
                        save_secs=self._config.save_checkpoints_secs,
                        save_steps=self._config.save_checkpoints_steps,
                        scaffold=estimator_spec.scaffold)
                ]
                saver_hooks = [chief_hooks[0]]
        if saving_listeners:
            if not saver_hooks:
                raise ValueError(
                    'There should be a CheckpointSaverHook to use saving_listeners. '
                    'Please set one of the RunConfig.save_checkpoints_steps or '
                    'RunConfig.save_checkpoints_secs.')
            else:
                # A single CheckpointSaverHook is expected; if there are
                # several, the listeners are attached to the first one.
                saver_hooks[0]._listeners.extend(saving_listeners)  # pylint: disable=protected-access

        if self._train_with_eval:
            self.dataset_handle_hook = IteratorStringHandleHook(
                self.train_iterator, self.eval_iterator)
            worker_hooks.append(self.dataset_handle_hook)
            self._predict_keys = estimator_spec.predictions

        if save_best_ckpt:
            EvaluatorCls = self._params.get("evaluator", None)
            if EvaluatorCls is None or not issubclass(EvaluatorCls,
                                                      EvaluateBase):
                raise TypeError(
                    "Parameter `evaluator` must be an EvaluateBase subclass, "
                    "but got {}".format(EvaluatorCls))
            eval_kwargs = self._params.get("eval_kwargs", {})
            eval_steps = self._params.get("eval_steps", 2500)
            primary_metric = self._params.get("primary_metric", None)
            secondary_metric = self._params.get("secondary_metric", None)

            # We must construct Evaluator inside a graph scope
            evaluator = EvaluatorCls(self, **eval_kwargs)

            worker_hooks.append(
                BestCheckpointSaverHook(evaluator=evaluator,
                                        checkpoint_dir=self._model_dir,
                                        compare_fn=partial(
                                            evaluator.compare,
                                            primary_metric=primary_metric,
                                            secondary_metric=secondary_metric),
                                        tag=self._params["args"].tag,
                                        save_steps=eval_steps))

        # Training session monitor
        with training.MonitoredTrainingSession(
                master=self._config.master,
                is_chief=self._config.is_chief,
                checkpoint_dir=self._model_dir,
                scaffold=estimator_spec.scaffold,
                hooks=worker_hooks,
                chief_only_hooks=(tuple(chief_hooks) +
                                  tuple(estimator_spec.training_chief_hooks)),
                save_checkpoint_secs=0,
                save_summaries_steps=self._config.save_summary_steps,
                config=self._session_config,
                log_step_count_steps=self._config.log_step_count_steps
        ) as mon_sess:
            loss = None

            # Note: self.dataset_handle_hook.xxx_handle is only valid after
            # the MonitoredTrainingSession above has been created.
            self._feed_dict = _add_key_value(
                self._feed_dict, self.handler,
                self.dataset_handle_hook.train_handle)
            while not mon_sess.should_stop():
                _, loss = mon_sess.run(
                    [estimator_spec.train_op, estimator_spec.loss],
                    self._feed_dict)
            return loss
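
A hedged sketch of the _params entries the save_best_ckpt branch reads. The keys come from the code above; the concrete values, MyEvaluator, and cli_args are hypothetical project-specific names.

    params = {
        'evaluator': MyEvaluator,   # hypothetical EvaluateBase subclass
        'eval_kwargs': {},          # forwarded to the evaluator constructor
        'eval_steps': 2500,         # becomes BestCheckpointSaverHook save_steps
        'primary_metric': 'Dice',   # hypothetical metric name
        'secondary_metric': None,
        'args': cli_args,           # argparse-style namespace; .tag labels the run
    }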
Example #10
    def _train_with_estimator_spec(self, estimator_spec, worker_hooks, hooks,
                                   global_step_tensor, saving_listeners):
        """Train a model with the given Estimator Spec."""
        if self._warm_start_settings:
            logging.info('Warm-starting with WarmStartSettings: %s' %
                         (self._warm_start_settings, ))
            warm_starting_util.warm_start(*self._warm_start_settings)
        # Check if the user created a loss summary, and add one if they didn't.
        # We assume here that the summary is called 'loss'. If it is not, we will
        # make another one with the name 'loss' to ensure it shows up in the right
        # graph in TensorBoard.
        # if not any([x.op.name == 'loss' for x in ops.get_collection(ops.GraphKeys.SUMMARIES)]):
        #     summary.scalar('loss', estimator_spec.loss)
        ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss)
        worker_hooks.extend(hooks)
        # worker_hooks.extend([
        #     training.NanTensorHook(estimator_spec.loss)
        # ])

        worker_hooks.extend(estimator_spec.training_hooks)

        if not (estimator_spec.scaffold.saver
                or ops.get_collection(ops.GraphKeys.SAVERS)):
            ops.add_to_collection(
                ops.GraphKeys.SAVERS,
                training.Saver(sharded=True,
                               max_to_keep=self._config.keep_checkpoint_max,
                               keep_checkpoint_every_n_hours=(
                                   self._config.keep_checkpoint_every_n_hours),
                               defer_build=True,
                               save_relative_paths=True))

        chief_hooks = []
        all_hooks = worker_hooks + list(estimator_spec.training_chief_hooks)
        saver_hooks = [
            h for h in all_hooks if isinstance(h, training.CheckpointSaverHook)
        ]
        if (self._config.save_checkpoints_secs
                or self._config.save_checkpoints_steps):
            if not saver_hooks:
                chief_hooks = [
                    training.CheckpointSaverHook(
                        self._model_dir,
                        save_secs=self._config.save_checkpoints_secs,
                        save_steps=self._config.save_checkpoints_steps,
                        scaffold=estimator_spec.scaffold)
                ]
                saver_hooks = [chief_hooks[0]]
        if saving_listeners:
            if not saver_hooks:
                raise ValueError(
                    'There should be a CheckpointSaverHook to use saving_listeners. '
                    'Please set one of the RunConfig.save_checkpoints_steps or '
                    'RunConfig.save_checkpoints_secs.')
            else:
                # A single CheckpointSaverHook is expected; if there are
                # several, the listeners are attached to the first one.
                saver_hooks[0]._listeners.extend(saving_listeners)  # pylint: disable=protected-access

        if is_rank0():
            log_step_count_steps = self._config.log_step_count_steps
            checkpoint_dir = self.model_dir
            chief_only_hooks = (tuple(chief_hooks) +
                                tuple(estimator_spec.training_chief_hooks))
        else:
            log_step_count_steps = None
            checkpoint_dir = None
            chief_only_hooks = None

        with MonitoredTrainingSession(
                master=self._config.master,
                is_chief=is_rank0(),
                checkpoint_dir=checkpoint_dir,
                scaffold=estimator_spec.scaffold,
                hooks=worker_hooks,
                chief_only_hooks=chief_only_hooks,
                save_checkpoint_secs=0,  # Saving is handled by a hook.
                save_summaries_steps=self._config.save_summary_steps,
                config=self._session_config,
                log_step_count_steps=log_step_count_steps) as mon_sess:
            loss = None
            while not mon_sess.should_stop():
                _, loss = mon_sess.run(
                    [estimator_spec.train_op, estimator_spec.loss])
        return loss
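
is_rank0 is not defined in this snippet; in a Horovod-based setup it is typically a thin wrapper like the following hedged sketch (an assumption, since the import is not shown here):

    import horovod.tensorflow as hvd

    def is_rank0():
        return hvd.rank() == 0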