Example No. 1
0
 def build_subnetwork_train_op(self, subnetwork, loss, var_list, labels,
                               iteration_step, summary, previous_ensemble):
     """Returns a no-op train op spec carrying the configured chief hook.

     When no chief hook was configured, returns None so the caller can fall
     back to its default training behavior.
     """
     if not self._chief_hook:
         return None
     return TrainOpSpec(
         train_op=tf.no_op(), chief_hooks=[self._chief_hook], hooks=None)
Example No. 2
0
    def build_subnetwork(self, features, labels, logits_dimension, training,
                         iteration_step, summary, previous_ensemble):
        """Builds a subnetwork by delegating to the wrapped Estimator's model_fn."""
        # We don't need an EVAL mode since AdaNet takes care of evaluation for us.
        mode = (tf.estimator.ModeKeys.TRAIN
                if training else tf.estimator.ModeKeys.PREDICT)
        estimator_spec = self._estimator.model_fn(
            features=features,
            labels=labels,
            mode=mode,
            config=self._estimator.config)
        logits = self._logits_fn(estimator_spec=estimator_spec)

        # Record the wrapped model's train op spec on the instance so it can
        # be retrieved later.
        self._subnetwork_train_op = TrainOpSpec(
            estimator_spec.train_op,
            chief_hooks=estimator_spec.training_chief_hooks,
            hooks=estimator_spec.training_hooks)

        # TODO: Replace with variance complexity measure.
        return Subnetwork(
            logits=logits,
            last_layer=logits,
            persisted_tensors={},
            complexity=tf.constant(0.))
Example No. 3
0
 def build_subnetwork_train_op(self, subnetwork, loss, var_list, labels,
                               iteration_step, summary, previous_ensemble):
     """Minimizes the loss with gradient descent, attaching any training hooks."""
     sgd = tf_compat.v1.train.GradientDescentOptimizer(
         learning_rate=self._learning_rate)
     minimize_op = sgd.minimize(loss, var_list=var_list)
     # Without hooks, the bare train op suffices; otherwise wrap it in a spec.
     if self._subnetwork_hooks:
         return TrainOpSpec(minimize_op, self._subnetwork_chief_hooks,
                            self._subnetwork_hooks)
     return minimize_op
Example No. 4
0
 def _call_model_fn(self, features, labels, mode, summary):
   """Runs the subestimator's model_fn and returns (logits, train_op_spec)."""
   with summary.current_scope():
     spec = self._subestimator.estimator.model_fn(
         features=features, labels=labels, mode=mode, config=self._config)
     train_op = TrainOpSpec(
         spec.train_op,
         chief_hooks=spec.training_chief_hooks,
         hooks=spec.training_hooks)
     logits = self._logits_fn(estimator_spec=spec)
   return logits, train_op
Example No. 5
0
    def build_subnetwork(self, features, labels, logits_dimension, training,
                         iteration_step, summary, previous_ensemble):
        """Builds a subnetwork from the subestimator's model_fn.

        When training with a dedicated `train_input_fn`, the shared template is
        called twice: once over the subestimator's own inputs (whose train op
        is executed via a hook) and once over the ensemble's inputs (whose
        logits and last layer are returned). Otherwise, a single call over the
        ensemble inputs supplies everything.
        """
        # We don't need an EVAL mode since AdaNet takes care of evaluation for us.
        mode = tf.estimator.ModeKeys.PREDICT
        if training:
            mode = tf.estimator.ModeKeys.TRAIN

        # Call in template to ensure that variables are created once and reused.
        call_model_fn_template = tf.compat.v1.make_template(
            "model_fn", self._call_model_fn)
        subestimator_features, subestimator_labels = features, labels
        local_init_ops = []
        if training and self._subestimator.train_input_fn:
            # TODO: Consider tensorflow_estimator/python/estimator/util.py.
            inputs = self._subestimator.train_input_fn()
            # A Dataset input must be converted to tensors via an iterator;
            # any other return value is assumed to already be (features, labels).
            if isinstance(inputs, (tf_compat.DatasetV1, tf_compat.DatasetV2)):
                subestimator_features, subestimator_labels = (
                    tf_compat.make_one_shot_iterator(inputs).get_next())
            else:
                subestimator_features, subestimator_labels = inputs

            # Construct subnetwork graph first because of dependencies on scope.
            _, _, bagging_train_op_spec, sub_local_init_op = call_model_fn_template(
                subestimator_features, subestimator_labels, mode, summary)
            # Graph for ensemble learning gets model_fn_1 for scope.
            logits, last_layer, _, ensemble_local_init_op = call_model_fn_template(
                features, labels, mode, summary)

            # Collect only the local init ops that actually exist.
            if sub_local_init_op:
                local_init_ops.append(sub_local_init_op)
            if ensemble_local_init_op:
                local_init_ops.append(ensemble_local_init_op)

            # Run train op in a hook so that exceptions can be intercepted by the
            # AdaNet framework instead of the Estimator's monitored training session.
            hooks = bagging_train_op_spec.hooks + (_SecondaryTrainOpRunnerHook(
                bagging_train_op_spec.train_op), )
            train_op_spec = TrainOpSpec(
                train_op=tf.no_op(),
                chief_hooks=bagging_train_op_spec.chief_hooks,
                hooks=hooks)
        else:
            # Single graph: the model's own train op spec is used directly.
            logits, last_layer, train_op_spec, local_init_op = call_model_fn_template(
                features, labels, mode, summary)
            if local_init_op:
                local_init_ops.append(local_init_op)

        # TODO: Replace with variance complexity measure.
        complexity = tf.constant(0.)
        return Subnetwork(logits=logits,
                          last_layer=last_layer,
                          shared={"train_op": train_op_spec},
                          complexity=complexity,
                          local_init_ops=local_init_ops)
Example No. 6
0
  def _call_model_fn(self, subestimator, features, labels, mode, summary):
    """Runs the given subestimator's model_fn and unpacks what AdaNet needs.

    Returns a `(logits, last_layer, train_op_spec, local_init_op)` tuple;
    `local_init_op` is None when the spec's scaffold does not provide one.
    """
    with summary.current_scope():
      spec = subestimator.estimator.model_fn(
          features=features, labels=labels, mode=mode, config=self._config)
      logits = self._logits_fn(estimator_spec=spec)
      # Fall back to the logits unless a dedicated last-layer fn is configured.
      last_layer = (self._last_layer_fn(estimator_spec=spec)
                    if self._last_layer_fn else logits)

      local_init_op = None
      if spec.scaffold and spec.scaffold.local_init_op:
        local_init_op = spec.scaffold.local_init_op

      train_op = TrainOpSpec(
          spec.train_op,
          chief_hooks=spec.training_chief_hooks,
          hooks=spec.training_hooks)
    return logits, last_layer, train_op, local_init_op
Example No. 7
0
def _to_train_op_spec(train_op):
    """Wraps a raw train op in a TrainOpSpec; existing specs pass through."""
    return (train_op if isinstance(train_op, TrainOpSpec)
            else TrainOpSpec(train_op))