Example #1
    def testGetDotGraph(self):
        a = nodes.apply_operation(_Constant, value='a', label='Constant[a]')
        b = nodes.apply_operation(_Constant, value='b', label='Constant[b]')
        b_copy, a_copy = nodes.apply_multi_output_operation(_Swap,
                                                            a,
                                                            b,
                                                            label='Swap[0]')
        b_copy2, unused_a_copy2 = nodes.apply_multi_output_operation(
            _Swap, a_copy, b_copy, label='Swap[1]')
        dot_string = nodes.get_dot_graph([b_copy2]).to_string()
        self.WriteRenderedDotFile(dot_string)

        self.assertMultiLineEqual(
            dot_string,
            """\
digraph G {
directed=True;
node [shape=Mrecord];
"Constant[a]" [label="{_Constant|value: a|label: Constant[a]}"];
"Constant[b]" [label="{_Constant|value: b|label: Constant[b]}"];
"Swap[0]" [label="{_Swap|label: Swap[0]|{<0>0|<1>1}}"];
"Constant[a]" -> "Swap[0]";
"Constant[b]" -> "Swap[0]";
"Swap[1]" [label="{_Swap|label: Swap[1]|{<0>0|<1>1}}"];
"Swap[0]":1 -> "Swap[1]";
"Swap[0]":0 -> "Swap[1]";
}
""",
            msg='Result dot graph is:\n{}'.format(dot_string))
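A hedged note on rendering: the `.to_string()` call above suggests `nodes.get_dot_graph` returns a pydot-style graph object, so the same `dot_string` could be rendered outside the test harness (which relies on the unshown `WriteRenderedDotFile` helper) roughly as follows, assuming pydot and Graphviz are installed; the output path is hypothetical.

import pydot

# Parse the DOT text produced above back into a pydot graph and render it.
# Requires the Graphviz binaries to be available on PATH.
(graph,) = pydot.graph_from_dot_data(dot_string)
graph.write_png('/tmp/dot_graph.png')  # hypothetical output location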
Example #2
def _apply_analyzer(ptransform: Union[_BeamPTransform,
                                      CacheablePTransformAnalyzer],
                    *tensor_inputs: common_types.TensorType,
                    **analyzer_def_kwargs: Any) -> Tuple[tf.Tensor, ...]:
    """Applies the analyzer over the whole dataset.

  Args:
    ptransform: A class inheriting from analyzer_nodes.AnalyzerDef or
      CacheablePTransformAnalyzer that should be applied.
    *tensor_inputs: A list of input `Tensor`s or `CompositeTensor`s.
    **analyzer_def_kwargs: KW arguments to use when constructing
      analyzer_def_cls.

  Returns:
    A list of `Tensor`s representing the values of the analysis result.
  """
    input_values_node = analyzer_nodes.get_input_tensors_value_nodes(
        tensor_inputs)
    if isinstance(ptransform, CacheablePTransformAnalyzer):
        with tf.compat.v1.name_scope('make_accumulators'):
            make_accumulators_value_node = nodes.apply_multi_output_operation(
                analyzer_nodes.PTransform,
                input_values_node,
                ptransform=ptransform.make_accumulators_ptransform,
                is_partitionable=True,
                **analyzer_def_kwargs)
        with tf.compat.v1.name_scope('local_merge_accumulators'):
            cached_value_nodes = nodes.apply_multi_output_operation(
                analyzer_nodes.PTransform,
                *make_accumulators_value_node,
                ptransform=ptransform.merge_accumulators_ptransform,
                is_partitionable=True,
                cache_coder=ptransform.cache_coder,
                **analyzer_def_kwargs)
        with tf.compat.v1.name_scope('global_merge_accumulators'):
            merge_output_value_nodes = nodes.apply_multi_output_operation(
                analyzer_nodes.PTransform,
                *cached_value_nodes,
                ptransform=ptransform.merge_accumulators_ptransform,
                is_partitionable=False,
                **analyzer_def_kwargs)
        with tf.compat.v1.name_scope('extract_output'):
            output_value_nodes = nodes.apply_multi_output_operation(
                analyzer_nodes.PTransform,
                *merge_output_value_nodes,
                ptransform=ptransform.extract_output_ptransform,
                is_partitionable=False,
                **analyzer_def_kwargs)
    else:
        output_value_nodes = nodes.apply_multi_output_operation(
            analyzer_nodes.PTransform,
            input_values_node,
            ptransform=ptransform,
            is_partitionable=False,
            **analyzer_def_kwargs)
    return tuple(map(analyzer_nodes.wrap_as_tensor, output_value_nodes))
Example #3
    def testTraverserComplexGraphMultipleCalls(self):
        a = nodes.apply_operation(_Constant, value='a', label='Constant[a]')
        b = nodes.apply_operation(_Constant, value='b', label='Constant[b]')
        c = nodes.apply_operation(_Constant, value='c', label='Constant[c]')
        b_copy, a_copy = nodes.apply_multi_output_operation(_Swap,
                                                            a,
                                                            b,
                                                            label='Swap')
        b_a = nodes.apply_operation(_Concat, b_copy, a_copy, label='Concat[0]')
        b_a_c = nodes.apply_operation(_Concat, b_a, c, label='Concat[1]')

        mock_visitor = mock.MagicMock()
        mock_visitor.visit.side_effect = [('a', ), ('b', ), ('b', 'a'),
                                          ('ba', ), ('c', ), ('bac', )]

        traverser = nodes.Traverser(mock_visitor)
        traverser.visit_value_node(b_a)
        traverser.visit_value_node(b_a_c)

        mock_visitor.assert_has_calls([
            mock.call.visit(_Constant('a', 'Constant[a]'), ()),
            mock.call.validate_value('a'),
            mock.call.visit(_Constant('b', 'Constant[b]'), ()),
            mock.call.validate_value('b'),
            mock.call.visit(_Swap('Swap'), ('a', 'b')),
            mock.call.validate_value('b'),
            mock.call.validate_value('a'),
            mock.call.visit(_Concat('Concat[0]'), ('b', 'a')),
            mock.call.validate_value('ba'),
            mock.call.visit(_Constant('c', 'Constant[c]'), ()),
            mock.call.validate_value('c'),
            mock.call.visit(_Concat('Concat[1]'), ('ba', 'c')),
            mock.call.validate_value('bac'),
        ])
Example #4
    def testTraverserComplexGraph(self):
        a = nodes.apply_operation(_Constant, value='a')
        b = nodes.apply_operation(_Constant, value='b')
        c = nodes.apply_operation(_Constant, value='c')
        b_copy, a_copy = nodes.apply_multi_output_operation(_Swap, a, b)
        b_a = nodes.apply_operation(_Concat, b_copy, a_copy)
        b_a_c = nodes.apply_operation(_Concat, b_a, c)

        mock_visitor = mock.MagicMock()
        mock_visitor.visit.side_effect = [('a', ), ('b', ), ('b', 'a'),
                                          ('ba', ), ('c', ), ('bac', )]

        nodes.Traverser(mock_visitor).visit_value_node(b_a_c)

        mock_visitor.assert_has_calls([
            mock.call.visit(_Constant('a'), ()),
            mock.call.validate_value('a'),
            mock.call.visit(_Constant('b'), ()),
            mock.call.validate_value('b'),
            mock.call.visit(_Swap(), ('a', 'b')),
            mock.call.validate_value('b'),
            mock.call.validate_value('a'),
            mock.call.visit(_Concat(), ('b', 'a')),
            mock.call.validate_value('ba'),
            mock.call.visit(_Constant('c'), ()),
            mock.call.validate_value('c'),
            mock.call.visit(_Concat(), ('ba', 'c')),
            mock.call.validate_value('bac'),
        ])
Example #5
    def testApplyOperationWithTupleOutput(self):
        a = nodes.apply_operation(_Constant, value='a')
        b = nodes.apply_operation(_Constant, value='b')
        b_copy, a_copy = nodes.apply_multi_output_operation(_Swap, a, b)
        op = b_copy.parent_operation
        self.assertEqual(b_copy.value_index, 0)
        self.assertEqual(a_copy.parent_operation, op)
        self.assertEqual(a_copy.value_index, 1)
        self.assertEqual(op.operation_def, _Swap())
        self.assertEqual(op.inputs, (a, b))
        self.assertEqual(op.outputs, (b_copy, a_copy))
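The `_Constant`, `_Swap`, and `_Concat` fixtures referenced by the test snippets above are not shown in this listing. Below is a minimal, hypothetical sketch of two of them, following the same namedtuple-plus-`OperationDef` pattern used in Example #9; the names, fields, and defaults are assumptions, and the library's actual test fixtures may differ (`_Concat` would follow the same single-output shape as `_Constant`).

import collections

from tensorflow_transform import nodes


class _Constant(collections.namedtuple('_Constant', ['value', 'label']),
                nodes.OperationDef):
    """Hypothetical zero-input operation producing one constant value."""

    def __new__(cls, value, label=None):
        return super(_Constant, cls).__new__(cls, value=value, label=label)

    @property
    def num_outputs(self):
        return 1


class _Swap(collections.namedtuple('_Swap', ['label']), nodes.OperationDef):
    """Hypothetical two-input operation returning its inputs reversed."""

    def __new__(cls, label=None):
        return super(_Swap, cls).__new__(cls, label=label)

    @property
    def num_outputs(self):
        return 2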
Example #6
    def visit(self, operation_def, input_values):
        self._validate_operation_def(operation_def)

        # TODO(b/37788560): Possibly make this generic instead of special casing the
        # ApplySavedModel operation.
        if (isinstance(operation_def, beam_nodes.ApplySavedModel)
                and operation_def.phase == 0):
            return self._visit_apply_savedmodel_operation(
                operation_def, input_values)

        # When self._cache_dict is None, it means that we shouldn't do any
        # caching for this pipeline, and so there's no need to create any
        # fine-grained views.
        if self._cache_dict is not None and operation_def.is_partitionable:
            return self._visit_partitionable_operation(operation_def,
                                                       input_values)

        if input_values and any(
                v.fine_grained_view and v.prefer_fine_grained_view
                for v in input_values):
            # We can 'flatten' the cached outputs of the parent operation since this
            # operation doesn't support partitioning.
            disaggregated_input_values = []
            for view in input_values:
                disaggregated_input_values.extend(
                    view.fine_grained_view.values())

            # Check that all cached inputs have the same size.
            assert len({len(value)
                        for value in disaggregated_input_values}) == 1

            next_inputs = nodes.apply_multi_output_operation(
                beam_nodes.Flatten,
                *disaggregated_input_values,
                label='FlattenCache[{}]'.format(operation_def.label))
        else:
            # Parent operation output is not cacheable, therefore we can just use
            # a flattened view.
            next_inputs = tuple(v.flattened_view for v in input_values)

        flattened_view = nodes.OperationNode(operation_def,
                                             next_inputs).outputs

        return tuple(
            _OptimizationView(  # pylint: disable=g-complex-comprehension
                prefer_fine_grained_view=False,
                flattened_view=flat,
                fine_grained_view=None,
                hashed_path=None) for flat in flattened_view)
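The `_OptimizationView` objects constructed here are not defined in this listing. A minimal compatible container, inferred purely from the call sites, might look like the sketch below; this is an assumption, and the library's actual class may add validation (for example, that at least one of the two views is set). `hashed_path` is given a default so the older three-argument call sites in Examples #7, #10, and #11 would also work.

import collections

# Hypothetical stand-in for the view container returned by these visitors;
# field names are taken from the keyword arguments used above.
_OptimizationView = collections.namedtuple(
    '_OptimizationView',
    ['prefer_fine_grained_view', 'flattened_view', 'fine_grained_view',
     'hashed_path'],
    defaults=(None,))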
Example #7
    def visit(self, operation_def, input_values):
        self._validate_operation_def(operation_def)

        if (isinstance(operation_def, beam_nodes.ApplySavedModel)
                and operation_def.phase == 0):
            return self._visit_apply_savedmodel_operation(
                operation_def, input_values)

        if self._cache_location and operation_def.is_partitionable:
            return self._visit_partitionable_operation(operation_def,
                                                       input_values)

        if input_values and any(
                v.fine_grained_view and v.prefer_fine_grained_view
                for v in input_values):
            # We can 'flatten' the cached outputs of the parent operation since this
            # operation doesn't support partitioning.
            disaggregated_input_values = []
            for view in input_values:
                disaggregated_input_values.extend(
                    view.fine_grained_view.values())

            # Check that all cached inputs have the same size.
            assert len({len(value)
                        for value in disaggregated_input_values}) == 1

            next_inputs = nodes.apply_multi_output_operation(
                beam_nodes.Flatten,
                *disaggregated_input_values,
                label='FlattenCache[{}]'.format(operation_def.label))
        else:
            # Parent operation output is not cacheable, therefore we can just use
            # a flattened view.
            next_inputs = tuple(v.flattened_view for v in input_values)

        flattened_view = nodes.OperationNode(operation_def,
                                             next_inputs).outputs

        return tuple(
            _OptimizationView(prefer_fine_grained_view=False,
                              flattened_view=flat,
                              fine_grained_view=None)
            for flat in flattened_view)
Example #8
    def _add_flatten_placeholder(self, operation_def, input_values):
        assert isinstance(operation_def,
                          analyzer_nodes.ExtractCombineMergeOutputs)
        parent = input_values[0].parent_operation
        assert isinstance(parent.operation_def,
                          analyzer_nodes.CacheableCombineMerge)
        packed_combine = self._get_packed_combine(parent.operation_def,
                                                  parent.inputs)
        # For the current combine, create the ExtractFromDict node which
        # extracts the accumulator corresponding to this combine from the
        # packed combine output.
        extract_dict_node = nodes.apply_operation(
            beam_nodes.ExtractFromDict,
            packed_combine,
            keys=parent.operation_def.label,
            label='ExtractFromDict[{}]'.format(parent.operation_def.label))
        # Create the new ExtractPackedCombineMergeOutputs node.
        return nodes.apply_multi_output_operation(
            analyzer_nodes.ExtractPackedCombineMergeOutputs,
            extract_dict_node,
            output_tensor_info_list=operation_def.output_tensor_infos,
            label='ExtractPackedCombineMergeOutputs[{}]'.format(
                parent.operation_def.label))
Example #9
def _preprocessing_fn_for_generalized_chained_ptransforms(inputs):
    class FakeChainablePartitionable(
            collections.namedtuple('FakeChainablePartitionable', ['label']),
            nodes.OperationDef):
        def __new__(cls, label=None):
            if label is None:
                scope = tf.compat.v1.get_default_graph().get_name_scope()
                label = '{}[{}]'.format(cls.__name__, scope)
            return super(FakeChainablePartitionable, cls).__new__(cls,
                                                                  label=label)

        @property
        def num_outputs(self):
            return 1

        @property
        def is_partitionable(self):
            return True

    class FakeChainableCacheable(
            collections.namedtuple('FakeChainableCacheable', ['label']),
            nodes.OperationDef):
        def __new__(cls, label=None):
            if label is None:
                scope = tf.compat.v1.get_default_graph().get_name_scope()
                label = '{}[{}]'.format(cls.__name__, scope)
            return super(FakeChainableCacheable, cls).__new__(cls, label=label)

        @property
        def num_outputs(self):
            return 1

        @property
        def is_partitionable(self):
            return True

        @property
        def cache_coder(self):
            return 'Not-a-coder-but-thats-ok!'

    class FakeChainable(collections.namedtuple('FakeChainable', ['label']),
                        nodes.OperationDef):
        def __new__(cls, label=None):
            if label is None:
                scope = tf.compat.v1.get_default_graph().get_name_scope()
                label = '{}[{}]'.format(cls.__name__, scope)
            return super(FakeChainable, cls).__new__(cls, label=label)

        @property
        def num_outputs(self):
            return 1

        @property
        def is_partitionable(self):
            return False

    with tf.compat.v1.name_scope('x'):
        input_values_node = nodes.apply_operation(analyzer_nodes.TensorSource,
                                                  tensors=[inputs['x']])
        with tf.compat.v1.name_scope('partitionable1'):
            partitionable_outputs = nodes.apply_multi_output_operation(
                FakeChainablePartitionable, input_values_node)
        with tf.compat.v1.name_scope('cacheable1'):
            intermediate_cached_value_node = nodes.apply_multi_output_operation(
                FakeChainableCacheable, *partitionable_outputs)
        with tf.compat.v1.name_scope('partitionable2'):
            partitionable_outputs = nodes.apply_multi_output_operation(
                FakeChainablePartitionable, *intermediate_cached_value_node)
        with tf.compat.v1.name_scope('cacheable2'):
            cached_value_node = nodes.apply_multi_output_operation(
                FakeChainableCacheable, *partitionable_outputs)
        with tf.compat.v1.name_scope('partitionable3'):
            output_value_node = nodes.apply_multi_output_operation(
                FakeChainablePartitionable, *cached_value_node)
        with tf.compat.v1.name_scope('merge'):
            output_value_node = nodes.apply_operation(FakeChainable,
                                                      *output_value_node)
        with tf.compat.v1.name_scope('not-cacheable'):
            non_cached_output = nodes.apply_operation(FakeChainable,
                                                      input_values_node)
        x_chained = analyzer_nodes.bind_future_as_tensor(
            output_value_node,
            analyzer_nodes.TensorInfo(tf.float32, (17, 27), False))
        x_plain = analyzer_nodes.bind_future_as_tensor(
            non_cached_output,
            analyzer_nodes.TensorInfo(tf.int64, (7, 13), False))
        return {'x_chained': x_chained, 'x_plain': x_plain}
Example #10
  def _visit_partitionable_operation(self, operation_def, upstream_views):
    # TODO(b/37788560): Possibly support partitionable operations with multiple
    # inputs.
    (upstream_view,) = upstream_views
    prefer_fine_grained_view = (
        upstream_view.prefer_fine_grained_view or
        upstream_view.fine_grained_view and
        operation_def.cache_coder is not None)

    if upstream_view.fine_grained_view:
      value_nodes = collections.OrderedDict()
      for key in self._dataset_keys:

        if operation_def.cache_coder is not None:
          # TODO(b/37788560): Add instrumentation.
          # TODO(b/37788560): Use a better cache key than label. A good
          # alternative is to reuse graph_tools logic to compose names that
          # include properties and fingerprint it.
          cache_file_path = analyzer_cache.make_cache_file_path(
              key, operation_def.label)
          # TODO(b/37788560): Come up with a more abstract way to do this that
          # also ensures consistency.
          pattern = '{}-00000*.gz'.format(
              os.path.join(self._cache_location.input_cache_dir,
                           cache_file_path))
          try:
            if tf.gfile.Glob(pattern):
              op_outputs = nodes.apply_multi_output_operation(
                  analyzer_nodes.ReadCache,
                  path=cache_file_path,
                  coder=operation_def.cache_coder,
                  label='ReadCache[{}][{}]'.format(operation_def.label, key))
              value_nodes[key] = op_outputs
              continue
          except tf.errors.NotFoundError:
            pass
        else:
          cache_file_path = None

        values = upstream_view.fine_grained_view[key]
        op_outputs = nodes.OperationNode(
            operation_def._replace(
                label='{}[{}]'.format(operation_def.label, key)),
            (values,)).outputs
        if cache_file_path is not None:
          op_outputs = nodes.apply_multi_output_operation(
              analyzer_nodes.WriteCache,
              *op_outputs,
              path=cache_file_path,
              coder=operation_def.cache_coder,
              label='WriteCache[{}][{}]'.format(operation_def.label, key))
        value_nodes[key] = op_outputs

      # Use a separate ordered dict per output to avoid aliasing one dict.
      fine_grained_views = [
          collections.OrderedDict()
          for _ in range(operation_def.num_outputs)
      ]
      for key in self._dataset_keys:
        for idx in range(operation_def.num_outputs):
          fine_grained_views[idx][key] = value_nodes[key][idx]
    else:
      fine_grained_views = (None,) * operation_def.num_outputs

    flattened_views = nodes.OperationNode(
        operation_def, (upstream_view.flattened_view,)).outputs

    return tuple(
        _OptimizationView(
            prefer_fine_grained_view=prefer_fine_grained_view,
            flattened_view=flat,
            fine_grained_view=fine)
        for flat, fine in zip(flattened_views, fine_grained_views))
Example #11
    def _visit_partitionable_operation(self, operation_def, upstream_views):
        (upstream_view, ) = upstream_views
        prefer_fine_grained_view = (upstream_view.prefer_fine_grained_view
                                    or upstream_view.fine_grained_view
                                    and operation_def.cache_coder is not None)

        if upstream_view.fine_grained_view:
            value_nodes = collections.OrderedDict()
            for key in self._dataset_keys:

                if operation_def.cache_coder is not None:
                    cache_file_path = analyzer_cache.make_cache_file_path(
                        key, operation_def.label)
                    pattern = '{}-00000*.gz'.format(
                        os.path.join(self._cache_location.input_cache_dir,
                                     cache_file_path))
                    try:
                        if tf.gfile.Glob(pattern):
                            op_outputs = nodes.apply_multi_output_operation(
                                analyzer_nodes.ReadCache,
                                path=cache_file_path,
                                coder=operation_def.cache_coder,
                                label='ReadCache[{}][{}]'.format(
                                    operation_def.label, key))
                            value_nodes[key] = op_outputs
                            continue
                    except tf.errors.NotFoundError:
                        pass
                else:
                    cache_file_path = None

                values = upstream_view.fine_grained_view[key]
                op_outputs = nodes.OperationNode(
                    operation_def._replace(
                        label='{}[{}]'.format(operation_def.label, key)),
                    (values, )).outputs
                if cache_file_path is not None:
                    op_outputs = nodes.apply_multi_output_operation(
                        analyzer_nodes.WriteCache,
                        *op_outputs,
                        path=cache_file_path,
                        coder=operation_def.cache_coder,
                        label='WriteCache[{}][{}]'.format(
                            operation_def.label, key))
                value_nodes[key] = op_outputs

            # Use a separate ordered dict per output to avoid aliasing.
            fine_grained_views = [
                collections.OrderedDict()
                for _ in range(operation_def.num_outputs)
            ]
            for key in self._dataset_keys:
                for idx in range(operation_def.num_outputs):
                    fine_grained_views[idx][key] = value_nodes[key][idx]
        else:
            fine_grained_views = (None, ) * operation_def.num_outputs

        flattened_views = nodes.OperationNode(
            operation_def, (upstream_view.flattened_view, )).outputs

        return tuple(
            _OptimizationView(
                prefer_fine_grained_view=prefer_fine_grained_view,
                flattened_view=flat,
                fine_grained_view=fine)
            for flat, fine in zip(flattened_views, fine_grained_views))