def test_run_std_server(self, mock_server):
    """Checks that run_std_server() starts and then joins the std server.

    TF_CONFIG is patched to describe PS task index 1 of the test cluster
    while the RunConfig is built.  For every estimator returned by
    `_estimators_for_tests`, an Experiment is created and `run_std_server()`
    is invoked.  The `mock_server` argument is expected to have recorded
    `().start()` followed by `().join()`.
    """
    # Arrange.
    cluster = self._cluster_spec()
    task_spec = {'type': run_config_lib.TaskType.PS, 'index': 1}
    patched_env = {
        'TF_CONFIG': json.dumps({'cluster': cluster, 'task': task_spec}),
    }
    with test.mock.patch.dict('os.environ', patched_env):
        config = RunConfig(
            master='host2:2222',
            num_cores=15,
            gpu_memory_fraction=0.314,
        )

    expected_calls = [test.mock.call().start(), test.mock.call().join()]
    for estimator in self._estimators_for_tests(config):
        exp = Experiment(
            estimator,
            train_input_fn='train_input',
            eval_input_fn='eval_input',
        )

        # Act.
        exp.run_std_server()

        # Assert.
        mock_server.assert_has_calls(expected_calls)
def __init__(self, config=None, max_evals=5, eval_dict=None):
    """Initialises the test estimator.

    Args:
      config: Run configuration. Any falsy value is replaced by a
        default `RunConfig()`.
      max_evals: Stored as `self._max_evals`.
      eval_dict: Stored as `self._eval_dict`.
    """
    # Call counters all start at zero (presumably bumped by the
    # estimator's fit/evaluate/export methods -- not visible here).
    self.eval_count = self.fit_count = self.export_count = 0
    self._max_evals = max_evals
    self._config = config if config else RunConfig()
    # Each instance gets its own freshly created temporary directory.
    self._model_dir = tempfile.mkdtemp()
    self._eval_dict = eval_dict
    tf_logging.info('Create Core Estimator')
def test_train_default_delay(self):
    """Checks that train() delays by 5 simulated seconds per task index.

    For each task index, TF_CONFIG is patched while a default RunConfig is
    built.  `time.time` and `time.sleep` are then replaced by a
    SheepCounter's methods, and the counter's reading after `train()` must
    equal `task_index * 5`.
    """
    for task_index in (0, 1, 3):
        patched_env = {
            'TF_CONFIG': json.dumps({'task': {'index': task_index}}),
        }
        with test.mock.patch.dict('os.environ', patched_env):
            config = RunConfig()

        expected_delay = task_index * 5
        for estimator in self._estimators_for_tests(config):
            exp = Experiment(
                estimator,
                train_input_fn='train_input',
                eval_input_fn='eval_input',
            )

            counter = SheepCounter()
            # Patchers only take effect on __enter__, so building them up
            # front is equivalent to nesting the two `with` statements.
            patch_time = test.mock.patch.object(time, 'time', counter.time)
            patch_sleep = test.mock.patch.object(time, 'sleep', counter.sleep)
            with patch_time, patch_sleep:
                exp.train()
                self.assertAlmostEqual(
                    expected_delay, counter.time(), delta=1e-4)
def __init__(self, model_fn=None, model_dir=None, config=None, params=None):
    """Constructs the estimator.

    Args:
      model_fn: Model function. It is checked against `params` via
        `self._verify_model_fn_args` and stored as `self._model_fn`.
      model_dir: Directory for this estimator. If both this and
        `config.model_dir` are set, they must be equal. If neither is set,
        a directory is obtained from `generate_model_dir()`.
      config: A `RunConfig` instance, or None to use a default `RunConfig()`.
      params: Optional hyperparameters, stored as `self._params`
        (`{}` when falsy).

    Raises:
      ValueError: If `config` is not a `RunConfig` instance, or if
        `model_dir` and `config.model_dir` are both set but differ.
    """
    # Create a run configuration.
    if config is None:
        self._config = RunConfig()
        logging.info("Using default config.")
    else:
        if not isinstance(config, RunConfig):
            raise ValueError("config must be an instance of RunConfig, "
                             "received {}.".format(config))
        self._config = config
    # Pass the value as a logging argument instead of pre-formatting the
    # message, so string conversion is left to the logging framework.
    logging.info("Using config: %s", vars(self._config))

    if (model_dir is not None) and (self._config.model_dir is not None):
        if model_dir != self._config.model_dir:
            # pylint: disable=g-doc-exception
            raise ValueError(
                "model_dir are set both in constructor and RunConfig, but with "
                "different values. In constructor: '{}', in RunConfig: "
                "'{}' ".format(model_dir, self._config.model_dir))

    # Use the first truthy of: constructor argument, config value, or a
    # freshly generated directory.
    self._model_dir = (
        model_dir or self._config.model_dir or generate_model_dir())
    # Keep the stored config in sync with the directory actually in use.
    if self._config.model_dir is None:
        self._config = self._config.replace(model_dir=self._model_dir)

    if self._config.session_config is None:
        self._session_config = config_pb2.ConfigProto(
            allow_soft_placement=True)
    else:
        self._session_config = self._config.session_config

    # Set device function depending if there are replicas or not.
    self._device_fn = _get_replica_device_setter(self._config)

    self._graph = None

    self._verify_model_fn_args(model_fn, params)

    self._model_fn = model_fn
    self._params = params or {}