Code Example #1
  def testInGraphContextWithEval(self):
    # Adds an EVALUATOR job.
    cluster_spec = copy.deepcopy(self._cluster_spec)
    cluster_spec[EVALUATOR] = ["fake_evaluator"]

    # Dumps the task contexts to the self._worker_context dict.
    distribute_coordinator.run_distribute_coordinator(
        self._dump_worker_context,
        MockStrategy(between_graph=False),
        cluster_spec=cluster_spec,
        rpc_layer=None)

    # There is one "None" task and one EVALUATOR task.
    self.assertEqual(len(self._worker_context), 2)
    self.assertTrue("None" in self._worker_context)
    self.assertTrue(EVALUATOR in self._worker_context)
    self.assertEqual(len(self._worker_context["None"]), 1)
    self.assertEqual(len(self._worker_context[EVALUATOR]), 1)

    # Check whether each task has the right master_target, num_workers, is_chief
    # and distributed_mode.
    self.assertEqual(self._worker_context["None"][0], (_strip_protocol(
        _bytes_to_str(self._workers[0].target)), 3, True, True))
    self.assertEqual(self._worker_context[EVALUATOR][0],
                     ("fake_evaluator", 3, True, False))
Code Example #2
def _run_standalone_client(test_obj, strategy, cluster_spec):
  input_shape = (28, 28, 1)
  with strategy.scope():
    orig_model = _get_model(input_shape)

  def worker_fn(strategy):
    with ops.Graph().as_default():
      batch_size = 64
      steps = 2

      with strategy.scope():
        train_ds, _ = _mnist_synthetic_dataset(batch_size, steps)
        model = _clone_and_build_model(orig_model, strategy)

        orig_loss, orig_acc = model.evaluate(train_ds, steps=steps)

        # Workaround for the metrics issue (b/122928955) in async training. This
        # can only be used in standalone client mode.
        dc_context.get_current_worker_context().wait_for_other_workers()

        model.fit(x=train_ds, epochs=2, steps_per_epoch=steps)

        dc_context.get_current_worker_context().wait_for_other_workers()

        trained_loss, trained_acc = model.evaluate(train_ds, steps=steps)

      test_obj.assertLessEqual(trained_loss, orig_loss)
      test_obj.assertGreaterEqual(trained_acc, orig_acc)

  dc.run_distribute_coordinator(
      worker_fn,
      strategy,
      mode=dc.CoordinatorMode.STANDALONE_CLIENT,
      cluster_spec=cluster_spec)
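
The `cluster_spec` used above (and throughout these examples) is a plain dict mapping job names to lists of task addresses. A minimal sketch with hypothetical localhost addresses, not taken from any of the examples:

cluster_spec = {
    "worker": ["localhost:12345", "localhost:23456"],
    "ps": ["localhost:34567"],
}
# e.g. _run_standalone_client(test_obj, strategy, cluster_spec)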
Code Example #3
  def testBetweenGraphContextWithChief(self):
    # Adds a chief node, so there are NUM_WORKERS + 1 workers in total.
    cluster_spec = copy.deepcopy(self._cluster_spec)
    cluster_spec[CHIEF] = ["fake_chief"]

    # Dumps the task contexts to the self._worker_context dict.
    distribute_coordinator.run_distribute_coordinator(
        self._dump_worker_context,
        MockStrategy(between_graph=True),
        cluster_spec=cluster_spec,
        rpc_layer="grpc")

    # There is one CHIEF and there are three workers.
    self.assertEqual(len(self._worker_context), 2)
    self.assertTrue(CHIEF in self._worker_context)
    self.assertTrue(WORKER in self._worker_context)
    self.assertEqual(len(self._worker_context[CHIEF]), 1)
    self.assertEqual(len(self._worker_context[WORKER]), NUM_WORKERS)

    # Check whether each task has the right master_target, num_workers, is_chief
    # and distributed_mode.
    self.assertEqual(self._worker_context[CHIEF][0],
                     ("grpc://fake_chief", 4, True, True))
    self.assertEqual(
        self._worker_context[WORKER][0],
        (_bytes_to_str(self._workers[0].target), NUM_WORKERS + 1, False, True))
    self.assertEqual(
        self._worker_context[WORKER][1],
        (_bytes_to_str(self._workers[1].target), NUM_WORKERS + 1, False, True))
    self.assertEqual(
        self._worker_context[WORKER][2],
        (_bytes_to_str(self._workers[2].target), NUM_WORKERS + 1, False, True))
Code Example #4
 def testInGraphSplitMode(self):
   """Test it runs in-graph replication in split client mode."""
   distribute_coordinator.run_distribute_coordinator(
       self._in_graph_worker_fn,
       cluster_spec=self._cluster_spec,
       between_graph=False)
   self.assertEqual(self._result_correct, 1)
Code Example #5
 def testInGraph(self):
   """Test it runs in-graph replicated training correctly."""
   distribute_coordinator.run_distribute_coordinator(
       self._in_graph_worker_fn,
       cluster_spec=self._cluster_spec,
       between_graph=False)
   self.assertEqual(self._result_correct, 1)
Code Example #6
 def testInGraphStandaloneMode(self):
   """Test it runs in-graph replication in standalone client mode."""
   distribute_coordinator.run_distribute_coordinator(
       self._in_graph_worker_fn,
       MockStrategy(between_graph=False),
       cluster_spec=self._cluster_spec)
   self.assertEqual(self._result_correct, 1)
Code Example #7
  def testRpcLayerEnvironmentVariable(self):
    cluster_spec = {"worker": ["fake_worker"], "ps": ["fake_ps"]}
    tf_config = {"cluster": cluster_spec, "rpc_layer": "cake"}

    rpc_layer_from_coordinator = [None]

    def _run_mock_server(cluster_spec=None,
                         task_type=None,
                         task_id=None,
                         session_config=None,
                         rpc_layer=None,
                         environment=None):
      del cluster_spec, task_type, task_id, session_config, environment
      rpc_layer_from_coordinator[0] = rpc_layer
      return MockServer()

    with test.mock.patch.dict(
        "os.environ",
        {"TF_CONFIG": json.dumps(tf_config)}), test.mock.patch.object(
            distribute_coordinator, "_run_std_server", _run_mock_server):
      distribute_coordinator.run_distribute_coordinator(
          None,
          MockStrategy(between_graph=True),
          mode=INDEPENDENT_WORKER,
          cluster_spec=cluster_spec,
          task_type="ps",
          task_id=0)
    self.assertEqual(rpc_layer_from_coordinator[0], "cake")
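
For reference, in INDEPENDENT_WORKER mode `run_distribute_coordinator` reads the cluster layout and `rpc_layer` from the `TF_CONFIG` environment variable, which is what the patched `_run_std_server` above verifies. A minimal sketch of such an environment, using placeholder addresses; note that the tests above pass `task_type`/`task_id` explicitly instead of a "task" entry:

import json
import os

os.environ["TF_CONFIG"] = json.dumps({
    "cluster": {"worker": ["localhost:12345"], "ps": ["localhost:23456"]},
    "task": {"type": "worker", "index": 0},
    "rpc_layer": "grpc",
})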
Code Example #8
 def _thread_fn(cluster_spec):
   distribute_coordinator.run_distribute_coordinator(
       None,
       None,
       mode=INDEPENDENT_WORKER,
       cluster_spec=cluster_spec,
       task_type="ps",
       task_id=0)
Code Example #9
 def _thread_fn(cluster_spec):
   distribute_coordinator.run_distribute_coordinator(
       None,
       MockStrategy(between_graph=True),
       mode=INDEPENDENT_WORKER,
       cluster_spec=cluster_spec,
       task_type="ps",
       task_id=0)
Code Example #10
  def testBetweenGraphWithMonitoredSession(self):
    """Test monitored session in standalone client mode."""
    distribute_coordinator.run_distribute_coordinator(
        self._between_graph_with_monitored_session,
        MockStrategy(between_graph=True),
        cluster_spec=self._cluster_spec)

    # Each finished worker will increment self._result_correct.
    self.assertEqual(self._result_correct, NUM_WORKERS)
Code Example #11
  def testBetweenGraph(self):
    """Test it runs between-graph replication in standalone client mode."""
    distribute_coordinator.run_distribute_coordinator(
        self._between_graph_worker_fn,
        MockStrategy(between_graph=True),
        cluster_spec=self._cluster_spec)

    # Each finished worker will increment self._result_correct.
    self.assertEqual(self._result_correct, NUM_WORKERS)
Code Example #12
  def testBetweenGraph(self):
    """Test it runs between-graph replicated training correctly."""
    distribute_coordinator.run_distribute_coordinator(
        self._between_graph_worker_fn,
        cluster_spec=self._cluster_spec,
        between_graph=True)

    # Each finished worker will increment self._result_correct.
    self.assertEqual(self._result_correct, NUM_WORKERS)
Code Example #13
  def testLocalContext(self):
    # Dumps the task contexts to the self._task_context dict.
    distribute_coordinator.run_distribute_coordinator(
        self._dump_task_context, cluster_spec=None, between_graph=True)

    # There is only a "None" task.
    self.assertEqual(len(self._task_context), 1)
    self.assertTrue("None" in self._task_context)
    self.assertEqual(len(self._task_context["None"]), 1)

    # Check whether each task has the right master_target, num_workers, is_chief
    # and distributed_mode.
    self.assertEqual(self._task_context["None"][0], ("local", 0, True, False))
Code Example #14
  def test_session_config_in_session_creator(self):
    cluster_spec = {"worker": ["localhost:0"]}
    tf_config = {"cluster": cluster_spec}

    with test.mock.patch.dict("os.environ",
                              {"TF_CONFIG": json.dumps(tf_config)}):
      distribute_coordinator.run_distribute_coordinator(
          self._worker_fn,
          MockStrategy(between_graph=True),
          mode=INDEPENDENT_WORKER,
          cluster_spec=cluster_spec,
          task_type="worker",
          task_id=0)
    self.assertEqual(self._device_filters, ["/job:worker/task:0", "/job:ps"])
    self.assertEqual(self._intra_op_parallelism_threads, 2)
    self.assertEqual(self._inter_op_parallelism_threads, 0)
Code Example #15
  def testInGraphContext(self):
    # Dumps the task contexts to the self._worker_context dict.
    distribute_coordinator.run_distribute_coordinator(
        self._dump_worker_context,
        MockStrategy(between_graph=False),
        cluster_spec=self._cluster_spec)

    # There is only a "None" task in the dumped task context.
    self.assertEqual(len(self._worker_context), 1)
    self.assertTrue("None" in self._worker_context)
    self.assertEqual(len(self._worker_context["None"]), 1)

    # Check whether each task has the right master_target, num_workers, is_chief
    # and distributed_mode.
    self.assertEqual(
        self._worker_context["None"][0],
        (_bytes_to_str(self._workers[0].target), NUM_WORKERS, True, True))
Code Example #16
  def testBetweenGraphStrategyProperties(self):
    # Dumps properties of the strategy objects.
    distribute_coordinator.run_distribute_coordinator(
        self._dump_strategy_property,
        MockStrategy(between_graph=True, should_init=True),
        cluster_spec=self._cluster_spec)

    # There is only one type of task and there are three such tasks.
    self.assertEqual(len(self._strategy_property), 1)
    self.assertTrue(WORKER in self._strategy_property)
    self.assertEqual(len(self._strategy_property[WORKER]), NUM_WORKERS)

    # Check whether each task has the right properties of should_init,
    # should_checkpoint and should_save_summary.
    self.assertEqual(self._strategy_property[WORKER][0], (True, True, True))
    self.assertEqual(self._strategy_property[WORKER][1], (True, False, False))
    self.assertEqual(self._strategy_property[WORKER][2], (True, False, False))
Code Example #17
  def test_eval_strategy_configure(self):
    cluster_spec = {"evaluator": ["localhost:0"]}
    tf_config = {"cluster": cluster_spec}

    with test.mock.patch.dict("os.environ",
                              {"TF_CONFIG": json.dumps(tf_config)}):
      distribute_coordinator.run_distribute_coordinator(
          lambda _: None,
          MockStrategy(between_graph=False),
          eval_fn=self._worker_fn,
          eval_strategy=MockStrategy(between_graph=True),
          mode=INDEPENDENT_WORKER,
          cluster_spec=cluster_spec,
          task_type="evaluator",
          task_id=0)
    self.assertEqual(self._device_filters, ["/job:somejob"])
    self.assertEqual(self._intra_op_parallelism_threads, 0)
    self.assertEqual(self._inter_op_parallelism_threads, 2)
Code Example #18
  def test_session_config_in_std_server(self):
    cluster_spec = {"worker": ["fake_worker"], "ps": ["fake_ps"]}
    tf_config = {"cluster": cluster_spec}

    with test.mock.patch.dict(
        "os.environ",
        {"TF_CONFIG": json.dumps(tf_config)}), test.mock.patch.object(
            distribute_coordinator, "_run_std_server",
            self._dump_device_filters):
      distribute_coordinator.run_distribute_coordinator(
          lambda _: None,
          MockStrategy(between_graph=True),
          mode=INDEPENDENT_WORKER,
          cluster_spec=cluster_spec,
          task_type="worker",
          task_id=0)
    self.assertEqual(self._intra_op_parallelism_threads, 1)
    self.assertEqual(self._inter_op_parallelism_threads, 0)
Code Example #19
    def test_session_config_in_session_creator(self):
        cluster_spec = {"worker": ["localhost:0"]}
        tf_config = {"cluster": cluster_spec}

        # Reset the saved Server state.
        distribute_coordinator._thread_local = threading.local()  # pylint: disable=protected-access

        with test.mock.patch.dict("os.environ",
                                  {"TF_CONFIG": json.dumps(tf_config)}):
            distribute_coordinator.run_distribute_coordinator(
                self._worker_fn,
                MockStrategy(between_graph=True),
                mode=INDEPENDENT_WORKER,
                cluster_spec=cluster_spec,
                task_type="worker",
                task_id=0)
        self.assertEqual(self._device_filters,
                         ["/job:worker/task:0", "/job:ps"])
        self.assertEqual(self._intra_op_parallelism_threads, 2)
        self.assertEqual(self._inter_op_parallelism_threads, 0)
Code Example #20
    def wrapper(model, **kwargs):
        def _worker_fn(_):
            callbacks = kwargs.pop('callbacks', None)
            filtered_callbacks = dist_utils.filter_distributed_callbacks(
                callbacks, model)
            kwargs['callbacks'] = filtered_callbacks
            return method(model, **kwargs)

        return dc.run_distribute_coordinator(_worker_fn,
                                             model._distribution_strategy,
                                             mode='independent_worker')
Code Example #21
    def testBetweenGraphStrategyProperties(self):
        # Dumps properties of the strategy objects.
        distribute_coordinator.run_distribute_coordinator(
            self._dump_strategy_property,
            MockStrategy(between_graph=True, should_init=True),
            cluster_spec=self._cluster_spec)

        # There is only one type of task and there are three such tasks.
        self.assertEqual(len(self._strategy_property), 1)
        self.assertTrue(WORKER in self._strategy_property)
        self.assertEqual(len(self._strategy_property[WORKER]), NUM_WORKERS)

        # Check whether each task has the right properties of should_init,
        # should_checkpoint and should_save_summary.
        self.assertEqual(self._strategy_property[WORKER][0],
                         (True, True, True))
        self.assertEqual(self._strategy_property[WORKER][1],
                         (True, False, False))
        self.assertEqual(self._strategy_property[WORKER][2],
                         (True, False, False))
Code Example #22
  def wrapper(model, **kwargs):
    def _worker_fn(_):
      callbacks = kwargs.pop('callbacks', None)
      filtered_callbacks = dist_utils.filter_distributed_callbacks(callbacks)
      kwargs['callbacks'] = filtered_callbacks
      return method(model, **kwargs)

    return dc.run_distribute_coordinator(
        _worker_fn,
        model._distribution_strategy,
        mode=dc.CoordinatorMode.INDEPENDENT_WORKER)
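
A hypothetical illustration (not the actual Keras integration) of how a wrapper like the one above could be attached to an existing training function so that it runs under the distribute coordinator in independent-worker mode:

def run_under_coordinator(method):
  def wrapper(model, **kwargs):
    def _worker_fn(_):
      return method(model, **kwargs)
    return dc.run_distribute_coordinator(
        _worker_fn,
        model._distribution_strategy,
        mode=dc.CoordinatorMode.INDEPENDENT_WORKER)
  return wrapper

# e.g. distributed_fit = run_under_coordinator(plain_fit_fn)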
Code Example #23
  def testBetweenGraphContext(self):
    # Dumps the task contexts to the self._worker_context dict.
    distribute_coordinator.run_distribute_coordinator(
        self._dump_worker_context,
        MockStrategy(between_graph=True),
        cluster_spec=self._cluster_spec)

    # There is only one type of task and there are three such tasks.
    self.assertEqual(len(self._worker_context), 1)
    self.assertTrue(WORKER in self._worker_context)
    self.assertEqual(len(self._worker_context[WORKER]), NUM_WORKERS)

    # Check whether each task has the right master_target, num_workers, is_chief
    # and distributed_mode.
    self.assertEqual(
        self._worker_context[WORKER][0],
        (_bytes_to_str(self._workers[0].target), NUM_WORKERS, True, True))
    self.assertEqual(
        self._worker_context[WORKER][1],
        (_bytes_to_str(self._workers[1].target), NUM_WORKERS, False, True))
    self.assertEqual(
        self._worker_context[WORKER][2],
        (_bytes_to_str(self._workers[2].target), NUM_WORKERS, False, True))
Code Example #24
    def testBetweenGraphContext(self):
        # Dumps the task contexts to the self._worker_context dict.
        distribute_coordinator.run_distribute_coordinator(
            self._dump_worker_context,
            MockStrategy(between_graph=True),
            cluster_spec=self._cluster_spec)

        # There is only one type of task and there are three such tasks.
        self.assertEqual(len(self._worker_context), 1)
        self.assertTrue(WORKER in self._worker_context)
        self.assertEqual(len(self._worker_context[WORKER]), NUM_WORKERS)

        # Check whether each task has the right master_target, num_workers, is_chief
        # and distributed_mode.
        self.assertEqual(
            self._worker_context[WORKER][0],
            (_bytes_to_str(self._workers[0].target), NUM_WORKERS, True, True))
        self.assertEqual(
            self._worker_context[WORKER][1],
            (_bytes_to_str(self._workers[1].target), NUM_WORKERS, False, True))
        self.assertEqual(
            self._worker_context[WORKER][2],
            (_bytes_to_str(self._workers[2].target), NUM_WORKERS, False, True))
Code Example #25
def _run_standalone_client(test_obj, strategy, cluster_spec):
    input_shape = (28, 28, 1)
    with strategy.scope():
        orig_model = multi_worker_testing_utils.get_mnist_model(input_shape)

    def worker_fn(strategy):
        with ops.Graph().as_default():
            batch_size = 64
            steps = 2

            with strategy.scope():
                train_ds, _ = multi_worker_testing_utils.mnist_synthetic_dataset(
                    batch_size, steps)
                model = _clone_and_build_model(orig_model, strategy)

                orig_loss, orig_acc = model.evaluate(train_ds, steps=steps)

                # Workaround for the metrics issue (b/122928955) in async training. This
                # can only be used in standalone client mode.
                dc_context.get_current_worker_context().wait_for_other_workers()

                model.fit(x=train_ds, epochs=2, steps_per_epoch=steps)

                dc_context.get_current_worker_context().wait_for_other_workers()

                trained_loss, trained_acc = model.evaluate(train_ds,
                                                           steps=steps)

            test_obj.assertLessEqual(trained_loss, orig_loss)
            test_obj.assertGreaterEqual(trained_acc, orig_acc)

    dc.run_distribute_coordinator(worker_fn,
                                  strategy,
                                  mode=dc.CoordinatorMode.STANDALONE_CLIENT,
                                  cluster_spec=cluster_spec)
Code Example #26
def estimator_train(estimator, train_distributed_fn, hooks):
    """Run distribute coordinator for Estimator's `train` method."""
    assert estimator._config._distribute_coordinator_mode
    run_config = estimator._config
    assert estimator._config.cluster_spec
    cluster_spec = multi_worker_util.normalize_cluster_spec(
        estimator._config.cluster_spec)
    assert estimator._config._train_distribute

    if 'evaluator' in cluster_spec.jobs:
        raise ValueError("'evaluator' job is not supported if you don't use "
                         '`train_and_evaluate`')

    if (estimator._config._distribute_coordinator_mode !=  # pylint: disable=protected-access
            dc.CoordinatorMode.STANDALONE_CLIENT):
        raise ValueError(
            'Only `STANDALONE_CLIENT` mode is supported when you call '
            '`estimator.train`')

    if estimator._config._train_distribute.extended.experimental_between_graph:
        # TODO(yuefengz): remove this limitation once we figure out how to merge
        # return values from `_worker_fn`s.
        raise ValueError(
            '`Estimator.train` API is not supported for %s with '
            '`STANDALONE_CLIENT` mode.' %
            estimator._config._train_distribute.__class__.__name__)

    def _worker_fn(strategy):
        """Function for worker task."""
        local_estimator = copy.deepcopy(estimator)
        local_estimator._config._train_distribute = strategy
        context = dc_context.get_current_worker_context()
        _init_run_config_from_worker_context(local_estimator._config, context)
        logging.info('Updated config: %s', str(vars(local_estimator._config)))
        local_estimator._train_distribution = strategy

        if context.is_chief:
            chief_hooks = hooks
        else:
            chief_hooks = []
        train_distributed_fn(local_estimator, strategy, chief_hooks)
        return local_estimator

    return dc.run_distribute_coordinator(
        _worker_fn,
        estimator._config.train_distribute,
        mode=run_config._distribute_coordinator_mode,
        cluster_spec=cluster_spec,
        session_config=run_config.session_config)
Code Example #27
def estimator_train(estimator, train_distributed_fn, hooks):
  """Run distribute coordinator for Estimator's `train` method."""
  assert estimator._config._distribute_coordinator_mode
  run_config = estimator._config
  assert estimator._config.cluster_spec
  cluster_spec = multi_worker_util.normalize_cluster_spec(
      estimator._config.cluster_spec)
  assert estimator._config._train_distribute

  if 'evaluator' in cluster_spec.jobs:
    raise ValueError("'evaluator' job is not supported if you don't use "
                     '`train_and_evaluate`')

  if (estimator._config._distribute_coordinator_mode !=  # pylint: disable=protected-access
      dc.CoordinatorMode.STANDALONE_CLIENT):
    raise ValueError('Only `STANDALONE_CLIENT` mode is supported when you call '
                     '`estimator.train`')

  if estimator._config._train_distribute.extended.experimental_between_graph:
    # TODO(yuefengz): remove this limitation once we figure out how to merge
    # return values from `_worker_fn`s.
    raise ValueError('`Estimator.train` API is not supported for %s with '
                     '`STANDALONE_CLIENT` mode.' %
                     estimator._config._train_distribute.__class__.__name__)

  def _worker_fn(strategy):
    """Function for worker task."""
    local_estimator = copy.deepcopy(estimator)
    local_estimator._config._train_distribute = strategy
    context = dc_context.get_current_worker_context()
    _init_run_config_from_worker_context(local_estimator._config, context)
    logging.info('Updated config: %s', str(vars(local_estimator._config)))
    local_estimator._train_distribution = strategy

    if context.is_chief:
      chief_hooks = hooks
    else:
      chief_hooks = []
    train_distributed_fn(local_estimator, strategy, chief_hooks)
    return local_estimator

  return dc.run_distribute_coordinator(
      _worker_fn,
      estimator._config.train_distribute,
      mode=run_config._distribute_coordinator_mode,
      cluster_spec=cluster_spec,
      session_config=run_config.session_config)
Code Example #28
    def fit(self,
            model,
            x=None,
            y=None,
            batch_size=None,
            epochs=1,
            verbose=1,
            callbacks=None,
            validation_split=0.,
            validation_data=None,
            shuffle=True,
            class_weight=None,
            sample_weight=None,
            initial_epoch=0,
            steps_per_epoch=None,
            validation_steps=None,
            validation_freq=1,
            **kwargs):
        # Multi-Worker mode runs the Keras training loop on multiple
        # servers via the Distribute Coordinator.
        def _worker_fn(_):
            """Run training inside the distributed coordinator."""
            filtered_callbacks = dist_utils.filter_distributed_callbacks(
                callbacks)
            return fit_distributed(model,
                                   x=x,
                                   y=y,
                                   batch_size=batch_size,
                                   epochs=epochs,
                                   verbose=verbose,
                                   callbacks=filtered_callbacks,
                                   validation_split=validation_split,
                                   validation_data=validation_data,
                                   shuffle=shuffle,
                                   class_weight=class_weight,
                                   sample_weight=sample_weight,
                                   initial_epoch=initial_epoch,
                                   steps_per_epoch=steps_per_epoch,
                                   validation_steps=validation_steps,
                                   validation_freq=validation_freq)

        # Independent worker only for now.
        return dc.run_distribute_coordinator(
            _worker_fn,
            model._distribution_strategy,
            mode=dc.CoordinatorMode.INDEPENDENT_WORKER)
Code Example #29
def train_and_evaluate(estimator, train_spec, eval_spec, executor_cls):
  """Run distribute coordinator for Estimator's `train_and_evaluate`.

  Args:
    estimator: An `Estimator` instance to train and evaluate.
    train_spec: A `TrainSpec` instance to specify the training specification.
    eval_spec: An `EvalSpec` instance to specify the evaluation and export
      specification.
    executor_cls: the evaluation executor class of Estimator.

  Raises:
    ValueError: if `distribute_coordinator_mode` is None in RunConfig.
  """
  run_config = estimator.config
  if not run_config._distribute_coordinator_mode:  # pylint: disable=protected-access
    raise ValueError(
        'Distribute coordinator mode is not specified in `RunConfig`.')

  def _worker_fn(strategy):
    """Function for worker task."""
    local_estimator = copy.deepcopy(estimator)
    # pylint: disable=protected-access
    local_estimator._config._train_distribute = strategy
    context = dc_context.get_current_worker_context()
    _init_run_config_from_worker_context(local_estimator._config, context)
    logging.info('Updated config: %s', str(vars(local_estimator._config)))
    local_estimator._train_distribution = strategy
    # pylint: enable=protected-access

    # In the standalone client, we don't need to run hooks on all threads
    # because logging hooks on all threads may be too much on the screen; also
    # a tensor passed to one hook can only be fetched with the graph where the
    # tensor is defined. Other hooks such as checkpointing hooks will be added
    # by MonitoredTrainingSession.
    # TODO(yuefengz): Is there a hook that does need to run on all threads in
    # standalone client mode?
    if (run_config._distribute_coordinator_mode ==  # pylint: disable=protected-access
        dc.CoordinatorMode.INDEPENDENT_WORKER or context.is_chief):
      hooks = list(train_spec.hooks)
    else:
      hooks = []

    # Prevent estimator.train from calling distribute coordinator again. This
    # function calls estimator.train which will use distribute coordinator path
    # again if `_distribute_coordinator_mode` is set.
    local_estimator._config._distribute_coordinator_mode = None  # pylint: disable=protected-access
    local_estimator.train(
        input_fn=train_spec.input_fn,
        max_steps=train_spec.max_steps,
        hooks=hooks)

  def _eval_fn(strategy):
    """Function for evaluator task."""
    local_estimator = copy.deepcopy(estimator)
    # pylint: disable=protected-access
    local_estimator._config._eval_distribute = strategy
    _init_run_config_from_worker_context(
        local_estimator._config, dc_context.get_current_worker_context())
    logging.info('Updated config: %s', str(vars(local_estimator._config)))
    local_estimator._eval_distribution = strategy

    # Prevent estimator.evaluate from calling distribute coordinator again. This
    # function calls estimator.evaluate which will use distribute coordinator
    # path again if `_distribute_coordinator_mode` is set.
    local_estimator._config._distribute_coordinator_mode = None  # pylint: disable=protected-access

    executor = executor_cls(local_estimator, train_spec, eval_spec)
    executor._start_continuous_evaluation()
    # pylint: enable=protected-access

  # pylint: disable=protected-access
  if (run_config._distribute_coordinator_mode ==
      dc.CoordinatorMode.STANDALONE_CLIENT):
    cluster_spec = run_config.cluster_spec
    assert cluster_spec
  else:
    # The cluster_spec comes from the TF_CONFIG environment variable in
    # INDEPENDENT_WORKER mode.
    cluster_spec = None

  dc.run_distribute_coordinator(
      _worker_fn,
      run_config.train_distribute,
      _eval_fn,
      run_config.eval_distribute,
      mode=run_config._distribute_coordinator_mode,
      cluster_spec=cluster_spec,
      session_config=run_config.session_config)
Code Example #30
def train_and_evaluate(estimator, train_spec, eval_spec, executor_cls):
  """Run distribute coordinator for Estimator's `train_and_evaluate`.

  Args:
    estimator: An `Estimator` instance to train and evaluate.
    train_spec: A `TrainSpec` instance to specify the training specification.
    eval_spec: An `EvalSpec` instance to specify the evaluation and export
      specification.
    executor_cls: the evaluation executor class of Estimator.

  Raises:
    ValueError: if `distribute_coordinator_mode` is None in RunConfig.
  """
  run_config = estimator.config
  if not run_config._distribute_coordinator_mode:  # pylint: disable=protected-access
    raise ValueError(
        'Distribute coordinator mode is not specified in `RunConfig`.')

  def _worker_fn(strategy):
    """Function for worker task."""
    local_estimator = copy.deepcopy(estimator)
    # pylint: disable=protected-access
    local_estimator._config._train_distribute = strategy
    _init_run_config_from_worker_context(
        local_estimator._config, dc_context.get_current_worker_context())
    local_estimator._train_distribution = strategy
    # pylint: enable=protected-access

    local_estimator.train(
        input_fn=train_spec.input_fn,
        max_steps=train_spec.max_steps,
        hooks=list(train_spec.hooks))

  def _eval_fn(strategy):
    """Function for evaluator task."""
    local_estimator = copy.deepcopy(estimator)
    # pylint: disable=protected-access
    local_estimator._config._eval_distribute = strategy
    _init_run_config_from_worker_context(
        local_estimator._config, dc_context.get_current_worker_context())
    local_estimator._eval_distribution = strategy

    executor = executor_cls(local_estimator, train_spec, eval_spec)
    executor._start_continuous_evaluation()
    # pylint: enable=protected-access

  # pylint: disable=protected-access
  if (run_config._distribute_coordinator_mode ==
      dc.CoordinatorMode.STANDALONE_CLIENT):
    cluster_spec = run_config.cluster_spec
    assert cluster_spec
  else:
    # The cluster_spec comes from the TF_CONFIG environment variable in
    # INDEPENDENT_WORKER mode.
    cluster_spec = None

  dc.run_distribute_coordinator(
      _worker_fn,
      run_config.train_distribute,
      _eval_fn,
      run_config.eval_distribute,
      mode=run_config._distribute_coordinator_mode,
      cluster_spec=cluster_spec,
      session_config=run_config.session_config)
Code Example #31
def train_and_evaluate(estimator, train_spec, eval_spec, executor_cls):
    """Run distribute coordinator for Estimator's `train_and_evaluate`.

  Args:
    estimator: An `Estimator` instance to train and evaluate.
    train_spec: A `TrainSpec` instance to specify the training specification.
    eval_spec: A `EvalSpec` instance to specify the evaluation and export
      specification.
    executor_cls: the evaluation executor class of Estimator.

  Raises:
    ValueError: if `distribute_coordinator_mode` is None in RunConfig.
  """
    run_config = estimator.config
    if not run_config._distribute_coordinator_mode:  # pylint: disable=protected-access
        raise ValueError(
            'Distribute coordinator mode is not specified in `RunConfig`.')

    def _worker_fn(strategy):
        """Function for worker task."""
        local_estimator = copy.deepcopy(estimator)
        # pylint: disable=protected-access
        local_estimator._config._train_distribute = strategy
        _init_run_config_from_worker_context(
            local_estimator._config, dc_context.get_current_worker_context())
        local_estimator._train_distribution = strategy
        # pylint: enable=protected-access

        local_estimator.train(input_fn=train_spec.input_fn,
                              max_steps=train_spec.max_steps,
                              hooks=list(train_spec.hooks))

    def _eval_fn(strategy):
        """Function for evaluator task."""
        local_estimator = copy.deepcopy(estimator)
        # pylint: disable=protected-access
        local_estimator._config._eval_distribute = strategy
        _init_run_config_from_worker_context(
            local_estimator._config, dc_context.get_current_worker_context())
        local_estimator._eval_distribution = strategy

        executor = executor_cls(local_estimator, train_spec, eval_spec)
        executor._start_continuous_evaluation()
        # pylint: enable=protected-access

    # pylint: disable=protected-access
    if (run_config._distribute_coordinator_mode ==
            dc.CoordinatorMode.STANDALONE_CLIENT):
        cluster_spec = run_config.cluster_spec
        assert cluster_spec
    else:
        # The cluster_spec comes from the TF_CONFIG environment variable in
        # INDEPENDENT_WORKER mode.
        cluster_spec = None

    dc.run_distribute_coordinator(_worker_fn,
                                  run_config.train_distribute,
                                  _eval_fn,
                                  run_config.eval_distribute,
                                  mode=run_config._distribute_coordinator_mode,
                                  cluster_spec=cluster_spec,
                                  session_config=run_config.session_config)
Code Example #32
def train_and_evaluate(estimator, train_spec, eval_spec, executor_cls):
  """Run distribute coordinator for Estimator's `train_and_evaluate`.

  Args:
    estimator: An `Estimator` instance to train and evaluate.
    train_spec: A `TrainSpec` instance to specify the training specification.
    eval_spec: An `EvalSpec` instance to specify the evaluation and export
      specification.
    executor_cls: the evaluation executor class of Estimator.

  Raises:
    ValueError: if `distribute_coordinator_mode` is None in RunConfig.
  """
  run_config = estimator.config
  if not run_config._distribute_coordinator_mode:  # pylint: disable=protected-access
    raise ValueError(
        'Distribute coordinator mode is not specified in `RunConfig`.')

  def _worker_fn(strategy):
    """Function for worker task."""
    local_estimator = copy.deepcopy(estimator)
    # pylint: disable=protected-access
    local_estimator._config._train_distribute = strategy
    context = dc_context.get_current_worker_context()
    _init_run_config_from_worker_context(local_estimator._config, context)
    logging.info('Updated config: %s', str(vars(local_estimator._config)))
    local_estimator._train_distribution = strategy
    # pylint: enable=protected-access

    # In the standalone client, we don't need to run hooks on all threads
    # because logging hooks on all threads may be too much on the screen; also
    # a tensor passed to one hook can only be fetched with the graph where the
    # tensor is defined. Other hooks such as checkpointing hooks will be added
    # by MonitoredTrainingSession.
    # TODO(yuefengz): Is there a hook that does need to run on all threads in
    # standalone client mode?
    if (run_config._distribute_coordinator_mode ==  # pylint: disable=protected-access
        dc.CoordinatorMode.INDEPENDENT_WORKER or context.is_chief):
      hooks = list(train_spec.hooks)
    else:
      hooks = []

    # Prevent estimator.train from calling distribute coordinator again. This
    # function calls estimator.train which will use distribute coordinator path
    # again if `_distribute_coordinator_mode` is set.
    local_estimator._config._distribute_coordinator_mode = None  # pylint: disable=protected-access
    local_estimator.train(
        input_fn=train_spec.input_fn,
        max_steps=train_spec.max_steps,
        hooks=hooks)

  def _eval_fn(strategy):
    """Function for evaluator task."""
    local_estimator = copy.deepcopy(estimator)
    # pylint: disable=protected-access
    local_estimator._config._eval_distribute = strategy
    _init_run_config_from_worker_context(
        local_estimator._config, dc_context.get_current_worker_context())
    logging.info('Updated config: %s', str(vars(local_estimator._config)))
    local_estimator._eval_distribution = strategy

    # Prevent estimator.evaluate from calling distribute coordinator again. This
    # function calls estimator.evaluate which will use distribute coordinator
    # path again if `_distribute_coordinator_mode` is set.
    local_estimator._config._distribute_coordinator_mode = None  # pylint: disable=protected-access

    executor = executor_cls(local_estimator, train_spec, eval_spec)
    executor._start_continuous_evaluation()
    # pylint: enable=protected-access

  # pylint: disable=protected-access
  if (run_config._distribute_coordinator_mode ==
      dc.CoordinatorMode.STANDALONE_CLIENT):
    cluster_spec = run_config.cluster_spec
    assert cluster_spec
  else:
    # The cluster_spec comes from the TF_CONFIG environment variable in
    # INDEPENDENT_WORKER mode.
    cluster_spec = None

  dc.run_distribute_coordinator(
      _worker_fn,
      run_config.train_distribute,
      _eval_fn,
      run_config.eval_distribute,
      mode=run_config._distribute_coordinator_mode,
      cluster_spec=cluster_spec,
      session_config=run_config.session_config)
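
Across these examples, the first argument to `run_distribute_coordinator` is a worker function that receives the (possibly reconfigured) strategy and can query the per-task worker context. A minimal sketch of such a function, assuming `dc_context` is the distribute coordinator context module used in the examples above:

def sketch_worker_fn(strategy):
  # The coordinator sets up a per-task worker context before invoking this
  # function; the context exposes master_target, num_workers and is_chief,
  # which the tests above dump and assert on.
  context = dc_context.get_current_worker_context()
  if context.is_chief:
    pass  # e.g. only the chief writes checkpoints or summaries.
  # Real work (building the graph, training, evaluating) would typically run
  # inside `with strategy.scope():`, as in the Keras examples above.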