def _apply_operation_on_fine_grained_view(self, operation_def,
                                              fine_grained_views,
                                              next_hashed_path):
        """Builds one output node per dataset for a shardable operation.

        For each key in `self._sorted_dataset_keys`, the output comes either
        from a `DecodeCache` node (when the operation has a cache coder and
        `self._cache_dict` holds a non-None entry for that dataset and cache
        entry key), or from `operation_def` applied to that dataset's input
        nodes. In the latter case, when the operation has a cache coder, an
        `EncodeCache` node is also recorded in `self.cache_output_nodes` and
        the dataset is flagged in `self._dataset_has_cache_misses`.

        Args:
          operation_def: A shardable `OperationDef`.
          fine_grained_views: A tuple of
            `_OptimizationView.fine_grained_view`s, each indexed by dataset
            key.
          next_hashed_path: The hashed path for the currently processed
            operation_def; it is appended to the bytes form of the label to
            build the cache entry key.

        Returns:
          A `collections.OrderedDict` mapping each dataset key to its output
          node, in the iteration order of `self._sorted_dataset_keys`.
        """
        base_label = operation_def.label
        entry_key = analyzer_cache.make_cache_entry_key(
            tf.compat.as_bytes(base_label) + b'-' + next_hashed_path)

        outputs_by_dataset = collections.OrderedDict()
        for position, dataset_key in enumerate(self._sorted_dataset_keys):
            # The positional index, rather than the dataset key itself, goes
            # into the labels in order to make beam labels more stable.
            infix = 'AnalysisIndex{}'.format(position)
            coder = operation_def.cache_coder
            cache_hit = bool(coder) and (
                self._cache_dict.get(dataset_key, {}).get(entry_key)
                is not None)

            if cache_hit:
                # OR-ing with False leaves an existing flag unchanged.
                # NOTE(review): presumably this only makes sure the key is
                # present (e.g. if the mapping is a defaultdict) - confirm.
                self._dataset_has_cache_misses[dataset_key] |= False
                decode_def = analyzer_nodes.DecodeCache(
                    dataset_key,
                    entry_key,
                    coder=coder,
                    label='DecodeCache[{}][{}]'.format(base_label, infix))
                [dataset_output] = nodes.OperationNode(decode_def, ()).outputs
            else:
                input_nodes = tuple(
                    view[dataset_key] for view in fine_grained_views)
                sharded_def = operation_def._replace(
                    label='{}[{}]'.format(base_label, infix))
                [dataset_output] = nodes.OperationNode(
                    sharded_def, input_nodes).outputs
                if coder:
                    # Cache miss for a cacheable operation: flag the dataset
                    # and register the node that encodes the fresh result.
                    self._dataset_has_cache_misses[dataset_key] = True
                    cache_writer = nodes.apply_operation(
                        analyzer_nodes.EncodeCache,
                        dataset_output,
                        coder=coder,
                        label='EncodeCache[{}][{}]'.format(base_label, infix))
                    self.cache_output_nodes[(dataset_key,
                                             entry_key)] = cache_writer

            outputs_by_dataset[dataset_key] = dataset_output

        return outputs_by_dataset
# Example no. 2
0
    def _apply_operation_on_fine_grained_view(self, operation_def,
                                              fine_grained_view):
        """Applies a shardable operation on a fine grained view.

        For each key in `self._dataset_keys`, the output comes either from a
        `DecodeCache` node (when the operation has a cache coder and
        `self._cache_dict` holds a non-None entry for that dataset and cache
        entry key), or from `operation_def` applied to that dataset's node in
        `fine_grained_view`.

        This also updates `cache_output_nodes` when necessary: whenever the
        operation is computed (not decoded) and has a cache coder, an
        `EncodeCache` node is registered under
        `(dataset_key, cache_entry_key)`.

        Args:
          operation_def: A shardable `OperationDef`.
          fine_grained_view: A `_OptimizationView.fine_grained_view`, indexed
            by dataset key.

        Returns:
          A `collections.OrderedDict` mapping each dataset key to its output
          node, in the iteration order of `self._dataset_keys`.
        """
        result_fine_grained_view = collections.OrderedDict()

        # TODO(b/37788560): Use a better cache key than label. A good alternative is
        # to reuse graph_tools logic to compose names that include properties and
        # fingerprint it.
        cache_entry_key = analyzer_cache.make_cache_entry_key(
            operation_def.label)
        for dataset_key in self._dataset_keys:

            # TODO(b/37788560): Add instrumentation.

            # Only read from the cache when the operation actually has a
            # coder: without one, a `DecodeCache` node would be built with
            # `coder=None`. Such operations are computed instead, mirroring
            # the guard already applied to the encode path below.
            if (operation_def.cache_coder and self._cache_dict.get(
                    dataset_key, {}).get(cache_entry_key) is not None):
                (op_output, ) = nodes.OperationNode(
                    analyzer_nodes.DecodeCache(
                        dataset_key,
                        cache_entry_key,
                        coder=operation_def.cache_coder), tuple()).outputs
            else:
                value_node = fine_grained_view[dataset_key]
                # Suffix the label with the dataset key so each per-dataset
                # copy of the operation gets a distinct label.
                (op_output, ) = nodes.OperationNode(
                    operation_def._replace(label='{}[{}]'.format(
                        operation_def.label, dataset_key)),
                    (value_node, )).outputs
                if operation_def.cache_coder:
                    encoded_cache = nodes.apply_operation(
                        analyzer_nodes.EncodeCache,
                        op_output,
                        coder=operation_def.cache_coder,
                        label='EncodeCache[{}][{}]'.format(
                            operation_def.label, dataset_key))
                    self.cache_output_nodes[(dataset_key,
                                             cache_entry_key)] = encoded_cache
            result_fine_grained_view[dataset_key] = op_output

        return result_fine_grained_view
  def _apply_operation_on_fine_grained_view(self, operation_def,
                                            fine_grained_view,
                                            next_hashed_path):
    """Builds one output node per dataset for a shardable operation.

    For each key in `self._dataset_keys`, the output comes either from a
    `DecodeCache` node (when the operation has a cache coder and
    `self._cache_dict` holds a non-None entry for that dataset and cache entry
    key), or from `operation_def` applied to that dataset's node in
    `fine_grained_view`. In the latter case, when the operation has a cache
    coder, an `EncodeCache` node is also recorded in `self.cache_output_nodes`.

    Args:
      operation_def: A shardable `OperationDef`.
      fine_grained_view: A `_OptimizationView.fine_grained_view`, indexed by
        dataset key.
      next_hashed_path: The hashed path for the currently processed
        operation_def; it is appended to the bytes form of the label to build
        the cache entry key.

    Returns:
      A `collections.OrderedDict` mapping each dataset key to its output node,
      in the iteration order of `self._dataset_keys`.
    """
    base_label = operation_def.label
    entry_key = analyzer_cache.make_cache_entry_key(
        tf.compat.as_bytes(base_label) + b'-' + next_hashed_path)

    outputs_by_dataset = collections.OrderedDict()
    for dataset_key in self._dataset_keys:
      coder = operation_def.cache_coder
      cache_hit = bool(coder) and (
          self._cache_dict.get(dataset_key, {}).get(entry_key) is not None)

      if cache_hit:
        # A cached value exists for this dataset: decode it instead of
        # recomputing the operation.
        decode_def = analyzer_nodes.DecodeCache(
            dataset_key, entry_key, base_label, coder=coder)
        [dataset_output] = nodes.OperationNode(decode_def, ()).outputs
      else:
        source_node = fine_grained_view[dataset_key]
        # Suffix the label with the dataset key so each per-dataset copy of
        # the operation gets a distinct label.
        sharded_def = operation_def._replace(
            label='{}[{}]'.format(base_label, dataset_key))
        [dataset_output] = nodes.OperationNode(
            sharded_def, (source_node,)).outputs
        if coder:
          # Register the node that encodes the freshly computed result.
          cache_writer = nodes.apply_operation(
              analyzer_nodes.EncodeCache,
              dataset_output,
              coder=coder,
              label='EncodeCache[{}][{}]'.format(base_label, dataset_key))
          self.cache_output_nodes[(dataset_key, entry_key)] = cache_writer

      outputs_by_dataset[dataset_key] = dataset_output

    return outputs_by_dataset