Example #1
    def test_ensemble_metrics(self):
        with context.graph_mode():
            self.setup_graph()
            architecture = _Architecture("test_ensemble_candidate",
                                         "test_ensembler")
            architecture.add_subnetwork(iteration_number=0,
                                        builder_name="b_0_0")
            architecture.add_subnetwork(iteration_number=0,
                                        builder_name="b_0_1")
            architecture.add_subnetwork(iteration_number=1,
                                        builder_name="b_1_0")
            architecture.add_subnetwork(iteration_number=2,
                                        builder_name="b_2_0")

            metrics = tu.create_ensemble_metrics(
                self._metric_fn,
                features=self._features,
                labels=self._labels,
                estimator_spec=self._estimator_spec,
                architecture=architecture)

            actual = self._run_metrics(metrics.eval_metrics_tuple())

            serialized_arch_proto = actual["architecture/adanet/ensembles"]
            expected_arch_string = b"| b_0_0 | b_0_1 | b_1_0 | b_2_0 |"
            self.assertIn(expected_arch_string, serialized_arch_proto)
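The expected string above is just the subnetwork builder names, ordered by iteration, rendered between pipes. The metric itself serializes the architecture into a proto; here is a minimal sketch that reproduces only the human-readable summary asserted on, assuming `architecture.subnetworks` yields flat (iteration_number, builder_name) pairs:

def architecture_summary_sketch(architecture):
    # Yields "| b_0_0 | b_0_1 | b_1_0 | b_2_0 |" for the architecture above.
    names = [name for _, name in architecture.subnetworks]
    return "| " + " | ".join(names) + " |"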
Example #2
 def test_serialization_lifecycle(self):
     arch = _Architecture("foo",
                          "dummy_ensembler_name",
                          replay_indices=[1, 2])
     arch.add_subnetwork(0, "linear")
     arch.add_subnetwork(0, "dnn")
     arch.add_subnetwork(1, "dnn")
     self.assertEqual("foo", arch.ensemble_candidate_name)
     self.assertEqual("dummy_ensembler_name", arch.ensembler_name)
     self.assertEqual(((0, ("linear", "dnn")), (1, ("dnn", ))),
                      arch.subnetworks_grouped_by_iteration)
     iteration_number = 2
     global_step = 100
     serialized = arch.serialize(iteration_number, global_step)
     self.assertEqual(
         '{"ensemble_candidate_name": "foo", "ensembler_name": '
         '"dummy_ensembler_name", "global_step": 100, "iteration_number": 2, '
         '"replay_indices": [1, 2], '
         '"subnetworks": [{"builder_name": "linear", "iteration_number": 0}, '
         '{"builder_name": "dnn", "iteration_number": 0},'
         ' {"builder_name": "dnn", "iteration_number": 1}]}', serialized)
     deserialized_arch = _Architecture.deserialize(serialized)
     self.assertEqual(arch.ensemble_candidate_name,
                      deserialized_arch.ensemble_candidate_name)
     self.assertEqual(arch.ensembler_name, deserialized_arch.ensembler_name)
     self.assertEqual(arch.subnetworks_grouped_by_iteration,
                      deserialized_arch.subnetworks_grouped_by_iteration)
     self.assertEqual(global_step, deserialized_arch.global_step)
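Since `serialize` emits plain JSON, as the assertion above shows, the payload can be inspected with the standard library. A minimal sketch, reusing the names from the test above:

import json

parsed = json.loads(serialized)
assert parsed["ensemble_candidate_name"] == "foo"
assert parsed["global_step"] == 100
assert [s["builder_name"] for s in parsed["subnetworks"]] == ["linear", "dnn", "dnn"]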
Example #3
 def test_serialization_lifecycle(self):
     arch = _Architecture()
     arch.add_subnetwork(0, "linear")
     arch.add_subnetwork(0, "dnn")
     arch.add_subnetwork(1, "dnn")
     self.assertEqual(((0, ("linear", "dnn")), (1, ("dnn", ))),
                      arch.subnetworks)
     serialized = arch.serialize()
     self.assertEqual(
         b"\n\x08\x12\x06linear\n\x05\x12\x03dnn\n\x07\x08\x01\x12\x03dnn",
         serialized)
     deserialized_arch = _Architecture.deserialize(serialized)
     self.assertEqual(arch.subnetworks, deserialized_arch.subnetworks)
Example #4
def create_ensemble_metrics(metric_fn,
                            use_tpu=False,
                            features=None,
                            labels=None,
                            estimator_spec=None,
                            architecture=None):
    """Creates an instance of the _EnsembleMetrics class.

  Args:
    metric_fn: A function which should obey the following signature:
      - Args: can only have the following three arguments in any order:
        * predictions: Predictions `Tensor` or dict of `Tensor` created by the
          given `Head`.
        * features: Input `dict` of `Tensor` objects created by `input_fn` which
          is given to `estimator.evaluate` as an argument.
        * labels: Labels `Tensor` or dict of `Tensor` (for multi-head) created
          by `input_fn` which is given to `estimator.evaluate` as an argument.
      - Returns: Dict of metric results keyed by name. Final metrics are a union
        of this and the `estimator`'s existing metrics. If there is a name
        conflict between this and the `estimator`'s existing metrics, this will
        override the existing one. The values of the dict are the results of
        calling a metric function, namely a `(metric_tensor, update_op)` tuple.
    use_tpu: Whether to use TPU-specific variable sharing logic.
    features: Input `dict` of `Tensor` objects.
    labels: Labels `Tensor` or a dictionary of string label name to `Tensor`
      (for multi-head).
    estimator_spec: The `EstimatorSpec` created by a `Head` instance.
    architecture: `_Architecture` object.

  Returns:
    An instance of _EnsembleMetrics.
  """

    if not estimator_spec:
        estimator_spec = tf_compat.v1.estimator.tpu.TPUEstimatorSpec(
            mode=tf.estimator.ModeKeys.EVAL,
            loss=tf.constant(2.),
            predictions=None,
            eval_metrics=None)
        if not use_tpu:
            estimator_spec = estimator_spec.as_estimator_spec()

    if not architecture:
        architecture = _Architecture(None, None)

    metrics = _EnsembleMetrics(use_tpu=use_tpu)
    metrics.create_eval_metrics(features, labels, estimator_spec, metric_fn,
                                architecture)

    return metrics
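A `metric_fn` passed to this helper may take any subset of the three documented arguments and must return a dict mapping metric names to `(metric_tensor, update_op)` tuples. A minimal sketch, assuming `predictions` arrives as a single `Tensor` (the metric name is illustrative):

import tensorflow.compat.v1 as tf

def mean_prediction_metric_fn(predictions):
    # Only (predictions, features, labels) are allowed; unused ones
    # may simply be omitted from the signature.
    return {"mean_prediction": tf.metrics.mean(predictions)}

metrics = create_ensemble_metrics(mean_prediction_metric_fn)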
Example #5
 def test_serialization_lifecycle(self):
     arch = _Architecture()
     arch.add_subnetwork(0, "linear")
     arch.add_subnetwork(0, "dnn")
     arch.add_subnetwork(1, "dnn")
     self.assertEqual(((0, ("linear", "dnn")), (1, ("dnn", ))),
                      arch.subnetworks_grouped_by_iteration)
     serialized = arch.serialize()
     self.assertEqual(
         '{"subnetworks": [{"builder_name": "linear", "iteration_number": 0}, '
         '{"builder_name": "dnn", "iteration_number": 0}, '
         '{"builder_name": "dnn", "iteration_number": 1}]}', serialized)
     deserialized_arch = _Architecture.deserialize(serialized)
     self.assertEqual(arch.subnetworks_grouped_by_iteration,
                      deserialized_arch.subnetworks_grouped_by_iteration)
Example #6
    def test_ensemble_metrics(self):
        architecture = _Architecture("test_ensemble_candidate")
        architecture.add_subnetwork(iteration_number=0, builder_name="b_0_0")
        architecture.add_subnetwork(iteration_number=0, builder_name="b_0_1")
        architecture.add_subnetwork(iteration_number=1, builder_name="b_1_0")
        architecture.add_subnetwork(iteration_number=2, builder_name="b_2_0")

        metrics = _EnsembleMetrics()
        metrics.create_eval_metrics(self._features, self._labels,
                                    self._estimator_spec, self._metric_fn,
                                    architecture)

        with self.test_session() as sess:
            actual = _run_metrics(sess, metrics.eval_metrics_tuple())

        serialized_arch_proto = actual["architecture/adanet/ensembles"]
        expected_arch_string = b"| b_0_0 | b_0_1 | b_1_0 | b_2_0 |"
        self.assertIn(expected_arch_string, serialized_arch_proto)
Example #7
    def build_ensemble_spec(self,
                            name,
                            candidate,
                            ensembler,
                            subnetwork_specs,
                            summary,
                            features,
                            mode,
                            iteration_number,
                            labels=None,
                            previous_ensemble_spec=None,
                            my_ensemble_index=None,
                            params=None,
                            previous_iteration_checkpoint=None):
        del ensembler
        del subnetwork_specs
        del summary
        del iteration_number
        del previous_ensemble_spec
        del my_ensemble_index
        del params
        del previous_iteration_checkpoint

        logits = [[.5]]

        estimator_spec = self._head.create_estimator_spec(features=features,
                                                          mode=mode,
                                                          labels=labels,
                                                          logits=logits)
        return _EnsembleSpec(name=name,
                             ensemble=None,
                             architecture=_Architecture("foo", "bar"),
                             subnetwork_builders=candidate.subnetwork_builders,
                             predictions=estimator_spec.predictions,
                             step=tf.Variable(0, dtype=tf.int64),
                             variables=[tf.Variable(1.)],
                             loss=None,
                             adanet_loss=.1,
                             train_op=None,
                             eval_metrics=None,
                             export_outputs=estimator_spec.export_outputs)
Example #8
  def build_ensemble_spec(self,
                          name,
                          candidate,
                          ensembler,
                          subnetwork_specs,
                          summary,
                          features,
                          mode,
                          iteration_step,
                          iteration_number,
                          labels=None,
                          previous_ensemble_spec=None):
    """Builds an `_EnsembleSpec` with the given `adanet.ensemble.Candidate`.

    Args:
      name: The string name of the ensemble. Typically the name of the builder
        that returned the given `Subnetwork`.
      candidate: The `adanet.ensemble.Candidate` for this spec.
      ensembler: The :class:`adanet.ensemble.Ensembler` to use to ensemble a
        group of subnetworks.
      subnetwork_specs: Iterable of `_SubnetworkSpecs` for this iteration.
      summary: A `_ScopedSummary` instance for recording ensemble summaries.
      features: Input `dict` of `Tensor` objects.
      mode: Estimator `ModeKeys` indicating training, evaluation, or inference.
      iteration_step: Integer `Tensor` representing the step since the beginning
        of the current iteration, as opposed to the global step.
      iteration_number: Integer current iteration number.
      labels: Labels `Tensor` or a dictionary of string label name to `Tensor`
        (for multi-head).
      previous_ensemble_spec: Link the rest of the `_EnsembleSpec` from
        iteration t-1. Used for creating the subnetwork train_op.

    Returns:
      An `_EnsembleSpec` instance.
    """

    with tf.variable_scope("ensemble_{}".format(name)):
      architecture = _Architecture(candidate.name)
      previous_subnetworks = []
      subnetwork_builders = []
      previous_ensemble = None
      if previous_ensemble_spec:
        previous_ensemble = previous_ensemble_spec.ensemble
        previous_architecture = previous_ensemble_spec.architecture
        keep_indices = range(len(previous_ensemble.subnetworks))
        if len(candidate.subnetwork_builders) == 1 and previous_ensemble:
          # Prune previous ensemble according to the subnetwork.Builder for
          # backwards compatibility.
          tf.logging.warn(
              "Using an `adanet.subnetwork.Builder#prune_previous_ensemble` "
              "is deprecated. Please use a custom `adanet.ensemble.Strategy` "
              "instead.")
          subnetwork_builder = candidate.subnetwork_builders[0]
          keep_indices = subnetwork_builder.prune_previous_ensemble(
              previous_ensemble)
        for i, builder in enumerate(previous_ensemble_spec.subnetwork_builders):
          if i not in keep_indices:
            continue
          if builder not in candidate.previous_ensemble_subnetwork_builders:
            continue
          previous_subnetworks.append(previous_ensemble.subnetworks[i])
          subnetwork_builders.append(builder)
          architecture.add_subnetwork(*previous_architecture.subnetworks[i])
      for builder in candidate.subnetwork_builders:
        architecture.add_subnetwork(iteration_number, builder.name)
        subnetwork_builders.append(builder)
      subnetwork_map = {s.builder.name: s.subnetwork for s in subnetwork_specs}
      subnetworks = [
          subnetwork_map[s.name] for s in candidate.subnetwork_builders
      ]
      ensemble_scope = tf.get_variable_scope()
      before_var_list = tf.trainable_variables()
      with summary.current_scope(), _monkey_patch_context(
          iteration_step_scope=ensemble_scope,
          scoped_summary=summary,
          trainable_vars=[]):
        ensemble = ensembler.build_ensemble(
            subnetworks,
            previous_ensemble_subnetworks=previous_subnetworks,
            features=features,
            labels=labels,
            logits_dimension=self._head.logits_dimension,
            training=mode == tf.estimator.ModeKeys.TRAIN,
            iteration_step=iteration_step,
            summary=summary,
            previous_ensemble=previous_ensemble)
      ensemble_var_list = _new_trainable_variables(before_var_list)

      estimator_spec = _create_estimator_spec(
          self._head, features, labels, mode, ensemble.logits, self._use_tpu)

      ensemble_loss = estimator_spec.loss
      adanet_loss = None
      if mode != tf.estimator.ModeKeys.PREDICT:
        # TODO: Support any kind of Ensemble. Use a moving average of
        # their train loss for the 'adanet_loss'.
        if not isinstance(ensemble, ComplexityRegularized):
          raise ValueError(
              "Only ComplexityRegularized ensembles are supported.")
        adanet_loss = estimator_spec.loss + ensemble.complexity_regularization

      ensemble_metrics = _EnsembleMetrics()
      if mode == tf.estimator.ModeKeys.EVAL:
        ensemble_metrics.create_eval_metrics(
            features=features,
            labels=labels,
            estimator_spec=estimator_spec,
            metric_fn=self._metric_fn,
            architecture=architecture)

      if mode == tf.estimator.ModeKeys.TRAIN:
        with summary.current_scope():
          summary.scalar("loss", estimator_spec.loss)

      # Create train ops for training subnetworks and ensembles.
      train_op = None
      if mode == tf.estimator.ModeKeys.TRAIN:
        # Note that these mixture weights are on top of the last_layer of the
        # subnetwork constructed in TRAIN mode, which means that dropout is
        # still applied when the mixture weights are being trained.
        ensemble_scope = tf.get_variable_scope()
        with tf.variable_scope("train_mixture_weights"):
          with summary.current_scope(), _monkey_patch_context(
              iteration_step_scope=ensemble_scope,
              scoped_summary=summary,
              trainable_vars=ensemble_var_list):
            # For backwards compatibility.
            subnetwork_builder = candidate.subnetwork_builders[0]
            old_train_op_fn = getattr(subnetwork_builder,
                                      "build_mixture_weights_train_op", None)
            if callable(old_train_op_fn):
              tf.logging.warn(
                  "The `build_mixture_weights_train_op` method is deprecated. "
                  "Please use the `Ensembler#build_train_op` instead.")
              train_op = _to_train_op_spec(
                  subnetwork_builder.build_mixture_weights_train_op(
                      loss=adanet_loss,
                      var_list=ensemble_var_list,
                      logits=ensemble.logits,
                      labels=labels,
                      iteration_step=iteration_step,
                      summary=summary))
            else:
              train_op = _to_train_op_spec(
                  ensembler.build_train_op(
                      ensemble=ensemble,
                      loss=adanet_loss,
                      var_list=ensemble_var_list,
                      labels=labels,
                      iteration_step=iteration_step,
                      summary=summary,
                      previous_ensemble=previous_ensemble))
    return _EnsembleSpec(
        name=name,
        architecture=architecture,
        subnetwork_builders=subnetwork_builders,
        ensemble=ensemble,
        predictions=estimator_spec.predictions,
        loss=ensemble_loss,
        adanet_loss=adanet_loss,
        train_op=train_op,
        eval_metrics=ensemble_metrics.eval_metrics_tuple(),
        export_outputs=estimator_spec.export_outputs)
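For reference, the deprecated `prune_previous_ensemble` hook invoked in the backwards-compatibility branch above receives the previous `Ensemble` and returns the indices of its subnetworks to keep. A sketch of the shape such a hook could take on an `adanet.subnetwork.Builder` (the keep-the-last-two policy is purely illustrative):

class PruningBuilderMixin(object):
    """Sketch of the deprecated hook; mix into an adanet.subnetwork.Builder."""

    def prune_previous_ensemble(self, previous_ensemble):
        # Keep only the indices of the two most recently added subnetworks.
        num_subnetworks = len(previous_ensemble.subnetworks)
        return list(range(max(0, num_subnetworks - 2), num_subnetworks))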
Example #9
def dummy_ensemble_spec(name,
                        random_seed=42,
                        num_subnetworks=1,
                        bias=0.,
                        loss=None,
                        adanet_loss=None,
                        eval_metrics=None,
                        dict_predictions=False,
                        export_output_key=None,
                        subnetwork_builders=None,
                        train_op=None):
  """Creates a dummy `_EnsembleSpec` instance.

  Args:
    name: _EnsembleSpec's name.
    random_seed: A scalar random seed.
    num_subnetworks: The number of fake subnetworks in this ensemble.
    bias: Bias value.
    loss: Float loss to return. When None, it's picked from a random
      distribution.
    adanet_loss: Float AdaNet loss to return. When None, it's picked from a
      random distribution.
    eval_metrics: Optional eval metrics tuple of (metric_fn, tensor args).
    dict_predictions: Boolean whether to return predictions as a dictionary of
      `Tensor` or just a single float `Tensor`.
    export_output_key: An `ExportOutputKeys` for faking export outputs.
    subnetwork_builders: List of `adanet.subnetwork.Builder` objects.
    train_op: A train op.

  Returns:
    A dummy `_EnsembleSpec` instance.
  """

  if loss is None:
    loss = dummy_tensor([], random_seed)

  if adanet_loss is None:
    adanet_loss = dummy_tensor([], random_seed * 2)
  else:
    adanet_loss = tf.convert_to_tensor(adanet_loss)

  logits = dummy_tensor([], random_seed * 3)
  if dict_predictions:
    predictions = {
        "logits": logits,
        "classes": tf.cast(tf.abs(logits), dtype=tf.int64)
    }
  else:
    predictions = logits
  weighted_subnetworks = [
      WeightedSubnetwork(
          name=name,
          iteration_number=1,
          logits=dummy_tensor([2, 1], random_seed * 4),
          weight=dummy_tensor([2, 1], random_seed * 4),
          subnetwork=Subnetwork(
              last_layer=dummy_tensor([1, 2], random_seed * 4),
              logits=dummy_tensor([2, 1], random_seed * 4),
              complexity=1.,
              persisted_tensors={}))
  ]

  export_outputs = _dummy_export_outputs(export_output_key, logits, predictions)
  bias = tf.constant(bias)
  return _EnsembleSpec(
      name=name,
      ensemble=ComplexityRegularized(
          weighted_subnetworks=weighted_subnetworks * num_subnetworks,
          bias=bias,
          logits=logits,
      ),
      architecture=_Architecture("dummy_ensemble_candidate"),
      subnetwork_builders=subnetwork_builders,
      predictions=predictions,
      loss=loss,
      adanet_loss=adanet_loss,
      train_op=train_op,
      eval_metrics=eval_metrics,
      export_outputs=export_outputs)
Example #10
    def build_ensemble_spec(self, name, candidate, ensembler, subnetwork_specs,
                            summary, features, mode, iteration_number, labels,
                            my_ensemble_index, previous_ensemble_spec,
                            previous_iteration_checkpoint):
        """Builds an `_EnsembleSpec` with the given `adanet.ensemble.Candidate`.

    Args:
      name: The string name of the ensemble. Typically the name of the builder
        that returned the given `Subnetwork`.
      candidate: The `adanet.ensemble.Candidate` for this spec.
      ensembler: The :class:`adanet.ensemble.Ensembler` to use to ensemble a
        group of subnetworks.
      subnetwork_specs: Iterable of `_SubnetworkSpecs` for this iteration.
      summary: A `_ScopedSummary` instance for recording ensemble summaries.
      features: Input `dict` of `Tensor` objects.
      mode: Estimator `ModeKeys` indicating training, evaluation, or inference.
      iteration_number: Integer current iteration number.
      labels: Labels `Tensor` or a dictionary of string label name to `Tensor`
        (for multi-head).
      my_ensemble_index: An integer holding the index of the ensemble in the
        candidates list of AdaNet.
      previous_ensemble_spec: Link the rest of the `_EnsembleSpec` from
        iteration t-1. Used for creating the subnetwork train_op.
      previous_iteration_checkpoint: `tf.train.Checkpoint` for iteration t-1.

    Returns:
      An `_EnsembleSpec` instance.
    """

        with tf_compat.v1.variable_scope("ensemble_{}".format(name)):
            step = tf_compat.v1.get_variable(
                "step",
                shape=[],
                initializer=tf_compat.v1.zeros_initializer(),
                trainable=False,
                dtype=tf.int64)
            # Convert to tensor so that users cannot mutate it.
            step_tensor = tf.convert_to_tensor(value=step)
            with summary.current_scope():
                summary.scalar("iteration_step/adanet/iteration_step",
                               step_tensor)
            replay_indices = []
            if previous_ensemble_spec:
                replay_indices = copy.copy(
                    previous_ensemble_spec.architecture.replay_indices)
            if my_ensemble_index is not None:
                replay_indices.append(my_ensemble_index)

            architecture = _Architecture(candidate.name,
                                         ensembler.name,
                                         replay_indices=replay_indices)
            previous_subnetworks = []
            previous_subnetwork_specs = []
            subnetwork_builders = []
            previous_ensemble = None
            if previous_ensemble_spec:
                previous_ensemble = previous_ensemble_spec.ensemble
                previous_architecture = previous_ensemble_spec.architecture
                keep_indices = range(len(previous_ensemble.subnetworks))
                if len(candidate.subnetwork_builders) == 1 and previous_ensemble:
                    # Prune previous ensemble according to the subnetwork.Builder for
                    # backwards compatibility.
                    subnetwork_builder = candidate.subnetwork_builders[0]
                    prune_previous_ensemble = getattr(
                        subnetwork_builder, "prune_previous_ensemble", None)
                    if callable(prune_previous_ensemble):
                        logging.warn(
                            "Using an `adanet.subnetwork.Builder#prune_previous_ensemble` "
                            "is deprecated. Please use a custom `adanet.ensemble.Strategy` "
                            "instead.")
                        keep_indices = prune_previous_ensemble(
                            previous_ensemble)
                for i, builder in enumerate(
                        previous_ensemble_spec.subnetwork_builders):
                    if i not in keep_indices:
                        continue
                    if builder not in candidate.previous_ensemble_subnetwork_builders:
                        continue
                    previous_subnetworks.append(
                        previous_ensemble.subnetworks[i])
                    previous_subnetwork_specs.append(
                        previous_ensemble_spec.subnetwork_specs[i])
                    subnetwork_builders.append(builder)
                    architecture.add_subnetwork(
                        *previous_architecture.subnetworks[i])
            for builder in candidate.subnetwork_builders:
                architecture.add_subnetwork(iteration_number, builder.name)
                subnetwork_builders.append(builder)
            subnetwork_spec_map = {s.builder.name: s for s in subnetwork_specs}
            relevant_subnetwork_specs = [
                subnetwork_spec_map[s.name]
                for s in candidate.subnetwork_builders
            ]
            ensemble_scope = tf_compat.v1.get_variable_scope()

            old_vars = _get_current_vars()

            with summary.current_scope(), _monkey_patch_context(
                    iteration_step_scope=ensemble_scope,
                    scoped_summary=summary,
                    trainable_vars=[]):
                ensemble = ensembler.build_ensemble(
                    subnetworks=[
                        s.subnetwork for s in relevant_subnetwork_specs
                    ],
                    previous_ensemble_subnetworks=previous_subnetworks,
                    features=features,
                    labels=labels,
                    logits_dimension=self._head.logits_dimension,
                    training=mode == tf.estimator.ModeKeys.TRAIN,
                    iteration_step=step_tensor,
                    summary=summary,
                    previous_ensemble=previous_ensemble,
                    previous_iteration_checkpoint=previous_iteration_checkpoint
                )

            estimator_spec = _create_estimator_spec(self._head, features,
                                                    labels, mode,
                                                    ensemble.logits,
                                                    self._use_tpu)

            ensemble_loss = estimator_spec.loss
            adanet_loss = None
            if mode != tf.estimator.ModeKeys.PREDICT:
                adanet_loss = estimator_spec.loss
                # Add ensembler specific loss
                if isinstance(ensemble, ensemble_lib.ComplexityRegularized):
                    adanet_loss += ensemble.complexity_regularization

            predictions = estimator_spec.predictions
            export_outputs = estimator_spec.export_outputs

            if (self._export_subnetwork_logits and export_outputs
                    and subnetwork_spec_map):
                first_subnetwork_logits = list(
                    subnetwork_spec_map.values())[0].subnetwork.logits
                if isinstance(first_subnetwork_logits, dict):
                    for head_name in first_subnetwork_logits.keys():
                        subnetwork_logits = {
                            subnetwork_name:
                            subnetwork_spec.subnetwork.logits[head_name]
                            for subnetwork_name, subnetwork_spec in
                            subnetwork_spec_map.items()
                        }
                        export_outputs.update({
                            "{}_{}".format(
                                _EnsembleBuilder._SUBNETWORK_LOGITS_EXPORT_SIGNATURE, head_name):
                            tf.estimator.export.PredictOutput(
                                subnetwork_logits)
                        })
                else:
                    subnetwork_logits = {
                        subnetwork_name: subnetwork_spec.subnetwork.logits
                        for subnetwork_name, subnetwork_spec in
                        subnetwork_spec_map.items()
                    }
                    export_outputs.update({
                        _EnsembleBuilder._SUBNETWORK_LOGITS_EXPORT_SIGNATURE:
                        tf.estimator.export.PredictOutput(subnetwork_logits)
                    })

            if (self._export_subnetwork_last_layer and export_outputs
                    and subnetwork_spec_map and list(
                        subnetwork_spec_map.values())[0].subnetwork.last_layer
                    is not None):
                first_subnetwork_last_layer = list(
                    subnetwork_spec_map.values())[0].subnetwork.last_layer
                if isinstance(first_subnetwork_last_layer, dict):
                    for head_name in first_subnetwork_last_layer.keys():
                        subnetwork_last_layer = {
                            subnetwork_name:
                            subnetwork_spec.subnetwork.last_layer[head_name]
                            for subnetwork_name, subnetwork_spec in
                            subnetwork_spec_map.items()
                        }
                        export_outputs.update({
                            "{}_{}".format(
                                _EnsembleBuilder._SUBNETWORK_LAST_LAYER_EXPORT_SIGNATURE, head_name):
                            tf.estimator.export.PredictOutput(
                                subnetwork_last_layer)
                        })
                else:
                    subnetwork_last_layer = {
                        subnetwork_name: subnetwork_spec.subnetwork.last_layer
                        for subnetwork_name, subnetwork_spec in
                        subnetwork_spec_map.items()
                    }
                    export_outputs.update({
                        _EnsembleBuilder._SUBNETWORK_LAST_LAYER_EXPORT_SIGNATURE:
                        tf.estimator.export.PredictOutput(
                            subnetwork_last_layer)
                    })

            if ensemble.predictions and predictions:
                predictions.update(ensemble.predictions)
            if ensemble.predictions and export_outputs:
                export_outputs.update({
                    k: tf.estimator.export.PredictOutput(v)
                    for k, v in ensemble.predictions.items()
                })

            ensemble_metrics = _EnsembleMetrics(use_tpu=self._use_tpu)
            if mode == tf.estimator.ModeKeys.EVAL:
                ensemble_metrics.create_eval_metrics(
                    features=features,
                    labels=labels,
                    estimator_spec=estimator_spec,
                    metric_fn=self._metric_fn,
                    architecture=architecture)

            if mode == tf.estimator.ModeKeys.TRAIN:
                with summary.current_scope():
                    summary.scalar("loss", estimator_spec.loss)

            ensemble_trainable_vars = _get_current_vars(
                diffbase=old_vars)["trainable"]
            # Create train ops for training subnetworks and ensembles.
            train_op = None
            if mode == tf.estimator.ModeKeys.TRAIN:
                # Note that these mixture weights are on top of the last_layer of the
                # subnetwork constructed in TRAIN mode, which means that dropout is
                # still applied when the mixture weights are being trained.
                ensemble_scope = tf_compat.v1.get_variable_scope()
                with tf_compat.v1.variable_scope("train_mixture_weights"):
                    with summary.current_scope(), _monkey_patch_context(
                            iteration_step_scope=ensemble_scope,
                            scoped_summary=summary,
                            trainable_vars=ensemble_trainable_vars):
                        # For backwards compatibility.
                        subnetwork_builder = candidate.subnetwork_builders[0]
                        old_train_op_fn = getattr(
                            subnetwork_builder,
                            "build_mixture_weights_train_op", None)
                        if callable(old_train_op_fn):
                            logging.warn(
                                "The `build_mixture_weights_train_op` method is deprecated. "
                                "Please use the `Ensembler#build_train_op` instead."
                            )
                            train_op = _to_train_op_spec(
                                subnetwork_builder.build_mixture_weights_train_op(
                                    loss=adanet_loss,
                                    var_list=ensemble_trainable_vars,
                                    logits=ensemble.logits,
                                    labels=labels,
                                    iteration_step=step_tensor,
                                    summary=summary))
                        else:
                            train_op = _to_train_op_spec(
                                ensembler.build_train_op(
                                    ensemble=ensemble,
                                    loss=adanet_loss,
                                    var_list=ensemble_trainable_vars,
                                    labels=labels,
                                    iteration_step=step_tensor,
                                    summary=summary,
                                    previous_ensemble=previous_ensemble))

            new_vars = _get_current_vars(diffbase=old_vars)
            # Sort our dictionary by key to remove non-determinism of variable order.
            new_vars = collections.OrderedDict(sorted(new_vars.items()))
            # Combine all trainable, global and savable variables into a single list.
            ensemble_variables = sum(new_vars.values(), []) + [step]

        return _EnsembleSpec(name=name,
                             architecture=architecture,
                             subnetwork_builders=subnetwork_builders,
                             subnetwork_specs=previous_subnetwork_specs +
                             relevant_subnetwork_specs,
                             ensemble=ensemble,
                             predictions=predictions,
                             step=step,
                             variables=ensemble_variables,
                             loss=ensemble_loss,
                             adanet_loss=adanet_loss,
                             train_op=train_op,
                             eval_metrics=ensemble_metrics,
                             export_outputs=export_outputs)
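The `_get_current_vars` helper used above is internal to AdaNet; conceptually it snapshots the graph's variable collections so that the variables created while building the ensemble can be recovered by diffing two snapshots. A rough sketch of that idea, not the actual implementation:

import tensorflow.compat.v1 as tf

def get_current_vars_sketch(diffbase=None):
    # Snapshot the current variable collections, keyed by collection name.
    snapshot = {
        "trainable": tf.trainable_variables(),
        "global": tf.global_variables(),
    }
    if diffbase:
        # Keep only variables created since the `diffbase` snapshot was taken.
        for key in snapshot:
            seen = set(id(v) for v in diffbase.get(key, []))
            snapshot[key] = [v for v in snapshot[key] if id(v) not in seen]
    return snapshot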
Example #11
    def append_new_subnetwork(self,
                              ensemble_name,
                              ensemble_spec,
                              subnetwork_builder,
                              iteration_number,
                              iteration_step,
                              summary,
                              features,
                              mode,
                              labels=None):
        """Adds a `Subnetwork` to an `_EnsembleSpec`.

    For iteration t > 0, the ensemble is built given the `Ensemble` for t-1 and
    the new subnetwork to train as part of the ensemble. The `Ensemble` at
    iteration 0 is comprised of just the subnetwork.

    The subnetwork is first given a weight 'w' in a `WeightedSubnetwork`
    which determines its contribution to the ensemble. The subnetwork's
    complexity L1-regularizes this weight.

    Args:
      ensemble_name: String name of the ensemble.
      ensemble_spec: The recipient `_EnsembleSpec` for the `Subnetwork`.
      subnetwork_builder: A `adanet.Builder` instance which defines how to train
        the subnetwork and ensemble mixture weights.
      iteration_number: Integer current iteration number.
      iteration_step: Integer `Tensor` representing the step since the beginning
        of the current iteration, as opposed to the global step.
      summary: A `_ScopedSummary` instance for recording ensemble summaries.
      features: Input `dict` of `Tensor` objects.
      mode: Estimator's `ModeKeys`.
      labels: Labels `Tensor` or a dictionary of string label name to `Tensor`
        (for multi-head). Can be `None`.

    Returns:
      A new `_EnsembleSpec` instance with the `Subnetwork` appended.
    """

        with tf.variable_scope("ensemble_{}".format(ensemble_name)):
            weighted_subnetworks = []
            subnetwork_index = 0
            num_subnetworks = 1
            ensemble = None
            architecture = _Architecture()
            if ensemble_spec:
                ensemble = ensemble_spec.ensemble
                previous_subnetworks = [
                    ensemble.weighted_subnetworks[index] for index in
                    subnetwork_builder.prune_previous_ensemble(ensemble)
                ]
                num_subnetworks += len(previous_subnetworks)
                for weighted_subnetwork in previous_subnetworks:
                    weight_initializer = None
                    if self._warm_start_mixture_weights:
                        weight_initializer = tf.contrib.framework.load_variable(
                            self._checkpoint_dir,
                            weighted_subnetwork.weight.op.name)
                    with tf.variable_scope(
                            "weighted_subnetwork_{}".format(subnetwork_index)):
                        weighted_subnetworks.append(
                            self._build_weighted_subnetwork(
                                weighted_subnetwork.name,
                                weighted_subnetwork.iteration_number,
                                weighted_subnetwork.subnetwork,
                                num_subnetworks,
                                weight_initializer=weight_initializer))
                    architecture.add_subnetwork(
                        weighted_subnetwork.iteration_number,
                        weighted_subnetwork.name)
                    subnetwork_index += 1

            ensemble_scope = tf.get_variable_scope()

            with tf.variable_scope(
                    "weighted_subnetwork_{}".format(subnetwork_index)):
                with tf.variable_scope("subnetwork"):
                    _clear_trainable_variables()
                    build_subnetwork = functools.partial(
                        subnetwork_builder.build_subnetwork,
                        features=features,
                        logits_dimension=self._head.logits_dimension,
                        training=mode == tf.estimator.ModeKeys.TRAIN,
                        iteration_step=iteration_step,
                        summary=summary,
                        previous_ensemble=ensemble)
                    # Check which args are in the implemented build_subnetwork method
                    # signature for backwards compatibility.
                    defined_args = inspect.getargspec(
                        subnetwork_builder.build_subnetwork).args
                    if "labels" in defined_args:
                        build_subnetwork = functools.partial(build_subnetwork,
                                                             labels=labels)
                    with summary.current_scope(), _subnetwork_context(
                            iteration_step_scope=ensemble_scope,
                            scoped_summary=summary):
                        tf.logging.info("Building subnetwork '%s'",
                                        subnetwork_builder.name)
                        subnetwork = build_subnetwork()
                    var_list = tf.trainable_variables()
                weighted_subnetworks.append(
                    self._build_weighted_subnetwork(subnetwork_builder.name,
                                                    iteration_number,
                                                    subnetwork,
                                                    num_subnetworks))
                architecture.add_subnetwork(iteration_number,
                                            subnetwork_builder.name)
            if ensemble:
                if len(previous_subnetworks) == len(
                        ensemble.weighted_subnetworks):
                    bias = self._create_bias_term(weighted_subnetworks,
                                                  prior=ensemble.bias)
                else:
                    bias = self._create_bias_term(weighted_subnetworks)
                    tf.logging.info(
                        "Builder '%s' is using a subset of the subnetworks "
                        "from the previous ensemble, so its ensemble's bias "
                        "term will not be warm started with the previous "
                        "ensemble's bias.", subnetwork_builder.name)
            else:
                bias = self._create_bias_term(weighted_subnetworks)

            return self._build_ensemble_spec(
                name=ensemble_name,
                weighted_subnetworks=weighted_subnetworks,
                architecture=architecture,
                summary=summary,
                bias=bias,
                features=features,
                mode=mode,
                iteration_step=iteration_step,
                labels=labels,
                subnetwork_builder=subnetwork_builder,
                var_list=var_list,
                previous_ensemble_spec=ensemble_spec)
Example #12
 def test_subnetworks(self, subnetworks, want):
     arch = _Architecture()
     for subnetwork in subnetworks:
         arch.add_subnetwork(*subnetwork)
     self.assertEqual(want, arch.subnetworks)
Example #13
 def test_set_and_add_replay_index(self):
     arch = _Architecture("foo", "dummy_ensembler_name")
     arch.set_replay_indices([1, 2, 3])
     self.assertAllEqual([1, 2, 3], arch.replay_indices)
     arch.add_replay_index(4)
     self.assertAllEqual([1, 2, 3, 4], arch.replay_indices)
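The replay-index behavior exercised here amounts to simple list bookkeeping on `_Architecture`. A minimal sketch of what could back these two methods, assuming a private `_replay_indices` list (an assumption about internals, not the actual source):

class ReplayIndicesSketch(object):

    def __init__(self):
        self._replay_indices = []

    @property
    def replay_indices(self):
        return self._replay_indices

    def set_replay_indices(self, indices):
        # Replace the stored replay indices wholesale.
        self._replay_indices = list(indices)

    def add_replay_index(self, replay_index):
        # Append a single ensemble index for the current iteration.
        self._replay_indices.append(replay_index)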
Example #14
 def test_subnetworks_grouped_by_iteration(self, subnetworks, want):
     arch = _Architecture("foo", "dummy_ensembler_name")
     for subnetwork in subnetworks:
         arch.add_subnetwork(*subnetwork)
     self.assertEqual(want, arch.subnetworks_grouped_by_iteration)
Example #15
 def test_subnetworks_grouped_by_iteration(self, subnetworks, want):
     arch = _Architecture()
     for subnetwork in subnetworks:
         arch.add_subnetwork(*subnetwork)
     self.assertEqual(want, arch.subnetworks_grouped_by_iteration)
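These parameterized tests receive (iteration_number, builder_name) pairs plus the expected grouping. A plausible parameter set, consistent with the serialization tests above (the case name is illustrative):

{
    "testcase_name": "two_iterations",
    "subnetworks": [(0, "linear"), (0, "dnn"), (1, "dnn")],
    "want": ((0, ("linear", "dnn")), (1, ("dnn",))),
}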