def _preprocessing_fn_with_chained_ptransforms(inputs):
    """Preprocessing fn that chains two fake operations under scope 'x'.

    Builds a TensorSource node from inputs['x'], applies a FakeChainable
    operation under each of two nested name scopes (each instance labels
    itself after the scope it was created in), and binds the final node
    as a float32 tensor of shape (17, 27).
    """

    class FakeChainable(tfx_namedtuple.namedtuple('FakeChainable', ['label']),
                        nodes.OperationDef):
        # Label is derived from the current graph name scope, so each
        # instantiation under a distinct scope gets a distinct label.
        def __new__(cls):
            current_scope = tf.compat.v1.get_default_graph().get_name_scope()
            return super(FakeChainable, cls).__new__(
                cls, label='{}[{}]'.format(cls.__name__, current_scope))

    with tf.compat.v1.name_scope('x'):
        source = nodes.apply_operation(
            analyzer_nodes.TensorSource, tensors=[inputs['x']])
        with tf.compat.v1.name_scope('ptransform1'):
            first = nodes.apply_operation(FakeChainable, source)
        with tf.compat.v1.name_scope('ptransform2'):
            second = nodes.apply_operation(FakeChainable, first)
        result = analyzer_nodes.bind_future_as_tensor(
            second, analyzer_nodes.TensorInfo(tf.float32, (17, 27), None))
        return {'x_chained': result}
# Example 2
def _preprocessing_fn_for_generalized_chained_ptransforms(inputs):
    """Preprocessing fn chaining partitionable and cacheable fake operations.

    Alternates partitionable and cacheable operations on a TensorSource built
    from inputs['x'], merges the chain with a non-partitionable FakeChainable,
    and additionally produces a second output ('x_plain') that bypasses the
    cacheable chain entirely.
    """

    class FakeChainablePartitionable(
            collections.namedtuple('FakeChainablePartitionable', ['label']),
            nodes.OperationDef):
        """Single-output, partitionable operation."""

        def __new__(cls, label=None):
            # Default the label to "<ClassName>[<current name scope>]".
            if label is None:
                label = '{}[{}]'.format(
                    cls.__name__,
                    tf.compat.v1.get_default_graph().get_name_scope())
            return super(FakeChainablePartitionable, cls).__new__(
                cls, label=label)

        @property
        def num_outputs(self):
            return 1

        @property
        def is_partitionable(self):
            return True

    class FakeChainableCacheable(
            collections.namedtuple('FakeChainableCacheable', ['label']),
            nodes.OperationDef):
        """Single-output, partitionable operation that also exposes a
        (placeholder) cache coder."""

        def __new__(cls, label=None):
            # Default the label to "<ClassName>[<current name scope>]".
            if label is None:
                label = '{}[{}]'.format(
                    cls.__name__,
                    tf.compat.v1.get_default_graph().get_name_scope())
            return super(FakeChainableCacheable, cls).__new__(cls, label=label)

        @property
        def num_outputs(self):
            return 1

        @property
        def is_partitionable(self):
            return True

        @property
        def cache_coder(self):
            # Deliberately not a real coder; the value is only carried around.
            return 'Not-a-coder-but-thats-ok!'

    class FakeChainable(collections.namedtuple('FakeChainable', ['label']),
                        nodes.OperationDef):
        """Single-output operation that is NOT partitionable."""

        def __new__(cls, label=None):
            # Default the label to "<ClassName>[<current name scope>]".
            if label is None:
                label = '{}[{}]'.format(
                    cls.__name__,
                    tf.compat.v1.get_default_graph().get_name_scope())
            return super(FakeChainable, cls).__new__(cls, label=label)

        @property
        def num_outputs(self):
            return 1

        @property
        def is_partitionable(self):
            return False

    with tf.compat.v1.name_scope('x'):
        source = nodes.apply_operation(
            analyzer_nodes.TensorSource, tensors=[inputs['x']])
        # Chain: partitionable -> cacheable -> partitionable -> cacheable
        # -> partitionable, each stage under its own name scope.
        with tf.compat.v1.name_scope('partitionable1'):
            stage = nodes.apply_multi_output_operation(
                FakeChainablePartitionable, source)
        with tf.compat.v1.name_scope('cacheable1'):
            stage = nodes.apply_multi_output_operation(
                FakeChainableCacheable, *stage)
        with tf.compat.v1.name_scope('partitionable2'):
            stage = nodes.apply_multi_output_operation(
                FakeChainablePartitionable, *stage)
        with tf.compat.v1.name_scope('cacheable2'):
            stage = nodes.apply_multi_output_operation(
                FakeChainableCacheable, *stage)
        with tf.compat.v1.name_scope('partitionable3'):
            stage = nodes.apply_multi_output_operation(
                FakeChainablePartitionable, *stage)
        with tf.compat.v1.name_scope('merge'):
            merged = nodes.apply_operation(FakeChainable, *stage)
        # Second path: apply FakeChainable directly to the source, skipping
        # every cacheable stage above.
        with tf.compat.v1.name_scope('not-cacheable'):
            plain = nodes.apply_operation(FakeChainable, source)
        x_chained = analyzer_nodes.bind_future_as_tensor(
            merged, analyzer_nodes.TensorInfo(tf.float32, (17, 27), False))
        x_plain = analyzer_nodes.bind_future_as_tensor(
            plain, analyzer_nodes.TensorInfo(tf.int64, (7, 13), False))
        return {'x_chained': x_chained, 'x_plain': x_plain}