Code example #1
0
  def _call_model_fn(self, subestimator, features, labels, mode, summary):
    """Invokes the subestimator's model_fn and extracts training artifacts.

    Builds an EstimatorSpec under the given summary scope, then pulls out
    the pieces AdaNet needs: logits, the last layer (falling back to logits
    when no extractor is configured), a TrainOpSpec wrapping the spec's
    train op and hooks, and the scaffold's local init op (or None).

    Returns:
      A `(logits, last_layer, train_op_spec, local_init_op)` tuple.
    """
    with summary.current_scope():
      spec = subestimator.estimator.model_fn(
          features=features, labels=labels, mode=mode, config=self._config)
      logits = self._logits_fn(estimator_spec=spec)
      # Use the dedicated last-layer extractor when present; otherwise the
      # logits double as the last layer.
      last_layer = (
          self._last_layer_fn(estimator_spec=spec)
          if self._last_layer_fn else logits)
      local_init_op = None
      scaffold = spec.scaffold
      if scaffold and scaffold.local_init_op:
        local_init_op = scaffold.local_init_op
      train_op = subnetwork_lib.TrainOpSpec(
          spec.train_op,
          chief_hooks=spec.training_chief_hooks,
          hooks=spec.training_hooks)
    return logits, last_layer, train_op, local_init_op
Code example #2
0
def _to_train_op_spec(train_op):
    """Wraps a raw train op in a TrainOpSpec; existing specs pass through."""
    return (train_op
            if isinstance(train_op, subnetwork_lib.TrainOpSpec)
            else subnetwork_lib.TrainOpSpec(train_op))
Code example #3
0
File: common.py  Project: yuanyichuangzhi/adanet
    def build_subnetwork(self,
                         features,
                         labels,
                         logits_dimension,
                         training,
                         iteration_step,
                         summary,
                         previous_ensemble,
                         config=None):
        """Builds an adanet Subnetwork from the wrapped Estimator's model_fn.

        When `training` is True and the subestimator supplies its own
        `train_input_fn`, the model graph is constructed twice via a shared
        `tf.compat.v1.make_template` (so variables are created once and
        reused): first on the subestimator's own inputs — whose train op is
        executed through a hook — and then on the ensemble's features/labels,
        whose logits/last_layer feed the returned Subnetwork. Otherwise a
        single graph on the ensemble inputs is built and its train op spec is
        used directly.

        Args:
          features: Input features passed through to the model_fn.
          labels: Labels passed through to the model_fn.
          logits_dimension: Unused here; part of the Builder interface.
          training: Whether to build the graph in TRAIN mode.
          iteration_step: Unused here; part of the Builder interface.
          summary: Scoped summary object handed to `_call_model_fn`.
          previous_ensemble: Unused here; part of the Builder interface.
          config: Optional RunConfig forwarded to `self._subestimator`.

        Returns:
          A `subnetwork_lib.Subnetwork` carrying logits, last_layer, the
          train op spec (under the "train_op" shared key), a constant-zero
          complexity, and any collected local init ops.
        """
        # We don't need an EVAL mode since AdaNet takes care of evaluation for us.
        mode = tf.estimator.ModeKeys.PREDICT
        if training:
            mode = tf.estimator.ModeKeys.TRAIN

        # Call in template to ensure that variables are created once and reused.
        call_model_fn_template = tf.compat.v1.make_template(
            "model_fn", self._call_model_fn)
        subestimator_features, subestimator_labels = features, labels
        local_init_ops = []
        subestimator = self._subestimator(config)
        if training and subestimator.train_input_fn:
            # TODO: Consider tensorflow_estimator/python/estimator/util.py.
            inputs = subestimator.train_input_fn()
            # train_input_fn may return either a Dataset or a plain
            # (features, labels) tuple.
            if isinstance(inputs, (tf_compat.DatasetV1, tf_compat.DatasetV2)):
                subestimator_features, subestimator_labels = (
                    tf_compat.make_one_shot_iterator(inputs).get_next())
            else:
                subestimator_features, subestimator_labels = inputs

            # Construct subnetwork graph first because of dependencies on scope.
            _, _, bagging_train_op_spec, sub_local_init_op = call_model_fn_template(
                subestimator, subestimator_features, subestimator_labels, mode,
                summary)
            # Graph for ensemble learning gets model_fn_1 for scope.
            logits, last_layer, _, ensemble_local_init_op = call_model_fn_template(
                subestimator, features, labels, mode, summary)

            if sub_local_init_op:
                local_init_ops.append(sub_local_init_op)
            if ensemble_local_init_op:
                local_init_ops.append(ensemble_local_init_op)

            # Run train op in a hook so that exceptions can be intercepted by the
            # AdaNet framework instead of the Estimator's monitored training session.
            hooks = bagging_train_op_spec.hooks + (_SecondaryTrainOpRunnerHook(
                bagging_train_op_spec.train_op), )
            # The returned spec's own train_op is a no-op; the real training
            # happens inside the hook above.
            train_op_spec = subnetwork_lib.TrainOpSpec(
                train_op=tf.no_op(),
                chief_hooks=bagging_train_op_spec.chief_hooks,
                hooks=hooks)
        else:
            logits, last_layer, train_op_spec, local_init_op = call_model_fn_template(
                subestimator, features, labels, mode, summary)
            if local_init_op:
                local_init_ops.append(local_init_op)

        # TODO: Replace with variance complexity measure.
        complexity = tf.constant(0.)
        return subnetwork_lib.Subnetwork(logits=logits,
                                         last_layer=last_layer,
                                         shared={"train_op": train_op_spec},
                                         complexity=complexity,
                                         local_init_ops=local_init_ops)