def _metric_fn(*args, **kwargs): """The wrapping function to be returned.""" # We can only be passed in either a dict or a list of tensors. args = args if args else kwargs metrics = call_eval_metrics((metric_fn, args)) if not self._use_tpu: return metrics logging.log_first_n(logging.INFO, "Writing eval metrics to variables for TPU", 1) wrapped_metrics = {} for i, key in enumerate(sorted(metrics)): tensor, op = tf_compat.metric_op(metrics[key]) # key cannot be in var name since it may contain illegal chars. var = tf_compat.v1.get_variable( "metric_{}".format(i), shape=tensor.shape, dtype=tensor.dtype, trainable=False, initializer=tf_compat.v1.zeros_initializer(), collections=[tf_compat.v1.GraphKeys.LOCAL_VARIABLES]) if isinstance(op, tf.Operation) or op.shape != tensor.shape: with tf.control_dependencies([op]): op = var.assign(tensor) metric = (var, var.assign(op)) wrapped_metrics[key] = metric return wrapped_metrics
def _group_metric_ops(self, metric_fns, metric_fn_args):
  """Runs the metric_fns and groups the returned metric ops by name.

  Args:
    metric_fns: The eval_metrics functions to run.
    metric_fn_args: The eval_metrics function arguments.

  Returns:
    The metric ops grouped by name.
  """
  groups = collections.defaultdict(list)
  # Evaluate each (fn, args) pair and bucket the resulting metric ops under
  # their metric names, iterating names in sorted order for determinism.
  for fn, fn_args in zip(metric_fns, metric_fn_args):
    ops_by_name = call_eval_metrics((fn, fn_args))
    for name in sorted(ops_by_name):
      groups[name].append(tf_compat.metric_op(ops_by_name[name]))
  return groups
def __new__(cls, hparams, attributes, metrics):
  """Validates the arguments and creates an immutable report instance.

  Args:
    hparams: Dict mapping string names to python primitive (bool, int, float,
      or string) hyperparameter values.
    attributes: Dict mapping string names to scalar `tf.Tensor`s of an
      accepted dtype (bool, int32, float32, float64, or string).
    metrics: Dict mapping string names to metric values that
      `tf_compat.metric_op` can normalize into (value, update_op) tuples.

  Returns:
    A new instance holding the validated hparams and attributes, and a
    sanitized copy of metrics (metrics of rank > 0 are dropped with a
    warning).

  Raises:
    ValueError: If any hparam, attribute, or metric fails validation.
  """

  def _is_scalar(tensor):
    """Returns True iff tensor is scalar."""
    return tensor.shape.ndims == 0

  def _is_accepted_dtype(tensor):
    """Returns True iff tensor has the dtype we can handle."""
    return tensor.dtype.base_dtype in (tf.bool, tf.int32, tf.float32,
                                       tf.float64, tf.string)

  # Validate hparams: only python primitives are accepted.
  for key, value in hparams.items():
    if not isinstance(value, (bool, int, float, six.string_types)):
      raise ValueError(
          "hparam '{}' refers to invalid value {}, type {}. type must be "
          "python primitive int, float, bool, or string.".format(
              key, value, type(value)))

  # Validate attributes: must be scalar Tensors of an accepted dtype.
  for key, value in attributes.items():
    if not isinstance(value, tf.Tensor):
      raise ValueError("attribute '{}' refers to invalid value: {}, type: {}."
                       "type must be Tensor.".format(key, value, type(value)))
    if not (_is_scalar(value) and _is_accepted_dtype(value)):
      raise ValueError(
          "attribute '{}' refers to invalid tensor {}. Shape: {}".format(
              key, value, value.get_shape()))

  # Validate metrics: normalize each into a (value, update_op) tuple and
  # build a sanitized copy.
  metrics_copy = {}
  for key, value in metrics.items():
    value = tf_compat.metric_op(value)
    if not isinstance(value, tuple):
      raise ValueError("metric '{}' has invalid type {}. Must be a "
                       "tuple.".format(key, type(value)))
    if len(value) < 2:
      raise ValueError(
          "metric tuple '{}' has fewer than 2 elements".format(key))
    if not isinstance(value[0], (tf.Tensor, tf.Variable)):
      raise ValueError(
          "First element of metric tuple '{}' has value {} and type {}. "
          "Must be a Tensor or Variable.".format(key, value[0],
                                                 type(value[0])))
    if not _is_accepted_dtype(value[0]):
      raise ValueError(
          "First element of metric '{}' refers to Tensor of the wrong "
          "dtype {}. Must be one of tf.bool, tf.int32, tf.float32, "
          "tf.float64 or tf.string.".format(key, value[0].dtype))
    if not _is_scalar(value[0]):
      # Fix: was `tf.logging.warn`, which is removed in TF 2.x; the rest of
      # this file logs via the `logging` module directly.
      logging.warning(
          "First element of metric '{}' refers to Tensor of rank > 0. "
          "AdaNet is currently unable to store metrics of rank > 0 -- this "
          "metric will be dropped from the report. "
          "value: {}".format(key, value[0]))
      continue
    if not isinstance(value[1], (tf.Tensor, tf.Operation, tf.Variable)):
      raise ValueError(
          "Second element of metric tuple '{}' has value {} and type {}. "
          "Must be a Tensor, Operation, or Variable.".format(
              key, value[1], type(value[1])))
    metrics_copy[key] = value
  return super(Report, cls).__new__(
      cls, hparams=hparams, attributes=attributes, metrics=metrics_copy)
def materialize_subnetwork_reports(self, sess, iteration_number,
                                   subnetwork_reports,
                                   included_subnetwork_names):
  """Materializes the Tensor objects in subnetwork_reports using sess.

  This converts the Tensors in subnetwork_reports to ndarrays, logs the
  progress, converts the ndarrays to python primitives, then packages them
  into `adanet.subnetwork.MaterializedReports`.

  Args:
    sess: `Session` instance with most recent variable values loaded.
    iteration_number: Integer iteration number.
    subnetwork_reports: Dict mapping string names to `subnetwork.Report`
      objects to be materialized.
    included_subnetwork_names: List of string names of the
      `subnetwork.Report`s that are included in the final ensemble.

  Returns:
    List of `adanet.subnetwork.MaterializedReport` objects.
  """

  def _to_primitives(values):
    """Converts scalar ndarray values of a dict to python primitives."""
    return {
        key: value.item() if hasattr(value, "item") else value
        for key, value in values.items()
    }

  # A metric is a (value_tensor, update_op) tuple; collect every update op so
  # they can all be run together on each step.
  update_ops = [
      tf_compat.metric_op(metric_tuple)[1]
      for report in subnetwork_reports.values()
      for metric_tuple in report.metrics.values()
  ]

  # Collect the Tensors (attributes and metric values) to be materialized.
  to_materialize = {
      name: {
          "attributes": report.attributes,
          "metrics": {
              metric_key: tf_compat.metric_op(metric_tuple)[0]
              for metric_key, metric_tuple in report.metrics.items()
          },
      } for name, report in subnetwork_reports.items()
  }

  # Log progress roughly ten times over the run (every step for short runs).
  if self.steps is None:
    log_every = 1000
  elif self.steps < 10:
    log_every = 1
  else:
    log_every = math.floor(self.steps / 10.)

  completed = 0
  while self.steps is None or completed != self.steps:
    try:
      completed += 1
      if completed % log_every == 0 or self.steps == completed:
        logging.info("Report materialization [%d/%s]", completed,
                     self.steps or "??")
      sess.run(update_ops)
    except tf.errors.OutOfRangeError:
      logging.info("Encountered end of input during report materialization")
      break

  materialized = sess.run(to_materialize)
  logging.info("Materialized subnetwork_reports.")

  # Convert scalar ndarrays into python primitives, then place them into
  # subnetwork.MaterializedReports.
  return [
      subnetwork.MaterializedReport(
          iteration_number=iteration_number,
          name=name,
          hparams=subnetwork_reports[name].hparams,
          attributes=_to_primitives(tensors["attributes"]),
          metrics=_to_primitives(tensors["metrics"]),
          included_in_final_ensemble=(name in included_subnetwork_names))
      for name, tensors in materialized.items()
  ]