def testInGraphContextWithEval(self):
  # Adds an EVALUATOR job.
  cluster_spec = copy.deepcopy(self._cluster_spec)
  cluster_spec[EVALUATOR] = ["fake_evaluator"]

  # Dumps the task contexts to the self._worker_context dict.
  distribute_coordinator.run_distribute_coordinator(
      self._dump_worker_context,
      MockStrategy(between_graph=False),
      cluster_spec=cluster_spec,
      rpc_layer=None)

  # There is one "None" task and one EVALUATOR task.
  self.assertEqual(len(self._worker_context), 2)
  self.assertTrue("None" in self._worker_context)
  self.assertTrue(EVALUATOR in self._worker_context)
  self.assertEqual(len(self._worker_context["None"]), 1)
  self.assertEqual(len(self._worker_context[EVALUATOR]), 1)

  # Check whether each task has the right master_target, num_workers,
  # is_chief and distributed_mode.
  self.assertEqual(
      self._worker_context["None"][0],
      (_strip_protocol(_bytes_to_str(self._workers[0].target)), 3, True,
       True))
  self.assertEqual(self._worker_context[EVALUATOR][0],
                   ("fake_evaluator", 3, True, False))
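# NOTE: a hypothetical reconstruction, not the actual test helper. The
# `_dump_worker_context` callback passed as the worker_fn in the context tests
# above only needs to record, per task type, the (master_target, num_workers,
# is_chief, distributed_mode) tuple exposed by the current worker context; the
# real helper in the test file may differ in detail (for example the assumed
# `self._lock` guarding the shared dict).
def _dump_worker_context(self, strategy):
  """Records the current task's context, keyed by task type (sketch)."""
  context = dc_context.get_current_worker_context()
  self.assertTrue(context is not None)
  task_type = str(context.task_type)
  task_id = context.task_id or 0
  with self._lock:  # assumed lock; the shared dict is written by many threads
    if task_type not in self._worker_context:
      self._worker_context[task_type] = []
    while len(self._worker_context[task_type]) <= task_id:
      self._worker_context[task_type].append(None)
    self._worker_context[task_type][task_id] = (context.master_target,
                                                context.num_workers,
                                                context.is_chief,
                                                context.distributed_mode)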
def _run_standalone_client(test_obj, strategy, cluster_spec):
  input_shape = (28, 28, 1)
  with strategy.scope():
    orig_model = _get_model(input_shape)

  def worker_fn(strategy):
    with ops.Graph().as_default():
      batch_size = 64
      steps = 2

      with strategy.scope():
        train_ds, _ = _mnist_synthetic_dataset(batch_size, steps)
        model = _clone_and_build_model(orig_model, strategy)

        orig_loss, orig_acc = model.evaluate(train_ds, steps=steps)

        # Workaround for the metrics issue (b/122928955) in async training.
        # This can only be used in standalone client mode.
        dc_context.get_current_worker_context().wait_for_other_workers()

        model.fit(x=train_ds, epochs=2, steps_per_epoch=steps)

        dc_context.get_current_worker_context().wait_for_other_workers()

        trained_loss, trained_acc = model.evaluate(train_ds, steps=steps)

        test_obj.assertLessEqual(trained_loss, orig_loss)
        test_obj.assertGreaterEqual(trained_acc, orig_acc)

  dc.run_distribute_coordinator(
      worker_fn,
      strategy,
      mode=dc.CoordinatorMode.STANDALONE_CLIENT,
      cluster_spec=cluster_spec)
def testBetweenGraphContextWithChief(self):
  # Adds a chief node, so there are NUM_WORKERS + 1 workers in total.
  cluster_spec = copy.deepcopy(self._cluster_spec)
  cluster_spec[CHIEF] = ["fake_chief"]

  # Dumps the task contexts to the self._worker_context dict.
  distribute_coordinator.run_distribute_coordinator(
      self._dump_worker_context,
      MockStrategy(between_graph=True),
      cluster_spec=cluster_spec,
      rpc_layer="grpc")

  # There is one CHIEF and there are three workers.
  self.assertEqual(len(self._worker_context), 2)
  self.assertTrue(CHIEF in self._worker_context)
  self.assertTrue(WORKER in self._worker_context)
  self.assertEqual(len(self._worker_context[CHIEF]), 1)
  self.assertEqual(len(self._worker_context[WORKER]), NUM_WORKERS)

  # Check whether each task has the right master_target, num_workers,
  # is_chief and distributed_mode.
  self.assertEqual(self._worker_context[CHIEF][0],
                   ("grpc://fake_chief", 4, True, True))
  self.assertEqual(
      self._worker_context[WORKER][0],
      (_bytes_to_str(self._workers[0].target), NUM_WORKERS + 1, False, True))
  self.assertEqual(
      self._worker_context[WORKER][1],
      (_bytes_to_str(self._workers[1].target), NUM_WORKERS + 1, False, True))
  self.assertEqual(
      self._worker_context[WORKER][2],
      (_bytes_to_str(self._workers[2].target), NUM_WORKERS + 1, False, True))
def testInGraphSplitMode(self):
  """Test it runs in-graph replication in split client mode."""
  distribute_coordinator.run_distribute_coordinator(
      self._in_graph_worker_fn,
      cluster_spec=self._cluster_spec,
      between_graph=False)
  self.assertEqual(self._result_correct, 1)
def testInGraph(self):
  """Test it runs in-graph replicated training correctly."""
  distribute_coordinator.run_distribute_coordinator(
      self._in_graph_worker_fn,
      cluster_spec=self._cluster_spec,
      between_graph=False)
  self.assertEqual(self._result_correct, 1)
def testInGraphStandaloneMode(self):
  """Test it runs in-graph replication in standalone client mode."""
  distribute_coordinator.run_distribute_coordinator(
      self._in_graph_worker_fn,
      MockStrategy(between_graph=False),
      cluster_spec=self._cluster_spec)
  self.assertEqual(self._result_correct, 1)
def testRpcLayerEnvironmentVariable(self):
  cluster_spec = {"worker": ["fake_worker"], "ps": ["fake_ps"]}
  tf_config = {"cluster": cluster_spec, "rpc_layer": "cake"}

  rpc_layer_from_coordinator = [None]

  def _run_mock_server(cluster_spec=None,
                       task_type=None,
                       task_id=None,
                       session_config=None,
                       rpc_layer=None,
                       environment=None):
    del cluster_spec, task_type, task_id, session_config, environment
    rpc_layer_from_coordinator[0] = rpc_layer
    return MockServer()

  with test.mock.patch.dict(
      "os.environ",
      {"TF_CONFIG": json.dumps(tf_config)}), test.mock.patch.object(
          distribute_coordinator, "_run_std_server", _run_mock_server):
    distribute_coordinator.run_distribute_coordinator(
        None,
        MockStrategy(between_graph=True),
        mode=INDEPENDENT_WORKER,
        cluster_spec=cluster_spec,
        task_type="ps",
        task_id=0)
  self.assertEqual(rpc_layer_from_coordinator[0], "cake")
def _thread_fn(cluster_spec):
  distribute_coordinator.run_distribute_coordinator(
      None,
      None,
      mode=INDEPENDENT_WORKER,
      cluster_spec=cluster_spec,
      task_type="ps",
      task_id=0)
def _thread_fn(cluster_spec):
  distribute_coordinator.run_distribute_coordinator(
      None,
      MockStrategy(between_graph=True),
      mode=INDEPENDENT_WORKER,
      cluster_spec=cluster_spec,
      task_type="ps",
      task_id=0)
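# Usage sketch (assumed, not from the original tests): the ps task blocks
# inside `run_distribute_coordinator` while its std server runs, so a test
# would typically launch one of the `_thread_fn` helpers above in a daemon
# thread and drive the worker tasks from the main thread. The addresses below
# are placeholders.
import threading

cluster_spec = {"worker": ["localhost:12345"], "ps": ["localhost:23456"]}
ps_thread = threading.Thread(target=_thread_fn, args=(cluster_spec,))
ps_thread.daemon = True
ps_thread.start()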
def testBetweenGraphWithMonitoredSession(self):
  """Test monitored session in standalone client mode."""
  distribute_coordinator.run_distribute_coordinator(
      self._between_graph_with_monitored_session,
      MockStrategy(between_graph=True),
      cluster_spec=self._cluster_spec)

  # Each finished worker will increment self._result_correct.
  self.assertEqual(self._result_correct, NUM_WORKERS)
def testBetweenGraph(self):
  """Test it runs between-graph replication in standalone client mode."""
  distribute_coordinator.run_distribute_coordinator(
      self._between_graph_worker_fn,
      MockStrategy(between_graph=True),
      cluster_spec=self._cluster_spec)

  # Each finished worker will increment self._result_correct.
  self.assertEqual(self._result_correct, NUM_WORKERS)
def testBetweenGraph(self):
  """Test it runs between-graph replicated training correctly."""
  distribute_coordinator.run_distribute_coordinator(
      self._between_graph_worker_fn,
      cluster_spec=self._cluster_spec,
      between_graph=True)

  # Each finished worker will increment self._result_correct.
  self.assertEqual(self._result_correct, NUM_WORKERS)
def testLocalContext(self):
  # Dumps the task contexts to the self._task_context dict.
  distribute_coordinator.run_distribute_coordinator(
      self._dump_task_context, cluster_spec=None, between_graph=True)

  # There is only a "None" task.
  self.assertEqual(len(self._task_context), 1)
  self.assertTrue("None" in self._task_context)
  self.assertEqual(len(self._task_context["None"]), 1)

  # Check whether each task has the right master_target, num_workers,
  # is_chief and distributed_mode.
  self.assertEqual(self._task_context["None"][0], ("local", 0, True, False))
def test_session_config_in_session_creator(self):
  cluster_spec = {"worker": ["localhost:0"]}
  tf_config = {"cluster": cluster_spec}
  with test.mock.patch.dict("os.environ",
                            {"TF_CONFIG": json.dumps(tf_config)}):
    distribute_coordinator.run_distribute_coordinator(
        self._worker_fn,
        MockStrategy(between_graph=True),
        mode=INDEPENDENT_WORKER,
        cluster_spec=cluster_spec,
        task_type="worker",
        task_id=0)
  self.assertEqual(self._device_filters, ["/job:worker/task:0", "/job:ps"])
  self.assertEqual(self._intra_op_parallelism_threads, 2)
  self.assertEqual(self._inter_op_parallelism_threads, 0)
def testInGraphContext(self):
  # Dumps the task contexts to the self._worker_context dict.
  distribute_coordinator.run_distribute_coordinator(
      self._dump_worker_context,
      MockStrategy(between_graph=False),
      cluster_spec=self._cluster_spec)

  # There is only a "None" task in the dumped task context.
  self.assertEqual(len(self._worker_context), 1)
  self.assertTrue("None" in self._worker_context)
  self.assertEqual(len(self._worker_context["None"]), 1)

  # Check whether each task has the right master_target, num_workers,
  # is_chief and distributed_mode.
  self.assertEqual(
      self._worker_context["None"][0],
      (_bytes_to_str(self._workers[0].target), NUM_WORKERS, True, True))
def testBetweenGraphStrategyProperties(self):
  # Dumps properties of the strategy objects.
  distribute_coordinator.run_distribute_coordinator(
      self._dump_strategy_property,
      MockStrategy(between_graph=True, should_init=True),
      cluster_spec=self._cluster_spec)

  # There is only one type of task and there are three such tasks.
  self.assertEqual(len(self._strategy_property), 1)
  self.assertTrue(WORKER in self._strategy_property)
  self.assertEqual(len(self._strategy_property[WORKER]), NUM_WORKERS)

  # Check whether each task has the right properties of should_init,
  # should_checkpoint and should_save_summary.
  self.assertEqual(self._strategy_property[WORKER][0], (True, True, True))
  self.assertEqual(self._strategy_property[WORKER][1], (True, False, False))
  self.assertEqual(self._strategy_property[WORKER][2], (True, False, False))
def test_eval_strategy_configure(self):
  cluster_spec = {"evaluator": ["localhost:0"]}
  tf_config = {"cluster": cluster_spec}
  with test.mock.patch.dict("os.environ",
                            {"TF_CONFIG": json.dumps(tf_config)}):
    distribute_coordinator.run_distribute_coordinator(
        lambda _: None,
        MockStrategy(between_graph=False),
        eval_fn=self._worker_fn,
        eval_strategy=MockStrategy(between_graph=True),
        mode=INDEPENDENT_WORKER,
        cluster_spec=cluster_spec,
        task_type="evaluator",
        task_id=0)
  self.assertEqual(self._device_filters, ["/job:somejob"])
  self.assertEqual(self._intra_op_parallelism_threads, 0)
  self.assertEqual(self._inter_op_parallelism_threads, 2)
def test_session_config_in_std_server(self):
  cluster_spec = {"worker": ["fake_worker"], "ps": ["fake_ps"]}
  tf_config = {"cluster": cluster_spec}
  with test.mock.patch.dict(
      "os.environ",
      {"TF_CONFIG": json.dumps(tf_config)}), test.mock.patch.object(
          distribute_coordinator, "_run_std_server",
          self._dump_device_filters):
    distribute_coordinator.run_distribute_coordinator(
        lambda _: None,
        MockStrategy(between_graph=True),
        mode=INDEPENDENT_WORKER,
        cluster_spec=cluster_spec,
        task_type="worker",
        task_id=0)
  self.assertEqual(self._intra_op_parallelism_threads, 1)
  self.assertEqual(self._inter_op_parallelism_threads, 0)
def test_session_config_in_session_creator(self):
  cluster_spec = {"worker": ["localhost:0"]}
  tf_config = {"cluster": cluster_spec}

  # Reset the saved Server state.
  distribute_coordinator._thread_local = threading.local()  # pylint: disable=protected-access

  with test.mock.patch.dict("os.environ",
                            {"TF_CONFIG": json.dumps(tf_config)}):
    distribute_coordinator.run_distribute_coordinator(
        self._worker_fn,
        MockStrategy(between_graph=True),
        mode=INDEPENDENT_WORKER,
        cluster_spec=cluster_spec,
        task_type="worker",
        task_id=0)
  self.assertEqual(self._device_filters, ["/job:worker/task:0", "/job:ps"])
  self.assertEqual(self._intra_op_parallelism_threads, 2)
  self.assertEqual(self._inter_op_parallelism_threads, 0)
def wrapper(model, **kwargs):
  def _worker_fn(_):
    callbacks = kwargs.pop('callbacks', None)
    filtered_callbacks = dist_utils.filter_distributed_callbacks(
        callbacks, model)
    kwargs['callbacks'] = filtered_callbacks
    return method(model, **kwargs)

  return dc.run_distribute_coordinator(
      _worker_fn, model._distribution_strategy, mode='independent_worker')
def wrapper(model, **kwargs):
  def _worker_fn(_):
    callbacks = kwargs.pop('callbacks', None)
    filtered_callbacks = dist_utils.filter_distributed_callbacks(callbacks)
    kwargs['callbacks'] = filtered_callbacks
    return method(model, **kwargs)

  return dc.run_distribute_coordinator(
      _worker_fn,
      model._distribution_strategy,
      mode=dc.CoordinatorMode.INDEPENDENT_WORKER)
def testBetweenGraphContext(self):
  # Dumps the task contexts to the self._worker_context dict.
  distribute_coordinator.run_distribute_coordinator(
      self._dump_worker_context,
      MockStrategy(between_graph=True),
      cluster_spec=self._cluster_spec)

  # There is only one type of task and there are three such tasks.
  self.assertEqual(len(self._worker_context), 1)
  self.assertTrue(WORKER in self._worker_context)
  self.assertEqual(len(self._worker_context[WORKER]), NUM_WORKERS)

  # Check whether each task has the right master_target, num_workers,
  # is_chief and distributed_mode.
  self.assertEqual(
      self._worker_context[WORKER][0],
      (_bytes_to_str(self._workers[0].target), NUM_WORKERS, True, True))
  self.assertEqual(
      self._worker_context[WORKER][1],
      (_bytes_to_str(self._workers[1].target), NUM_WORKERS, False, True))
  self.assertEqual(
      self._worker_context[WORKER][2],
      (_bytes_to_str(self._workers[2].target), NUM_WORKERS, False, True))
def _run_standalone_client(test_obj, strategy, cluster_spec):
  input_shape = (28, 28, 1)
  with strategy.scope():
    orig_model = multi_worker_testing_utils.get_mnist_model(input_shape)

  def worker_fn(strategy):
    with ops.Graph().as_default():
      batch_size = 64
      steps = 2

      with strategy.scope():
        train_ds, _ = multi_worker_testing_utils.mnist_synthetic_dataset(
            batch_size, steps)
        model = _clone_and_build_model(orig_model, strategy)

        orig_loss, orig_acc = model.evaluate(train_ds, steps=steps)

        # Workaround for the metrics issue (b/122928955) in async training.
        # This can only be used in standalone client mode.
        dc_context.get_current_worker_context().wait_for_other_workers()

        model.fit(x=train_ds, epochs=2, steps_per_epoch=steps)

        dc_context.get_current_worker_context().wait_for_other_workers()

        trained_loss, trained_acc = model.evaluate(train_ds, steps=steps)

        test_obj.assertLessEqual(trained_loss, orig_loss)
        test_obj.assertGreaterEqual(trained_acc, orig_acc)

  dc.run_distribute_coordinator(
      worker_fn,
      strategy,
      mode=dc.CoordinatorMode.STANDALONE_CLIENT,
      cluster_spec=cluster_spec)
def estimator_train(estimator, train_distributed_fn, hooks):
  """Run distribute coordinator for Estimator's `train` method."""
  assert estimator._config._distribute_coordinator_mode
  run_config = estimator._config
  assert estimator._config.cluster_spec
  cluster_spec = multi_worker_util.normalize_cluster_spec(
      estimator._config.cluster_spec)
  assert estimator._config._train_distribute

  if 'evaluator' in cluster_spec.jobs:
    raise ValueError("'evaluator' job is not supported if you don't use "
                     '`train_and_evaluate`')

  if (estimator._config._distribute_coordinator_mode !=  # pylint: disable=protected-access
      dc.CoordinatorMode.STANDALONE_CLIENT):
    raise ValueError(
        'Only `STANDALONE_CLIENT` mode is supported when you call '
        '`estimator.train`')

  if estimator._config._train_distribute.extended.experimental_between_graph:
    # TODO(yuefengz): remove this limitation once we figure out how to merge
    # return values from `_worker_fn`s.
    raise ValueError(
        '`Estimator.train` API is not supported for %s with '
        '`STANDALONE_CLIENT` mode.' %
        estimator._config._train_distribute.__class__.__name__)

  def _worker_fn(strategy):
    """Function for worker task."""
    local_estimator = copy.deepcopy(estimator)
    local_estimator._config._train_distribute = strategy
    context = dc_context.get_current_worker_context()
    _init_run_config_from_worker_context(local_estimator._config, context)
    logging.info('Updated config: %s', str(vars(local_estimator._config)))
    local_estimator._train_distribution = strategy

    if context.is_chief:
      chief_hooks = hooks
    else:
      chief_hooks = []
    train_distributed_fn(local_estimator, strategy, chief_hooks)
    return local_estimator

  return dc.run_distribute_coordinator(
      _worker_fn,
      estimator._config.train_distribute,
      mode=run_config._distribute_coordinator_mode,
      cluster_spec=cluster_spec,
      session_config=run_config.session_config)
def estimator_train(estimator, train_distributed_fn, hooks):
  """Run distribute coordinator for Estimator's `train` method."""
  assert estimator._config._distribute_coordinator_mode
  run_config = estimator._config
  assert estimator._config.cluster_spec
  cluster_spec = multi_worker_util.normalize_cluster_spec(
      estimator._config.cluster_spec)
  assert estimator._config._train_distribute

  if 'evaluator' in cluster_spec.jobs:
    raise ValueError("'evaluator' job is not supported if you don't use "
                     '`train_and_evaluate`')

  if (estimator._config._distribute_coordinator_mode !=  # pylint: disable=protected-access
      dc.CoordinatorMode.STANDALONE_CLIENT):
    raise ValueError('Only `STANDALONE_CLIENT` mode is supported when you call '
                     '`estimator.train`')

  if estimator._config._train_distribute.extended.experimental_between_graph:
    # TODO(yuefengz): remove this limitation once we figure out how to merge
    # return values from `_worker_fn`s.
    raise ValueError('`Estimator.train` API is not supported for %s with '
                     '`STANDALONE_CLIENT` mode.' %
                     estimator._config._train_distribute.__class__.__name__)

  def _worker_fn(strategy):
    """Function for worker task."""
    local_estimator = copy.deepcopy(estimator)
    local_estimator._config._train_distribute = strategy
    context = dc_context.get_current_worker_context()
    _init_run_config_from_worker_context(local_estimator._config, context)
    logging.info('Updated config: %s', str(vars(local_estimator._config)))
    local_estimator._train_distribution = strategy

    if context.is_chief:
      chief_hooks = hooks
    else:
      chief_hooks = []
    train_distributed_fn(local_estimator, strategy, chief_hooks)
    return local_estimator

  return dc.run_distribute_coordinator(
      _worker_fn,
      estimator._config.train_distribute,
      mode=run_config._distribute_coordinator_mode,
      cluster_spec=cluster_spec,
      session_config=run_config.session_config)
def fit(self,
        model,
        x=None,
        y=None,
        batch_size=None,
        epochs=1,
        verbose=1,
        callbacks=None,
        validation_split=0.,
        validation_data=None,
        shuffle=True,
        class_weight=None,
        sample_weight=None,
        initial_epoch=0,
        steps_per_epoch=None,
        validation_steps=None,
        validation_freq=1,
        **kwargs):
  # Multi-Worker mode runs the Keras training loop on multiple
  # servers via the Distribute Coordinator.
  def _worker_fn(_):
    """Run training inside the distributed coordinator."""
    filtered_callbacks = dist_utils.filter_distributed_callbacks(callbacks)
    return fit_distributed(
        model,
        x=x,
        y=y,
        batch_size=batch_size,
        epochs=epochs,
        verbose=verbose,
        callbacks=filtered_callbacks,
        validation_split=validation_split,
        validation_data=validation_data,
        shuffle=shuffle,
        class_weight=class_weight,
        sample_weight=sample_weight,
        initial_epoch=initial_epoch,
        steps_per_epoch=steps_per_epoch,
        validation_steps=validation_steps,
        validation_freq=validation_freq)

  # Independent worker only for now.
  return dc.run_distribute_coordinator(
      _worker_fn,
      model._distribution_strategy,
      mode=dc.CoordinatorMode.INDEPENDENT_WORKER)
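# Hedged illustration of the independent-worker setup the `fit` path above
# relies on: each worker process carries a TF_CONFIG environment variable
# describing the same cluster but its own task, and the coordinator reads the
# cluster from it. Host names and the task index below are placeholders.
import json
import os

os.environ["TF_CONFIG"] = json.dumps({
    "cluster": {
        "worker": ["host0:12345", "host1:12345"]
    },
    "task": {
        "type": "worker",
        "index": 0
    }
})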
def train_and_evaluate(estimator, train_spec, eval_spec, executor_cls):
  """Run distribute coordinator for Estimator's `train_and_evaluate`.

  Args:
    estimator: An `Estimator` instance to train and evaluate.
    train_spec: A `TrainSpec` instance to specify the training specification.
    eval_spec: An `EvalSpec` instance to specify the evaluation and export
      specification.
    executor_cls: the evaluation executor class of Estimator.

  Raises:
    ValueError: if `distribute_coordinator_mode` is None in RunConfig.
  """
  run_config = estimator.config
  if not run_config._distribute_coordinator_mode:  # pylint: disable=protected-access
    raise ValueError(
        'Distribute coordinator mode is not specified in `RunConfig`.')

  def _worker_fn(strategy):
    """Function for worker task."""
    local_estimator = copy.deepcopy(estimator)
    # pylint: disable=protected-access
    local_estimator._config._train_distribute = strategy
    context = dc_context.get_current_worker_context()
    _init_run_config_from_worker_context(local_estimator._config, context)
    logging.info('Updated config: %s', str(vars(local_estimator._config)))
    local_estimator._train_distribution = strategy
    # pylint: enable=protected-access

    # In the standalone client, we don't need to run hooks on all threads
    # because logging hooks on all threads may be too much on the screen;
    # also a tensor passed to one hook can only be fetched with the graph
    # where the tensor is defined. Other hooks such as checkpointing hooks
    # will be added by MonitoredTrainingSession.
    # TODO(yuefengz): Is there a hook that does need to run on all threads in
    # standalone client mode?
    if (run_config._distribute_coordinator_mode ==  # pylint: disable=protected-access
        dc.CoordinatorMode.INDEPENDENT_WORKER or context.is_chief):
      hooks = list(train_spec.hooks)
    else:
      hooks = []

    # Prevent estimator.train from calling distribute coordinator again. This
    # function calls estimator.train which will use the distribute coordinator
    # path again if `_distribute_coordinator_mode` is set.
    local_estimator._config._distribute_coordinator_mode = None  # pylint: disable=protected-access
    local_estimator.train(
        input_fn=train_spec.input_fn,
        max_steps=train_spec.max_steps,
        hooks=hooks)

  def _eval_fn(strategy):
    """Function for evaluator task."""
    local_estimator = copy.deepcopy(estimator)
    # pylint: disable=protected-access
    local_estimator._config._eval_distribute = strategy
    _init_run_config_from_worker_context(
        local_estimator._config, dc_context.get_current_worker_context())
    logging.info('Updated config: %s', str(vars(local_estimator._config)))
    local_estimator._eval_distribution = strategy

    # Prevent estimator.evaluate from calling distribute coordinator again.
    # This function calls estimator.evaluate which will use the distribute
    # coordinator path again if `_distribute_coordinator_mode` is set.
    local_estimator._config._distribute_coordinator_mode = None  # pylint: disable=protected-access

    executor = executor_cls(local_estimator, train_spec, eval_spec)
    executor._start_continuous_evaluation()
    # pylint: enable=protected-access

  # pylint: disable=protected-access
  if (run_config._distribute_coordinator_mode ==
      dc.CoordinatorMode.STANDALONE_CLIENT):
    cluster_spec = run_config.cluster_spec
    assert cluster_spec
  else:
    # The cluster_spec comes from the TF_CONFIG environment variable if it is
    # INDEPENDENT_WORKER mode.
    cluster_spec = None

  dc.run_distribute_coordinator(
      _worker_fn,
      run_config.train_distribute,
      _eval_fn,
      run_config.eval_distribute,
      mode=run_config._distribute_coordinator_mode,
      cluster_spec=cluster_spec,
      session_config=run_config.session_config)
def train_and_evaluate(estimator, train_spec, eval_spec, executor_cls):
  """Run distribute coordinator for Estimator's `train_and_evaluate`.

  Args:
    estimator: An `Estimator` instance to train and evaluate.
    train_spec: A `TrainSpec` instance to specify the training specification.
    eval_spec: An `EvalSpec` instance to specify the evaluation and export
      specification.
    executor_cls: the evaluation executor class of Estimator.

  Raises:
    ValueError: if `distribute_coordinator_mode` is None in RunConfig.
  """
  run_config = estimator.config
  if not run_config._distribute_coordinator_mode:  # pylint: disable=protected-access
    raise ValueError(
        'Distribute coordinator mode is not specified in `RunConfig`.')

  def _worker_fn(strategy):
    """Function for worker task."""
    local_estimator = copy.deepcopy(estimator)
    # pylint: disable=protected-access
    local_estimator._config._train_distribute = strategy
    _init_run_config_from_worker_context(
        local_estimator._config, dc_context.get_current_worker_context())
    local_estimator._train_distribution = strategy
    # pylint: enable=protected-access

    local_estimator.train(
        input_fn=train_spec.input_fn,
        max_steps=train_spec.max_steps,
        hooks=list(train_spec.hooks))

  def _eval_fn(strategy):
    """Function for evaluator task."""
    local_estimator = copy.deepcopy(estimator)
    # pylint: disable=protected-access
    local_estimator._config._eval_distribute = strategy
    _init_run_config_from_worker_context(
        local_estimator._config, dc_context.get_current_worker_context())
    local_estimator._eval_distribution = strategy

    executor = executor_cls(local_estimator, train_spec, eval_spec)
    executor._start_continuous_evaluation()
    # pylint: enable=protected-access

  # pylint: disable=protected-access
  if (run_config._distribute_coordinator_mode ==
      dc.CoordinatorMode.STANDALONE_CLIENT):
    cluster_spec = run_config.cluster_spec
    assert cluster_spec
  else:
    # The cluster_spec comes from the TF_CONFIG environment variable if it is
    # INDEPENDENT_WORKER mode.
    cluster_spec = None

  dc.run_distribute_coordinator(
      _worker_fn,
      run_config.train_distribute,
      _eval_fn,
      run_config.eval_distribute,
      mode=run_config._distribute_coordinator_mode,
      cluster_spec=cluster_spec,
      session_config=run_config.session_config)
def train_and_evaluate(estimator, train_spec, eval_spec, executor_cls):
  """Run distribute coordinator for Estimator's `train_and_evaluate`.

  Args:
    estimator: An `Estimator` instance to train and evaluate.
    train_spec: A `TrainSpec` instance to specify the training specification.
    eval_spec: An `EvalSpec` instance to specify the evaluation and export
      specification.
    executor_cls: the evaluation executor class of Estimator.

  Raises:
    ValueError: if `distribute_coordinator_mode` is None in RunConfig.
  """
  run_config = estimator.config
  if not run_config._distribute_coordinator_mode:  # pylint: disable=protected-access
    raise ValueError(
        'Distribute coordinator mode is not specified in `RunConfig`.')

  def _worker_fn(strategy):
    """Function for worker task."""
    local_estimator = copy.deepcopy(estimator)
    # pylint: disable=protected-access
    local_estimator._config._train_distribute = strategy
    _init_run_config_from_worker_context(
        local_estimator._config, dc_context.get_current_worker_context())
    local_estimator._train_distribution = strategy
    # pylint: enable=protected-access

    local_estimator.train(input_fn=train_spec.input_fn,
                          max_steps=train_spec.max_steps,
                          hooks=list(train_spec.hooks))

  def _eval_fn(strategy):
    """Function for evaluator task."""
    local_estimator = copy.deepcopy(estimator)
    # pylint: disable=protected-access
    local_estimator._config._eval_distribute = strategy
    _init_run_config_from_worker_context(
        local_estimator._config, dc_context.get_current_worker_context())
    local_estimator._eval_distribution = strategy

    executor = executor_cls(local_estimator, train_spec, eval_spec)
    executor._start_continuous_evaluation()
    # pylint: enable=protected-access

  # pylint: disable=protected-access
  if (run_config._distribute_coordinator_mode ==
      dc.CoordinatorMode.STANDALONE_CLIENT):
    cluster_spec = run_config.cluster_spec
    assert cluster_spec
  else:
    # The cluster_spec comes from the TF_CONFIG environment variable if it is
    # INDEPENDENT_WORKER mode.
    cluster_spec = None

  dc.run_distribute_coordinator(_worker_fn,
                                run_config.train_distribute,
                                _eval_fn,
                                run_config.eval_distribute,
                                mode=run_config._distribute_coordinator_mode,
                                cluster_spec=cluster_spec,
                                session_config=run_config.session_config)
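# Illustrative cluster_spec shape for the STANDALONE_CLIENT branch above
# (host addresses are placeholders): the optional "evaluator" job is what
# routes a task to `_eval_fn` instead of `_worker_fn`.
cluster_spec = {
    "chief": ["host0:2222"],
    "worker": ["host1:2222", "host2:2222"],
    "ps": ["host3:2222"],
    "evaluator": ["host4:2222"],
}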