Beispiel #1
0
    def build_subnetwork(self, features, labels, logits_dimension, training,
                         iteration_step, summary, previous_ensemble):
        """Builds a `Subnetwork` by invoking the wrapped estimator's `model_fn`.

        Side effect: the train op and training hooks from the resulting
        estimator spec are captured in `self._subnetwork_train_op`.

        Args:
          features: Input features, forwarded unchanged to `model_fn`.
          labels: Labels, forwarded unchanged to `model_fn`.
          logits_dimension: Not used by this implementation.
          training: If truthy, `model_fn` runs in TRAIN mode; otherwise it
            runs in PREDICT mode.
          iteration_step: Not used by this implementation.
          summary: Not used by this implementation.
          previous_ensemble: Not used by this implementation.

        Returns:
          A `Subnetwork` whose `logits` and `last_layer` are both the tensor
          extracted by `self._logits_fn`, with empty `persisted_tensors` and a
          constant `0.` complexity.
        """
        # EVAL mode is never requested: AdaNet performs evaluation itself.
        if training:
            run_mode = tf.estimator.ModeKeys.TRAIN
        else:
            run_mode = tf.estimator.ModeKeys.PREDICT

        wrapped = self._estimator
        spec = wrapped.model_fn(features=features,
                                labels=labels,
                                mode=run_mode,
                                config=wrapped.config)
        subnet_logits = self._logits_fn(estimator_spec=spec)

        # Captured for later use (the consumer is not visible in this chunk).
        # NOTE(review): in PREDICT mode `spec.train_op` is likely None; confirm
        # that TrainOpSpec tolerates that.
        self._subnetwork_train_op = TrainOpSpec(
            spec.train_op,
            chief_hooks=spec.training_chief_hooks,
            hooks=spec.training_hooks)

        # TODO: Replace with variance complexity measure.
        return Subnetwork(logits=subnet_logits,
                          last_layer=subnet_logits,
                          persisted_tensors={},
                          complexity=tf.constant(0.))
Beispiel #2
0
 def build_subnetwork_train_op(self, subnetwork, loss, var_list, labels,
                               iteration_step, summary, previous_ensemble):
     """Returns a no-op train op that carries the chief hook, if one is set.

     Args:
       subnetwork: Not used by this implementation.
       loss: Not used by this implementation.
       var_list: Not used by this implementation.
       labels: Not used by this implementation.
       iteration_step: Not used by this implementation.
       summary: Not used by this implementation.
       previous_ensemble: Not used by this implementation.

     Returns:
       `None` when `self._chief_hook` is falsy. Otherwise, a `TrainOpSpec`
       wrapping `tf.no_op()` with `[self._chief_hook]` as its chief hooks and
       no regular hooks.
     """
     if not self._chief_hook:
         return None
     noop = tf.no_op()
     return TrainOpSpec(train_op=noop,
                        chief_hooks=[self._chief_hook],
                        hooks=None)
Beispiel #3
0
def _to_train_op_spec(train_op):
  """Normalizes `train_op` into a `TrainOpSpec`.

  Args:
    train_op: Either an existing `TrainOpSpec`, or a value to wrap as the
      first positional argument of a new `TrainOpSpec`.

  Returns:
    `train_op` itself if it is already a `TrainOpSpec`; otherwise a new
    `TrainOpSpec` built from it.
  """
  already_spec = isinstance(train_op, TrainOpSpec)
  return train_op if already_spec else TrainOpSpec(train_op)