Exemple #1
0
  def _train_model(self,
                   input_fn,
                   steps,
                   feed_fn=None,
                   init_op=None,
                   init_feed_fn=None,
                   init_fn=None,
                   device_fn=None,
                   monitors=None,
                   log_every_steps=100,
                   fail_on_nan_loss=True,
                   max_steps=None):
    """Builds a new training graph and runs it through monitored training.

    Args:
      input_fn: Callable taking no arguments and returning a
        `(features, targets)` pair.
      steps: Passed through to `graph_actions._monitored_train`.
      feed_fn: Passed through to `graph_actions._monitored_train`.
      init_op: Passed through to `graph_actions._monitored_train`.
      init_feed_fn: Optional callable taking no arguments; when given, its
        result is passed on as `init_feed_dict`.
      init_fn: Passed through to `graph_actions._monitored_train`.
      device_fn: Device function applied to the new graph. Falls back to
        `self._device_fn` when falsy.
      monitors: Optional iterable mixing `SessionRunHook` instances and legacy
        monitors. Legacy monitors are attached to this estimator and wrapped
        in a single `RunHookAdapterForMonitors` hook; on non-chief tasks only
        those with `run_on_all_workers` set are kept.
      log_every_steps: Passed through to `graph_actions._monitored_train`.
      fail_on_nan_loss: Passed through to `graph_actions._monitored_train`.
      max_steps: Passed through to `graph_actions._monitored_train`.

    Returns:
      The result of `graph_actions._monitored_train`, or `None` when the
      config defines an `execution_mode` other than 'all' or 'train'.
    """
    # TODO(wicke): Remove this once Model and associated code are gone.
    if hasattr(self._config, 'execution_mode'):
      if self._config.execution_mode not in ('all', 'train'):
        return

      # Stagger startup of worker sessions based on task id.
      sleep_secs = min(
          self._config.training_worker_max_startup_secs,
          self._config.task *
          self._config.training_worker_session_startup_stagger_secs)
      if sleep_secs:
        logging.info('Waiting %d secs before starting task %d.', sleep_secs,
                     self._config.task)
        time.sleep(sleep_secs)

    # Device allocation
    device_fn = device_fn or self._device_fn

    self._graph = ops.Graph()
    with self._graph.as_default() as g, g.device(device_fn):
      random_seed.set_random_seed(self._config.tf_random_seed)
      global_step = contrib_framework.create_global_step(g)
      features, targets = input_fn()
      self._check_inputs(features, targets)
      train_op, loss_op = self._get_train_ops(features, targets)

      # Materialize once: `monitors` is scanned twice below, so a one-shot
      # iterable (e.g. a generator) would be exhausted by the first scan and
      # every legacy monitor would be silently dropped.
      monitors = [] if monitors is None else list(monitors)

      hooks = [m for m in monitors
               if isinstance(m, session_run_hook.SessionRunHook)]

      deprecated_monitors = [
          m for m in monitors
          if not isinstance(m, session_run_hook.SessionRunHook)
      ]

      supervisor_is_chief = self._config.is_chief
      if not supervisor_is_chief:
        # Prune list of monitor to the ones runnable on all workers.
        deprecated_monitors = [m for m in deprecated_monitors
                               if m.run_on_all_workers]

      # Setup monitors.
      for monitor in deprecated_monitors:
        monitor.set_estimator(self)

      # All legacy monitors share a single adapter hook.
      if deprecated_monitors:
        hooks.append(monitor_lib.RunHookAdapterForMonitors(deprecated_monitors))

      return graph_actions._monitored_train(  # pylint: disable=protected-access
          graph=g,
          output_dir=self._model_dir,
          train_op=train_op,
          loss_op=loss_op,
          global_step_tensor=global_step,
          init_op=init_op,
          init_feed_dict=init_feed_fn() if init_feed_fn is not None else None,
          init_fn=init_fn,
          log_every_steps=log_every_steps,
          supervisor_is_chief=supervisor_is_chief,
          supervisor_master=self._config.master,
          supervisor_save_model_secs=self._config.save_checkpoints_secs,
          supervisor_save_summaries_steps=self._config.save_summary_steps,
          keep_checkpoint_max=self._config.keep_checkpoint_max,
          feed_fn=feed_fn,
          steps=steps,
          fail_on_nan_loss=fail_on_nan_loss,
          hooks=hooks,
          max_steps=max_steps)
Exemple #2
0
  def _train_model(self,
                   input_fn,
                   steps,
                   feed_fn=None,
                   init_op=None,
                   init_feed_fn=None,
                   init_fn=None,
                   device_fn=None,
                   monitors=None,
                   log_every_steps=100,
                   fail_on_nan_loss=True,
                   max_steps=None):
    """Constructs a fresh graph for training and hands it to monitored train.

    Args:
      input_fn: Zero-argument callable returning `(features, labels)`.
      steps: Forwarded to `graph_actions._monitored_train`.
      feed_fn: Forwarded to `graph_actions._monitored_train`.
      init_op: Forwarded to `graph_actions._monitored_train`.
      init_feed_fn: Optional zero-argument callable whose result becomes
        `init_feed_dict`; `None` is forwarded when it is not supplied.
      init_fn: Forwarded to `graph_actions._monitored_train`.
      device_fn: Device function for the graph; `self._device_fn` is used
        when this is falsy.
      monitors: Converted to hooks by
        `monitor_lib.replace_monitors_with_hooks`.
      log_every_steps: Forwarded to `graph_actions._monitored_train`.
      fail_on_nan_loss: Forwarded to `graph_actions._monitored_train`.
      max_steps: Forwarded to `graph_actions._monitored_train`.

    Returns:
      The value returned by `graph_actions._monitored_train`, or `None` if
      the config has an `execution_mode` that is neither 'all' nor 'train'.

    Raises:
      ValueError: If `_get_train_ops` returns a non-`ModelFnOps` value whose
        length is not 2.
    """
    # TODO(wicke): Remove this once Model and associated code are gone.
    if hasattr(self._config, 'execution_mode'):
      if self._config.execution_mode not in ('all', 'train'):
        return

      # Each task waits proportionally to its index, up to a configured cap,
      # so that worker sessions do not all start at once.
      startup_cap = self._config.training_worker_max_startup_secs
      staggered = (self._config.task *
                   self._config.training_worker_session_startup_stagger_secs)
      delay = min(startup_cap, staggered)
      if delay:
        logging.info('Waiting %d secs before starting task %d.', delay,
                     self._config.task)
        time.sleep(delay)

    # An explicitly supplied device function wins over the estimator's own.
    if not device_fn:
      device_fn = self._device_fn

    self._graph = ops.Graph()
    with self._graph.as_default() as graph, graph.device(device_fn):
      random_seed.set_random_seed(self._config.tf_random_seed)
      step_tensor = contrib_framework.create_global_step(graph)
      features, labels = input_fn()
      self._check_inputs(features, labels)

      # _get_train_ops normally returns ModelFnOps, but some subclasses of
      # tf.contrib.learn.Estimator still override it with the legacy
      # signature and return a (train_op, loss) tuple. The legacy branch goes
      # away once those subclasses are migrated.
      # TODO(b/32664904): Update subclasses and delete the legacy branch.
      model_ops = self._get_train_ops(features, labels)
      if not isinstance(model_ops, ModelFnOps):  # Legacy signature
        if len(model_ops) != 2:
          raise ValueError('Expected a tuple of train_op and loss, got {}'.
                           format(model_ops))
        train_op = model_ops[0]
        loss_op = model_ops[1]
      else:  # Default signature
        train_op = model_ops.train_op
        loss_op = model_ops.loss

      hooks = monitor_lib.replace_monitors_with_hooks(monitors, self)

      ops.add_to_collection(ops.GraphKeys.LOSSES, loss_op)

      if init_feed_fn is not None:
        init_feed_dict = init_feed_fn()
      else:
        init_feed_dict = None

      return graph_actions._monitored_train(  # pylint: disable=protected-access
          graph=graph,
          output_dir=self._model_dir,
          train_op=train_op,
          loss_op=loss_op,
          global_step_tensor=step_tensor,
          init_op=init_op,
          init_feed_dict=init_feed_dict,
          init_fn=init_fn,
          log_every_steps=log_every_steps,
          supervisor_is_chief=self.config.is_chief,
          supervisor_master=self._config.master,
          supervisor_save_model_secs=self._config.save_checkpoints_secs,
          supervisor_save_model_steps=self._config.save_checkpoints_steps,
          supervisor_save_summaries_steps=self._config.save_summary_steps,
          keep_checkpoint_max=self._config.keep_checkpoint_max,
          feed_fn=feed_fn,
          steps=steps,
          fail_on_nan_loss=fail_on_nan_loss,
          hooks=hooks,
          max_steps=max_steps)
Exemple #3
0
    def _train_model(self,
                     input_fn,
                     steps,
                     feed_fn=None,
                     init_op=None,
                     init_feed_fn=None,
                     init_fn=None,
                     device_fn=None,
                     monitors=None,
                     log_every_steps=100,
                     fail_on_nan_loss=True,
                     max_steps=None):
        """Builds a new training graph and runs it through monitored training.

        Args:
          input_fn: Callable taking no arguments and returning a
            `(features, targets)` pair.
          steps: Passed through to `graph_actions._monitored_train`.
          feed_fn: Passed through to `graph_actions._monitored_train`.
          init_op: Passed through to `graph_actions._monitored_train`.
          init_feed_fn: Optional callable taking no arguments; when given,
            its result is passed on as `init_feed_dict`.
          init_fn: Passed through to `graph_actions._monitored_train`.
          device_fn: Device function applied to the new graph. Falls back to
            `self._device_fn` when falsy.
          monitors: Optional iterable mixing `SessionRunHook` instances and
            legacy monitors. Legacy monitors are attached to this estimator
            and wrapped in a single `RunHookAdapterForMonitors` hook; on
            non-chief tasks only those with `run_on_all_workers` set are
            kept.
          log_every_steps: Passed through to
            `graph_actions._monitored_train`.
          fail_on_nan_loss: Passed through to
            `graph_actions._monitored_train`.
          max_steps: Passed through to `graph_actions._monitored_train`.

        Returns:
          The result of `graph_actions._monitored_train`, or `None` when the
          config defines an `execution_mode` other than 'all' or 'train'.
        """
        # TODO(wicke): Remove this once Model and associated code are gone.
        if hasattr(self._config, 'execution_mode'):
            if self._config.execution_mode not in ('all', 'train'):
                return

            # Stagger startup of worker sessions based on task id.
            sleep_secs = min(
                self._config.training_worker_max_startup_secs,
                self._config.task *
                self._config.training_worker_session_startup_stagger_secs)
            if sleep_secs:
                logging.info('Waiting %d secs before starting task %d.',
                             sleep_secs, self._config.task)
                time.sleep(sleep_secs)

        # Device allocation
        device_fn = device_fn or self._device_fn

        self._graph = ops.Graph()
        with self._graph.as_default() as g, g.device(device_fn):
            random_seed.set_random_seed(self._config.tf_random_seed)
            global_step = contrib_framework.create_global_step(g)
            features, targets = input_fn()
            self._check_inputs(features, targets)
            train_op, loss_op = self._get_train_ops(features, targets)

            # Materialize once: `monitors` is scanned twice below, so a
            # one-shot iterable (e.g. a generator) would be exhausted by the
            # first scan and every legacy monitor would be silently dropped.
            monitors = [] if monitors is None else list(monitors)

            hooks = [
                m for m in monitors
                if isinstance(m, session_run_hook.SessionRunHook)
            ]

            deprecated_monitors = [
                m for m in monitors
                if not isinstance(m, session_run_hook.SessionRunHook)
            ]

            supervisor_is_chief = self._config.is_chief
            if not supervisor_is_chief:
                # Prune list of monitor to the ones runnable on all workers.
                deprecated_monitors = [
                    m for m in deprecated_monitors if m.run_on_all_workers
                ]

            # Setup monitors.
            for monitor in deprecated_monitors:
                monitor.set_estimator(self)

            # All legacy monitors share a single adapter hook.
            if deprecated_monitors:
                hooks.append(
                    monitor_lib.RunHookAdapterForMonitors(deprecated_monitors))

            ops.add_to_collection(ops.GraphKeys.LOSSES, loss_op)
            return graph_actions._monitored_train(  # pylint: disable=protected-access
                graph=g,
                output_dir=self._model_dir,
                train_op=train_op,
                loss_op=loss_op,
                global_step_tensor=global_step,
                init_op=init_op,
                init_feed_dict=(init_feed_fn()
                                if init_feed_fn is not None else None),
                init_fn=init_fn,
                log_every_steps=log_every_steps,
                supervisor_is_chief=supervisor_is_chief,
                supervisor_master=self._config.master,
                supervisor_save_model_secs=self._config.save_checkpoints_secs,
                supervisor_save_model_steps=self._config.save_checkpoints_steps,
                supervisor_save_summaries_steps=self._config.save_summary_steps,
                keep_checkpoint_max=self._config.keep_checkpoint_max,
                feed_fn=feed_fn,
                steps=steps,
                fail_on_nan_loss=fail_on_nan_loss,
                hooks=hooks,
                max_steps=max_steps)
  def _train_model(self,
                   input_fn,
                   steps,
                   feed_fn=None,
                   init_op=None,
                   init_feed_fn=None,
                   init_fn=None,
                   device_fn=None,
                   monitors=None,
                   log_every_steps=100,
                   fail_on_nan_loss=True,
                   max_steps=None):
    """Assembles the training graph and delegates to monitored training.

    Args:
      input_fn: Zero-argument callable producing `(features, labels)`.
      steps: Handed on to `graph_actions._monitored_train`.
      feed_fn: Handed on to `graph_actions._monitored_train`.
      init_op: Handed on to `graph_actions._monitored_train`.
      init_feed_fn: Optional zero-argument callable; its result is handed on
        as `init_feed_dict`, or `None` when the callable is absent.
      init_fn: Handed on to `graph_actions._monitored_train`.
      device_fn: Device function for the graph; when falsy,
        `self._device_fn` is used instead.
      monitors: Turned into hooks via
        `monitor_lib.replace_monitors_with_hooks`.
      log_every_steps: Handed on to `graph_actions._monitored_train`.
      fail_on_nan_loss: Handed on to `graph_actions._monitored_train`.
      max_steps: Handed on to `graph_actions._monitored_train`.

    Returns:
      Whatever `graph_actions._monitored_train` returns, or `None` when the
      config carries an `execution_mode` outside ('all', 'train').

    Raises:
      ValueError: If `_get_train_ops` returns something that is not a
        `ModelFnOps` and does not have exactly two elements.
    """
    # TODO(wicke): Remove this once Model and associated code are gone.
    if hasattr(self._config, 'execution_mode'):
      if self._config.execution_mode not in ('all', 'train'):
        return

      # Delay this worker's session start in proportion to its task id,
      # bounded by the configured maximum.
      max_wait = self._config.training_worker_max_startup_secs
      task_wait = (self._config.task_id *
                   self._config.training_worker_session_startup_stagger_secs)
      wait_secs = min(max_wait, task_wait)
      if wait_secs:
        logging.info('Waiting %d secs before starting task %d.', wait_secs,
                     self._config.task_id)
        time.sleep(wait_secs)

    # Device allocation: prefer the caller's function when one is given.
    device_fn = device_fn if device_fn else self._device_fn

    self._graph = ops.Graph()
    with self._graph.as_default() as graph, graph.device(device_fn):
      random_seed.set_random_seed(self._config.tf_random_seed)
      global_step_tensor = contrib_framework.create_global_step(graph)
      features, labels = input_fn()
      self._check_inputs(features, labels)

      # ModelFnOps is the default return type of _get_train_ops. Subclasses
      # of tf.contrib.learn.Estimator that still override it with the legacy
      # signature return a (train_op, loss) tuple instead; that path is
      # handled by the else-branch until those subclasses are updated.
      # TODO(b/32664904): Update subclasses and delete the else-statement.
      result = self._get_train_ops(features, labels)
      if isinstance(result, model_fn_lib.ModelFnOps):  # Default signature
        train_op, loss_op = result.train_op, result.loss
      else:  # Legacy signature
        if len(result) != 2:
          raise ValueError('Expected a tuple of train_op and loss, got {}'.
                           format(result))
        train_op, loss_op = result[0], result[1]

      hooks = monitor_lib.replace_monitors_with_hooks(monitors, self)

      ops.add_to_collection(ops.GraphKeys.LOSSES, loss_op)

      # Collect the arguments first so the call itself stays short.
      train_kwargs = dict(
          graph=graph,
          output_dir=self._model_dir,
          train_op=train_op,
          loss_op=loss_op,
          global_step_tensor=global_step_tensor,
          init_op=init_op,
          init_feed_dict=init_feed_fn() if init_feed_fn is not None else None,
          init_fn=init_fn,
          log_every_steps=log_every_steps,
          supervisor_is_chief=self.config.is_chief,
          supervisor_master=self._config.master,
          supervisor_save_model_secs=self._config.save_checkpoints_secs,
          supervisor_save_model_steps=self._config.save_checkpoints_steps,
          supervisor_save_summaries_steps=self._config.save_summary_steps,
          keep_checkpoint_max=self._config.keep_checkpoint_max,
          feed_fn=feed_fn,
          steps=steps,
          fail_on_nan_loss=fail_on_nan_loss,
          hooks=hooks,
          max_steps=max_steps)
      # pylint: disable=protected-access
      return graph_actions._monitored_train(**train_kwargs)
      # pylint: enable=protected-access