def build_subnetwork(self, features, labels, logits_dimension, training,
                     iteration_step, summary, previous_ensemble):
    """Builds a `Subnetwork` by running the wrapped estimator's `model_fn`.

    Args:
      features: Forwarded unchanged to the estimator's `model_fn`.
      labels: Forwarded unchanged to the estimator's `model_fn`.
      logits_dimension: Not used by this implementation.
      training: When truthy, `model_fn` is invoked in TRAIN mode;
        otherwise it is invoked in PREDICT mode.
      iteration_step: Not used by this implementation.
      summary: Not used by this implementation.
      previous_ensemble: Not used by this implementation.

    Returns:
      A `Subnetwork` whose `logits` and `last_layer` are both the value
      produced by `self._logits_fn` from the estimator spec. It has an empty
      `persisted_tensors` dict and a constant `0.0` complexity.

    Side effects:
      Sets `self._subnetwork_train_op` to a `TrainOpSpec` built from the
      estimator spec's `train_op`, `training_chief_hooks` and
      `training_hooks`.
    """
    estimator = self._estimator
    run_model = estimator.model_fn

    # AdaNet runs evaluation on its own, so only TRAIN or PREDICT is needed.
    if training:
        mode = tf.estimator.ModeKeys.TRAIN
    else:
        mode = tf.estimator.ModeKeys.PREDICT

    spec = run_model(
        features=features,
        labels=labels,
        mode=mode,
        config=estimator.config,
    )
    subnetwork_logits = self._logits_fn(estimator_spec=spec)

    # Keep the estimator's training op and hooks on the instance; they are
    # not part of the returned Subnetwork.
    self._subnetwork_train_op = TrainOpSpec(
        spec.train_op,
        chief_hooks=spec.training_chief_hooks,
        hooks=spec.training_hooks,
    )

    # TODO: Replace with variance complexity measure.
    zero_complexity = tf.constant(0.0)
    return Subnetwork(
        logits=subnetwork_logits,
        last_layer=subnetwork_logits,
        persisted_tensors={},
        complexity=zero_complexity,
    )
def build_subnetwork_train_op(self, subnetwork, loss, var_list, labels,
                              iteration_step, summary, previous_ensemble):
    """Returns a no-op train op that carries the configured chief hook, if any.

    Only `self._chief_hook` is consulted. All other arguments are accepted
    for interface compatibility and ignored.

    Returns:
      `None` when `self._chief_hook` is falsy. Otherwise, a `TrainOpSpec`
      whose `train_op` is `tf.no_op()`, whose `chief_hooks` is a one-element
      list holding `self._chief_hook`, and whose `hooks` is `None`.
    """
    # Guard clause: without a (truthy) chief hook there is nothing to attach.
    if not self._chief_hook:
        return None
    return TrainOpSpec(
        train_op=tf.no_op(),
        chief_hooks=[self._chief_hook],
        hooks=None,
    )
def _to_train_op_spec(train_op):
    """Normalizes `train_op` into a `TrainOpSpec`.

    Args:
      train_op: Either an existing `TrainOpSpec` or any other value to be
        wrapped.

    Returns:
      `train_op` itself (the same object) when it is already a `TrainOpSpec`.
      Otherwise, a new `TrainOpSpec` constructed with `train_op` as its only
      positional argument.
    """
    already_spec = isinstance(train_op, TrainOpSpec)
    return train_op if already_spec else TrainOpSpec(train_op)