def _call_model_fn(self, subestimator, features, labels, mode, summary):
    """Build the subestimator's graph and extract the pieces AdaNet needs.

    Invokes ``subestimator.estimator.model_fn`` inside
    ``summary.current_scope()``, passing ``self._config`` as the run config.

    Args:
      subestimator: Object exposing the wrapped estimator as ``.estimator``.
      features: Features forwarded unchanged to the model_fn.
      labels: Labels forwarded unchanged to the model_fn.
      mode: Mode forwarded unchanged to the model_fn.
      summary: Object whose ``current_scope()`` context manager wraps the
        whole graph construction.

    Returns:
      A ``(logits, last_layer, train_op_spec, local_init_op)`` tuple:
        * ``logits``: result of ``self._logits_fn(estimator_spec=...)``.
        * ``last_layer``: result of ``self._last_layer_fn(estimator_spec=...)``
          when ``self._last_layer_fn`` is set, otherwise ``logits``.
        * ``train_op_spec``: ``subnetwork_lib.TrainOpSpec`` built from the
          spec's ``train_op``, ``training_chief_hooks`` and ``training_hooks``.
        * ``local_init_op``: the spec scaffold's ``local_init_op``, or ``None``
          when there is no scaffold or it has no local init op.
    """
    with summary.current_scope():
        model_fn = subestimator.estimator.model_fn
        estimator_spec = model_fn(
            features=features, labels=labels, mode=mode, config=self._config)
        logits = self._logits_fn(estimator_spec=estimator_spec)
        last_layer = logits
        if self._last_layer_fn:
            last_layer = self._last_layer_fn(estimator_spec=estimator_spec)

        # Compare against None explicitly rather than relying on truthiness:
        # a TF tensor must not be used as a Python bool (it raises in graph
        # mode), and "absent" is represented by None.
        scaffold = estimator_spec.scaffold
        local_init_op = None
        if scaffold is not None and scaffold.local_init_op is not None:
            local_init_op = scaffold.local_init_op

        train_op = subnetwork_lib.TrainOpSpec(
            estimator_spec.train_op,
            chief_hooks=estimator_spec.training_chief_hooks,
            hooks=estimator_spec.training_hooks)
    return logits, last_layer, train_op, local_init_op
def _to_train_op_spec(train_op):
    """Coerce ``train_op`` into a ``subnetwork_lib.TrainOpSpec``.

    A value that is already a ``TrainOpSpec`` is returned as-is. Anything else
    is wrapped as the spec's ``train_op``, with no hook arguments supplied.
    """
    already_spec = isinstance(train_op, subnetwork_lib.TrainOpSpec)
    return train_op if already_spec else subnetwork_lib.TrainOpSpec(train_op)
def build_subnetwork(self,
                     features,
                     labels,
                     logits_dimension,
                     training,
                     iteration_step,
                     summary,
                     previous_ensemble,
                     config=None):
    """Build a ``subnetwork_lib.Subnetwork`` backed by the wrapped subestimator.

    The subestimator's model_fn is called through a
    ``tf.compat.v1.make_template("model_fn", ...)`` wrapper, so repeated calls
    reuse the variables created by the first call.

    When ``training`` is truthy and the subestimator has a ``train_input_fn``:
      * The subestimator's graph is first built on the inputs returned by
        ``train_input_fn()``. A Dataset is read via a one-shot iterator;
        otherwise the result is unpacked as ``(features, labels)``.
      * A second template call on ``features``/``labels`` supplies the
        ``logits`` and ``last_layer`` used by the ensemble.
      * The first call's train op is run by a ``_SecondaryTrainOpRunnerHook``.
        The returned ``TrainOpSpec`` carries ``tf.no_op()`` as its train op.
    Otherwise a single template call on ``features``/``labels`` supplies
    everything, including the train op spec.

    Args:
      features: Features used for the ensemble-facing graph.
      labels: Labels used for the ensemble-facing graph.
      logits_dimension: Unused by this implementation.
      training: Selects ``ModeKeys.TRAIN`` when truthy, else ``PREDICT``.
      iteration_step: Unused by this implementation.
      summary: Forwarded to the model_fn call for summary scoping.
      previous_ensemble: Unused by this implementation.
      config: Forwarded to ``self._subestimator(config)``.

    Returns:
      A ``subnetwork_lib.Subnetwork`` with the computed ``logits`` and
      ``last_layer``, ``shared={"train_op": <TrainOpSpec>}``, a constant zero
      ``complexity``, and the non-None local init ops from the model_fn calls.
    """
    # We don't need an EVAL mode since AdaNet takes care of evaluation for us.
    mode = (tf.estimator.ModeKeys.TRAIN
            if training else tf.estimator.ModeKeys.PREDICT)

    # Call in template to ensure that variables are created once and reused.
    call_model_fn_template = tf.compat.v1.make_template(
        "model_fn", self._call_model_fn)
    subestimator = self._subestimator(config)

    if training and subestimator.train_input_fn:
        # TODO: Consider tensorflow_estimator/python/estimator/util.py.
        inputs = subestimator.train_input_fn()
        if isinstance(inputs, (tf_compat.DatasetV1, tf_compat.DatasetV2)):
            subestimator_features, subestimator_labels = (
                tf_compat.make_one_shot_iterator(inputs).get_next())
        else:
            subestimator_features, subestimator_labels = inputs

        # Construct subnetwork graph first because of dependencies on scope.
        _, _, bagging_train_op_spec, sub_local_init_op = call_model_fn_template(
            subestimator, subestimator_features, subestimator_labels, mode,
            summary)
        # Graph for ensemble learning gets model_fn_1 for scope.
        logits, last_layer, _, ensemble_local_init_op = call_model_fn_template(
            subestimator, features, labels, mode, summary)
        candidate_init_ops = (sub_local_init_op, ensemble_local_init_op)

        # Run train op in a hook so that exceptions can be intercepted by the
        # AdaNet framework instead of the Estimator's monitored training session.
        hooks = bagging_train_op_spec.hooks + (
            _SecondaryTrainOpRunnerHook(bagging_train_op_spec.train_op),)
        train_op_spec = subnetwork_lib.TrainOpSpec(
            train_op=tf.no_op(),
            chief_hooks=bagging_train_op_spec.chief_hooks,
            hooks=hooks)
    else:
        logits, last_layer, train_op_spec, local_init_op = call_model_fn_template(
            subestimator, features, labels, mode, summary)
        candidate_init_ops = (local_init_op,)

    # Filter with an explicit None check: TF ops/tensors should not be
    # evaluated as Python booleans.
    local_init_ops = [op for op in candidate_init_ops if op is not None]

    # TODO: Replace with variance complexity measure.
    complexity = tf.constant(0.)
    return subnetwork_lib.Subnetwork(
        logits=logits,
        last_layer=last_layer,
        shared={"train_op": train_op_spec},
        complexity=complexity,
        local_init_ops=local_init_ops)