def back_up(self, epoch):
    """Back up the current state of training into a checkpoint file.

    Arguments:
      epoch: The current epoch information to be saved.
    """
    # pylint: disable=protected-access
    self._assert_in_multi_worker_mode()

    # Update `_ckpt_saved_epoch`.
    K.set_value(self._ckpt_saved_epoch, epoch)

    # If this is multi-worker training and this worker should not save the
    # checkpoint, replace the filepath with a dummy filepath so it writes to a
    # file that is removed at the end of this method. This is because the
    # SyncOnReadVariable needs to be synced across all the workers in order to
    # be read, and all workers need to initiate that.
    if dc_context.get_current_worker_context().should_checkpoint:
      save_filepath = self._backup_filepath
    else:
      save_filepath = self._temp_filepath

    # Save the weights plus CKPT_SAVED_EPOCH variable.
    self._model.save_weights(save_filepath, overwrite=True)

    if not dc_context.get_current_worker_context().should_checkpoint:
      # Remove the file in multi-worker training where this worker should
      # not checkpoint. It is a dummy file previously saved for sync distributed
      # training.
      _remove_dir(self._temp_dir)
    def worker_fn(strategy):
        with ops.Graph().as_default():
            batch_size = 64
            steps = 10

            with strategy.scope():
                train_ds, _ = _mnist_synthetic_dataset(batch_size, steps)
                model = _clone_and_build_model(orig_model, strategy)

                orig_loss, orig_acc = model.evaluate(train_ds, steps=steps)

                # Workaround for the metrics issue (b/122928955) in async training. This
                # can only be used in standalone client mode.
                dc_context.get_current_worker_context().wait_for_other_workers(
                )

                model.fit(x=train_ds, epochs=2, steps_per_epoch=steps)

                dc_context.get_current_worker_context().wait_for_other_workers(
                )

                trained_loss, trained_acc = model.evaluate(train_ds,
                                                           steps=steps)

            test_obj.assertLessEqual(trained_loss, orig_loss)
            test_obj.assertGreaterEqual(trained_acc, orig_acc)
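A worker_fn like the one above is not called directly; in these examples it is handed to the distribute coordinator, which enters a _WorkerContext and invokes the function once per worker. Below is a minimal wiring sketch, assuming the standalone-client entry point run_distribute_coordinator and a hypothetical two-worker cluster_spec; exact module paths and keyword arguments vary across TF versions.

from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import distribute_coordinator as dc

# Hypothetical in-process cluster; in practice this usually comes from TF_CONFIG.
cluster_spec = {"worker": ["localhost:12345", "localhost:12346"]}
strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy()

# Runs `worker_fn(strategy)` in one thread per worker (standalone client mode);
# inside each call, get_current_worker_context() returns that worker's context.
dc.run_distribute_coordinator(
    worker_fn,
    strategy,
    mode=dc.CoordinatorMode.STANDALONE_CLIENT,
    cluster_spec=cluster_spec)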
  def _worker_fn(strategy):
    """Function for worker task."""
    local_estimator = copy.deepcopy(estimator)
    # pylint: disable=protected-access
    local_estimator._config._train_distribute = strategy
    context = dc_context.get_current_worker_context()
    _init_run_config_from_worker_context(local_estimator._config, context)
    logging.info('Updated config: %s', str(vars(local_estimator._config)))
    local_estimator._train_distribution = strategy
    # pylint: enable=protected-access

    # In the standalone client, we don't need to run hooks on all threads
    # because logging hooks on all threads may be too much on the screen; also
    # tensor passed to one hook can only be fetched with the graph where the
    # tensor is defined. Other hooks such as checkpointing hooks will be added by
    # MonitoredTrainingSession.
    # TODO(yuefengz): Is there a hook that does need to run on all threads in
    # standalone client mode?
    if (run_config._distribute_coordinator_mode ==  # pylint: disable=protected-access
        dc.CoordinatorMode.INDEPENDENT_WORKER or context.is_chief):
      hooks = list(train_spec.hooks)
    else:
      hooks = []

    # Prevent estimator.train from calling distribute coordinator again. This
    # function calls estimator.train which will use distribute coordinator path
    # again if `_distribute_coordinator_mode` is set.
    local_estimator._config._distribute_coordinator_mode = None  # pylint: disable=protected-access
    local_estimator.train(
        input_fn=train_spec.input_fn,
        max_steps=train_spec.max_steps,
        hooks=hooks)
def filter_distributed_callbacks(callbacks_list):
  """Filter Callbacks based on the worker context when running multi-worker.

  Arguments:
    callbacks_list: A list of `Callback` instances.

  Returns:
    The list of `Callback` instances that should be run on this worker.
  """

  if not K.in_multi_worker_mode():
    raise ValueError(
        'filter_distributed_callbacks() should only be called when Keras '
        'is in multi worker mode.')

  worker_context = dc_context.get_current_worker_context()
  callbacks_list = callbacks_list or []
  if not [
      c for c in callbacks_list if isinstance(c, callbacks.ModelCheckpoint)
  ]:
    # TODO(rchao): Consider providing a ModelCheckpoint here if the user
    # fails to.
    logging.warning('ModelCheckpoint callback is not provided. '
                    'Workers will need to restart training if any fails.')
  # TODO(rchao): Add similar warning for restoring callback (to be designed).

  if callbacks_list is None or worker_context.is_chief:
    return callbacks_list

  # Some Callbacks should only run on the chief worker.
  return [
      callback for callback in callbacks_list if not callback._chief_worker_only
  ]  # pylint: disable=protected-access
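A small usage sketch for the helper above: filter the user-provided callbacks before handing them to fit() when running in multi-worker mode. The model and dataset names below are hypothetical.

callbacks_list = [
    callbacks.ModelCheckpoint(filepath='/tmp/ckpt-{epoch}'),
    callbacks.TensorBoard(log_dir='/tmp/logs'),
]
if K.in_multi_worker_mode():
  # Non-chief workers drop chief-only callbacks such as ModelCheckpoint.
  callbacks_to_run = filter_distributed_callbacks(callbacks_list)
else:
  callbacks_to_run = callbacks_list
model.fit(train_ds, epochs=2, steps_per_epoch=10, callbacks=callbacks_to_run)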
def configure_and_create_session(distribution_strategy):
  """Configure session config and create a session with it."""
  # TODO(priyag): Throw error if a session already exists.
  session_config = K.get_default_session_config()

  if is_tpu_strategy(distribution_strategy):
    # TODO(priyag, yuefengz): Remove this workaround when Distribute
    # Coordinator is integrated with keras and we can create a session from
    # there.
    distribution_strategy.configure(session_config)
    master = distribution_strategy.extended._tpu_cluster_resolver.master()  # pylint: disable=protected-access
    session = session_module.Session(config=session_config, target=master)
  else:
    worker_context = dc_context.get_current_worker_context()
    if worker_context:
      dc_session_config = worker_context.session_config
      # Merge the default session config into the one from the distribute
      # coordinator, which is fine for now since they don't have conflicting
      # configurations.
      dc_session_config.MergeFrom(session_config)
      session = session_module.Session(
          config=dc_session_config, target=worker_context.master_target)
    else:
      session = session_module.Session(config=session_config)

  K.set_session(session)
Example #6
  def _worker_fn(strategy):
    """Function for worker task."""
    local_estimator = copy.deepcopy(estimator)
    # pylint: disable=protected-access
    local_estimator._config._train_distribute = strategy
    context = dc_context.get_current_worker_context()
    _init_run_config_from_worker_context(local_estimator._config, context)
    logging.info('Updated config: %s', str(vars(local_estimator._config)))
    local_estimator._train_distribution = strategy
    # pylint: enable=protected-access

    # In the standalone client, we don't need to run hooks on all threads
    # because logging hooks on all threads may be too much on the screen; also
    # tensor passed to one hook can only be fetched with the graph where the
    # tensor is defined. Other hooks such as checkpointing hooks will be added by
    # MonitoredTrainingSession.
    # TODO(yuefengz): Is there a hook that does need to run on all threads in
    # standalone client mode?
    if (run_config._distribute_coordinator_mode ==  # pylint: disable=protected-access
        dc.CoordinatorMode.INDEPENDENT_WORKER or context.is_chief):
      hooks = list(train_spec.hooks)
    else:
      hooks = []

    # Prevent estimator.train from calling distribute coordinator again. This
    # function calls estimator.train which will use distribute coordinator path
    # again if `_distribute_coordinator_mode` is set.
    local_estimator._config._distribute_coordinator_mode = None  # pylint: disable=protected-access
    local_estimator.train(
        input_fn=train_spec.input_fn,
        max_steps=train_spec.max_steps,
        hooks=hooks)
Example #7
def filter_distributed_callbacks(callbacks_list):
  """Filter Callbacks based on the worker context when running multi-worker.

  Arguments:
    callbacks_list: A list of `Callback` instances.

  Returns:
    The list of `Callback` instances that should be run on this worker.
  """

  if not K.in_multi_worker_mode():
    raise ValueError(
        'filter_distributed_callbacks() should only be called when Keras '
        'is in multi worker mode.')

  worker_context = dc_context.get_current_worker_context()
  callbacks_list = callbacks_list or []
  if not [
      c for c in callbacks_list if isinstance(c, callbacks.ModelCheckpoint)
  ]:
    # TODO(rchao): Consider providing a ModelCheckpoint here if the user
    # fails to.
    logging.warning('ModelCheckpoint callback is not provided. '
                    'Workers will need to restart training if any fails.')
  # TODO(rchao): Add similar warning for restoring callback (to be designed).

  if callbacks_list is None or worker_context.is_chief:
    return callbacks_list

  # Some Callbacks should only run on the chief worker.
  return [
      callback for callback in callbacks_list if not callback._chief_worker_only
  ]  # pylint: disable=protected-access
  def _between_graph_with_monitored_session(self, strategy):
    context = distribute_coordinator_context.get_current_worker_context()
    self.assertTrue(context is not None)
    with ops.device("/job:ps/task:0"):
      # TODO(yuefengz): investigate why not using resource variable will make
      # the test flaky.
      x = variable_scope.get_variable("xx", initializer=10.0, use_resource=True)
    with ops.device("/job:ps/task:1"):
      y = variable_scope.get_variable("yy", initializer=20.0, use_resource=True)

    x_add = x.assign_add(2.0)
    y_sub = y.assign_sub(2.0)
    train_op = control_flow_ops.group([x_add, y_sub])

    # The monitored session will run init or ready ops.
    with monitored_session.MonitoredSession() as sess:
      sess.run(train_op)

      # Synchronize workers after one step to make sure they all have finished
      # training.
      if context.has_barrier:
        context.wait_for_other_workers()
      else:
        self._barrier.wait()

      x_val, y_val = sess.run([x, y])

    self.assertEqual(x_val, 16.0)
    self.assertEqual(y_val, 14.0)
    if x_val == 16.0 and y_val == 14.0:
      with self._lock:
        self._result_correct += 1
  def __init__(self, model, original_filepath):
    self._model = model

    # The directory and filepath that store the training state backup file.
    self._backup_dir, self._backup_filepath = self._get_backup_filepath(
        original_filepath)

    # For those who should not checkpoint (e.g. non-chief worker in sync
    # training), create a temporary directory to write to (that will be
    # removed later).
    if not dc_context.get_current_worker_context().should_checkpoint:
      self._temp_dir, self._temp_filepath = self._get_temp_filepath()

    # The epoch at which the checkpoint is saved. Used for fault-tolerance.
    # The GPU device only has a VarHandleOp registered for the int64 dtype.
    self._ckpt_saved_epoch = variables.Variable(
        initial_value=constant_op.constant(
            CKPT_SAVED_EPOCH_UNUSED_VALUE, dtype=dtypes.int64),
        name='ckpt_saved_epoch')

    # Variable initialization.
    K.set_value(self._ckpt_saved_epoch, CKPT_SAVED_EPOCH_UNUSED_VALUE)

    # Calling `AutoTrackable.__setattr__` to avoid getting added as a weight of
    # the model (which is what `Layer.__setattr__` would do), which breaks
    # saving/loading in hdf5 format. Once it becomes an attr of `model`,
    # `_ckpt_saved_epoch` gets tracked and will be included in the checkpoint
    # file when backing up.
    tracking.AutoTrackable.__setattr__(self._model, CKPT_SAVED_EPOCH,
                                       self._ckpt_saved_epoch)
Example #11
def configure_and_create_session(distribution_strategy):
    """Configure session config and create a session with it."""
    # TODO(priyag): Throw error if a session already exists.
    session_config = K.get_default_session_config()

    if is_tpu_strategy(distribution_strategy):
        # TODO(priyag, yuefengz): Remove this workaround when Distribute
        # Coordinator is integrated with keras and we can create a session from
        # there.
        distribution_strategy.configure(session_config)
        master = distribution_strategy.extended._tpu_cluster_resolver.master()  # pylint: disable=protected-access
        session = session_module.Session(config=session_config, target=master)
    else:
        worker_context = dc_context.get_current_worker_context()
        if worker_context:
            dc_session_config = worker_context.session_config
            # Merge the default session config into the one from the distribute
            # coordinator, which is fine for now since they don't have
            # conflicting configurations.
            dc_session_config.MergeFrom(session_config)
            session = session_module.Session(
                config=dc_session_config, target=worker_context.master_target)
        else:
            distribution_strategy.configure(session_config)
            session = session_module.Session(config=session_config)

    K.set_session(session)
Example #12
 def __enter__(self):
   old_context = distribute_coordinator_context.get_current_worker_context()
   if old_context:
     raise ValueError(
         "You cannot run distribute coordinator in a `worker_fn`.\t" +
         self._debug_message())
   # pylint: disable=protected-access
   distribute_coordinator_context._worker_context.current = self
Example #13
 def _maybe_remove_file(self, file_handle, filepath):
     # Remove the file in multi-worker training where this worker should
     # not checkpoint. It is a dummy file previously saved for sync distributed
     # training.
     if K.in_multi_worker_mode(
     ) and not dc_context.get_current_worker_context().should_checkpoint:
         os.close(file_handle)
         os.remove(filepath)
def init_restore_or_wait_for_variables():
  """Initialize or restore variables or wait for variables to be initialized."""
  session = K._get_session()  # pylint: disable=protected-access
  worker_context = dc_context.get_current_worker_context()
  if not worker_context or worker_context.experimental_should_init:
    # TODO(yuefengz): if checkpoints exist, restore from checkpoint.
    K._initialize_variables(session)  # pylint: disable=protected-access
  else:
    _wait_for_variable_initialization(session)
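_wait_for_variable_initialization is not shown in these examples. A plausible minimal sketch, assuming it simply polls report_uninitialized_variables() the same way the between-graph test examples further down do:

import time

from tensorflow.python.ops import variables


def _wait_for_variable_initialization(session):
  """Blocks until all global variables report as initialized (sketch only)."""
  while True:
    uninit_vars = session.run(variables.report_uninitialized_variables())
    if len(uninit_vars) == 0:  # pylint: disable=g-explicit-length-test
      return
    time.sleep(0.5)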
 def _worker_fn(self, strategy):
   worker_context = distribute_coordinator_context.get_current_worker_context()
   session_config = worker_context._session_config
   self._device_filters.extend(session_config.device_filters)
   self._intra_op_parallelism_threads = (
       session_config.intra_op_parallelism_threads)
   self._inter_op_parallelism_threads = (
       session_config.inter_op_parallelism_threads)
   return MockServer()
Example #18
def init_restore_or_wait_for_variables():
    """Initialize or restore variables or wait for variables to be initialized."""
    session = K._get_session()  # pylint: disable=protected-access
    worker_context = dc_context.get_current_worker_context()
    if not worker_context or worker_context.should_init:
        # TODO(yuefengz): if checkpoints exist, restore from checkpoint.
        K._initialize_variables(session)  # pylint: disable=protected-access
    else:
        _wait_for_variable_initialization(session)
Example #19
def should_load_checkpoint():
    """Returns whether the current worker should load checkpoints.

  In multi-worker training, if loading a checkpoint is requested by the user or
  needed for fault-tolerance, the cluster should load the checkpoint, but not
  necessarily every worker in the cluster needs to.

  Returns:
      Whether this particular worker in the cluster should load checkpoints.
  """
    return dc_context.get_current_worker_context().experimental_should_init
Example #20
  def _eval_fn(strategy):
    """Function for evaluator task."""
    local_estimator = copy.deepcopy(estimator)
    # pylint: disable=protected-access
    local_estimator._config._eval_distribute = strategy
    _init_run_config_from_worker_context(
        local_estimator._config, dc_context.get_current_worker_context())
    local_estimator._eval_distribution = strategy

    executor = executor_cls(local_estimator, train_spec, eval_spec)
    executor._start_continuous_evaluation()
    def _eval_fn(strategy):
        """Function for evaluator task."""
        local_estimator = copy.deepcopy(estimator)
        # pylint: disable=protected-access
        local_estimator._config._eval_distribute = strategy
        _init_run_config_from_worker_context(
            local_estimator._config, dc_context.get_current_worker_context())
        local_estimator._eval_distribution = strategy

        executor = executor_cls(local_estimator, train_spec, eval_spec)
        executor._start_continuous_evaluation()
  def delete_backup(self):
    """Delete the backup directories.

    Delete the backup directories which should not exist after `fit()`
    successfully finishes.
    """
    self._assert_in_multi_worker_mode()
    tracking.AutoTrackable.__delattr__(self._model, CKPT_SAVED_EPOCH)
    if dc_context.get_current_worker_context().should_checkpoint:
      _remove_dir(self._backup_dir)
    else:
      assert not os.path.exists(self._temp_dir)
    def _worker_fn(strategy):
        """Function for worker task."""
        local_estimator = copy.deepcopy(estimator)
        # pylint: disable=protected-access
        local_estimator._config._train_distribute = strategy
        _init_run_config_from_worker_context(
            local_estimator._config, dc_context.get_current_worker_context())
        local_estimator._train_distribution = strategy
        # pylint: enable=protected-access

        local_estimator.train(input_fn=train_spec.input_fn,
                              max_steps=train_spec.max_steps,
                              hooks=list(train_spec.hooks))
Example #24
def should_save_checkpoint():
    """Returns whether the current worker should save checkpoints.

  In multi-worker training, if saving a checkpoint is requested by the user or
  needed for fault-tolerance, the cluster should save the checkpoint, but not
  necessarily every worker in the cluster needs to.

  TODO(rchao): Consider generalizing this util to be `should_save_file` as there
  can be other files to save such as summary.

  Returns:
      Whether this particular worker in the cluster should save checkpoints.
  """
    return dc_context.get_current_worker_context().should_checkpoint
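A short usage sketch combining should_load_checkpoint() and should_save_checkpoint() on a worker; model and checkpoint_path are hypothetical.

checkpoint_path = '/tmp/multi_worker_ckpt'

# Only workers that are allowed to initialize/load do so.
if should_load_checkpoint() and os.path.exists(checkpoint_path + '.index'):
  model.load_weights(checkpoint_path)

# ... train for some steps ...

# Typically only the chief (or workers with should_checkpoint) writes it back.
if should_save_checkpoint():
  model.save_weights(checkpoint_path, overwrite=True)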
Example #25
  def worker_fn(strategy):
    with ops.Graph().as_default():
      batch_size = 64
      steps = 2

      with strategy.scope():
        train_ds, _ = _mnist_synthetic_dataset(batch_size, steps)
        model = _clone_and_build_model(orig_model, strategy)

        orig_loss, orig_acc = model.evaluate(train_ds, steps=steps)

        # Workaround for the metrics issue (b/122928955) in async training. This
        # can only be used in standalone client mode.
        dc_context.get_current_worker_context().wait_for_other_workers()

        model.fit(x=train_ds, epochs=2, steps_per_epoch=steps)

        dc_context.get_current_worker_context().wait_for_other_workers()

        trained_loss, trained_acc = model.evaluate(train_ds, steps=steps)

      test_obj.assertLessEqual(trained_loss, orig_loss)
      test_obj.assertGreaterEqual(trained_acc, orig_acc)
Example #26
  def _worker_fn(strategy):
    """Function for worker task."""
    local_estimator = copy.deepcopy(estimator)
    # pylint: disable=protected-access
    local_estimator._config._train_distribute = strategy
    _init_run_config_from_worker_context(
        local_estimator._config, dc_context.get_current_worker_context())
    local_estimator._train_distribution = strategy
    # pylint: enable=protected-access

    local_estimator.train(
        input_fn=train_spec.input_fn,
        max_steps=train_spec.max_steps,
        hooks=list(train_spec.hooks))
Example #27
  def _worker_fn(strategy):
    """Function for evaluation."""
    local_estimator = copy.deepcopy(estimator)
    local_estimator._config._eval_distribute = strategy
    context = dc_context.get_current_worker_context()
    _init_run_config_from_worker_context(local_estimator._config, context)
    logging.info('Updated config: %s', str(vars(local_estimator._config)))
    local_estimator._eval_distribution = strategy

    if context.is_chief:
      chief_hooks = hooks
    else:
      chief_hooks = []
    return evaluate_distributed_fn(local_estimator, strategy, chief_hooks)
Example #29
    def _between_graph_worker_fn(self, strategy):
        context = distribute_coordinator_context.get_current_worker_context()
        self.assertTrue(context is not None)
        with self._test_session(target=context.master_target) as sess:
            with ops.device("/job:ps/task:0"):
                # TODO(yuefengz): investigate why not using resource variable will make
                # the test flaky.
                x = variable_scope.get_variable("x",
                                                initializer=10.0,
                                                use_resource=True)
            with ops.device("/job:ps/task:1"):
                y = variable_scope.get_variable("y",
                                                initializer=20.0,
                                                use_resource=True)

            x_add = x.assign_add(2.0)
            y_sub = y.assign_sub(2.0)
            train_op = control_flow_ops.group([x_add, y_sub])

            if context.is_chief:
                variables.global_variables_initializer().run()

            # Synchronize workers after initialization.
            if context.has_barrier:
                context.wait_for_other_workers()
            else:
                while True:
                    uninit_vars = sess.run(
                        variables.report_uninitialized_variables())
                    # pylint: disable=g-explicit-length-test
                    if len(uninit_vars) == 0:
                        break

            sess.run(train_op)

            # Synchronize workers after one step to make sure they all have finished
            # training.
            if context.has_barrier:
                context.wait_for_other_workers()
            else:
                self._barrier.wait()

            x_val, y_val = sess.run([x, y])

            self.assertEqual(x_val, 16.0)
            self.assertEqual(y_val, 14.0)
            if x_val == 16.0 and y_val == 14.0:
                with self._lock:
                    self._result_correct += 1
Example #30
  def _eval_fn(strategy):
    """Function for evaluator task."""
    local_estimator = copy.deepcopy(estimator)
    # pylint: disable=protected-access
    local_estimator._config._eval_distribute = strategy
    _init_run_config_from_worker_context(
        local_estimator._config, dc_context.get_current_worker_context())
    logging.info('Updated config: %s', str(vars(local_estimator._config)))
    local_estimator._eval_distribution = strategy

    # Prevent estimator.evaluate from calling distribute coordinator again. This
    # function calls estimator.evaluate which will use distribute coordinator
    # path again if `_distribute_coordinator_mode` is set.
    local_estimator._config._distribute_coordinator_mode = None  # pylint: disable=protected-access

    executor = executor_cls(local_estimator, train_spec, eval_spec)
    executor._start_continuous_evaluation()
  def _between_graph_worker_fn(self, strategy):
    context = distribute_coordinator_context.get_current_worker_context()
    self.assertTrue(context is not None)
    with self._test_session(target=context.master_target) as sess:
      with ops.device("/job:ps/task:0"):
        # TODO(yuefengz): investigate why not using resource variable will make
        # the test flaky.
        x = variable_scope.get_variable(
            "x", initializer=10.0, use_resource=True)
      with ops.device("/job:ps/task:1"):
        y = variable_scope.get_variable(
            "y", initializer=20.0, use_resource=True)

      x_add = x.assign_add(2.0)
      y_sub = y.assign_sub(2.0)
      train_op = control_flow_ops.group([x_add, y_sub])

      if context.is_chief:
        self.evaluate(variables.global_variables_initializer())

      # Synchronize workers after initialization.
      if context.has_barrier:
        context.wait_for_other_workers()
      else:
        while True:
          uninit_vars = sess.run(variables.report_uninitialized_variables())
          # pylint: disable=g-explicit-length-test
          if len(uninit_vars) == 0:
            break

      sess.run(train_op)

      # Synchronize workers after one step to make sure they all have finished
      # training.
      if context.has_barrier:
        context.wait_for_other_workers()
      else:
        self._barrier.wait()

      x_val, y_val = sess.run([x, y])

      self.assertEqual(x_val, 16.0)
      self.assertEqual(y_val, 14.0)
      if x_val == 16.0 and y_val == 14.0:
        with self._lock:
          self._result_correct += 1
  def _eval_fn(strategy):
    """Function for evaluator task."""
    local_estimator = copy.deepcopy(estimator)
    # pylint: disable=protected-access
    local_estimator._config._eval_distribute = strategy
    _init_run_config_from_worker_context(
        local_estimator._config, dc_context.get_current_worker_context())
    logging.info('Updated config: %s', str(vars(local_estimator._config)))
    local_estimator._eval_distribution = strategy

    # Prevent estimator.evaluate from calling distribute coordinator again. This
    # function calls estimator.evaluate which will use distribute coordinator
    # path again if `_distribute_coordinator_mode` is set.
    local_estimator._config._distribute_coordinator_mode = None  # pylint: disable=protected-access

    executor = executor_cls(local_estimator, train_spec, eval_spec)
    executor._start_continuous_evaluation()
Example #33
  def _dump_strategy_property(self, strategy):
    context = distribute_coordinator_context.get_current_worker_context()
    self.assertTrue(context is not None)

    self.assertEqual(context._strategy.should_init, strategy.should_init)
    self.assertEqual(context.should_checkpoint, strategy.should_checkpoint)
    self.assertEqual(context.should_save_summary, strategy.should_save_summary)

    task_type = str(context.task_type)
    task_id = context.task_id or 0
    with self._lock:
      if task_type not in self._strategy_property:
        self._strategy_property[task_type] = []
      while len(self._strategy_property[task_type]) <= task_id:
        self._strategy_property[task_type].append(None)
      self._strategy_property[task_type][task_id] = (
          context._strategy.should_init, context.should_checkpoint,
          context.should_save_summary)
def filter_callbacks(callbacks_list):
    """Filter Callbacks based on the worker context when running multi-worker.

  Arguments:
    callbacks_list: A list of `Callback` instances.

  Returns:
    The list of `Callback` instances that should be run on this worker.
  """
    worker_context = dc_context.get_current_worker_context()
    if callbacks_list is None or worker_context.is_chief:
        return callbacks_list

    # Some Callbacks should only run on the chief worker.
    return [
        callback for callback in callbacks_list
        if not callback._chief_worker_only
    ]  # pylint: disable=protected-access
Example #35
 def _get_file_handle_and_path(self, epoch, logs):
     """Returns the file handle and path."""
     # TODO(rchao): Replace dc_context reference with
     # distributed_training_utils.should_current_worker_checkpoint() once
     # distributed_training_utils.py no longer depends on callbacks.py.
     if not K.in_multi_worker_mode(
     ) or dc_context.get_current_worker_context().should_checkpoint:
         return None, self.filepath.format(epoch=epoch + 1, **logs)
     else:
         # If this is multi-worker training, and this worker should not
         # save checkpoint, we replace the filepath with a dummy filepath so
         # it writes to a file that will be removed at the end of _save_model()
         # call. This is because the SyncOnReadVariable needs to be synced across
         # all the workers in order to be read, and all workers need to initiate
         # that.
         file_handle, temp_file_name = tempfile.mkstemp()
         extension = os.path.splitext(self.filepath)[1]
         return file_handle, temp_file_name + extension
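A sketch of how the two ModelCheckpoint helpers above might pair up inside a _save_model()-style method; this is not the actual Keras implementation, which also handles save_best_only, save_weights_only, logging, and so on.

  def _save_model(self, epoch, logs):
    """Sketch: save weights, then drop the dummy file on non-checkpointing workers."""
    file_handle, filepath = self._get_file_handle_and_path(epoch, logs)
    self.model.save_weights(filepath, overwrite=True)
    # On workers that should not checkpoint, `filepath` is a throwaway temp file
    # written only so the SyncOnReadVariable sync involves every worker.
    self._maybe_remove_file(file_handle, filepath)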
  def _in_graph_worker_fn(self, strategy):
    context = distribute_coordinator_context.get_current_worker_context()
    self.assertTrue(context is not None)
    with self._test_session(target=context.master_target) as sess:
      xs = []
      expected = 0.0
      for i in range(context.num_workers):
        with ops.device("/job:worker/task:%d" % i):
          x = variable_scope.get_variable("x_%d" % i, initializer=10.0)
          x_add = x.assign_add(float(i))
          xs.append(x_add)
          expected += i + 10.0

      with ops.device("/job:worker/task:0"):
        result = math_ops.add_n(xs)

      variables.global_variables_initializer().run()
      result_value = sess.run(result)
    self.assertEqual(result_value, expected)
    if result_value == expected:
      self._result_correct += 1
  def _in_graph_worker_fn(self, strategy):
    context = distribute_coordinator_context.get_current_worker_context()
    self.assertTrue(context is not None)
    with self._test_session(target=context.master_target) as sess:
      xs = []
      expected = 0.0
      for i in range(context.num_workers):
        with ops.device("/job:worker/task:%d" % i):
          x = variable_scope.get_variable("x_%d" % i, initializer=10.0)
          x_add = x.assign_add(float(i))
          xs.append(x_add)
          expected += i + 10.0

      with ops.device("/job:worker/task:0"):
        result = math_ops.add_n(xs)

      self.evaluate(variables.global_variables_initializer())
      result_value = sess.run(result)
    self.assertEqual(result_value, expected)
    if result_value == expected:
      self._result_correct += 1
Example #39
    def on_train_begin(self, logs=None):
        if K.in_multi_worker_mode():
            # pylint: disable=protected-access
            # MultiWorkerTrainingState is used to manage the training state needed
            # for preemption-recovery of a worker in multi-worker training.
            self.model._training_state = (
                multi_worker_training_state.MultiWorkerTrainingState(
                    self.model, self.filepath))
            self._training_state = self.model._training_state
            if self._training_state.restore():
                # If the training state needs to be and is successfully restored,
                # it is recovering from a previous failure (or preemption). In such
                # case, do not load the weights from user specified file path.
                return

        # If this is not multi-worker training, restoring is not needed, or
        # restoring failed, check whether it should load weights on restart.
        # TODO(rchao): Also restore the epoch in single-worker training when
        # `self.load_weights_on_restart=True`.
        if self.load_weights_on_restart:
            # In multi-worker training, it should only do so if
            # `experimental_should_init` is True.
            # TODO(rchao): Reference the `experimental_should_init` api from a
            # util file.
            if not K.in_multi_worker_mode(
            ) or dc_context.get_current_worker_context(
            ).experimental_should_init:
                filepath_to_load = (
                    self._get_most_recently_modified_file_matching_pattern(
                        self.filepath))
                if filepath_to_load is not None and os.path.exists(
                        filepath_to_load):
                    try:
                        # `filepath` may contain placeholders such as `{epoch:02d}`, and
                        # thus it attempts to load the most recently modified file with file
                        # name matching the pattern.
                        self.model.load_weights(filepath_to_load)
                    except (IOError, ValueError) as e:
                        raise ValueError(
                            'Error loading file from {}. Reason: {}'.format(
                                filepath_to_load, e))
  def restore(self):
    """Restore the training state from the backed up checkpoint file.

    Returns:
      True if the training state is successfully restored. False if the
      training state doesn't need to be restored, or an error occurred so it
      can't be.
    """
    self._assert_in_multi_worker_mode()
    if not dc_context.get_current_worker_context().experimental_should_init:
      # For multi-worker training, it should not restore a model in certain
      # worker setting (e.g. non-chief worker in ParameterServerStrategy).
      return False
    if os.path.exists(self._backup_dir):
      try:
        # Load the weights plus CKPT_SAVED_EPOCH variable.
        self._model.load_weights(self._backup_filepath)
        return True

      except (IOError, ValueError) as e:
        raise ValueError('Error loading file from {}. Reason: {}'.format(
            self._backup_filepath, e))
    return False
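Taken together with back_up() and delete_backup() shown earlier, restore() supports a simple fault-tolerance loop around fit(). A minimal sketch, with hypothetical model, train_ds, filepath and num_epochs; the class and module names follow the on_train_begin example above.

training_state = multi_worker_training_state.MultiWorkerTrainingState(
    model, filepath)

initial_epoch = 0
if training_state.restore():
  # Recovering from a previous failure: weights and CKPT_SAVED_EPOCH were
  # reloaded from the backup checkpoint, so resume from the next epoch.
  initial_epoch = int(K.get_value(training_state._ckpt_saved_epoch)) + 1  # pylint: disable=protected-access

for epoch in range(initial_epoch, num_epochs):
  model.fit(train_ds, initial_epoch=epoch, epochs=epoch + 1, steps_per_epoch=10)
  training_state.back_up(epoch)  # persist weights + CKPT_SAVED_EPOCH

# The backup directory must not survive a successful run.
training_state.delete_backup()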
Example #41
    def _dump_worker_context(self, strategy):
        """Dumps the propoerties of each worker context.

    It dumps the context properties to a dict mapping from task_type to a list
    of tuples of master_target, num_workers, is_chief and distribute_mode, where
    the list is indexed by the task_id.

    Args:
      strategy: a `DistributionStrategy` object.
    """
        context = distribute_coordinator_context.get_current_worker_context()
        self.assertTrue(context is not None)
        task_type = str(context.task_type)
        task_id = context.task_id or 0
        with self._lock:
            if task_type not in self._worker_context:
                self._worker_context[task_type] = []
            while len(self._worker_context[task_type]) <= task_id:
                self._worker_context[task_type].append(None)
            self._worker_context[task_type][task_id] = (
                context.master_target, context.num_workers, context.is_chief,
                context.distributed_mode)
Example #42
def is_chief(cluster_spec=None, task_type=None, task_id=None):
    """Returns whether the given task is chief in the cluster.

  Since there is at most one evaluator and the evaluator itself should be
  independent of the training cluster, the evaluator job is also a chief job on
  its own.

  If this is currently running under a `_WorkerContext` of distribute
  coordinator, the arguments can be omitted as the result is already available.

  Args:
    cluster_spec: a dict, `ClusterDef` or `ClusterSpec` object specifying the
      cluster configurations.
    task_type: the task type in the cluster.
    task_id: the task id in the cluster.

  Returns:
    a boolean indicating whether the given task is chief.

  Raises:
    ValueError: if `task_type` is not in the `cluster_spec` or `task_id` exceeds
      the maximum id of the `task_type`.
  """
    if has_worker_context():
        # If a worker context exists, use the value provided by it.
        return dc_context.get_current_worker_context().is_chief

    _validate_cluster_spec(cluster_spec, task_type, task_id)
    cluster_spec = normalize_cluster_spec(cluster_spec).as_dict()

    if task_type == "chief" or task_type == "evaluator":
        return True

    # If chief not in the cluster_spec, use the first worker as chief. This is
    # common in CollectiveAllReduceStrategy.
    if ("chief" not in cluster_spec and task_type == "worker"
            and task_id == 0):
        return True
    return False
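A small usage sketch for is_chief with an explicit cluster spec; the host addresses are hypothetical. With no "chief" job defined, worker 0 is treated as chief.

cluster_spec = {
    "worker": ["host0:2222", "host1:2222"],
    "ps": ["host2:2222"],
}
assert is_chief(cluster_spec, task_type="worker", task_id=0)
assert not is_chief(cluster_spec, task_type="worker", task_id=1)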
  def _dump_worker_context(self, strategy):
    """Dumps the propoerties of each worker context.

    It dumps the context properties to a dict mapping from task_type to a list
    of tuples of master_target, num_workers, is_chief and distribute_mode, where
    the list is indexed by the task_id.

    Args:
      strategy: a `DistributionStrategy` object.
    """
    context = distribute_coordinator_context.get_current_worker_context()
    self.assertTrue(context is not None)
    task_type = str(context.task_type)
    task_id = context.task_id or 0
    with self._lock:
      if task_type not in self._worker_context:
        self._worker_context[task_type] = []
      while len(self._worker_context[task_type]) <= task_id:
        self._worker_context[task_type].append(None)
      self._worker_context[task_type][task_id] = (context.master_target,
                                                  context.num_workers,
                                                  context.is_chief,
                                                  context.distributed_mode)
def is_current_worker_chief():
  return dc_context.get_current_worker_context().is_chief
 def on_train_begin(self, logs):
   if not dc_context.get_current_worker_context().is_chief:
     # Non-chief workers shouldn't run this callback.
     self.filtered_correctly = False
Example #46
 def on_train_begin(self, logs):
     if not dc_context.get_current_worker_context().is_chief:
         # Non-chief workers shouldn't run this callback.
         self.filtered_correctly = False
Example #47
def has_worker_context():
    """Returns whether a worker context has been entered."""
    return dc_context.get_current_worker_context() is not None
Example #48
def wait_for_other_workers():
    """Waits for other workers to reach the same call to this method."""
    return dc_context.get_current_worker_context().wait_for_other_workers()
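A tiny sketch combining the last two helpers: only attempt the barrier when a worker context has actually been entered (for example, when running under the distribute coordinator).

if has_worker_context():
  # Every worker blocks here until all workers reach the same call.
  wait_for_other_workers()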