Exemplo n.º 1
0
 def test_init_errors(self, max_steps):
     """Checks that `_IterationBuilder` raises `ValueError` for `max_steps`."""
     with self.assertRaises(ValueError):
         # Every collaborator is constructed inside the context, in the same
         # order as the positional/keyword arguments below.
         candidate_builder = _FakeCandidateBuilder()
         subnetwork_manager = _FakeSubnetworkManager()
         ensemble_builder = _FakeEnsembleBuilder()
         _IterationBuilder(
             candidate_builder,
             subnetwork_manager,
             ensemble_builder,
             summary_maker=_ScopedSummary,
             ensemblers=[_FakeEnsembler()],
             max_steps=max_steps)
Exemplo n.º 2
0
 def test_build_iteration_error(self,
                                ensemble_builder,
                                subnetwork_builders,
                                want_raises,
                                previous_ensemble_spec_fn=lambda: None,
                                mode=tf.estimator.ModeKeys.TRAIN,
                                summary_maker=_ScopedSummary):
     """Checks that `build_iteration` raises `want_raises` for these inputs."""
     with context.graph_mode():
         iteration_builder = _IterationBuilder(
             _FakeCandidateBuilder(),
             _FakeSubnetworkManager(),
             ensemble_builder,
             summary_maker=summary_maker,
             ensemblers=[_FakeEnsembler()],
             max_steps=100)
         fixed_kwargs = dict(
             base_global_step=0,
             iteration_number=0,
             subnetwork_builders=subnetwork_builders,
             features=[[1., -1., 0.]],
             labels=[1],
             mode=mode)
         with self.assertRaises(want_raises):
             # Built inside the context so that errors raised while creating
             # the candidate, config or previous spec also satisfy the check.
             candidates = [EnsembleCandidate("test", subnetwork_builders, None)]
             run_config = tf.estimator.RunConfig(
                 model_dir=self.test_subdirectory)
             iteration_builder.build_iteration(
                 ensemble_candidates=candidates,
                 config=run_config,
                 previous_ensemble_spec=previous_ensemble_spec_fn(),
                 **fixed_kwargs)
 def test_build_iteration_error(self,
                                ensemble_builder,
                                subnetwork_builders,
                                want_raises,
                                multiple_candidates=False,
                                mode=tf.estimator.ModeKeys.TRAIN,
                                summary_maker=_ScopedSummary):
   """Checks that `build_iteration` raises `want_raises` for these inputs.

   When `multiple_candidates` is truthy, two separate candidates that share the
   name "test" are passed instead of a single one.
   """
   with context.graph_mode():
     tf_compat.v1.train.create_global_step()
     iteration_builder = _IterationBuilder(
         _FakeCandidateBuilder(),
         _FakeSubnetworkManager(),
         ensemble_builder,
         summary_maker=summary_maker,
         ensemblers=[_FakeEnsembler()],
         max_steps=100)
     candidate_count = 2 if multiple_candidates else 1
     # A comprehension so that each entry is a distinct EnsembleCandidate.
     candidates = [
         EnsembleCandidate("test", subnetwork_builders, None)
         for _ in range(candidate_count)
     ]
     with self.assertRaises(want_raises):
       iteration_builder.build_iteration(
           base_global_step=0,
           iteration_number=0,
           ensemble_candidates=candidates,
           subnetwork_builders=subnetwork_builders,
           features=[[1., -1., 0.]],
           labels=[1],
           mode=mode,
           config=tf.estimator.RunConfig(model_dir=self.test_subdirectory))
Exemplo n.º 4
0
 def test_build_iteration_error(self,
                                ensemble_builder,
                                subnetwork_builders,
                                want_raises,
                                previous_ensemble_spec_fn=lambda: None,
                                mode=tf.estimator.ModeKeys.TRAIN,
                                summary_maker=_ScopedSummary):
     """Checks that `build_iteration` raises `want_raises` for these inputs."""
     iteration_builder = _IterationBuilder(
         _FakeCandidateBuilder(),
         _FakeSubnetworkManager(),
         ensemble_builder,
         summary_maker=summary_maker,
         ensemblers=[_FakeEnsembler()])
     with self.test_session(), self.assertRaises(want_raises):
         iteration_builder.build_iteration(
             iteration_number=0,
             ensemble_candidates=[
                 EnsembleCandidate("test", subnetwork_builders, None)
             ],
             subnetwork_builders=subnetwork_builders,
             features=[[1., -1., 0.]],
             labels=[1],
             mode=mode,
             previous_ensemble_spec=previous_ensemble_spec_fn())
Exemplo n.º 5
0
    def test_head_export_outputs(self, head):
        """Checks the iteration's export outputs against the head's defaults.

        `head.create_estimator_spec` is called directly with logits `[[.5]]`,
        and each of its export outputs is compared, according to its output
        type, with the same-keyed export output of the built iteration.
        """
        ensemble_builder = _HeadEnsembleBuilder(head)
        builder = _IterationBuilder("/tmp", _FakeCandidateBuilder(),
                                    ensemble_builder)
        features = [[1., -1., 0.]]
        labels = [1]
        mode = tf.estimator.ModeKeys.PREDICT
        iteration = builder.build_iteration(
            iteration_number=0,
            subnetwork_builders=[_FakeBuilder("test")],
            features=features,
            labels=labels,
            mode=mode)

        # Compare iteration outputs with default head outputs.
        spec = head.create_estimator_spec(features=features,
                                          labels=labels,
                                          mode=mode,
                                          logits=[[.5]])
        got_export_outputs = iteration.estimator_spec.export_outputs
        self.assertEqual(len(spec.export_outputs), len(got_export_outputs))
        # PredictOutput entries holding string Tensors cannot be compared with
        # assertAllClose, so they are checked exactly and then left out of the
        # numeric comparison.
        string_output_keys = ("classes", "all_classes")
        with self.test_session() as sess:
            for key, want in spec.export_outputs.items():
                got = got_export_outputs[key]
                if isinstance(want, tf.estimator.export.RegressionOutput):
                    # assertAllClose, unlike assertAlmostEqual, also works
                    # when the batch holds more than one prediction.
                    self.assertAllClose(sess.run(want.value),
                                        sess.run(got.value))
                elif isinstance(want,
                                tf.estimator.export.ClassificationOutput):
                    self.assertAllClose(sess.run(want.scores),
                                        sess.run(got.scores))
                    self.assertAllEqual(sess.run(want.classes),
                                        sess.run(got.classes))
                elif isinstance(want, tf.estimator.export.PredictOutput):
                    checked = set()
                    for name in string_output_keys:
                        if name in want.outputs:
                            # Verify string Tensor outputs separately.
                            self.assertAllEqual(sess.run(want.outputs[name]),
                                                sess.run(got.outputs[name]))
                            checked.add(name)
                    # Filter copies instead of deleting from `outputs`, so the
                    # specs being compared are not mutated.
                    self.assertAllClose(
                        sess.run({k: v for k, v in want.outputs.items()
                                  if k not in checked}),
                        sess.run({k: v for k, v in got.outputs.items()
                                  if k not in checked}))
                else:
                    self.fail("Invalid export_output for {}.".format(key))
Exemplo n.º 6
0
    def test_build_iteration(self,
                             ensemble_builder,
                             subnetwork_builders,
                             features,
                             labels,
                             want_predictions,
                             want_best_candidate_index,
                             want_eval_metric_ops=(),
                             want_is_over=False,
                             previous_ensemble_spec=lambda: None,
                             want_loss=None,
                             want_export_outputs=None,
                             mode=tf.estimator.ModeKeys.TRAIN):
        """Builds one iteration and checks its estimator spec.

        `features`, `labels` and `previous_ensemble_spec` are zero-argument
        callables that produce the corresponding `build_iteration` inputs.
        Predictions, eval metric names and the best candidate index are always
        checked; the remaining checks depend on `mode`.
        """
        global_step = tf.train.create_global_step()
        builder = _IterationBuilder("/tmp", _FakeCandidateBuilder(),
                                    ensemble_builder)
        iteration = builder.build_iteration(
            iteration_number=0,
            subnetwork_builders=subnetwork_builders,
            features=features(),
            labels=labels(),
            mode=mode,
            previous_ensemble_spec=previous_ensemble_spec())
        with self.test_session() as sess:
            init = tf.group(tf.global_variables_initializer(),
                            tf.local_variables_initializer())
            sess.run(init)
            estimator_spec = iteration.estimator_spec
            self.assertAllClose(want_predictions,
                                sess.run(estimator_spec.predictions),
                                atol=1e-3)
            self.assertEqual(set(want_eval_metric_ops),
                             set(estimator_spec.eval_metric_ops.keys()))
            self.assertEqual(want_best_candidate_index,
                             sess.run(iteration.best_candidate_index))

            if mode == tf.estimator.ModeKeys.PREDICT:
                # Validate the test case's expectation before it is used, so a
                # missing expectation is reported clearly.
                self.assertIsNotNone(want_export_outputs)
                self.assertIsNotNone(estimator_spec.export_outputs)
                self.assertAllClose(want_export_outputs,
                                    sess.run(
                                        _export_output_tensors(
                                            estimator_spec.export_outputs)),
                                    atol=1e-3)
                # Compare with the op type name directly instead of adding a
                # throwaway `tf.no_op()` to the graph just to read its type.
                self.assertEqual(iteration.estimator_spec.train_op.type,
                                 "NoOp")
                self.assertIsNone(iteration.estimator_spec.loss)
                return

            self.assertAlmostEqual(want_loss,
                                   sess.run(iteration.estimator_spec.loss),
                                   places=3)
            self.assertIsNone(iteration.estimator_spec.export_outputs)
            if mode == tf.estimator.ModeKeys.TRAIN:
                sess.run(iteration.estimator_spec.train_op)
                self.assertEqual(want_is_over,
                                 sess.run(iteration.is_over_fn()))
                self.assertEqual(1, sess.run(global_step))
                self.assertEqual(1, sess.run(iteration.step))
Exemplo n.º 7
0
 def test_build_iteration_error(self,
                                ensemble_builder,
                                subnetwork_builders,
                                want_raises,
                                previous_ensemble_spec_fn=lambda: None,
                                mode=tf.estimator.ModeKeys.TRAIN):
     """Checks that `build_iteration` raises `want_raises` for these inputs."""
     iteration_builder = _IterationBuilder(_FakeCandidateBuilder(),
                                           ensemble_builder)
     with self.test_session(), self.assertRaises(want_raises):
         iteration_builder.build_iteration(
             iteration_number=0,
             subnetwork_builders=subnetwork_builders,
             features=[[1., -1., 0.]],
             labels=[1],
             mode=mode,
             previous_ensemble_spec=previous_ensemble_spec_fn())
Exemplo n.º 8
0
    def test_head_export_outputs(self, head):
        """Checks the iteration's export outputs against the head's defaults.

        `head.create_estimator_spec` is called directly with logits `[[.5]]`,
        and each of its export outputs is compared, according to its output
        type, with the same-keyed export output of the built iteration.
        """
        with context.graph_mode():
            ensemble_builder = _HeadEnsembleBuilder(head)
            builder = _IterationBuilder(_FakeCandidateBuilder(),
                                        _FakeSubnetworkManager(),
                                        ensemble_builder,
                                        summary_maker=_ScopedSummary,
                                        ensemblers=[_FakeEnsembler()],
                                        max_steps=10)
            features = [[1., -1., 0.]]
            labels = [1]
            mode = tf.estimator.ModeKeys.PREDICT
            subnetwork_builders = [_FakeBuilder("test")]
            iteration = builder.build_iteration(
                base_global_step=0,
                iteration_number=0,
                ensemble_candidates=[
                    EnsembleCandidate("test", subnetwork_builders, None)
                ],
                subnetwork_builders=subnetwork_builders,
                features=features,
                labels=labels,
                config=tf.estimator.RunConfig(
                    model_dir=self.test_subdirectory),
                mode=mode)

            # Compare iteration outputs with default head outputs.
            spec = head.create_estimator_spec(features=features,
                                              labels=labels,
                                              mode=mode,
                                              logits=[[.5]])
            got_export_outputs = iteration.estimator_spec.export_outputs
            self.assertEqual(len(spec.export_outputs),
                             len(got_export_outputs))
            # PredictOutput entries holding string Tensors cannot be compared
            # with assertAllClose, so they are checked exactly and then left
            # out of the numeric comparison.
            string_output_keys = ("classes", "all_classes")
            for key, want in spec.export_outputs.items():
                got = got_export_outputs[key]
                if isinstance(want, tf.estimator.export.RegressionOutput):
                    # assertAllClose, unlike assertAlmostEqual, also works
                    # when the batch holds more than one prediction.
                    self.assertAllClose(self.evaluate(want.value),
                                        self.evaluate(got.value))
                elif isinstance(want,
                                tf.estimator.export.ClassificationOutput):
                    self.assertAllClose(self.evaluate(want.scores),
                                        self.evaluate(got.scores))
                    self.assertAllEqual(self.evaluate(want.classes),
                                        self.evaluate(got.classes))
                elif isinstance(want, tf.estimator.export.PredictOutput):
                    checked = set()
                    for name in string_output_keys:
                        if name in want.outputs:
                            # Verify string Tensor outputs separately.
                            self.assertAllEqual(
                                self.evaluate(want.outputs[name]),
                                self.evaluate(got.outputs[name]))
                            checked.add(name)
                    # Filter copies instead of deleting from `outputs`, so the
                    # specs being compared are not mutated.
                    self.assertAllClose(
                        self.evaluate({k: v for k, v in want.outputs.items()
                                       if k not in checked}),
                        self.evaluate({k: v for k, v in got.outputs.items()
                                       if k not in checked}))
                else:
                    self.fail("Invalid export_output for {}.".format(key))
Exemplo n.º 9
0
    def test_build_iteration(self,
                             ensemble_builder,
                             subnetwork_builders,
                             features,
                             labels,
                             want_predictions,
                             want_best_candidate_index,
                             want_eval_metric_ops=(),
                             previous_ensemble_spec=lambda: None,
                             want_loss=None,
                             want_export_outputs=None,
                             mode=tf.estimator.ModeKeys.TRAIN,
                             summary_maker=_ScopedSummary,
                             want_chief_hooks=False):
        """Builds one iteration in graph mode and checks its estimator spec.

        One `EnsembleCandidate` is created per subnetwork builder. `features`,
        `labels` and `previous_ensemble_spec` are zero-argument callables that
        produce the corresponding `build_iteration` inputs. Predictions, eval
        metric names and the best candidate index are always checked; the
        remaining checks depend on `mode`.
        """
        with context.graph_mode():
            tf_compat.v1.train.create_global_step()
            builder = _IterationBuilder(_FakeCandidateBuilder(),
                                        _FakeSubnetworkManager(),
                                        ensemble_builder,
                                        summary_maker=summary_maker,
                                        ensemblers=[_FakeEnsembler()],
                                        max_steps=1)
            iteration = builder.build_iteration(
                base_global_step=0,
                iteration_number=0,
                ensemble_candidates=[
                    EnsembleCandidate(b.name, [b], None)
                    for b in subnetwork_builders
                ],
                subnetwork_builders=subnetwork_builders,
                features=features(),
                labels=labels(),
                mode=mode,
                config=tf.estimator.RunConfig(
                    model_dir=self.test_subdirectory),
                previous_ensemble_spec=previous_ensemble_spec())
            init = tf.group(tf_compat.v1.global_variables_initializer(),
                            tf_compat.v1.local_variables_initializer())
            self.evaluate(init)
            estimator_spec = iteration.estimator_spec
            if want_chief_hooks:
                self.assertNotEmpty(
                    iteration.estimator_spec.training_chief_hooks)
            self.assertAllClose(want_predictions,
                                self.evaluate(estimator_spec.predictions),
                                atol=1e-3)

            # A default architecture metric is always included, even if we don't
            # specify one, so it is excluded from the comparison. The names are
            # copied into a set so the spec's `eval_metric_ops` is not mutated.
            eval_metric_names = set(estimator_spec.eval_metric_ops)
            eval_metric_names.discard("architecture/adanet/ensembles")
            self.assertEqual(set(want_eval_metric_ops), eval_metric_names)

            self.assertEqual(want_best_candidate_index,
                             self.evaluate(iteration.best_candidate_index))

            if mode == tf.estimator.ModeKeys.PREDICT:
                # Validate the test case's expectation before it is used, so a
                # missing expectation is reported clearly.
                self.assertIsNotNone(want_export_outputs)
                self.assertIsNotNone(estimator_spec.export_outputs)
                self.assertAllClose(want_export_outputs,
                                    self.evaluate(
                                        _export_output_tensors(
                                            estimator_spec.export_outputs)),
                                    atol=1e-3)
                self.assertIsNone(iteration.estimator_spec.train_op)
                self.assertIsNone(iteration.estimator_spec.loss)
                return

            self.assertAlmostEqual(want_loss,
                                   self.evaluate(
                                       iteration.estimator_spec.loss),
                                   places=3)
            self.assertIsNone(iteration.estimator_spec.export_outputs)
            if mode == tf.estimator.ModeKeys.TRAIN:
                self.evaluate(iteration.estimator_spec.train_op)
Exemplo n.º 10
0
    def test_build_iteration(self,
                             ensemble_builder,
                             subnetwork_builders,
                             features,
                             labels,
                             want_predictions,
                             want_best_candidate_index,
                             want_eval_metric_ops=(),
                             want_is_over=False,
                             previous_ensemble_spec=lambda: None,
                             want_loss=None,
                             want_export_outputs=None,
                             mode=tf.estimator.ModeKeys.TRAIN,
                             summary_maker=_ScopedSummary,
                             want_chief_hooks=False):
        """Builds one iteration and checks its estimator spec.

        One `EnsembleCandidate` is created per subnetwork builder. `features`,
        `labels` and `previous_ensemble_spec` are zero-argument callables that
        produce the corresponding `build_iteration` inputs. Predictions, eval
        metric names and the best candidate index are always checked; the
        remaining checks depend on `mode`.
        """
        global_step = tf_compat.v1.train.create_global_step()
        builder = _IterationBuilder(_FakeCandidateBuilder(),
                                    _FakeSubnetworkManager(),
                                    ensemble_builder,
                                    summary_maker=summary_maker,
                                    ensemblers=[_FakeEnsembler()])
        iteration = builder.build_iteration(
            base_global_step=0,
            iteration_number=0,
            ensemble_candidates=[
                EnsembleCandidate(b.name, [b], None)
                for b in subnetwork_builders
            ],
            subnetwork_builders=subnetwork_builders,
            features=features(),
            labels=labels(),
            mode=mode,
            config=tf.estimator.RunConfig(),
            previous_ensemble_spec=previous_ensemble_spec())
        with self.test_session() as sess:
            init = tf.group(tf_compat.v1.global_variables_initializer(),
                            tf_compat.v1.local_variables_initializer())
            sess.run(init)
            estimator_spec = iteration.estimator_spec
            if want_chief_hooks:
                self.assertNotEmpty(
                    iteration.estimator_spec.training_chief_hooks)
            self.assertAllClose(want_predictions,
                                sess.run(estimator_spec.predictions),
                                atol=1e-3)
            self.assertEqual(set(want_eval_metric_ops),
                             set(estimator_spec.eval_metric_ops.keys()))
            self.assertEqual(want_best_candidate_index,
                             sess.run(iteration.best_candidate_index))

            if mode == tf.estimator.ModeKeys.PREDICT:
                # Validate the test case's expectation before it is used, so a
                # missing expectation is reported clearly.
                self.assertIsNotNone(want_export_outputs)
                self.assertIsNotNone(estimator_spec.export_outputs)
                self.assertAllClose(want_export_outputs,
                                    sess.run(
                                        _export_output_tensors(
                                            estimator_spec.export_outputs)),
                                    atol=1e-3)
                # Compare with the op type name directly instead of adding a
                # throwaway `tf.no_op()` to the graph just to read its type.
                self.assertEqual(iteration.estimator_spec.train_op.type,
                                 "NoOp")
                self.assertIsNone(iteration.estimator_spec.loss)
                return

            self.assertAlmostEqual(want_loss,
                                   sess.run(iteration.estimator_spec.loss),
                                   places=3)
            self.assertIsNone(iteration.estimator_spec.export_outputs)
            if mode == tf.estimator.ModeKeys.TRAIN:
                sess.run(iteration.estimator_spec.train_op)
                self.assertEqual(want_is_over,
                                 sess.run(iteration.is_over_fn()))
                self.assertEqual(1, sess.run(global_step))
Exemplo n.º 11
0
  def __init__(self,
               head,
               subnetwork_generator,
               max_iteration_steps,
               mixture_weight_type=MixtureWeightType.SCALAR,
               mixture_weight_initializer=None,
               warm_start_mixture_weights=False,
               adanet_lambda=0.,
               adanet_beta=0.,
               evaluator=None,
               report_materializer=None,
               use_bias=False,
               replicate_ensemble_in_training=False,
               adanet_loss_decay=.9,
               worker_wait_timeout_secs=7200,
               model_dir=None,
               report_dir=None,
               config=None):
    """Initializes an `Estimator`.

    Regarding the options for `mixture_weight_type`:

    A `SCALAR` mixture weight is a rank 0 tensor. It performs an element-
    wise multiplication with its subnetwork's logits. This mixture weight
    is the simplest to learn, the quickest to train, and most likely to
    generalize well.

    A `VECTOR` mixture weight is a tensor of shape [k] where k is the
    ensemble's logits dimension as defined by `head`. It is similar to
    `SCALAR` in that it performs an element-wise multiplication with its
    subnetwork's logits, but is more flexible in learning a subnetwork's
    preferences per class.

    A `MATRIX` mixture weight is a tensor of shape [a, b] where a is the
    number of outputs from the subnetwork's `last_layer` and b is the
    number of outputs from the ensemble's `logits`. This weight
    matrix-multiplies the subnetwork's `last_layer`. This mixture weight
    offers the most flexibility and expressivity, allowing subnetworks to
    have outputs of different dimensionalities. However, it also has the
    most trainable parameters (a*b), and is therefore the most sensitive to
    learning rates and regularization.

    Args:
      head: A `tf.contrib.estimator.Head` instance for computing loss and
        evaluation metrics for every candidate.
      subnetwork_generator: The `adanet.subnetwork.Generator` which defines the
        candidate subnetworks to train and evaluate at every AdaNet iteration.
      max_iteration_steps: Total number of steps for which to train candidates
        per iteration. If `OutOfRange` or `StopIteration` occurs in the middle,
        training stops before `max_iteration_steps` steps.
      mixture_weight_type: The `adanet.MixtureWeightType` defining which mixture
        weight type to learn in the linear combination of subnetwork outputs.
      mixture_weight_initializer: The initializer for mixture_weights. When
        `None`, the default is different according to `mixture_weight_type`.
        `SCALAR` initializes to 1/N where N is the number of subnetworks in the
        ensemble giving a uniform average. `VECTOR` initializes each entry to
        1/N where N is the number of subnetworks in the ensemble giving a
        uniform average. `MATRIX` uses `tf.zeros_initializer`.
      warm_start_mixture_weights: Whether, at the beginning of an iteration, to
        initialize the mixture weights of the subnetworks from the previous
        ensemble to their learned value at the previous iteration, as opposed to
        retraining them from scratch. Takes precedence over the value for
        `mixture_weight_initializer` for subnetworks from previous iterations.
      adanet_lambda: Float multiplier 'lambda' for applying L1 regularization to
        subnetworks' mixture weights 'w' in the ensemble proportional to their
        complexity. See Equation (4) in the AdaNet paper.
      adanet_beta: Float L1 regularization multiplier 'beta' to apply equally to
        all subnetworks' weights 'w' in the ensemble regardless of their
        complexity. See Equation (4) in the AdaNet paper.
      evaluator: An `Evaluator` for comparing `Ensemble` instances in evaluation
        mode using the training set, or a holdout set. When `None`, they are
        compared using a moving average of their `Ensemble`'s AdaNet loss during
        training.
      report_materializer: A `ReportMaterializer` for materializing a
        `Builder`'s `subnetwork.Reports` into `subnetwork.MaterializedReport`s.
        These reports are made available to the Generator at the next iteration,
        so that it can adapt its search space. When `None`, the Generators'
        `generate_candidates` method will receive empty Lists for their
        `previous_ensemble_reports` and `all_reports` arguments.
      use_bias: Whether to add a bias term to the ensemble's logits. Adding a
        bias allows the ensemble to learn a shift in the data, often leading to
        more stable training and better predictions.
      replicate_ensemble_in_training: Whether to freeze a copy of the ensembled
        subnetworks' subgraphs in training mode in addition to prediction mode.
        A copy of the subnetworks' subgraphs is always saved in prediction mode
        so that at prediction time, the ensemble and composing subnetworks are
        all in prediction mode. This argument only affects the outputs of the
        frozen subnetworks in the ensemble. When `False` and during candidate
        training, the frozen subnetworks in the ensemble are in prediction mode,
        so training-only ops like dropout are not applied to them. When `True`
        and training the candidates, the frozen subnetworks will be in training
        mode as well, so they will apply training-only ops like dropout. However
        when `True`, this doubles the amount of disk space required to store the
        frozen ensembles, and increases the preparation stage between boosting
        iterations. This argument is useful for regularizing learning mixture
        weights, or for making training-only side inputs available in subsequent
        iterations. For most use-cases, this should be `False`.
      adanet_loss_decay: Float decay for the exponential-moving-average of the
        AdaNet objective throughout training. This moving average is a data-
        driven way of tracking the best candidate with only the training set.
      worker_wait_timeout_secs: Float number of seconds for workers to wait for
        chief to prepare the next iteration during distributed training. This is
        needed to prevent workers waiting indefinitely for a chief that may have
        crashed or been turned down. When the timeout is exceeded, the worker
        exits the train loop. In situations where the chief job is much slower
        than the worker jobs, this timeout should be increased.
      model_dir: Directory to save model parameters, graph, etc. This can
        also be used to load checkpoints from the directory into an estimator to
        continue training a previously saved model.
      report_dir: Directory where the `adanet.subnetwork.MaterializedReport`s
        materialized by `report_materializer` would be saved. If
        `report_materializer` is None, this will not save anything. If `None` or
        empty string, defaults to "<model_dir>/report".
      config: `RunConfig` object to configure the runtime settings.

    Returns:
      An `Estimator` instance.

    Raises:
      ValueError: If `subnetwork_generator` is `None`.
      ValueError: If `max_iteration_steps` is <= 0.
    """

    # TODO: Add argument to specify how many frozen graph
    # checkpoints to keep.

    # Validate arguments before any state is created.
    if subnetwork_generator is None:
      raise ValueError("subnetwork_generator can't be None.")
    if max_iteration_steps <= 0.:
      raise ValueError("max_iteration_steps must be > 0.")

    self._adanet_loss_decay = adanet_loss_decay

    # Overwrite superclass's assert that members are not overwritten in order
    # to overwrite public methods. Note that we are doing something that is not
    # explicitly supported by the Estimator API and may break in the future.
    # NOTE(review): this assigns to the `tf.estimator.Estimator` class itself,
    # so the check is disabled for every Estimator in the process, not only
    # this instance. It is placed before `super().__init__` below, presumably
    # because that is where the base class runs the check -- TODO confirm.
    tf.estimator.Estimator._assert_members_are_not_overridden = staticmethod(
        lambda _: None)

    # Builds the ensemble for each candidate from the ensembling options above.
    self._ensemble_builder = _EnsembleBuilder(
        head=head,
        mixture_weight_type=mixture_weight_type,
        mixture_weight_initializer=mixture_weight_initializer,
        warm_start_mixture_weights=warm_start_mixture_weights,
        adanet_lambda=adanet_lambda,
        adanet_beta=adanet_beta,
        use_bias=use_bias)
    # Candidates share the per-iteration step budget and the loss decay.
    candidate_builder = _CandidateBuilder(
        max_steps=max_iteration_steps,
        adanet_loss_decay=self._adanet_loss_decay)
    self._iteration_builder = _IterationBuilder(candidate_builder,
                                                self._ensemble_builder)
    self._freezer = _EnsembleFreezer()
    self._evaluation_checkpoint_path = None
    self._evaluator = evaluator
    self._report_materializer = report_materializer

    self._replicate_ensemble_in_training = replicate_ensemble_in_training
    self._worker_wait_timeout_secs = worker_wait_timeout_secs

    self._evaluation_name = None

    self._inside_adanet_training_loop = False

    # This `Estimator` is responsible for bookkeeping across iterations, and
    # for training the subnetworks in both a local and distributed setting.
    # Subclassing improves future-proofing against new private methods being
    # added to `tf.estimator.Estimator` that are expected to be callable by
    # external functions, such as in b/110435640.
    # The generator is passed through `params` under `_Keys.SUBNETWORK_GENERATOR`
    # rather than stored directly on `self`.
    super(Estimator, self).__init__(
        model_fn=self._model_fn,
        params={
            self._Keys.SUBNETWORK_GENERATOR: subnetwork_generator,
        },
        config=config,
        model_dir=model_dir)

    # This is defined after base Estimator's init so that report_accessor can
    # use the same temporary model_dir as the underlying Estimator even if
    # model_dir is not provided.
    # NOTE(review): `self._model_dir` is not assigned in this method; it is
    # presumably populated by the base `Estimator.__init__` call above.
    report_dir = report_dir or os.path.join(self._model_dir, "report")
    self._report_accessor = _ReportAccessor(report_dir)