Code example #1
    def _scoped_metric_fn(**kwargs):
        """The wrapping function to be returned."""

        if not metric_fn:
            return {}

        kwargs = dict_utils.unflatten_dict(kwargs,
                                           prefixes=[group_name])[group_name]
        kwargs = dict_utils.unflatten_dict(kwargs, prefixes=_PREFIXES)
        kwargs = {
            k: _reconstruct_tuple_keys(v)
            for k, v in six.iteritems(kwargs)
        }
        kwargs_ = {}
        for key, value in six.iteritems(kwargs):
            if key in _PREFIXES and key != _KWARGS_KEY:
                kwargs_[key.replace("_", "")] = value
            else:
                kwargs_[key] = value
        kwargs = kwargs_

        metrics = _reflective_call(metric_fn, **kwargs)
        rescoped_metrics = {}
        for key, value in six.iteritems(metrics):
            rescoped_metrics["{}/adanet/{}".format(key, group_name)] = value
        return rescoped_metrics
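
For reference, the re-scoping loop at the end of this snippet just appends an "/adanet/<group_name>" scope to each metric name. A toy, self-contained run (the metric and group names are hypothetical):

    # Hypothetical output of _reflective_call: metric_fn returned one metric.
    metrics = {"accuracy": ("value_tensor", "update_op")}
    group_name = "subnetwork_0"

    rescoped_metrics = {
        "{}/adanet/{}".format(k, group_name): v for k, v in metrics.items()
    }
    # -> {"accuracy/adanet/subnetwork_0": ("value_tensor", "update_op")}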
Code example #2
    def test_unflatten_dict(self):
        flat_dict = {
            "hello-world": 1,
            "hello-sailor": 2,
            "ada-net": 3,
            "ada-boost": 4,
            "nodict": 5,
        }

        actual_wrong_delimiter = dict_utils.unflatten_dict(
            flat_dict, prefixes=["hello", "ada"], delimiter="/")
        actual_unflattened = dict_utils.unflatten_dict(flat_dict,
                                                       prefixes=["ada", "unk"],
                                                       delimiter="-")

        expected = {
            "hello-world": 1,
            "hello-sailor": 2,
            "ada": {
                "net": 3,
                "boost": 4,
            },
            "nodict": 5,
        }

        self.assertDictEqual(actual_wrong_delimiter, flat_dict)
        self.assertDictEqual(actual_unflattened, expected)
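
The test pins down unflatten_dict's contract: keys of the form "<prefix><delimiter><rest>" are nested under <prefix>, keys with the wrong delimiter or an unlisted prefix pass through unchanged, and a prefix with no matching keys produces no entry. A minimal sketch consistent with those assertions (the real dict_utils implementation and its default delimiter may differ):

    def unflatten_dict(flat_dict, prefixes, delimiter="/"):
        """Nests "<prefix><delimiter><rest>" keys under <prefix>."""
        unflattened = {}
        for key, value in flat_dict.items():
            parts = key.split(delimiter, 1)
            if len(parts) == 2 and parts[0] in prefixes:
                prefix, rest = parts
                unflattened.setdefault(prefix, {})[rest] = value
            else:
                # Wrong delimiter or unlisted prefix: pass through unchanged.
                unflattened[key] = value
        return unflattened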
Code example #3
        def _best_eval_metrics_fn(**kwargs):
            """Returns the best eval metrics."""

            with tf.variable_scope("best_eval_metrics"):
                subnetwork_metric_fns = {
                    k: metric_fns[k]
                    for k in metric_fns if k.startswith("subnetwork_")
                }
                subnetwork_tensors = dict_utils.unflatten_dict(
                    kwargs, subnetwork_metric_fns.keys())
                subnetwork_metric_ops = self._group_metric_ops(
                    subnetwork_metric_fns, subnetwork_tensors)
                ensemble_metric_fns = {
                    k: metric_fns[k]
                    for k in metric_fns if k.startswith("ensemble_")
                }
                ensemble_tensors = dict_utils.unflatten_dict(
                    kwargs, ensemble_metric_fns.keys())
                grouped_metrics = self._group_metric_ops(
                    ensemble_metric_fns, ensemble_tensors)

                eval_metric_ops = {}
                for metric_name in sorted(grouped_metrics):
                    metric_ops = grouped_metrics[metric_name]
                    if len(metric_ops) != len(candidates):
                        continue
                    if metric_name == "loss":
                        continue

                    best_candidate_index = kwargs["best_candidate_index"]
                    values, ops = list(six.moves.zip(*metric_ops))
                    idx, idx_update_op = tf.metrics.mean(best_candidate_index)
                    best_value = tf.stack(values)[tf.cast(idx, tf.int32)]
                    # All tensors in this function have been outfed from the TPU, so we
                    # must update them manually; otherwise the TPU will hang indefinitely,
                    # waiting for the value of idx to update.
                    ops = list(ops)
                    ops.append(idx_update_op)
                    # Bundle the subnetwork eval metric ops and the ensemble "loss" ops
                    # ("loss" is a restricted Estimator key) into the other metric ops so
                    # that they are still computed.
                    ensemble_loss_ops = grouped_metrics.get("loss", tf.no_op())
                    all_ops = tf.group(ops, ensemble_loss_ops,
                                       subnetwork_metric_ops)
                    eval_metric_ops[metric_name] = (best_value, all_ops)

                # tf.estimator.Estimator does not allow a "loss" key to be present in
                # its eval_metrics.
                assert "loss" not in eval_metric_ops
                return eval_metric_ops
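
The core of the selection above is a gather: stacking the per-candidate metric values and indexing with the (averaged) best_candidate_index picks the winning candidate's value. A toy illustration with made-up numbers:

    import tensorflow as tf

    # Hypothetical metric values reported by three candidates.
    values = [tf.constant(0.7), tf.constant(0.9), tf.constant(0.8)]
    # Stands in for tf.metrics.mean(best_candidate_index); here candidate 1 won.
    idx = tf.constant(1.0)

    best_value = tf.stack(values)[tf.cast(idx, tf.int32)]  # -> 0.9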
Code example #4
        def _best_eval_metrics_fn(**kwargs):
            """Returns the best eval metrics."""

            with tf.variable_scope("best_eval_metrics"):
                tensors = dict_utils.unflatten_dict(kwargs, metric_fns.keys())
                grouped_metrics = self._group_metric_ops(metric_fns, tensors)

                eval_metric_ops = {}
                for metric_name in sorted(grouped_metrics):
                    metric_ops = grouped_metrics[metric_name]
                    if len(metric_ops) != len(candidates):
                        continue

                    best_candidate_index = tensors["best_candidate_index"]
                    values, ops = list(six.moves.zip(*metric_ops))
                    idx, idx_update_op = tf.metrics.mean(best_candidate_index)
                    best_value = tf.stack(values)[tf.cast(idx, tf.int32)]
                    # All tensors in this function have been outfed from the TPU, so we
                    # must update them manually; otherwise the TPU will hang indefinitely,
                    # waiting for the value of idx to update.
                    ops = list(ops)
                    ops.append(idx_update_op)
                    best_op = tf.group(ops)
                    best_candidate_metric = (best_value, best_op)
                    eval_metric_ops[metric_name] = best_candidate_metric

                    # Include any evaluation metric shared among all the candidates in
                    # the top level metrics in TensorBoard. These "root" metrics track
                    # AdaNet's overall performance, making it easier to compare with other
                    # estimators that report the same metrics.
                    suffix = "/adanet/adanet_weighted_ensemble"
                    if not metric_name.endswith(suffix):
                        continue
                    root_metric_name = metric_name[:-len(suffix)]
                    if root_metric_name == "loss":
                        continue
                    eval_metric_ops[root_metric_name] = best_candidate_metric

                return eval_metric_ops
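
The suffix handling at the end just strips the ensemble scope so the same (value, op) pair is reported a second time under a top-level name; for example (hypothetical metric name):

    suffix = "/adanet/adanet_weighted_ensemble"
    metric_name = "accuracy" + suffix
    root_metric_name = metric_name[:-len(suffix)]
    # eval_metric_ops ends up with both
    # "accuracy/adanet/adanet_weighted_ensemble" and the top-level "accuracy".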
Code example #5
    def _wrapped_metric_fn(**kwargs):
        """The wrapping function to be returned."""

        if not metric_fn:
            return {}

        kwargs = dict_utils.unflatten_dict(kwargs, prefixes=_PREFIXES)
        kwargs = {
            k: _reconstruct_tuple_keys(v)
            for k, v in six.iteritems(kwargs)
        }
        kwargs_ = {}
        for key, value in six.iteritems(kwargs):
            if key in _PREFIXES and key != _KWARGS_KEY:
                kwargs_[key.replace("_", "")] = value
            else:
                kwargs_[key] = value
        kwargs = kwargs_

        metrics = _reflective_call(metric_fn, **kwargs)
        wrapped_metrics = {}
        # Hooks on TPU cannot depend on any graph Tensors. Instead, the metric
        # values are stored in Variables that the evaluation hooks later read.
        for i, key in enumerate(sorted(metrics)):
            tensor, op = metrics[key]
            # `key` cannot be in the var name since it can contain illegal characters.
            var = tf.get_variable("metric_{}".format(i),
                                  shape=tensor.shape,
                                  dtype=tensor.dtype,
                                  trainable=False,
                                  initializer=tf.zeros_initializer(),
                                  collections=[tf.GraphKeys.LOCAL_VARIABLES])
            if isinstance(op, tf.Operation):
                with tf.control_dependencies([op]):
                    op = tf.assign(var, tensor)
            metric = (var, tf.assign(var, op))
            wrapped_metrics[key] = metric
        return wrapped_metrics
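
The comment in this snippet says the stored Variables are later read from evaluation hooks; a minimal sketch of such a hook (our own illustration, not AdaNet's actual hook class):

    class _ReadMetricVariablesHook(tf.train.SessionRunHook):
        """Reads the metric Variables once evaluation has finished."""

        def __init__(self, metric_vars):
            # metric_vars: {metric_name: tf.Variable}, e.g. the `var`s created
            # in _wrapped_metric_fn above.
            self._metric_vars = metric_vars
            self.results = {}

        def end(self, session):
            # `end` runs after every update op has executed, so each Variable
            # holds the final metric value.
            self.results = {
                name: session.run(var)
                for name, var in self._metric_vars.items()
            }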