Example #1
 def test_init_run_config_independent_worker(self):
   # When `train_distribute` is specified and TF_CONFIG is detected, use
   # distribute coordinator with INDEPENDENT_WORKER mode.
   with test.mock.patch.dict("os.environ",
                             {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_CHIEF)}):
     config = run_config_lib.RunConfig(
         train_distribute=mirrored_strategy.MirroredStrategy())
   self.assertEqual(config._distribute_coordinator_mode,
                    dc.CoordinatorMode.INDEPENDENT_WORKER)
Example #2
  def test_complete_flow_with_mode(self, distribution):
    label_dimension = 2
    input_dimension = label_dimension
    batch_size = 10
    data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)
    data = data.reshape(batch_size, label_dimension)
    train_input_fn = self.dataset_input_fn(
        x={'x': data},
        y=data,
        batch_size=batch_size // len(distribution.worker_devices),
        shuffle=True)
    eval_input_fn = self.dataset_input_fn(
        x={'x': data},
        y=data,
        batch_size=batch_size // len(distribution.worker_devices),
        shuffle=False)
    predict_input_fn = numpy_io.numpy_input_fn(
        x={'x': data}, batch_size=batch_size, shuffle=False)

    linear_feature_columns = [
        feature_column.numeric_column('x', shape=(input_dimension,))
    ]
    dnn_feature_columns = [
        feature_column.numeric_column('x', shape=(input_dimension,))
    ]
    feature_columns = linear_feature_columns + dnn_feature_columns
    estimator = dnn_linear_combined.DNNLinearCombinedRegressor(
        linear_feature_columns=linear_feature_columns,
        dnn_hidden_units=(2, 2),
        dnn_feature_columns=dnn_feature_columns,
        label_dimension=label_dimension,
        model_dir=self._model_dir,
        # TODO(isaprykin): Work around the colocate_with error.
        dnn_optimizer=adagrad.AdagradOptimizer(0.001),
        linear_optimizer=adagrad.AdagradOptimizer(0.001),
        config=run_config.RunConfig(
            train_distribute=distribution, eval_distribute=distribution))

    num_steps = 10
    estimator.train(train_input_fn, steps=num_steps)

    scores = estimator.evaluate(eval_input_fn)
    self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP])
    self.assertIn('loss', six.iterkeys(scores))

    predictions = np.array([
        x[prediction_keys.PredictionKeys.PREDICTIONS]
        for x in estimator.predict(predict_input_fn)
    ])
    self.assertAllEqual((batch_size, label_dimension), predictions.shape)

    feature_spec = feature_column.make_parse_example_spec(feature_columns)
    serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
        feature_spec)
    export_dir = estimator.export_savedmodel(tempfile.mkdtemp(),
                                             serving_input_receiver_fn)
    self.assertTrue(gfile.Exists(export_dir))
Example #3
    def test_should_run_distribute_coordinator(self):
        """Tests that should_run_distribute_coordinator return a correct value."""
        # We don't use distribute coordinator for local training.
        self.assertFalse(
            dc_training.should_run_distribute_coordinator(
                run_config_lib.RunConfig()))

        # When `train_distribute` is not specified, don't use distribute
        # coordinator.
        with test.mock.patch.dict(
                "os.environ", {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_CHIEF)}):
            self.assertFalse(
                dc_training.should_run_distribute_coordinator(
                    run_config_lib.RunConfig()))

        # When `train_distribute` is specified and TF_CONFIG is detected, use
        # distribute coordinator.
        with test.mock.patch.dict(
                "os.environ", {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_CHIEF)}):
            config_with_train_distribute = run_config_lib.RunConfig(
                experimental_distribute=DistributeConfig(
                    train_distribute=mirrored_strategy.MirroredStrategy(
                        num_gpus=2)))
            config_with_eval_distribute = run_config_lib.RunConfig(
                experimental_distribute=DistributeConfig(
                    eval_distribute=mirrored_strategy.MirroredStrategy(
                        num_gpus=2)))
        self.assertTrue(
            dc_training.should_run_distribute_coordinator(
                config_with_train_distribute))
        self.assertFalse(
            dc_training.should_run_distribute_coordinator(
                config_with_eval_distribute))

        # With a master in the cluster, don't run distribute coordinator.
        with test.mock.patch.dict(
                "os.environ",
                {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_MASTER)}):
            config = run_config_lib.RunConfig(
                experimental_distribute=DistributeConfig(
                    train_distribute=mirrored_strategy.MirroredStrategy(
                        num_gpus=2)))
        self.assertFalse(dc_training.should_run_distribute_coordinator(config))
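
The `TF_CONFIG_WITH_CHIEF` and `TF_CONFIG_WITH_MASTER` constants are assumed by
these tests but never shown. A minimal sketch of the shapes the assertions
imply (the host strings are placeholders, not real addresses):

TF_CONFIG_WITH_CHIEF = {
    # A cluster whose coordinating task is of type "chief"; the tests above
    # expect distribute coordinator to run when `train_distribute` is also set.
    "cluster": {"chief": ["fake_chief"]},
    "task": {"type": "chief", "index": 0}
}

TF_CONFIG_WITH_MASTER = {
    # A cluster with a "master" task; the tests above expect distribute
    # coordinator NOT to run in this case.
    "cluster": {"master": ["fake_master"]},
    "task": {"type": "master", "index": 0}
}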
Example #4
 def test_default_property_values(self):
     config = run_config_lib.RunConfig()
     self.assertIsNone(config.model_dir)
     self.assertIsNone(config.session_config)
     self.assertEqual(1, config.tf_random_seed)
     self.assertEqual(100, config.save_summary_steps)
     self.assertEqual(600, config.save_checkpoints_secs)
     self.assertIsNone(config.save_checkpoints_steps)
     self.assertEqual(5, config.keep_checkpoint_max)
     self.assertEqual(10000, config.keep_checkpoint_every_n_hours)
Example #5
 def test_init_run_config_standalone_client(self):
  # When `train_distribute` is specified, TF_CONFIG is detected, and
  # `experimental.remote_cluster` is set, use distribute coordinator with
  # STANDALONE_CLIENT mode.
   config = run_config_lib.RunConfig(
       train_distribute=mirrored_strategy.CoreMirroredStrategy(),
       experimental_distribute=DistributeConfig(
           remote_cluster={"chief": ["fake_worker"]}))
   self.assertEqual(config._distribute_coordinator_mode,
                    dc.CoordinatorMode.STANDALONE_CLIENT)
Example #6
 def test_default_values(self):
     self._assert_distributed_properties(
         run_config=run_config_lib.RunConfig(),
         expected_cluster_spec={},
         expected_task_type=run_config_lib.TaskType.WORKER,
         expected_task_id=0,
         expected_master='',
         expected_evaluation_master='',
         expected_is_chief=True,
         expected_num_worker_replicas=1,
         expected_num_ps_replicas=0)
Example #7
  def test_init_run_config_none_distribute_coordinator_mode(self):
    # We don't use distribute coordinator for local training.
    config = run_config_lib.RunConfig(
        train_distribute=mirrored_strategy.CoreMirroredStrategy())
    dc_training.init_run_config(config, {})
    self.assertIsNone(config._distribute_coordinator_mode)

    # With a master in the cluster, don't run distribute coordinator.
    with test.mock.patch.dict("os.environ",
                              {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_MASTER)}):
      config = run_config_lib.RunConfig(
          train_distribute=mirrored_strategy.CoreMirroredStrategy())
      self.assertIsNone(config._distribute_coordinator_mode)

    # When `train_distribute` is not specified, don't use distribute
    # coordinator.
    with test.mock.patch.dict("os.environ",
                              {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_CHIEF)}):
      config = run_config_lib.RunConfig()
      self.assertFalse(hasattr(config, "_distribute_coordinator_mode"))
Example #8
    def test_replace(self):
        config = run_config_lib.RunConfig()

        with self.assertRaisesRegexp(ValueError,
                                     _NOT_SUPPORTED_REPLACE_PROPERTY_MSG):
            # master is not allowed to be replaced.
            config.replace(master=_MASTER)

        with self.assertRaisesRegexp(ValueError,
                                     _NOT_SUPPORTED_REPLACE_PROPERTY_MSG):
            config.replace(some_undefined_property=_MASTER)
Example #9
 def test_init_invalid_values(self):
     with self.assertRaisesRegexp(ValueError, _MODEL_DIR_ERR):
         run_config_lib.RunConfig(model_dir='')
     with self.assertRaisesRegexp(ValueError, _SAVE_SUMMARY_STEPS_ERR):
         run_config_lib.RunConfig(save_summary_steps=-1)
     with self.assertRaisesRegexp(ValueError, _SAVE_CKPT_STEPS_ERR):
         run_config_lib.RunConfig(save_checkpoints_steps=-1)
     with self.assertRaisesRegexp(ValueError, _SAVE_CKPT_SECS_ERR):
         run_config_lib.RunConfig(save_checkpoints_secs=-1)
     with self.assertRaisesRegexp(ValueError, _SESSION_CONFIG_ERR):
         run_config_lib.RunConfig(session_config={})
     with self.assertRaisesRegexp(ValueError, _KEEP_CKPT_MAX_ERR):
         run_config_lib.RunConfig(keep_checkpoint_max=-1)
     with self.assertRaisesRegexp(ValueError, _KEEP_CKPT_HOURS_ERR):
         run_config_lib.RunConfig(keep_checkpoint_every_n_hours=0)
     with self.assertRaisesRegexp(ValueError, _TF_RANDOM_SEED_ERR):
         run_config_lib.RunConfig(tf_random_seed=1.0)
     with self.assertRaisesRegexp(ValueError, _DEVICE_FN_ERR):
         run_config_lib.RunConfig(device_fn=lambda x: "/cpu:0")
Example #10
 def test_replace_with_allowed_properties(self):
     config = run_config_lib.RunConfig().replace(
         tf_random_seed=11,
         save_summary_steps=12,
         save_checkpoints_secs=14,
         session_config=15,
         keep_checkpoint_max=16,
         keep_checkpoint_every_n_hours=17)
     self.assertEqual(11, config.tf_random_seed)
     self.assertEqual(12, config.save_summary_steps)
     self.assertEqual(14, config.save_checkpoints_secs)
     self.assertEqual(15, config.session_config)
     self.assertEqual(16, config.keep_checkpoint_max)
     self.assertEqual(17, config.keep_checkpoint_every_n_hours)
Example #11
 def do_test_multi_inputs_multi_outputs_with_input_fn(self, train_input_fn,
                                                      eval_input_fn):
   config = run_config_lib.RunConfig(
       tf_random_seed=_RANDOM_SEED,
       model_dir=self._base_dir,
       train_distribute=self._dist)
   with self.cached_session():
     model = multi_inputs_multi_outputs_model()
     est_keras = keras_lib.model_to_estimator(keras_model=model, config=config)
     baseline_eval_results = est_keras.evaluate(
         input_fn=eval_input_fn, steps=1)
      est_keras.train(input_fn=train_input_fn, steps=_TRAIN_SIZE // 16)
     eval_results = est_keras.evaluate(input_fn=eval_input_fn, steps=1)
     self.assertLess(eval_results['loss'], baseline_eval_results['loss'])
Example #12
  def test_graph_initialization_global_step_and_random_seed(self):
    expected_random_seed = run_config.RunConfig().tf_random_seed
    def _model_fn(features, labels, mode):
      _, _, _ = features, labels, mode
      self.assertIsNotNone(training.get_global_step())
      self.assertEqual(expected_random_seed, ops.get_default_graph().seed)
      return model_fn_lib.EstimatorSpec(
          mode=mode,
          loss=constant_op.constant(0.),
          train_op=constant_op.constant(0.),
          predictions=constant_op.constant([[0.]]))

    est = estimator.Estimator(model_fn=_model_fn)
    est.train(dummy_input_fn, steps=1)
Example #13
  def test_with_empty_config_and_empty_model_dir(self):
    keras_model, _, _, _, _ = get_resource_for_simple_model(
        model_type='sequential', is_evaluate=True)
    keras_model.compile(
        loss='categorical_crossentropy',
        optimizer='rmsprop',
        metrics=['mse', keras.metrics.CategoricalAccuracy()])

    with self.cached_session():
      with test.mock.patch.object(tempfile, 'mkdtemp', return_value=_TMP_DIR):
        est_keras = keras_lib.model_to_estimator(
            keras_model=keras_model,
            config=run_config_lib.RunConfig())
        self.assertEqual(est_keras._model_dir, _TMP_DIR)
Example #14
def run_config():
    config = run_config_lib.RunConfig()
    if config:
        config_dict = {
            'master': config.master,
            'task_id': config.task_id,
            'num_ps_replicas': config.num_ps_replicas,
            'num_worker_replicas': config.num_worker_replicas,
            'cluster_spec': config.cluster_spec.as_dict(),
            'task_type': config.task_type,
            'is_chief': config.is_chief,
        }
        return json.dumps(config_dict)
    return ""
Example #15
  def test_with_conflicting_model_dir_and_config(self):
    keras_model, _, _, _, _ = get_resource_for_simple_model(
        model_type='sequential', is_evaluate=True)
    keras_model.compile(
        loss='categorical_crossentropy',
        optimizer='rmsprop',
        metrics=['mse', keras.metrics.CategoricalAccuracy()])

    with self.cached_session():
      with self.assertRaisesRegexp(ValueError, '`model_dir` are set both in '
                                   'constructor and `RunConfig`'):
        keras_lib.model_to_estimator(
            keras_model=keras_model, model_dir=self._base_dir,
            config=run_config_lib.RunConfig(model_dir=_TMP_DIR))
Example #16
    def test_save_checkpoint(self):
        empty_config = run_config_lib.RunConfig()
        self.assertEqual(600, empty_config.save_checkpoints_secs)
        self.assertIsNone(empty_config.save_checkpoints_steps)

        config_with_steps = empty_config.replace(save_checkpoints_steps=100)
        del empty_config
        self.assertEqual(100, config_with_steps.save_checkpoints_steps)
        self.assertIsNone(config_with_steps.save_checkpoints_secs)

        config_with_secs = config_with_steps.replace(save_checkpoints_secs=200)
        del config_with_steps
        self.assertEqual(200, config_with_secs.save_checkpoints_secs)
        self.assertIsNone(config_with_secs.save_checkpoints_steps)
Example #17
    def test_init_with_allowed_properties(self):
        session_config = config_pb2.ConfigProto(allow_soft_placement=True)

        config = run_config_lib.RunConfig(tf_random_seed=11,
                                          save_summary_steps=12,
                                          save_checkpoints_secs=14,
                                          session_config=session_config,
                                          keep_checkpoint_max=16,
                                          keep_checkpoint_every_n_hours=17)
        self.assertEqual(11, config.tf_random_seed)
        self.assertEqual(12, config.save_summary_steps)
        self.assertEqual(14, config.save_checkpoints_secs)
        self.assertEqual(session_config, config.session_config)
        self.assertEqual(16, config.keep_checkpoint_max)
        self.assertEqual(17, config.keep_checkpoint_every_n_hours)
Example #18
  def test_run_local(self):
    mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.config = run_config_lib.RunConfig()
    mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = test.mock.Mock(spec=training.EvalSpec)

    with test.mock.patch.object(training, '_TrainingExecutor') as mock_executor:
      mock_executor.return_value = self._mock_executor_instance()
      return_value = training.train_and_evaluate(
          mock_est, mock_train_spec, mock_eval_spec)

      self.assertEqual('local', return_value)
      mock_executor.assert_called_with(estimator=mock_est,
                                       train_spec=mock_train_spec,
                                       eval_spec=mock_eval_spec)
Example #19
 def test_init_none_value(self):
     config = run_config_lib.RunConfig(tf_random_seed=None,
                                       model_dir=None,
                                       save_summary_steps=None,
                                       save_checkpoints_secs=None,
                                       save_checkpoints_steps=None,
                                       session_config=None,
                                       keep_checkpoint_max=None,
                                       keep_checkpoint_every_n_hours=None)
     self.assertIsNone(config.tf_random_seed)
     self.assertIsNone(config.model_dir)
     self.assertIsNone(config.save_summary_steps)
     self.assertIsNone(config.save_checkpoints_secs)
     self.assertIsNone(config.save_checkpoints_steps)
     self.assertIsNone(config.session_config)
     self.assertIsNone(config.keep_checkpoint_max)
     self.assertIsNone(config.keep_checkpoint_every_n_hours)
Example #20
    def testTrain(self):

        shutil.rmtree("testlogs", True)

        opts = utils.create_ipu_config()
        utils.configure_ipu_system(opts)

        run_cfg = run_config.RunConfig()

        classifier = estimator.Estimator(model_fn=model_fn,
                                         config=run_cfg,
                                         model_dir="testlogs")

        classifier.train(input_fn=input_fn, steps=16)

        event_file = glob.glob("testlogs/event*")

        self.assertEqual(1, len(event_file))
Example #21
    def test_replace_invalid_values(self):
        config = run_config_lib.RunConfig()

        with self.assertRaisesRegexp(ValueError, _MODEL_DIR_ERR):
            config.replace(model_dir='')
        with self.assertRaisesRegexp(ValueError, _SAVE_SUMMARY_STEPS_ERR):
            config.replace(save_summary_steps=-1)
        with self.assertRaisesRegexp(ValueError, _SAVE_CKPT_STEPS_ERR):
            config.replace(save_checkpoints_steps=-1)
        with self.assertRaisesRegexp(ValueError, _SAVE_CKPT_SECS_ERR):
            config.replace(save_checkpoints_secs=-1)
        with self.assertRaisesRegexp(ValueError, _SESSION_CONFIG_ERR):
            config.replace(session_config={})
        with self.assertRaisesRegexp(ValueError, _KEEP_CKPT_MAX_ERR):
            config.replace(keep_checkpoint_max=-1)
        with self.assertRaisesRegexp(ValueError, _KEEP_CKPT_HOURS_ERR):
            config.replace(keep_checkpoint_every_n_hours=0)
        with self.assertRaisesRegexp(ValueError, _TF_RANDOM_SEED_ERR):
            config.replace(tf_random_seed=1.0)
Example #22
  def test_keras_optimizer_with_distribution_strategy(self, distribution):
    keras_model = simple_sequential_model()
    keras_model.compile(
        loss='categorical_crossentropy',
        optimizer=keras.optimizers.rmsprop(lr=0.01))

    config = run_config_lib.RunConfig(tf_random_seed=_RANDOM_SEED,
                                      model_dir=self._base_dir,
                                      train_distribute=distribution)
    with self.cached_session():
      est_keras = keras_lib.model_to_estimator(keras_model=keras_model,
                                               config=config)
      with self.assertRaisesRegexp(ValueError,
                                   'Only TensorFlow native optimizers are '
                                   'supported with DistributionStrategy.'):
        est_keras.train(input_fn=get_ds_train_input_fn, steps=_TRAIN_SIZE // 16)

    writer_cache.FileWriterCache.clear()
    gfile.DeleteRecursively(self._config.model_dir)
Example #23
    def test_replace_with_allowed_properties(self):
        session_config = config_pb2.ConfigProto(allow_soft_placement=True)
        device_fn = lambda op: "/cpu:0"

        config = run_config_lib.RunConfig().replace(
            tf_random_seed=11,
            save_summary_steps=12,
            save_checkpoints_secs=14,
            session_config=session_config,
            keep_checkpoint_max=16,
            keep_checkpoint_every_n_hours=17,
            device_fn=device_fn)
        self.assertEqual(11, config.tf_random_seed)
        self.assertEqual(12, config.save_summary_steps)
        self.assertEqual(14, config.save_checkpoints_secs)
        self.assertEqual(session_config, config.session_config)
        self.assertEqual(16, config.keep_checkpoint_max)
        self.assertEqual(17, config.keep_checkpoint_every_n_hours)
        self.assertEqual(device_fn, config.device_fn)
Example #24
  def test_train_sequential_with_distribution_strategy(self, distribution):
    keras_model = simple_sequential_model()
    keras_model.compile(
        loss='categorical_crossentropy',
        metrics=[keras.metrics.CategoricalAccuracy()],
        optimizer=rmsprop.RMSPropOptimizer(learning_rate=0.01))
    config = run_config_lib.RunConfig(tf_random_seed=_RANDOM_SEED,
                                      model_dir=self._base_dir,
                                      train_distribute=distribution)
    with self.cached_session():
      est_keras = keras_lib.model_to_estimator(
          keras_model=keras_model, config=config)
      before_eval_results = est_keras.evaluate(
          input_fn=get_ds_test_input_fn, steps=1)
      est_keras.train(input_fn=get_ds_train_input_fn, steps=_TRAIN_SIZE // 16)
      after_eval_results = est_keras.evaluate(input_fn=get_ds_test_input_fn,
                                              steps=1)
      self.assertLess(after_eval_results['loss'], before_eval_results['loss'])

    writer_cache.FileWriterCache.clear()
    gfile.DeleteRecursively(self._config.model_dir)
Example #25
  def test_train_functional_with_distribution_strategy(self):
    dist = mirrored_strategy.MirroredStrategy(
        devices=['/device:GPU:0', '/device:GPU:1'])
    keras_model = simple_functional_model()
    keras_model.compile(
        loss='categorical_crossentropy',
        optimizer=rmsprop.RMSPropOptimizer(learning_rate=0.01))
    config = run_config_lib.RunConfig(tf_random_seed=_RANDOM_SEED,
                                      model_dir=self._base_dir,
                                      train_distribute=dist,
                                      eval_distribute=dist)
    with self.test_session():
      est_keras = keras_lib.model_to_estimator(
          keras_model=keras_model, config=config)
      before_eval_results = est_keras.evaluate(
          input_fn=get_ds_test_input_fn, steps=1)
      est_keras.train(input_fn=get_ds_train_input_fn, steps=_TRAIN_SIZE // 16)
      after_eval_results = est_keras.evaluate(input_fn=get_ds_test_input_fn,
                                              steps=1)
      self.assertLess(after_eval_results['loss'], before_eval_results['loss'])

    writer_cache.FileWriterCache.clear()
    gfile.DeleteRecursively(self._config.model_dir)
Example #26
  def test_graph_initialization_global_step_and_random_seed(self):
    expected_random_seed = run_config.RunConfig().tf_random_seed
    def _model_fn(features, labels, mode):
      _, _, _ = features, labels, mode
      self.assertIsNotNone(training.get_global_step())
      self.assertEqual(expected_random_seed, ops.get_default_graph().seed)
      return model_fn_lib.EstimatorSpec(
          mode=mode,
          loss=constant_op.constant(0.),
          train_op=constant_op.constant(0.),
          predictions=constant_op.constant([[0.]]),
          export_outputs={
              'test': export_output.ClassificationOutput(
                  constant_op.constant([[0.]]))
          })

    def serving_input_receiver_fn():
      return export.ServingInputReceiver(
          {'test-features': constant_op.constant([[1], [1]])},
          array_ops.placeholder(dtype=dtypes.string))

    est = estimator.Estimator(model_fn=_model_fn)
    est.train(dummy_input_fn, steps=1)
    est.export_savedmodel(tempfile.mkdtemp(), serving_input_receiver_fn)
Example #27
    def __init__(self, model_fn, model_dir=None, config=None, params=None):
        """Constructs an `Estimator` instance.

    Args:
      model_fn: Model function. Follows the signature:

        * Args:

          * `features`: This is the first item returned from the `input_fn`
                 passed to `train`, `evaluate`, and `predict`. This should be a
                 single `Tensor` or `dict` of same.
          * `labels`: This is the second item returned from the `input_fn`
                 passed to `train`, `evaluate`, and `predict`. This should be a
                 single `Tensor` or `dict` of same (for multi-head models). If
                 mode is `ModeKeys.PREDICT`, `labels=None` will be passed. If
                 the `model_fn`'s signature does not accept `mode`, the
                 `model_fn` must still be able to handle `labels=None`.
          * `mode`: Optional. Specifies if this is training, evaluation or
                 prediction. See `ModeKeys`.
          * `params`: Optional `dict` of hyperparameters. Will receive what
                 is passed to Estimator in the `params` parameter. This allows
                 one to configure Estimators from hyperparameter tuning.
          * `config`: Optional configuration object. Will receive what is
                 passed to Estimator in the `config` parameter, or the default
                 `config`. Allows updating things in your `model_fn` based on
                 configuration such as `num_ps_replicas` or `model_dir`.

        * Returns:
          `EstimatorSpec`

      model_dir: Directory to save model parameters, graph, etc. This can
        also be used to load checkpoints from the directory into an estimator
        to continue training a previously saved model. If `None`, the model_dir
        in `config` will be used if set. If both are set, they must be the
        same. If both are `None`, a temporary directory will be used.
      config: Configuration object.
      params: `dict` of hyperparameters that will be passed into `model_fn`.
              Keys are names of parameters, values are basic Python types.

    Raises:
      ValueError: if parameters of `model_fn` don't match `params`.
      ValueError: if this is called via a subclass and if that class overrides
        a member of `Estimator`.
    """
        Estimator._assert_members_are_not_overridden(self)

        if config is None:
            self._config = run_config.RunConfig()
            logging.info('Using default config.')
        else:
            if not isinstance(config, run_config.RunConfig):
                raise ValueError(
                    'config must be an instance of RunConfig, but provided %s.'
                    % config)
            self._config = config

        # Model directory.
        if (model_dir is not None) and (self._config.model_dir is not None):
            if model_dir != self._config.model_dir:
                # pylint: disable=g-doc-exception
                raise ValueError(
                    "model_dir are set both in constructor and RunConfig, but with "
                    "different values. In constructor: '{}', in RunConfig: "
                    "'{}' ".format(model_dir, self._config.model_dir))
                # pylint: enable=g-doc-exception

        self._model_dir = model_dir or self._config.model_dir
        if self._model_dir is None:
            self._model_dir = tempfile.mkdtemp()
            logging.warning('Using temporary folder as model directory: %s',
                            self._model_dir)
        if self._config.model_dir is None:
            self._config = self._config.replace(model_dir=self._model_dir)
        logging.info('Using config: %s', str(vars(self._config)))

        if self._config.session_config is None:
            self._session_config = config_pb2.ConfigProto(
                allow_soft_placement=True)
        else:
            self._session_config = self._config.session_config

        self._device_fn = _get_replica_device_setter(self._config)

        if model_fn is None:
            raise ValueError('model_fn must be provided to Estimator.')
        _verify_model_fn_args(model_fn, params)
        self._model_fn = model_fn
        self._params = params or {}
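
A minimal sketch of the `model_fn` contract documented above, reusing the
`constant_op` and `model_fn_lib` helpers from the earlier examples; the
function name and the `learning_rate` key are hypothetical, and the constant
loss and train_op stand in for a real model:

def my_model_fn(features, labels, mode, params, config):
    # Must tolerate labels=None (e.g. under ModeKeys.PREDICT) and return an
    # EstimatorSpec for the given mode.
    del features, labels  # Unused in this sketch.
    return model_fn_lib.EstimatorSpec(
        mode=mode,
        loss=constant_op.constant(0.),
        train_op=constant_op.constant(0.),
        predictions=constant_op.constant([[0.]]))

# `params` is forwarded to my_model_fn's `params` argument.
est = estimator.Estimator(model_fn=my_model_fn,
                          params={"learning_rate": 0.01})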
Example #28
 def config(tf_random_seed):
     return run_config.RunConfig().replace(tf_random_seed=tf_random_seed)
Example #29
    def test_model_dir(self):
        empty_config = run_config_lib.RunConfig()
        self.assertIsNone(empty_config.model_dir)

        new_config = empty_config.replace(model_dir=_TEST_DIR)
        self.assertEqual(_TEST_DIR, new_config.model_dir)
Example #30
 def test_save_checkpoint_flip_secs_to_none(self):
     config_with_secs = run_config_lib.RunConfig()
     config_without_ckpt = config_with_secs.replace(
         save_checkpoints_secs=None)
     self.assertIsNone(config_without_ckpt.save_checkpoints_steps)
     self.assertIsNone(config_without_ckpt.save_checkpoints_secs)