示例#1
0
    def test_create_estimator_with_misspecified_args(self):
        """Checks EstimatorBuilder argument validation.

        `EstimatorBuilder` is expected to raise `ValueError` when the example
        feature columns or the scoring function is missing, or when neither an
        optimizer nor `hparams["learning_rate"]` is supplied. Providing either
        one of those two is expected to make construction succeed.
        """
        hparams = _get_hparams()
        with self.assertRaises(ValueError):
            _ = tfr_estimator.EstimatorBuilder(
                _context_feature_columns(),
                None,  # `example_feature_columns` is None.
                _scoring_function,
                hparams=hparams)

        with self.assertRaises(ValueError):
            _ = tfr_estimator.EstimatorBuilder(
                _context_feature_columns(),
                _example_feature_columns(),
                None,  # `scoring_function` is None.
                hparams=hparams)

        # Either the optimizer or the hparams["learning_rate"] should be specified.
        del hparams["learning_rate"]
        with self.assertRaises(ValueError):
            _ = tfr_estimator.EstimatorBuilder(
                _context_feature_columns(),
                _example_feature_columns(),
                _scoring_function,
                optimizer=None,
                hparams=hparams)

        # Passing an optimizer (no hparams["learning_rate"]) silences the error.
        # The mutated `hparams` must be used here: a fresh `_get_hparams()`
        # result presumably still contains "learning_rate", which would make
        # this case pass without ever exercising the optimizer-only path.
        pip = tfr_estimator.EstimatorBuilder(
            _context_feature_columns(),
            _example_feature_columns(),
            _scoring_function,
            optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=0.01),
            hparams=hparams)
        self.assertIsInstance(pip, tfr_estimator.EstimatorBuilder)

        # Adding "learning_rate" back to hparams (no optimizer) also silences
        # the error; pass the updated dict so the update is actually exercised.
        hparams.update(learning_rate=0.01)
        pip = tfr_estimator.EstimatorBuilder(
            _context_feature_columns(),
            _example_feature_columns(),
            _scoring_function,
            optimizer=None,
            hparams=hparams)
        self.assertIsInstance(pip, tfr_estimator.EstimatorBuilder)
示例#2
0
    def test_optimizer(self):
        """Checks the default optimizer and an explicitly supplied one.

        Without an `optimizer` argument the builder is expected to hold an
        Adagrad optimizer; when one is passed in, its type must be preserved.
        """
        default_builder = self._create_default_estimator()
        self.assertIsInstance(default_builder._optimizer,
                              tf.compat.v1.train.AdagradOptimizer)

        # Build the explicit-optimizer variant from named pieces.
        context_columns = _context_feature_columns()
        example_columns = _example_feature_columns()
        adam = tf.compat.v1.train.AdamOptimizer(learning_rate=0.01)
        adam_builder = tfr_estimator.EstimatorBuilder(
            context_columns,
            example_columns,
            _scoring_function,
            optimizer=adam,
            hparams=_get_hparams())
        self.assertIsInstance(adam_builder._optimizer,
                              tf.compat.v1.train.AdamOptimizer)
示例#3
0
    def test_custom_transform_fn(self):
        """Checks that a custom `transform_function` is used by the builder.

        The builder's `_transform_fn` is expected to return only "c1" as a
        context feature and "f1"-"f3" as example features, with values
        doubled by `_multiply_by_two_transform_fn`.
        """
        builder = tfr_estimator.EstimatorBuilder(
            _context_feature_columns(),
            _example_feature_columns(),
            _scoring_function,
            transform_function=_multiply_by_two_transform_fn,
            hparams=_get_hparams())

        # Example features ("f*") are rank-3 and context features ("c*") are
        # rank-2 here. "c2" is supplied but is not expected in the output.
        raw_features = {
            "f1": tf.ones([10, 10, 1], dtype=tf.float32),
            "f2": tf.ones([10, 10, 1], dtype=tf.float32) * 2.0,
            "f3": tf.ones([10, 10, 1], dtype=tf.float32) * 3.0,
            "c1": tf.ones([10, 1], dtype=tf.float32),
            "c2": tf.ones([10, 1], dtype=tf.float32) * 2.0,
        }
        context, example = builder._transform_fn(raw_features,
                                                 tf.estimator.ModeKeys.TRAIN)

        self.assertCountEqual(context.keys(), ["c1"])
        self.assertCountEqual(example.keys(), ["f1", "f2", "f3"])

        # Inputs of all ones should come back as all twos after the transform.
        expected_context = 2 * tf.ones(shape=[10, 1])
        expected_example = 2 * tf.ones(shape=[10, 10, 1])
        self.assertAllEqual(expected_context, context["c1"])
        self.assertAllEqual(expected_example, example["f1"])
示例#4
0
 def _create_default_estimator(self):
     """Returns an `EstimatorBuilder` built from the shared test fixtures.

     No `optimizer` or `transform_function` is passed, so those are left at
     whatever defaults `EstimatorBuilder` applies.
     """
     builder = tfr_estimator.EstimatorBuilder(
         _context_feature_columns(),
         _example_feature_columns(),
         _scoring_function,
         hparams=_get_hparams())
     return builder