Ejemplo n.º 1
0
    def _metric_fn(*args, **kwargs):
      """The wrapping function to be returned.

      Calls the wrapped `metric_fn`, and when running on TPU additionally
      copies each metric value into a local variable so that it can be read
      outside the TPU evaluation loop.
      """

      # We can only be passed in either a dict or a list of tensors.
      args = args if args else kwargs
      metrics = call_eval_metrics((metric_fn, args))
      if not self._use_tpu:
        return metrics

      logging.log_first_n(logging.INFO,
                          "Writing eval metrics to variables for TPU", 1)
      wrapped_metrics = {}
      # Iterate in sorted key order so that variable names ("metric_0",
      # "metric_1", ...) are assigned deterministically across runs.
      for i, key in enumerate(sorted(metrics)):
        # metric_op normalizes the metric into a (value_tensor, update_op)
        # pair -- TODO confirm against tf_compat.metric_op's contract.
        tensor, op = tf_compat.metric_op(metrics[key])
        # key cannot be in var name since it may contain illegal chars.
        var = tf_compat.v1.get_variable(
            "metric_{}".format(i),
            shape=tensor.shape,
            dtype=tensor.dtype,
            trainable=False,
            initializer=tf_compat.v1.zeros_initializer(),
            collections=[tf_compat.v1.GraphKeys.LOCAL_VARIABLES])
        # A bare tf.Operation has no output value, and an op whose shape
        # differs from the value tensor cannot be assigned to the variable
        # directly; in either case, run the update op first and then assign
        # the (post-update) value tensor to the variable.
        if isinstance(op, tf.Operation) or op.shape != tensor.shape:
          with tf.control_dependencies([op]):
            op = var.assign(tensor)
        # Expose (variable, assignment) so reading the metric reads the
        # variable and updating it refreshes the variable's contents.
        metric = (var, var.assign(op))
        wrapped_metrics[key] = metric
      return wrapped_metrics
Ejemplo n.º 2
0
  def _group_metric_ops(self, metric_fns, metric_fn_args):
    """Runs the metric_fns and groups the returned metric ops by name.

    Args:
      metric_fns: The eval_metrics functions to run.
      metric_fn_args: The eval_metrics function arguments.

    Returns:
      The metric ops grouped by name.
    """

    grouped = collections.defaultdict(list)
    for fn, fn_args in zip(metric_fns, metric_fn_args):
      ops_by_name = call_eval_metrics((fn, fn_args))
      # Visit names in sorted order for deterministic grouping.
      for name in sorted(ops_by_name):
        grouped[name].append(tf_compat.metric_op(ops_by_name[name]))
    return grouped
Ejemplo n.º 3
0
    def __new__(cls, hparams, attributes, metrics):
        """Validates the fields and constructs a new instance.

        Args:
            cls: The class being instantiated.
            hparams: Dict mapping string keys to python primitives
                (bool, int, float, or string).
            attributes: Dict mapping string keys to scalar `tf.Tensor`s of
                an accepted dtype (bool, int32, float32, float64, string).
            metrics: Dict mapping string keys to metric values accepted by
                `tf_compat.metric_op` (normalized to (value, update_op)
                tuples). Metrics whose value tensor has rank > 0 are dropped
                with a warning rather than raising.

        Returns:
            A new instance with validated `hparams`, `attributes`, and the
            filtered copy of `metrics`.

        Raises:
            ValueError: If any hparam, attribute, or metric fails the
                validations below.
        """
        def _is_scalar(tensor):
            """Returns True iff tensor is scalar."""
            return tensor.shape.ndims == 0

        def _is_accepted_dtype(tensor):
            """Returns True iff tensor has the dtype we can handle."""
            return tensor.dtype.base_dtype in (tf.bool, tf.int32, tf.float32,
                                               tf.float64, tf.string)

        # Validate hparams
        for key, value in hparams.items():
            if not isinstance(value, (bool, int, float, six.string_types)):
                raise ValueError(
                    "hparam '{}' refers to invalid value {}, type {}. type must be "
                    "python primitive int, float, bool, or string.".format(
                        key, value, type(value)))

        # Validate attributes
        for key, value in attributes.items():
            if not isinstance(value, tf.Tensor):
                raise ValueError(
                    "attribute '{}' refers to invalid value: {}, type: {}."
                    "type must be Tensor.".format(key, value, type(value)))

            if not (_is_scalar(value) and _is_accepted_dtype(value)):
                raise ValueError(
                    "attribute '{}' refers to invalid tensor {}. Shape: {}".
                    format(key, value, value.get_shape()))

        # Validate metrics
        metrics_copy = {}
        for key, value in metrics.items():
            # Normalize the metric into tuple form before validating its
            # elements -- TODO confirm metric_op always returns a tuple.
            value = tf_compat.metric_op(value)
            if not isinstance(value, tuple):
                raise ValueError(
                    "metric '{}' has invalid type {}. Must be a tuple.".format(
                        key, type(value)))

            if len(value) < 2:
                raise ValueError(
                    "metric tuple '{}' has fewer than 2 elements".format(key))

            if not isinstance(value[0], (tf.Tensor, tf.Variable)):
                raise ValueError(
                    "First element of metric tuple '{}' has value {} and type {}. "
                    "Must be a Tensor or Variable.".format(
                        key, value[0], type(value[0])))

            if not _is_accepted_dtype(value[0]):
                raise ValueError(
                    "First element of metric '{}' refers to Tensor of the wrong "
                    "dtype {}. Must be one of tf.bool, tf.int32, tf.float32, "
                    "tf.float64 or tf.string.".format(key, value[0].dtype))

            # Non-scalar metrics are deliberately dropped (not raised on),
            # since they cannot be stored in the report.
            if not _is_scalar(value[0]):
                tf.logging.warn(
                    "First element of metric '{}' refers to Tensor of rank > 0. "
                    "AdaNet is currently unable to store metrics of rank > 0 -- this "
                    "metric will be dropped from the report. "
                    "value: {}".format(key, value[0]))
                continue

            if not isinstance(value[1],
                              (tf.Tensor, tf.Operation, tf.Variable)):
                raise ValueError(
                    "Second element of metric tuple '{}' has value {} and type {}. "
                    "Must be a Tensor, Operation, or Variable.".format(
                        key, value[1], type(value[1])))

            metrics_copy[key] = value

        # NOTE(review): Report appears to be a namedtuple subclass -- the
        # validated fields are forwarded to its __new__.
        return super(Report, cls).__new__(cls,
                                          hparams=hparams,
                                          attributes=attributes,
                                          metrics=metrics_copy)
Ejemplo n.º 4
0
    def materialize_subnetwork_reports(self, sess, iteration_number,
                                       subnetwork_reports,
                                       included_subnetwork_names):
        """Materializes the Tensor objects in subnetwork_reports using sess.

    This converts the Tensors in subnetwork_reports to ndarrays, logs the
    progress, converts the ndarrays to python primitives, then packages them
    into `adanet.subnetwork.MaterializedReports`.

    Args:
      sess: `Session` instance with most recent variable values loaded.
      iteration_number: Integer iteration number.
      subnetwork_reports: Dict mapping string names to `subnetwork.Report`
        objects to be materialized.
      included_subnetwork_names: List of string names of the
        `subnetwork.Report`s that are included in the final ensemble.

    Returns:
      List of `adanet.subnetwork.MaterializedReport` objects.
    """

        # A metric is a tuple where the first element is a Tensor and
        # the second element is an update op. We collate the update ops here.
        metric_update_ops = []
        for subnetwork_report in subnetwork_reports.values():
            for metric_tuple in subnetwork_report.metrics.values():
                metric_update_ops.append(tf_compat.metric_op(metric_tuple)[1])

        # Extract the Tensors to be materialized.
        tensors_to_materialize = {}
        for name, subnetwork_report in subnetwork_reports.items():
            # First element of the normalized metric tuple is the value
            # tensor; the update op was collected above.
            metrics = {
                metric_key: tf_compat.metric_op(metric_tuple)[0]
                for metric_key, metric_tuple in
                subnetwork_report.metrics.items()
            }
            tensors_to_materialize[name] = {
                "attributes": subnetwork_report.attributes,
                "metrics": metrics
            }

        # Choose a logging cadence of roughly 10 progress messages over the
        # run; every step when there are fewer than 10, every 1000 steps
        # when the step count is unknown.
        if self.steps is None:
            logging_frequency = 1000
        elif self.steps < 10:
            logging_frequency = 1
        else:
            logging_frequency = math.floor(self.steps / 10.)

        # Run the metric update ops repeatedly to accumulate metric state,
        # stopping after self.steps steps or when the input is exhausted
        # (whichever comes first; runs until end-of-input if steps is None).
        steps_completed = 0
        while True:
            if self.steps is not None and steps_completed == self.steps:
                break
            try:
                steps_completed += 1
                if (steps_completed % logging_frequency == 0
                        or self.steps == steps_completed):
                    logging.info("Report materialization [%d/%s]",
                                 steps_completed, self.steps or "??")

                sess.run(metric_update_ops)
            except tf.errors.OutOfRangeError:
                logging.info(
                    "Encountered end of input during report materialization")
                break

        materialized_tensors_dict = sess.run(tensors_to_materialize)
        logging.info("Materialized subnetwork_reports.")

        # Convert scalar ndarrays into python primitives, then place them into
        # subnetwork.MaterializedReports.
        materialized_reports = []
        for name, materialized_tensors in materialized_tensors_dict.items():
            # .item() unboxes a numpy scalar into the corresponding python
            # primitive; values without .item() are passed through as-is.
            attributes = {
                key: value.item() if hasattr(value, "item") else value
                for key, value in materialized_tensors["attributes"].items()
            }
            metrics = {
                key: value.item() if hasattr(value, "item") else value
                for key, value in materialized_tensors["metrics"].items()
            }
            materialized_reports.append(
                subnetwork.MaterializedReport(
                    iteration_number=iteration_number,
                    name=name,
                    hparams=subnetwork_reports[name].hparams,
                    attributes=attributes,
                    metrics=metrics,
                    included_in_final_ensemble=(name
                                                in included_subnetwork_names)))
        return materialized_reports