Example #1
    def _apply_operation_on_fine_grained_view(self, operation_def,
                                              fine_grained_views,
                                              next_hashed_path):
        """Applies a shardable operation on a fine grained view.

    This also updates `cache_output_nodes` when necessary.

    Args:
      operation_def: A shardable `OperationDef`.
      fine_grained_views: A tuple of `_OptimizationView.fine_grained_view`s.
      next_hashed_path: The hashed path for the currently processed
        operation_def.

    Returns:
      The resulting list of `_OptimizationView.fine_grained_view`s.
    """
        result_fine_grained_view = collections.OrderedDict()

        cache_entry_key = analyzer_cache.make_cache_entry_key(
            tf.compat.as_bytes(operation_def.label) + b'-' + next_hashed_path)

        for (dataset_idx, dataset_key) in enumerate(self._sorted_dataset_keys):
            # We use an index for the label in order to make Beam labels more stable.
            infix = 'AnalysisIndex{}'.format(dataset_idx)
            if (operation_def.cache_coder and self._cache_dict.get(
                    dataset_key, {}).get(cache_entry_key) is not None):
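                # `|= False` leaves the value unchanged, but if the mapping is
                # a `defaultdict(bool)` (assumed; its construction is not shown
                # here) it still registers `dataset_key` on a cache hit.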
                self._dataset_has_cache_misses[dataset_key] |= False
                decode_cache = analyzer_nodes.DecodeCache(
                    dataset_key,
                    cache_entry_key,
                    coder=operation_def.cache_coder,
                    label='DecodeCache[{}][{}]'.format(operation_def.label,
                                                       infix))
                (op_output, ) = nodes.OperationNode(decode_cache,
                                                    tuple()).outputs
            else:
                value_nodes = tuple(v[dataset_key] for v in fine_grained_views)
                (op_output, ) = nodes.OperationNode(
                    operation_def._replace(
                        label='{}[{}]'.format(operation_def.label, infix)),
                    value_nodes).outputs
                if operation_def.cache_coder:
                    self._dataset_has_cache_misses[dataset_key] = True
                    encode_cache = nodes.apply_operation(
                        analyzer_nodes.EncodeCache,
                        op_output,
                        coder=operation_def.cache_coder,
                        label='EncodeCache[{}][{}]'.format(
                            operation_def.label, infix))
                    self.cache_output_nodes[(dataset_key,
                                             cache_entry_key)] = encode_cache
            result_fine_grained_view[dataset_key] = op_output

        return result_fine_grained_view
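The `|= False` on the cache-hit branch looks like a no-op. Here is a minimal sketch of why it can still matter, assuming `_dataset_has_cache_misses` is a `collections.defaultdict(bool)` (an assumption; the snippet above does not show how it is constructed):

import collections

has_cache_misses = collections.defaultdict(bool)

# OR-ing with False never flips a value, but subscripting a defaultdict
# inserts the key, so a pure cache hit still registers the dataset.
has_cache_misses['hit_only_dataset'] |= False
has_cache_misses['missed_dataset'] = True

print(dict(has_cache_misses))
# {'hit_only_dataset': False, 'missed_dataset': True}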
Example #2
    def test_cache_helpers_round_trip(self):
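        # Prefer Bazel's TEST_UNDECLARED_OUTPUTS_DIR when set; otherwise fall
        # back to a per-test temp directory.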
        base_test_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)

        with beam.Pipeline() as p:
            cache_pcoll_dict = {
                'dataset_key_0': {
                    'a': p | 'CreateA' >> beam.Create([b'[1, 2, 3]']),
                    'b': p | 'CreateB' >> beam.Create([b'[5]']),
                },
                'dataset_key_1': {
                    'c': p | 'CreateC' >> beam.Create([b'[9, 5, 2, 1]']),
                },
            }
            _ = cache_pcoll_dict | analyzer_cache.WriteAnalysisCacheToFS(
                base_test_dir)

        with beam.Pipeline() as p:
            read_cache = p | analyzer_cache.ReadAnalysisCacheFromFS(
                base_test_dir, list(cache_pcoll_dict.keys()))

            def assert_equal_matcher(expected_encoded):
                def _assert_equal(encoded_cache_list):
                    (encode_cache, ) = encoded_cache_list
                    self.assertEqual(expected_encoded, encode_cache)

                return _assert_equal

            beam_test_util.assert_that(
                read_cache['dataset_key_0'][
                    analyzer_cache.make_cache_entry_key('a')],
                beam_test_util.equal_to([b'[1, 2, 3]']),
                label='AssertA')
            beam_test_util.assert_that(
                read_cache['dataset_key_0'][
                    analyzer_cache.make_cache_entry_key('b')],
                assert_equal_matcher(b'[5]'),
                label='AssertB')
            beam_test_util.assert_that(
                read_cache['dataset_key_1'][
                    analyzer_cache.make_cache_entry_key('c')],
                assert_equal_matcher(b'[9, 5, 2, 1]'),
                label='AssertC')
Example #3
    def _apply_operation_on_fine_grained_view(self, operation_def,
                                              fine_grained_view):
        """Applies a shardable operation on a fine grained view.

    This also updates `cache_output_nodes` when necessary.

    Args:
      operation_def: A shardable `OperationDef`.
      fine_grained_view: A `_OptimizationView.fine_grained_view`.

    Returns:
      The resulting list of `_OptimizationView.fine_grained_view`s.
    """
        result_fine_grained_view = collections.OrderedDict()

        # TODO(b/37788560): Use a better cache key than label. A good alternative is
        # to reuse graph_tools logic to compose names that include properties and
        # fingerprint it.
        cache_entry_key = analyzer_cache.make_cache_entry_key(
            operation_def.label)
        for dataset_key in self._dataset_keys:
            # TODO(b/37788560): Add instrumentation.
            if self._cache_dict.get(dataset_key,
                                    {}).get(cache_entry_key) is not None:
                (op_output, ) = nodes.OperationNode(
                    analyzer_nodes.DecodeCache(
                        dataset_key,
                        cache_entry_key,
                        coder=operation_def.cache_coder), tuple()).outputs
            else:
                value_node = fine_grained_view[dataset_key]
                (op_output, ) = nodes.OperationNode(
                    operation_def._replace(label='{}[{}]'.format(
                        operation_def.label, dataset_key)),
                    (value_node, )).outputs
                if operation_def.cache_coder:
                    encoded_cache = nodes.apply_operation(
                        analyzer_nodes.EncodeCache,
                        op_output,
                        coder=operation_def.cache_coder,
                        label='EncodeCache[{}][{}]'.format(
                            operation_def.label, dataset_key))
                    self.cache_output_nodes[(dataset_key,
                                             cache_entry_key)] = encoded_cache
            result_fine_grained_view[dataset_key] = op_output

        return result_fine_grained_view
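The TODO in this example is the motivation for the `next_hashed_path` variants of the method (Examples #1 and #4): a label-only cache key cannot tell apart two operations that happen to share a label. A minimal sketch of the idea, using a hypothetical `hashed_key` helper rather than the fingerprinting that `graph_tools` would actually provide:

import hashlib


def hashed_key(label, upstream_path):
    # Hypothetical: suffix the label with a digest of the operation's inputs.
    return label + b'-' + hashlib.sha1(upstream_path).digest()


label = b'CacheableCombineAccumulate'
key_for_op_a = label                 # label-only keys: identical
key_for_op_b = label
assert key_for_op_a == key_for_op_b  # collision across distinct operations
# Path-aware keys stay distinct for different upstream graphs.
assert hashed_key(label, b'inputs/x') != hashed_key(label, b'inputs/y')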
Example #4
  def _apply_operation_on_fine_grained_view(self, operation_def,
                                            fine_grained_view,
                                            next_hashed_path):
    """Applies a shardable operation on a fine grained view.

    This also updates `cache_output_nodes` when necessary.

    Args:
      operation_def: A shardable `OperationDef`.
      fine_grained_view: A `_OptimizationView.fine_grained_view`.
      next_hashed_path: The hashed path for the currently processed
        operation_def.

    Returns:
      The resulting list of `_OptimizationView.fine_grained_view`s.
    """
    result_fine_grained_view = collections.OrderedDict()

    cache_entry_key = analyzer_cache.make_cache_entry_key(
        tf.compat.as_bytes(operation_def.label) + b'-' + next_hashed_path)

    for dataset_key in self._dataset_keys:
      if (operation_def.cache_coder and self._cache_dict.get(
          dataset_key, {}).get(cache_entry_key) is not None):
        (op_output,) = nodes.OperationNode(
            analyzer_nodes.DecodeCache(
                dataset_key,
                cache_entry_key,
                operation_def.label,
                coder=operation_def.cache_coder), tuple()).outputs
      else:
        value_node = fine_grained_view[dataset_key]
        (op_output,) = nodes.OperationNode(
            operation_def._replace(
                label='{}[{}]'.format(operation_def.label, dataset_key)),
            (value_node,)).outputs
        if operation_def.cache_coder:
          encoded_cache = nodes.apply_operation(
              analyzer_nodes.EncodeCache,
              op_output,
              coder=operation_def.cache_coder,
              label='EncodeCache[{}][{}]'.format(operation_def.label,
                                                 dataset_key))
          self.cache_output_nodes[(dataset_key,
                                   cache_entry_key)] = encoded_cache
      result_fine_grained_view[dataset_key] = op_output

    return result_fine_grained_view
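The `operation_def._replace(label=...)` calls in these examples rely on `OperationDef` behaving like a namedtuple: `_replace` returns a relabeled copy and leaves the original untouched. A minimal sketch with a hypothetical `FakeOpDef` standing in for a real `OperationDef`:

import collections

FakeOpDef = collections.namedtuple('FakeOpDef', ['label', 'cache_coder'])

op = FakeOpDef(label='CacheableCombineAccumulate', cache_coder=None)
relabeled = op._replace(label='{}[{}]'.format(op.label, 'dataset_key_0'))
print(relabeled.label)  # CacheableCombineAccumulate[dataset_key_0]
print(op.label)         # original unchanged; namedtuples are immutable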