def test_ensemble_metrics(self):
    """Checks that ensemble metrics report the architecture's builder names.

    An architecture made of four subnetworks spread over iterations 0, 1 and 2
    is handed to `tu.create_ensemble_metrics`; the evaluated
    "architecture/adanet/ensembles" metric must contain the builder names in
    insertion order.
    """
    subnetwork_layout = (
        (0, "b_0_0"),
        (0, "b_0_1"),
        (1, "b_1_0"),
        (2, "b_2_0"),
    )
    with context.graph_mode():
        self.setup_graph()
        architecture = _Architecture("test_ensemble_candidate",
                                     "test_ensembler")
        for iteration, builder in subnetwork_layout:
            architecture.add_subnetwork(iteration_number=iteration,
                                        builder_name=builder)
        ensemble_metrics = tu.create_ensemble_metrics(
            self._metric_fn,
            features=self._features,
            labels=self._labels,
            estimator_spec=self._estimator_spec,
            architecture=architecture)
        results = self._run_metrics(ensemble_metrics.eval_metrics_tuple())
        self.assertIn(b"| b_0_0 | b_0_1 | b_1_0 | b_2_0 |",
                      results["architecture/adanet/ensembles"])
def test_serialization_lifecycle(self):
    """Round-trips an `_Architecture` through `serialize` and `deserialize`.

    Verifies the names and grouped subnetworks of the freshly built object,
    the exact JSON emitted by `serialize`, and that the deserialized copy
    matches the original (including the global step).
    """
    original = _Architecture("foo", "dummy_ensembler_name",
                             replay_indices=[1, 2])
    for iteration, builder in ((0, "linear"), (0, "dnn"), (1, "dnn")):
        original.add_subnetwork(iteration, builder)

    self.assertEqual("foo", original.ensemble_candidate_name)
    self.assertEqual("dummy_ensembler_name", original.ensembler_name)
    expected_groups = ((0, ("linear", "dnn")), (1, ("dnn", )))
    self.assertEqual(expected_groups,
                     original.subnetworks_grouped_by_iteration)

    iteration_number = 2
    global_step = 100
    serialized = original.serialize(iteration_number, global_step)
    expected_json = (
        '{"ensemble_candidate_name": "foo", "ensembler_name": '
        '"dummy_ensembler_name", "global_step": 100, "iteration_number": 2, '
        '"replay_indices": [1, 2], '
        '"subnetworks": [{"builder_name": "linear", "iteration_number": 0}, '
        '{"builder_name": "dnn", "iteration_number": 0},'
        ' {"builder_name": "dnn", "iteration_number": 1}]}')
    self.assertEqual(expected_json, serialized)

    restored = _Architecture.deserialize(serialized)
    self.assertEqual(original.ensemble_candidate_name,
                     restored.ensemble_candidate_name)
    self.assertEqual(original.ensembler_name, restored.ensembler_name)
    self.assertEqual(original.subnetworks_grouped_by_iteration,
                     restored.subnetworks_grouped_by_iteration)
    self.assertEqual(global_step, restored.global_step)
def test_serialization_lifecycle(self):
    """Round-trips an `_Architecture` through its byte serialization.

    Checks the grouped `subnetworks` view, the exact serialized bytes, and
    that deserializing yields the same `subnetworks`.
    """
    original = _Architecture()
    for iteration, builder in ((0, "linear"), (0, "dnn"), (1, "dnn")):
        original.add_subnetwork(iteration, builder)
    expected_subnetworks = ((0, ("linear", "dnn")), (1, ("dnn", )))
    self.assertEqual(expected_subnetworks, original.subnetworks)

    serialized = original.serialize()
    expected_bytes = (
        b"\n\x08\x12\x06linear\n\x05\x12\x03dnn\n\x07\x08\x01\x12\x03dnn")
    self.assertEqual(expected_bytes, serialized)

    restored = _Architecture.deserialize(serialized)
    self.assertEqual(original.subnetworks, restored.subnetworks)
def create_ensemble_metrics(metric_fn,
                            use_tpu=False,
                            features=None,
                            labels=None,
                            estimator_spec=None,
                            architecture=None):
    """Builds an `_EnsembleMetrics` object with its eval metrics created.

    Args:
      metric_fn: A function which should obey the following signature:
        - Args: can only have following three arguments in any order:
          * predictions: Predictions `Tensor` or dict of `Tensor` created by
            given `Head`.
          * features: Input `dict` of `Tensor` objects created by `input_fn`
            which is given to `estimator.evaluate` as an argument.
          * labels: Labels `Tensor` or dict of `Tensor` (for multi-head)
            created by `input_fn` which is given to `estimator.evaluate` as an
            argument.
        - Returns: Dict of metric results keyed by name. Final metrics are a
          union of this and `estimator`s existing metrics. If there is a name
          conflict between this and `estimator`s existing metrics, this will
          override the existing one. The values of the dict are the results of
          calling a metric function, namely a `(metric_tensor, update_op)`
          tuple.
      use_tpu: Whether to use TPU-specific variable sharing logic.
      features: Input `dict` of `Tensor` objects.
      labels: Labels `Tensor` or a dictionary of string label name to `Tensor`
        (for multi-head).
      estimator_spec: The `EstimatorSpec` created by a `Head` instance. When
        falsy, a default EVAL-mode spec with a constant loss of 2 is used.
      architecture: `_Architecture` object. When falsy, an architecture with
        `None` names is used.

    Returns:
      An instance of _EnsembleMetrics.
    """
    if not estimator_spec:
        # The default spec is built as a TPUEstimatorSpec and converted to a
        # plain EstimatorSpec when TPU support is not requested.
        default_spec = tf_compat.v1.estimator.tpu.TPUEstimatorSpec(
            mode=tf.estimator.ModeKeys.EVAL,
            loss=tf.constant(2.),
            predictions=None,
            eval_metrics=None)
        estimator_spec = (
            default_spec if use_tpu else default_spec.as_estimator_spec())

    architecture = architecture or _Architecture(None, None)

    ensemble_metrics = _EnsembleMetrics(use_tpu=use_tpu)
    ensemble_metrics.create_eval_metrics(
        features, labels, estimator_spec, metric_fn, architecture)
    return ensemble_metrics
def test_serialization_lifecycle(self):
    """Round-trips an `_Architecture` through its JSON serialization.

    Checks the grouped-by-iteration view, the exact JSON string produced by
    `serialize`, and that deserializing yields the same grouping.
    """
    original = _Architecture()
    for iteration, builder in ((0, "linear"), (0, "dnn"), (1, "dnn")):
        original.add_subnetwork(iteration, builder)
    expected_groups = ((0, ("linear", "dnn")), (1, ("dnn", )))
    self.assertEqual(expected_groups,
                     original.subnetworks_grouped_by_iteration)

    serialized = original.serialize()
    expected_json = (
        '{"subnetworks": [{"builder_name": "linear", "iteration_number": 0}, '
        '{"builder_name": "dnn", "iteration_number": 0}, '
        '{"builder_name": "dnn", "iteration_number": 1}]}')
    self.assertEqual(expected_json, serialized)

    restored = _Architecture.deserialize(serialized)
    self.assertEqual(original.subnetworks_grouped_by_iteration,
                     restored.subnetworks_grouped_by_iteration)
def test_ensemble_metrics(self):
    """Checks that `_EnsembleMetrics` reports the architecture's builders.

    Four subnetworks over iterations 0, 1 and 2 are registered on the
    architecture; the evaluated "architecture/adanet/ensembles" metric must
    contain their builder names in insertion order.
    """
    architecture = _Architecture("test_ensemble_candidate")
    subnetwork_layout = (
        (0, "b_0_0"),
        (0, "b_0_1"),
        (1, "b_1_0"),
        (2, "b_2_0"),
    )
    for iteration, builder in subnetwork_layout:
        architecture.add_subnetwork(iteration_number=iteration,
                                    builder_name=builder)

    ensemble_metrics = _EnsembleMetrics()
    ensemble_metrics.create_eval_metrics(
        self._features, self._labels, self._estimator_spec, self._metric_fn,
        architecture)

    with self.test_session() as sess:
        results = _run_metrics(sess, ensemble_metrics.eval_metrics_tuple())

    self.assertIn(b"| b_0_0 | b_0_1 | b_1_0 | b_2_0 |",
                  results["architecture/adanet/ensembles"])
def build_ensemble_spec(self,
                        name,
                        candidate,
                        ensembler,
                        subnetwork_specs,
                        summary,
                        features,
                        mode,
                        iteration_number,
                        labels=None,
                        previous_ensemble_spec=None,
                        my_ensemble_index=None,
                        params=None,
                        previous_iteration_checkpoint=None):
    """Returns a fixed `_EnsembleSpec` built from constant logits.

    Only `name`, `candidate`, `features`, `mode` and `labels` are used; every
    other argument is accepted and discarded. The head's estimator spec is
    created from the constant logits `[[.5]]` and supplies the predictions and
    export outputs of the returned spec.
    """
    # Arguments accepted for signature compatibility but intentionally unused.
    del ensembler, subnetwork_specs, summary, iteration_number
    del previous_ensemble_spec, my_ensemble_index, params
    del previous_iteration_checkpoint

    head_spec = self._head.create_estimator_spec(
        features=features, mode=mode, labels=labels, logits=[[.5]])
    architecture = _Architecture("foo", "bar")
    step = tf.Variable(0, dtype=tf.int64)
    variables = [tf.Variable(1.)]
    return _EnsembleSpec(
        name=name,
        ensemble=None,
        architecture=architecture,
        subnetwork_builders=candidate.subnetwork_builders,
        predictions=head_spec.predictions,
        step=step,
        variables=variables,
        loss=None,
        adanet_loss=.1,
        train_op=None,
        eval_metrics=None,
        export_outputs=head_spec.export_outputs)
def build_ensemble_spec(self, name, candidate, ensembler, subnetwork_specs, summary, features, mode, iteration_step, iteration_number, labels=None, previous_ensemble_spec=None):
    """Builds an `_EnsembleSpec` with the given `adanet.ensemble.Candidate`.

    The ensemble is assembled from the subnetworks kept from the previous
    ensemble (if any) plus the candidate's new subnetworks. Outside of PREDICT
    mode the AdaNet loss is the head loss plus the ensemble's complexity
    regularization; in TRAIN mode a train op for the ensemble variables is
    also created.

    Args:
      name: The string name of the ensemble. Typically the name of the builder
        that returned the given `Subnetwork`.
      candidate: The `adanet.ensemble.Candidate` for this spec.
      ensembler: The :class:`adanet.ensemble.Ensembler` to use to ensemble a
        group of subnetworks.
      subnetwork_specs: Iterable of `_SubnetworkSpecs` for this iteration.
      summary: A `_ScopedSummary` instance for recording ensemble summaries.
      features: Input `dict` of `Tensor` objects.
      mode: Estimator `ModeKeys` indicating training, evaluation, or inference.
      iteration_step: Integer `Tensor` representing the step since the
        beginning of the current iteration, as opposed to the global step.
      iteration_number: Integer current iteration number.
      labels: Labels `Tensor` or a dictionary of string label name to `Tensor`
        (for multi-head).
      previous_ensemble_spec: Link the rest of the `_EnsembleSpec` from
        iteration t-1. Used for creating the subnetwork train_op.

    Returns:
      An `_EnsembleSpec` instance.

    Raises:
      ValueError: If `mode` is not PREDICT and the built ensemble is not a
        `ComplexityRegularized` instance.
    """
    with tf.variable_scope("ensemble_{}".format(name)):
        architecture = _Architecture(candidate.name)
        previous_subnetworks = []
        subnetwork_builders = []
        previous_ensemble = None
        if previous_ensemble_spec:
            previous_ensemble = previous_ensemble_spec.ensemble
            previous_architecture = previous_ensemble_spec.architecture
            # By default every subnetwork of the previous ensemble is kept.
            keep_indices = range(len(previous_ensemble.subnetworks))
            if len(candidate.subnetwork_builders) == 1 and previous_ensemble:
                # Prune previous ensemble according to the subnetwork.Builder for
                # backwards compatibility.
                tf.logging.warn(
                    "Using an `adanet.subnetwork.Builder#prune_previous_ensemble` "
                    "is deprecated. Please use a custom `adanet.ensemble.Strategy` "
                    "instead.")
                subnetwork_builder = candidate.subnetwork_builders[0]
                keep_indices = subnetwork_builder.prune_previous_ensemble(
                    previous_ensemble)
            # Carry over only the previous subnetworks that survive pruning and
            # that the candidate explicitly lists.
            for i, builder in enumerate(previous_ensemble_spec.subnetwork_builders):
                if i not in keep_indices:
                    continue
                if builder not in candidate.previous_ensemble_subnetwork_builders:
                    continue
                previous_subnetworks.append(previous_ensemble.subnetworks[i])
                subnetwork_builders.append(builder)
                architecture.add_subnetwork(*previous_architecture.subnetworks[i])
        # Register the candidate's new subnetworks under the current iteration.
        for builder in candidate.subnetwork_builders:
            architecture.add_subnetwork(iteration_number, builder.name)
            subnetwork_builders.append(builder)
        # Look up the built subnetworks by builder name.
        subnetwork_map = {s.builder.name: s.subnetwork for s in subnetwork_specs}
        subnetworks = [
            subnetwork_map[s.name] for s in candidate.subnetwork_builders
        ]
        ensemble_scope = tf.get_variable_scope()
        # Snapshot taken so the variables created by the ensembler can be
        # identified afterwards.
        before_var_list = tf.trainable_variables()
        with summary.current_scope(), _monkey_patch_context(
                iteration_step_scope=ensemble_scope,
                scoped_summary=summary,
                trainable_vars=[]):
            ensemble = ensembler.build_ensemble(
                subnetworks,
                previous_ensemble_subnetworks=previous_subnetworks,
                features=features,
                labels=labels,
                logits_dimension=self._head.logits_dimension,
                training=mode == tf.estimator.ModeKeys.TRAIN,
                iteration_step=iteration_step,
                summary=summary,
                previous_ensemble=previous_ensemble)
        ensemble_var_list = _new_trainable_variables(before_var_list)
        estimator_spec = _create_estimator_spec(
            self._head, features, labels, mode, ensemble.logits, self._use_tpu)
        ensemble_loss = estimator_spec.loss
        adanet_loss = None
        if mode != tf.estimator.ModeKeys.PREDICT:
            # TODO: Support any kind of Ensemble. Use a moving average of
            # their train loss for the 'adanet_loss'.
            if not isinstance(ensemble, ComplexityRegularized):
                raise ValueError(
                    "Only ComplexityRegularized ensembles are supported.")
            adanet_loss = estimator_spec.loss + ensemble.complexity_regularization
        ensemble_metrics = _EnsembleMetrics()
        if mode == tf.estimator.ModeKeys.EVAL:
            ensemble_metrics.create_eval_metrics(
                features=features,
                labels=labels,
                estimator_spec=estimator_spec,
                metric_fn=self._metric_fn,
                architecture=architecture)
        if mode == tf.estimator.ModeKeys.TRAIN:
            with summary.current_scope():
                summary.scalar("loss", estimator_spec.loss)
        # Create train ops for training subnetworks and ensembles.
        train_op = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            # Note that these mixture weights are on top of the last_layer of the
            # subnetwork constructed in TRAIN mode, which means that dropout is
            # still applied when the mixture weights are being trained.
            ensemble_scope = tf.get_variable_scope()
            with tf.variable_scope("train_mixture_weights"):
                with summary.current_scope(), _monkey_patch_context(
                        iteration_step_scope=ensemble_scope,
                        scoped_summary=summary,
                        trainable_vars=ensemble_var_list):
                    # For backwards compatibility.
                    subnetwork_builder = candidate.subnetwork_builders[0]
                    old_train_op_fn = getattr(subnetwork_builder, "build_mixture_weights_train_op", None)
                    if callable(old_train_op_fn):
                        tf.logging.warn(
                            "The `build_mixture_weights_train_op` method is deprecated. "
                            "Please use the `Ensembler#build_train_op` instead.")
                        train_op = _to_train_op_spec(
                            subnetwork_builder.build_mixture_weights_train_op(
                                loss=adanet_loss,
                                var_list=ensemble_var_list,
                                logits=ensemble.logits,
                                labels=labels,
                                iteration_step=iteration_step,
                                summary=summary))
                    else:
                        train_op = _to_train_op_spec(
                            ensembler.build_train_op(
                                ensemble=ensemble,
                                loss=adanet_loss,
                                var_list=ensemble_var_list,
                                labels=labels,
                                iteration_step=iteration_step,
                                summary=summary,
                                previous_ensemble=previous_ensemble))
        return _EnsembleSpec(
            name=name,
            architecture=architecture,
            subnetwork_builders=subnetwork_builders,
            ensemble=ensemble,
            predictions=estimator_spec.predictions,
            loss=ensemble_loss,
            adanet_loss=adanet_loss,
            train_op=train_op,
            eval_metrics=ensemble_metrics.eval_metrics_tuple(),
            export_outputs=estimator_spec.export_outputs)
def dummy_ensemble_spec(name,
                        random_seed=42,
                        num_subnetworks=1,
                        bias=0.,
                        loss=None,
                        adanet_loss=None,
                        eval_metrics=None,
                        dict_predictions=False,
                        export_output_key=None,
                        subnetwork_builders=None,
                        train_op=None):
    """Builds a placeholder `_EnsembleSpec` for use in tests.

    Args:
      name: _EnsembleSpec's name.
      random_seed: A scalar random seed.
      num_subnetworks: The number of fake subnetworks in this ensemble.
      bias: Bias value.
      loss: Float loss to return. When None, it's picked from a random
        distribution.
      adanet_loss: Float AdaNet loss to return. When None, it's picked from a
        random distribution.
      eval_metrics: Optional eval metrics tuple of (metric_fn, tensor args).
      dict_predictions: Boolean whether to return predictions as a dictionary
        of `Tensor` or just a single float `Tensor`.
      export_output_key: An `ExportOutputKeys` for faking export outputs.
      subnetwork_builders: List of `adanet.subnetwork.Builder` objects.
      train_op: A train op.

    Returns:
      A dummy `_EnsembleSpec` instance.
    """
    if loss is None:
        loss = dummy_tensor([], random_seed)

    adanet_loss = (
        dummy_tensor([], random_seed * 2)
        if adanet_loss is None else tf.convert_to_tensor(adanet_loss))

    logits = dummy_tensor([], random_seed * 3)
    predictions = logits
    if dict_predictions:
        predictions = {
            "logits": logits,
            "classes": tf.cast(tf.abs(logits), dtype=tf.int64)
        }

    # All tensors of the fake subnetwork share one seed. They are created in
    # the same order as the nested constructor arguments they feed.
    subnetwork_seed = random_seed * 4
    weighted_logits = dummy_tensor([2, 1], subnetwork_seed)
    mixture_weight = dummy_tensor([2, 1], subnetwork_seed)
    last_layer = dummy_tensor([1, 2], subnetwork_seed)
    subnetwork_logits = dummy_tensor([2, 1], subnetwork_seed)
    fake_weighted_subnetwork = WeightedSubnetwork(
        name=name,
        iteration_number=1,
        logits=weighted_logits,
        weight=mixture_weight,
        subnetwork=Subnetwork(
            last_layer=last_layer,
            logits=subnetwork_logits,
            complexity=1.,
            persisted_tensors={}))

    export_outputs = _dummy_export_outputs(export_output_key, logits,
                                           predictions)
    bias_tensor = tf.constant(bias)

    # The same fake subnetwork object is repeated `num_subnetworks` times.
    ensemble = ComplexityRegularized(
        weighted_subnetworks=[fake_weighted_subnetwork] * num_subnetworks,
        bias=bias_tensor,
        logits=logits)
    return _EnsembleSpec(
        name=name,
        ensemble=ensemble,
        architecture=_Architecture("dummy_ensemble_candidate"),
        subnetwork_builders=subnetwork_builders,
        predictions=predictions,
        loss=loss,
        adanet_loss=adanet_loss,
        train_op=train_op,
        eval_metrics=eval_metrics,
        export_outputs=export_outputs)
def build_ensemble_spec(self, name, candidate, ensembler, subnetwork_specs, summary, features, mode, iteration_number, labels, my_ensemble_index, previous_ensemble_spec, previous_iteration_checkpoint):
    """Builds an `_EnsembleSpec` with the given `adanet.ensemble.Candidate`.

    The ensemble is assembled from the subnetworks kept from the previous
    ensemble (if any) plus the candidate's new subnetworks. Depending on the
    builder's export flags, per-subnetwork logits and last layers are added to
    the export outputs. In TRAIN mode a train op for the ensemble's trainable
    variables is also created.

    Args:
      name: The string name of the ensemble. Typically the name of the builder
        that returned the given `Subnetwork`.
      candidate: The `adanet.ensemble.Candidate` for this spec.
      ensembler: The :class:`adanet.ensemble.Ensembler` to use to ensemble a
        group of subnetworks.
      subnetwork_specs: Iterable of `_SubnetworkSpecs` for this iteration.
      summary: A `_ScopedSummary` instance for recording ensemble summaries.
      features: Input `dict` of `Tensor` objects.
      mode: Estimator `ModeKeys` indicating training, evaluation, or inference.
      iteration_number: Integer current iteration number.
      labels: Labels `Tensor` or a dictionary of string label name to `Tensor`
        (for multi-head).
      my_ensemble_index: An integer holding the index of the ensemble in the
        candidates list of AdaNet.
      previous_ensemble_spec: Link the rest of the `_EnsembleSpec` from
        iteration t-1. Used for creating the subnetwork train_op.
      previous_iteration_checkpoint: `tf.train.Checkpoint` for iteration t-1.

    Returns:
      An `_EnsembleSpec` instance.
    """
    with tf_compat.v1.variable_scope("ensemble_{}".format(name)):
        # Per-ensemble step counter; it is appended to the ensemble's variable
        # list at the end of this method.
        step = tf_compat.v1.get_variable(
            "step",
            shape=[],
            initializer=tf_compat.v1.zeros_initializer(),
            trainable=False,
            dtype=tf.int64)
        # Convert to tensor so that users cannot mutate it.
        step_tensor = tf.convert_to_tensor(value=step)
        with summary.current_scope():
            summary.scalar("iteration_step/adanet/iteration_step", step_tensor)
        # Extend a copy of the previous replay indices with this ensemble's
        # index, leaving the previous architecture's list untouched.
        replay_indices = []
        if previous_ensemble_spec:
            replay_indices = copy.copy(
                previous_ensemble_spec.architecture.replay_indices)
        if my_ensemble_index is not None:
            replay_indices.append(my_ensemble_index)
        architecture = _Architecture(candidate.name, ensembler.name, replay_indices=replay_indices)
        previous_subnetworks = []
        previous_subnetwork_specs = []
        subnetwork_builders = []
        previous_ensemble = None
        if previous_ensemble_spec:
            previous_ensemble = previous_ensemble_spec.ensemble
            previous_architecture = previous_ensemble_spec.architecture
            # By default every subnetwork of the previous ensemble is kept.
            keep_indices = range(len(previous_ensemble.subnetworks))
            if len(candidate.subnetwork_builders) == 1 and previous_ensemble:
                # Prune previous ensemble according to the subnetwork.Builder for
                # backwards compatibility.
                subnetwork_builder = candidate.subnetwork_builders[0]
                prune_previous_ensemble = getattr(
                    subnetwork_builder, "prune_previous_ensemble", None)
                if callable(prune_previous_ensemble):
                    logging.warn(
                        "Using an `adanet.subnetwork.Builder#prune_previous_ensemble` "
                        "is deprecated. Please use a custom `adanet.ensemble.Strategy` "
                        "instead.")
                    keep_indices = prune_previous_ensemble(
                        previous_ensemble)
            # Carry over only the previous subnetworks that survive pruning and
            # that the candidate explicitly lists.
            for i, builder in enumerate(
                    previous_ensemble_spec.subnetwork_builders):
                if i not in keep_indices:
                    continue
                if builder not in candidate.previous_ensemble_subnetwork_builders:
                    continue
                previous_subnetworks.append(
                    previous_ensemble.subnetworks[i])
                previous_subnetwork_specs.append(
                    previous_ensemble_spec.subnetwork_specs[i])
                subnetwork_builders.append(builder)
                architecture.add_subnetwork(
                    *previous_architecture.subnetworks[i])
        # Register the candidate's new subnetworks under the current iteration.
        for builder in candidate.subnetwork_builders:
            architecture.add_subnetwork(iteration_number, builder.name)
            subnetwork_builders.append(builder)
        # Look up this iteration's subnetwork specs by builder name.
        subnetwork_spec_map = {s.builder.name: s for s in subnetwork_specs}
        relevant_subnetwork_specs = [
            subnetwork_spec_map[s.name] for s in candidate.subnetwork_builders
        ]
        ensemble_scope = tf_compat.v1.get_variable_scope()
        # Snapshot of the current variables; later calls diff against it to
        # find the variables created for this ensemble.
        old_vars = _get_current_vars()
        with summary.current_scope(), _monkey_patch_context(
                iteration_step_scope=ensemble_scope,
                scoped_summary=summary,
                trainable_vars=[]):
            ensemble = ensembler.build_ensemble(
                subnetworks=[
                    s.subnetwork for s in relevant_subnetwork_specs
                ],
                previous_ensemble_subnetworks=previous_subnetworks,
                features=features,
                labels=labels,
                logits_dimension=self._head.logits_dimension,
                training=mode == tf.estimator.ModeKeys.TRAIN,
                iteration_step=step_tensor,
                summary=summary,
                previous_ensemble=previous_ensemble,
                previous_iteration_checkpoint=previous_iteration_checkpoint
            )
        estimator_spec = _create_estimator_spec(self._head, features, labels, mode, ensemble.logits, self._use_tpu)
        ensemble_loss = estimator_spec.loss
        adanet_loss = None
        if mode != tf.estimator.ModeKeys.PREDICT:
            adanet_loss = estimator_spec.loss
            # Add ensembler specific loss
            if isinstance(ensemble, ensemble_lib.ComplexityRegularized):
                adanet_loss += ensemble.complexity_regularization
        predictions = estimator_spec.predictions
        export_outputs = estimator_spec.export_outputs
        # Optionally export each subnetwork's logits. The first subnetwork's
        # logits decide whether the export is done per head (dict) or as a
        # single signature.
        if (self._export_subnetwork_logits and export_outputs and subnetwork_spec_map):
            first_subnetwork_logits = list(
                subnetwork_spec_map.values())[0].subnetwork.logits
            if isinstance(first_subnetwork_logits, dict):
                for head_name in first_subnetwork_logits.keys():
                    subnetwork_logits = {
                        subnetwork_name: subnetwork_spec.subnetwork.logits[head_name]
                        for subnetwork_name, subnetwork_spec in subnetwork_spec_map.items()
                    }
                    export_outputs.update({
                        "{}_{}".format(
                            _EnsembleBuilder._SUBNETWORK_LOGITS_EXPORT_SIGNATURE,
                            head_name): tf.estimator.export.PredictOutput(
                                subnetwork_logits)
                    })
            else:
                subnetwork_logits = {
                    subnetwork_name: subnetwork_spec.subnetwork.logits
                    for subnetwork_name, subnetwork_spec in subnetwork_spec_map.items()
                }
                export_outputs.update({
                    _EnsembleBuilder._SUBNETWORK_LOGITS_EXPORT_SIGNATURE:
                        tf.estimator.export.PredictOutput(subnetwork_logits)
                })
        # Same export logic for the subnetworks' last layers; skipped when the
        # first subnetwork has no last layer.
        if (self._export_subnetwork_last_layer and export_outputs and subnetwork_spec_map and list(
                subnetwork_spec_map.values())[0].subnetwork.last_layer is not None):
            first_subnetwork_last_layer = list(
                subnetwork_spec_map.values())[0].subnetwork.last_layer
            if isinstance(first_subnetwork_last_layer, dict):
                for head_name in first_subnetwork_last_layer.keys():
                    subnetwork_last_layer = {
                        subnetwork_name: subnetwork_spec.subnetwork.last_layer[head_name]
                        for subnetwork_name, subnetwork_spec in subnetwork_spec_map.items()
                    }
                    export_outputs.update({
                        "{}_{}".format(
                            _EnsembleBuilder._SUBNETWORK_LAST_LAYER_EXPORT_SIGNATURE,
                            head_name): tf.estimator.export.PredictOutput(
                                subnetwork_last_layer)
                    })
            else:
                subnetwork_last_layer = {
                    subnetwork_name: subnetwork_spec.subnetwork.last_layer
                    for subnetwork_name, subnetwork_spec in subnetwork_spec_map.items()
                }
                export_outputs.update({
                    _EnsembleBuilder._SUBNETWORK_LAST_LAYER_EXPORT_SIGNATURE:
                        tf.estimator.export.PredictOutput(
                            subnetwork_last_layer)
                })
        # Merge any ensemble-provided predictions into the head's predictions
        # and export outputs.
        if ensemble.predictions and predictions:
            predictions.update(ensemble.predictions)
        if ensemble.predictions and export_outputs:
            export_outputs.update({
                k: tf.estimator.export.PredictOutput(v)
                for k, v in ensemble.predictions.items()
            })
        ensemble_metrics = _EnsembleMetrics(use_tpu=self._use_tpu)
        if mode == tf.estimator.ModeKeys.EVAL:
            ensemble_metrics.create_eval_metrics(
                features=features,
                labels=labels,
                estimator_spec=estimator_spec,
                metric_fn=self._metric_fn,
                architecture=architecture)
        if mode == tf.estimator.ModeKeys.TRAIN:
            with summary.current_scope():
                summary.scalar("loss", estimator_spec.loss)
        # Trainable variables created since the `old_vars` snapshot.
        ensemble_trainable_vars = _get_current_vars(
            diffbase=old_vars)["trainable"]
        # Create train ops for training subnetworks and ensembles.
        train_op = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            # Note that these mixture weights are on top of the last_layer of the
            # subnetwork constructed in TRAIN mode, which means that dropout is
            # still applied when the mixture weights are being trained.
            ensemble_scope = tf_compat.v1.get_variable_scope()
            with tf_compat.v1.variable_scope("train_mixture_weights"):
                with summary.current_scope(), _monkey_patch_context(
                        iteration_step_scope=ensemble_scope,
                        scoped_summary=summary,
                        trainable_vars=ensemble_trainable_vars):
                    # For backwards compatibility.
                    subnetwork_builder = candidate.subnetwork_builders[0]
                    old_train_op_fn = getattr(
                        subnetwork_builder, "build_mixture_weights_train_op", None)
                    if callable(old_train_op_fn):
                        logging.warn(
                            "The `build_mixture_weights_train_op` method is deprecated. "
                            "Please use the `Ensembler#build_train_op` instead."
                        )
                        train_op = _to_train_op_spec(
                            subnetwork_builder.build_mixture_weights_train_op(
                                loss=adanet_loss,
                                var_list=ensemble_trainable_vars,
                                logits=ensemble.logits,
                                labels=labels,
                                iteration_step=step_tensor,
                                summary=summary))
                    else:
                        train_op = _to_train_op_spec(
                            ensembler.build_train_op(
                                ensemble=ensemble,
                                loss=adanet_loss,
                                var_list=ensemble_trainable_vars,
                                labels=labels,
                                iteration_step=step_tensor,
                                summary=summary,
                                previous_ensemble=previous_ensemble))
        # Re-diff after the train op is built so variables created while
        # building it are included as well.
        new_vars = _get_current_vars(diffbase=old_vars)
        # Sort our dictionary by key to remove non-determinism of variable order.
        new_vars = collections.OrderedDict(sorted(new_vars.items()))
        # Combine all trainable, global and savable variables into a single list.
        ensemble_variables = sum(new_vars.values(), []) + [step]
        return _EnsembleSpec(name=name, architecture=architecture, subnetwork_builders=subnetwork_builders, subnetwork_specs=previous_subnetwork_specs + relevant_subnetwork_specs, ensemble=ensemble, predictions=predictions, step=step, variables=ensemble_variables, loss=ensemble_loss, adanet_loss=adanet_loss, train_op=train_op, eval_metrics=ensemble_metrics, export_outputs=export_outputs)
def append_new_subnetwork(self, ensemble_name, ensemble_spec, subnetwork_builder, iteration_number, iteration_step, summary, features, mode, labels=None):
    """Adds a `Subnetwork` to an `_EnsembleSpec`.

    For iteration t > 0, the ensemble is built given the `Ensemble` for t-1
    and the new subnetwork to train as part of the ensemble. The `Ensemble` at
    iteration 0 is comprised of just the subnetwork.

    The subnetwork is first given a weight 'w' in a `WeightedSubnetwork`
    which determines its contribution to the ensemble. The subnetwork's
    complexity L1-regularizes this weight.

    Args:
      ensemble_name: String name of the ensemble.
      ensemble_spec: The recipient `_EnsembleSpec` for the `Subnetwork`.
      subnetwork_builder: A `adanet.Builder` instance which defines how to
        train the subnetwork and ensemble mixture weights.
      iteration_number: Integer current iteration number.
      iteration_step: Integer `Tensor` representing the step since the
        beginning of the current iteration, as opposed to the global step.
      summary: A `_ScopedSummary` instance for recording ensemble summaries.
      features: Input `dict` of `Tensor` objects.
      mode: Estimator's `ModeKeys`.
      labels: Labels `Tensor` or a dictionary of string label name to `Tensor`
        (for multi-head). Can be `None`.

    Returns:
      An new `EnsembleSpec` instance with the `Subnetwork` appended.
    """
    # NOTE(review): the nesting of the variable scopes below was reconstructed
    # from statement order; confirm it against the variable names expected in
    # existing checkpoints.
    with tf.variable_scope("ensemble_{}".format(ensemble_name)):
        weighted_subnetworks = []
        subnetwork_index = 0
        # Counts the new subnetwork plus any kept from the previous ensemble.
        num_subnetworks = 1
        ensemble = None
        architecture = _Architecture()
        if ensemble_spec:
            ensemble = ensemble_spec.ensemble
            # The builder decides which previous subnetworks are kept.
            previous_subnetworks = [
                ensemble.weighted_subnetworks[index]
                for index in subnetwork_builder.prune_previous_ensemble(ensemble)
            ]
            num_subnetworks += len(previous_subnetworks)
            for weighted_subnetwork in previous_subnetworks:
                weight_initializer = None
                if self._warm_start_mixture_weights:
                    # Initialize the mixture weight from the value stored in
                    # the checkpoint directory.
                    weight_initializer = tf.contrib.framework.load_variable(
                        self._checkpoint_dir, weighted_subnetwork.weight.op.name)
                with tf.variable_scope(
                        "weighted_subnetwork_{}".format(subnetwork_index)):
                    weighted_subnetworks.append(
                        self._build_weighted_subnetwork(
                            weighted_subnetwork.name,
                            weighted_subnetwork.iteration_number,
                            weighted_subnetwork.subnetwork,
                            num_subnetworks,
                            weight_initializer=weight_initializer))
                architecture.add_subnetwork(
                    weighted_subnetwork.iteration_number,
                    weighted_subnetwork.name)
                subnetwork_index += 1
        ensemble_scope = tf.get_variable_scope()
        with tf.variable_scope(
                "weighted_subnetwork_{}".format(subnetwork_index)):
            with tf.variable_scope("subnetwork"):
                _clear_trainable_variables()
                build_subnetwork = functools.partial(
                    subnetwork_builder.build_subnetwork,
                    features=features,
                    logits_dimension=self._head.logits_dimension,
                    training=mode == tf.estimator.ModeKeys.TRAIN,
                    iteration_step=iteration_step,
                    summary=summary,
                    previous_ensemble=ensemble)
                # Check which args are in the implemented build_subnetwork method
                # signature for backwards compatibility.
                defined_args = inspect.getargspec(
                    subnetwork_builder.build_subnetwork).args
                if "labels" in defined_args:
                    build_subnetwork = functools.partial(build_subnetwork, labels=labels)
                with summary.current_scope(), _subnetwork_context(
                        iteration_step_scope=ensemble_scope,
                        scoped_summary=summary):
                    tf.logging.info("Building subnetwork '%s'", subnetwork_builder.name)
                    subnetwork = build_subnetwork()
                # Captured before the new subnetwork's weighted wrapper is
                # built below.
                var_list = tf.trainable_variables()
            weighted_subnetworks.append(
                self._build_weighted_subnetwork(subnetwork_builder.name, iteration_number, subnetwork, num_subnetworks))
            architecture.add_subnetwork(iteration_number, subnetwork_builder.name)
        if ensemble:
            # The previous bias is only reused as a prior when no previous
            # subnetwork was pruned.
            if len(previous_subnetworks) == len(
                    ensemble.weighted_subnetworks):
                bias = self._create_bias_term(weighted_subnetworks, prior=ensemble.bias)
            else:
                bias = self._create_bias_term(weighted_subnetworks)
                tf.logging.info(
                    "Builder '%s' is using a subset of the subnetworks "
                    "from the previous ensemble, so its ensemble's bias "
                    "term will not be warm started with the previous "
                    "ensemble's bias.", subnetwork_builder.name)
        else:
            bias = self._create_bias_term(weighted_subnetworks)
        return self._build_ensemble_spec(
            name=ensemble_name,
            weighted_subnetworks=weighted_subnetworks,
            architecture=architecture,
            summary=summary,
            bias=bias,
            features=features,
            mode=mode,
            iteration_step=iteration_step,
            labels=labels,
            subnetwork_builder=subnetwork_builder,
            var_list=var_list,
            previous_ensemble_spec=ensemble_spec)
def test_subnetworks(self, subnetworks, want):
    """Adds each given subnetwork and checks the `subnetworks` property.

    Args:
      subnetworks: Iterable of argument tuples, each unpacked into
        `_Architecture.add_subnetwork`.
      want: Expected value of `_Architecture.subnetworks` afterwards.
    """
    architecture = _Architecture()
    for add_args in subnetworks:
        architecture.add_subnetwork(*add_args)
    actual = architecture.subnetworks
    self.assertEqual(want, actual)
def test_set_and_add_replay_index(self):
    """Checks that replay indices can be replaced and then appended to."""
    architecture = _Architecture("foo", "dummy_ensembler_name")

    initial_indices = [1, 2, 3]
    architecture.set_replay_indices(initial_indices)
    self.assertAllEqual([1, 2, 3], architecture.replay_indices)

    architecture.add_replay_index(4)
    self.assertAllEqual([1, 2, 3, 4], architecture.replay_indices)
def test_subnetworks_grouped_by_iteration(self, subnetworks, want):
    """Adds each given subnetwork and checks the grouped-by-iteration view.

    Args:
      subnetworks: Iterable of argument tuples, each unpacked into
        `_Architecture.add_subnetwork`.
      want: Expected value of
        `_Architecture.subnetworks_grouped_by_iteration` afterwards.
    """
    architecture = _Architecture("foo", "dummy_ensembler_name")
    for add_args in subnetworks:
        architecture.add_subnetwork(*add_args)
    actual = architecture.subnetworks_grouped_by_iteration
    self.assertEqual(want, actual)
def test_subnetworks_grouped_by_iteration(self, subnetworks, want):
    """Adds each given subnetwork and checks the grouped-by-iteration view.

    Args:
      subnetworks: Iterable of argument tuples, each unpacked into
        `_Architecture.add_subnetwork`.
      want: Expected value of
        `_Architecture.subnetworks_grouped_by_iteration` afterwards.
    """
    architecture = _Architecture()
    for add_args in subnetworks:
        architecture.add_subnetwork(*add_args)
    actual = architecture.subnetworks_grouped_by_iteration
    self.assertEqual(want, actual)