def _sort(self, session_groups):
    """Orders 'session_groups' in place as requested by _request.col_params.

    The groups are first ordered by name so the outcome is deterministic, and
    are then re-sorted once for every column whose order is specified.

    Raises:
      HParamsError: if a column carries an order that is neither
        unspecified, ascending nor descending.
    """
    session_groups.sort(key=operator.attrgetter('name'))

    # list.sort() is stable, so sorting by the least significant column first
    # and by the most significant column last produces a lexicographic
    # ordering in which the first ordered column is the primary key.
    columns = list(zip(self._request.col_params, self._extractors))
    for col_param, extractor in columns[::-1]:
        order = col_param.order
        if order == api_pb2.ORDER_UNSPECIFIED:
            continue
        if order == api_pb2.ORDER_ASC:
            descending = False
            none_is_largest = not col_param.missing_values_first
        elif order == api_pb2.ORDER_DESC:
            # The list is reversed after sorting, so "largest" ends up first.
            descending = True
            none_is_largest = col_param.missing_values_first
        else:
            raise error.HParamsError('Unknown col_param.order given: %s' %
                                     col_param)
        session_groups.sort(
            key=_create_key_func(extractor, none_is_largest=none_is_largest),
            reverse=descending)
 def filter_fn(value):
   """Returns whether the number 'value' lies within the closed interval.

   Raises:
     HParamsError: if 'value' is neither an integer nor a float.
   """
   numeric_types = six.integer_types + (float,)
   if isinstance(value, numeric_types):
     return interval.min_value <= value <= interval.max_value
   raise error.HParamsError(
       'Cannot use an interval filter for a value of type: %s, Value: %s' %
       (type(value), value))
Esempio n. 3
0
 def filter_fn(value):
     """Returns whether 'compiled_regex' matches anywhere within 'value'.

     Raises:
       HParamsError: if 'value' is not a string.
     """
     if isinstance(value, six.string_types):
         found = re.search(compiled_regex, value)
         return found is not None
     raise error.HParamsError(
         "Cannot use a regexp filter for a value of type %s. Value: %s"
         % (type(value), value)
     )
Esempio n. 4
0
 def list_metric_evals_route(self, request):
     """Serves the metric evaluations asked for by 'request' as JSON.

     The request is parsed into a ListMetricEvalsRequest and handed, together
     with the scalars plugin and the experiment id taken from the request
     environment, to list_metric_evals.Handler.

     Args:
       request: The incoming HTTP request.

     Returns:
       The response built by http_util.Respond with an 'application/json'
       content type.

     Raises:
       werkzeug.exceptions.BadRequest: if handling the request raises an
         HParamsError, including when the scalars plugin is unavailable.
     """
     experiment = plugin_util.experiment_id(request.environ)
     try:
         request_proto = _parse_request_argument(
             request, api_pb2.ListMetricEvalsRequest
         )
         scalars_plugin = self._get_scalars_plugin()
         if not scalars_plugin:
             raise error.HParamsError(
                 "Internal error: the scalars plugin is not"
                 " registered; yet, the hparams plugin is"
                 " active."
             )
         return http_util.Respond(
             request,
             json.dumps(
                 list_metric_evals.Handler(
                     request_proto, scalars_plugin, experiment
                 ).run()
             ),
             "application/json",
         )
     except error.HParamsError as e:
         # Hand the exception to the logger as an argument so the message is
         # only formatted when the record is actually emitted.
         logger.error("HParams error: %s", e)
         raise werkzeug.exceptions.BadRequest(description=str(e))
Esempio n. 5
0
    def _aggregate_metrics(self, session_group):
        """Populates the group's metrics according to the request's aggregation_type.

        An unset aggregation type is treated the same as averaging.

        Raises:
          HParamsError: if the aggregation type is not recognized.
        """
        aggregation_type = self._request.aggregation_type

        if aggregation_type in (
            api_pb2.AGGREGATION_AVG,
            api_pb2.AGGREGATION_UNSET,
        ):
            _set_avg_session_metrics(session_group)
            return

        if aggregation_type == api_pb2.AGGREGATION_MEDIAN:
            _set_median_session_metrics(
                session_group, self._request.aggregation_metric
            )
            return

        # The remaining supported types differ only in the extremum picked.
        if aggregation_type == api_pb2.AGGREGATION_MIN:
            extremum_fn = min
        elif aggregation_type == api_pb2.AGGREGATION_MAX:
            extremum_fn = max
        else:
            raise error.HParamsError(
                "Unknown aggregation_type in request: %s"
                % self._request.aggregation_type
            )
        _set_extremum_session_metrics(
            session_group, self._request.aggregation_metric, extremum_fn
        )
Esempio n. 6
0
 def __init__(self, regex, extractor, include_missing_values):
     """Initializes the base filter and compiles 'regex'.

     Raises:
       HParamsError: if 'regex' is not a valid regular expression.
     """
     super(_SessionGroupRegexFilter, self).__init__(
         extractor, include_missing_values)
     try:
         compiled = re.compile(regex)
     except re.error as compile_error:
         details = (regex, compile_error)
         raise error.HParamsError(
             'Error parsing regexp: %s. Error: %s' % details)
     self._regex = compiled
Esempio n. 7
0
def _parse_plugin_data_as(content, data_oneof_field):
    """Extracts one field of the data oneof from serialized HParams plugin data.

    Args:
      content: An HParam's SummaryMetadata.plugin_data.content, parsed here
        as a serialized HParamsPluginData message.
      data_oneof_field: Name of the data oneof field to return.

    Returns:
      The value of 'data_oneof_field' in the parsed plugin data.

    Raises:
      HParamsError: if the stored version differs from PLUGIN_DATA_VERSION or
        'data_oneof_field' is not set in the content.
    """
    parsed = plugin_data_pb2.HParamsPluginData.FromString(content)

    if parsed.version != PLUGIN_DATA_VERSION:
        version_details = (PLUGIN_DATA_VERSION, parsed.version, parsed)
        raise error.HParamsError(
            'Only supports plugin_data version: %s; found: %s in: %s' %
            version_details)

    if parsed.HasField(data_oneof_field):
        return getattr(parsed, data_oneof_field)

    raise error.HParamsError('Expected plugin_data.%s to be set. Got: %s' %
                             (data_oneof_field, parsed))
Esempio n. 8
0
 def create(col_param):
     """Returns the session-group extractor matching 'col_param'.

     A metric extractor is built when 'metric' is set; otherwise an hparam
     extractor is built when 'hparam' is set.

     Raises:
       HParamsError: if neither 'metric' nor 'hparam' is set.
     """
     # 'metric' is checked before 'hparam'.
     candidates = (
         ("metric", _SessionGroupMetricExtractor),
         ("hparam", _SessionGroupHParamExtractor),
     )
     for field_name, extractor_class in candidates:
         if col_param.HasField(field_name):
             return extractor_class(getattr(col_param, field_name))
     raise error.HParamsError(
         'Got ColParam with both "metric" and "hparam" fields unset: %s'
         % col_param)
def _create_extractor(col_param):
  """Returns an extractor for the metric or hparam that 'col_param' names.

  Raises:
    HParamsError: if neither 'metric' nor 'hparam' is set on 'col_param'.
  """
  if col_param.HasField('metric'):
    return _create_metric_extractor(col_param.metric)
  if not col_param.HasField('hparam'):
    raise error.HParamsError(
        'Got ColParam with both "metric" and "hparam" fields unset: %s' %
        col_param)
  return _create_hparam_extractor(col_param.hparam)
Esempio n. 10
0
def _parse_request_argument(request, proto_class):
  """Parses the JSON payload of an HTTP request into a 'proto_class' message.

  For POST requests the JSON is read from the request body; for any other
  method it is read from the 'request' query argument.

  Args:
    request: The HTTP request; its 'method', 'data' and 'args' are read.
    proto_class: The protocol buffer message class to parse the JSON into.

  Returns:
    A 'proto_class' instance populated from the JSON.

  Raises:
    HParamsError: if the 'request' argument is missing, or if the JSON cannot
      be parsed into 'proto_class'.
  """
  if request.method == 'POST':
    request_json = request.data
  else:
    # args.get() returns the request URI-unescaped.
    request_json = request.args.get('request')
    if request_json is None:
      raise error.HParamsError(
          'Expected a JSON-formatted \'request\' arg of type: %s' % proto_class)

  try:
    return json_format.Parse(request_json, proto_class())
  except json_format.ParseError as e:
    # Malformed JSON is a client error. Reporting it as an HParamsError lets
    # callers that already handle HParamsError deal with it uniformly instead
    # of letting a protobuf-specific exception escape.
    raise error.HParamsError(
        'Expected a JSON-formatted request of type: %s. Parse error: %s' %
        (proto_class, e))
Esempio n. 11
0
def _parse_plugin_data_as(content, data_oneof_field):
    """Returns a data oneof's field from plugin_data.content.

    Args:
      content: The SummaryMetadata.plugin_data.content to use.
      data_oneof_field: string. The name of the data oneof field to return.

    Raises:
      HParamsError: if the content doesn't have 'data_oneof_field' set or
        this file is incompatible with the version of the metadata stored.
    """
    message = plugin_data_pb2.HParamsPluginData.FromString(content)

    # Refuse metadata written with a different plugin data version.
    stored_version = message.version
    if stored_version != PLUGIN_DATA_VERSION:
        raise error.HParamsError(
            'Only supports plugin_data version: %s; found: %s in: %s' %
            (PLUGIN_DATA_VERSION, stored_version, message))

    field_is_set = message.HasField(data_oneof_field)
    if not field_is_set:
        raise error.HParamsError(
            'Expected plugin_data.%s to be set. Got: %s' %
            (data_oneof_field, message))
    return getattr(message, data_oneof_field)
Esempio n. 12
0
    def get_experiment_route(self, request):
        """Serves the context's experiment as a JSON response.

        Raises:
          werkzeug.exceptions.BadRequest: if the plugin is not active, or if
            an HParamsError is raised while building the response.
        """
        try:
            if self.is_active():
                experiment_json = json_format.MessageToJson(
                    self._context.experiment())
                return http_util.Respond(
                    request, experiment_json, 'application/json')
            raise error.HParamsError("HParams plugin is not active.")
        except error.HParamsError as hparams_error:
            raise werkzeug.exceptions.BadRequest(
                description=str(hparams_error))
Esempio n. 13
0
 def list_session_groups_route(self, request):
     """Serves the session groups selected by the 'request' query argument.

     The 'request' argument must hold a JSON-formatted
     ListSessionGroupsRequest; the handler's result is returned as JSON.

     Args:
       request: The incoming HTTP request.

     Returns:
       The response built by http_util.Respond with an 'application/json'
       content type.

     Raises:
       werkzeug.exceptions.BadRequest: if the plugin is not active, the
         'request' argument is missing or is not valid JSON for a
         ListSessionGroupsRequest, or the handler raises an HParamsError.
     """
     try:
         if not self.is_active():
             raise error.HParamsError("HParams plugin is not active.")
         # args.get() returns the request unquoted.
         request_json = request.args.get('request')
         if request_json is None:
             raise error.HParamsError(
                 '/session_groups must have a \'request\' arg.')
         try:
             request_proto = json_format.Parse(
                 request_json, api_pb2.ListSessionGroupsRequest())
         except json_format.ParseError as e:
             # A malformed argument is a client error: report it through the
             # same path as the other validation failures so it becomes a
             # BadRequest rather than an unhandled exception.
             raise error.HParamsError(
                 'Error parsing the \'request\' arg: %s' % e)
         return http_util.Respond(
             request,
             json_format.MessageToJson(
                 list_session_groups.Handler(self._context,
                                             request_proto).run()),
             'application/json')
     except error.HParamsError as e:
         raise werkzeug.exceptions.BadRequest(description=str(e))
Esempio n. 14
0
    def run(self):
        """Returns the experiment provided by the context.

        Returns:
          An Experiment object.

        Raises:
          HParamsError: if the context has no experiment to offer.
        """
        found = self._context.experiment()
        if found is not None:
            return found
        raise error.HParamsError(
            "Can't find an HParams-plugin experiment data in"
            " the log directory. Note that it takes some time to"
            " scan the log directory; if you just started"
            " Tensorboard it could be that we haven't finished"
            " scanning it yet. Consider trying again in a"
            " few seconds.")
Esempio n. 15
0
 def list_metric_evals_route(self, request):
   """Serves the metric evaluations asked for by 'request' as JSON.

   The request is parsed into a ListMetricEvalsRequest and handed, together
   with the scalars plugin, to list_metric_evals.Handler.

   Args:
     request: The incoming HTTP request.

   Returns:
     The response built by http_util.Respond with an 'application/json'
     content type.

   Raises:
     werkzeug.exceptions.BadRequest: if handling the request raises an
       HParamsError, including when the scalars plugin is unavailable.
   """
   try:
     request_proto = _parse_request_argument(
         request, api_pb2.ListMetricEvalsRequest)
     scalars_plugin = self._get_scalars_plugin()
     if not scalars_plugin:
       raise error.HParamsError('Internal error: the scalars plugin is not'
                                ' registered; yet, the hparams plugin is'
                                ' active.')
     return http_util.Respond(
         request,
         json.dumps(
             list_metric_evals.Handler(request_proto, scalars_plugin).run()),
         'application/json')
   except error.HParamsError as e:
     # Hand the exception to the logger as an argument so the message is only
     # formatted when the record is actually emitted.
     logger.error('HParams error: %s', e)
     raise werkzeug.exceptions.BadRequest(description=str(e))
Esempio n. 16
0
    def _sort(self, session_groups):
        """Orders 'session_groups' in place as requested by _request.col_params.

        The groups are first ordered by name so the outcome is deterministic,
        and are then re-sorted once for every column whose order is specified.

        Raises:
          HParamsError: if a column carries an order that is neither
            unspecified, ascending nor descending.
        """

        def _create_key_func(extractor, none_is_largest):
            """Builds a list.sort() key ordering groups by their extracted value.

            Groups whose extracted value is None compare greater than all
            others when 'none_is_largest' is true, and smaller otherwise.
            """

            def key_func(session_group):
                value = extractor.extract(session_group)
                is_missing = value is None
                # The leading flag is compared first, so it decides on which
                # side of the real values the missing ones are placed.
                if none_is_largest:
                    return (is_missing, value)
                return (not is_missing, value)

            return key_func

        def _by_name(session_group):
            return session_group.name

        session_groups.sort(key=_by_name)

        # list.sort() is stable, so sorting by the least significant column
        # first and by the most significant column last produces a
        # lexicographic ordering in which the first ordered column is the
        # primary key.
        columns = list(zip(self._request.col_params, self._extractors))
        for col_param, extractor in columns[::-1]:
            order = col_param.order
            if order == api_pb2.ORDER_UNSPECIFIED:
                continue
            if order == api_pb2.ORDER_ASC:
                descending = False
                none_is_largest = not col_param.missing_values_first
            elif order == api_pb2.ORDER_DESC:
                # The list is reversed after sorting, so "largest" is first.
                descending = True
                none_is_largest = col_param.missing_values_first
            else:
                raise error.HParamsError('Unknown col_param.order given: %s' %
                                         col_param)
            session_groups.sort(
                key=_create_key_func(
                    extractor, none_is_largest=none_is_largest),
                reverse=descending)
Esempio n. 17
0
def _verify_request_is_post(request, end_point):
  """Raises HParamsError unless 'request' was issued with the POST method.

  Args:
    request: The HTTP request whose 'method' is checked.
    end_point: The end point's name; only used in the error message.
  """
  if request.method == 'POST':
    return
  raise error.HParamsError('%s must be a POST. Got: %s' %
                           (end_point, request.method))
Esempio n. 18
0
    def run(self):
        """Renders the session groups as a table in the requested format.

        The columns are the experiment's hparams followed by its metrics,
        restricted to the ones flagged as visible. The first row holds the
        column titles.

        Returns:
          A response body.
          A mime type (string) for the response.

        Raises:
          HParamsError: if the response format is not JSON, LaTeX or CSV.
        """
        experiment = self._experiment
        session_groups = self._session_groups
        response_format = self._response_format
        visibility = self._columns_visibility

        def _visible_only(cells):
            return [cell for cell, shown in zip(cells, visibility) if shown]

        def _hparam_cell(value):
            # The set field is looked up in this order: number, string, bool.
            for field in ("number_value", "string_value", "bool_value"):
                if value.HasField(field):
                    return getattr(value, field)
            # hyperparameter values can be optional in a session group
            return ""

        def _metric_key(metric_name):
            return metric_name.group + "." + metric_name.tag

        def _latex_cell(value):
            if value is None:
                return "-"
            if isinstance(value, int):
                return "$%d$" % value
            if not isinstance(value, float):
                # Anything else is treated as text and escaped.
                return value.replace("_", "\\_").replace("%", "\\%")
            if math.isnan(value):
                return r"$\mathrm{NaN}$"
            if value in (float("inf"), float("-inf")):
                return r"$%s\infty$" % ("-" if value < 0 else "+")
            scientific = "%.3g" % value
            if "e" not in scientific:
                return "$%s$" % scientific
            coefficient, exponent = scientific.split("e")
            return "$%s\\cdot 10^{%d}$" % (coefficient, int(exponent))

        hparam_titles = [
            info.display_name or info.name for info in experiment.hparam_infos
        ]
        metric_titles = [
            info.display_name or info.name.tag
            for info in experiment.metric_infos
        ]
        header = _visible_only(hparam_titles + metric_titles)

        rows = []
        for group in session_groups.session_groups:
            hparam_cells = [
                _hparam_cell(group.hparams[info.name])
                for info in experiment.hparam_infos
            ]
            values_by_metric = {
                _metric_key(metric_value.name): metric_value.value
                for metric_value in group.metric_values
            }
            # Metrics the group has no value for yield None.
            metric_cells = [
                values_by_metric.get(_metric_key(info.name))
                for info in experiment.metric_infos
            ]
            rows.append(_visible_only(hparam_cells + metric_cells))

        if response_format == OutputFormat.JSON:
            return dict(header=header, rows=rows), "application/json"

        if response_format == OutputFormat.LATEX:
            pieces = [
                "\\begin{table}[tbp]\n\\begin{tabular}{%s}\n"
                % ("l" * len(header)),
                " & ".join(map(_latex_cell, header)) + " \\\\ \\hline\n",
            ]
            pieces.extend(
                " & ".join(map(_latex_cell, row)) + " \\\\\n" for row in rows
            )
            pieces.append("\\hline\n\\end{tabular}\n\\end{table}\n")
            return "".join(pieces), "application/x-latex"

        if response_format == OutputFormat.CSV:
            buffer = six.StringIO()
            csv_writer = csv.writer(buffer)
            csv_writer.writerow(header)
            csv_writer.writerows(rows)
            return buffer.getvalue(), "text/csv"

        raise error.HParamsError(
            "Invalid reponses format: %s" % response_format
        )
 def filter_fn(value):
     """Returns whether the number 'value' lies within the closed interval.

     Raises:
       HParamsError: if 'value' is neither an int nor a float.
     """
     if isinstance(value, (int, float)):
         return interval.min_value <= value <= interval.max_value
     raise error.HParamsError(
         "Cannot use an interval filter for a value of type: %s, Value: %s"
         % (type(value), value))
Esempio n. 20
0
 def _value_passes(self, value):
     """Returns whether the filter's regex matches anywhere within 'value'.

     Raises:
       HParamsError: if 'value' is not a string.
     """
     if isinstance(value, six.string_types):
         found = self._regex.search(value)
         return found is not None
     raise error.HParamsError(
         'Cannot use a regexp filter for a value of type %s. Value: %s'
         % (type(value), value))