Example #1
0
 def test_returns_single_called_graph_after_resolving_multiple_variables(
         self):
     """Two chained identity calls collapse into one call on the free `var`."""
     free_var = building_blocks.Reference('var', tf.int32)
     inner_identity = building_block_factory.create_compiled_identity(
         computation_types.TensorType(tf.int32))
     inner_call = building_blocks.Call(inner_identity, free_var)

     # `second_call` feeds the result of `call` through another identity.
     outer_identity = building_block_factory.create_compiled_identity(
         computation_types.TensorType(tf.int32))
     outer_call = building_blocks.Call(
         outer_identity,
         building_blocks.Reference('call', inner_call.type_signature))

     block = building_blocks.Block(
         [('call', inner_call), ('second_call', outer_call)],
         building_blocks.Reference('second_call', inner_call.type_signature))

     representing, _ = transformations.create_tensorflow_representing_block(
         block)

     self.assertEqual(representing.type_signature, block.type_signature)
     self.assertIsInstance(representing, building_blocks.Call)
     self.assertIsInstance(representing.function,
                           building_blocks.CompiledComputation)
     self.assertIsInstance(representing.argument, building_blocks.Reference)
     self.assertEqual(representing.argument.name, 'var')
    def test_replaces_lambda_to_named_tuple_of_called_graphs_with_tf_of_same_type(
            self):
        """A lambda returning <a=id(x[0]), b=id(x[1])> parses to one graph."""
        int_identity = building_block_factory.create_compiled_identity(
            tf.int32)
        float_identity = building_block_factory.create_compiled_identity(
            tf.float32)
        param_type = [tf.int32, tf.float32]
        x_ref = building_blocks.Reference('x', param_type)

        named_calls = building_blocks.Tuple([
            ('a', building_blocks.Call(
                int_identity, building_blocks.Selection(x_ref, index=0))),
            ('b', building_blocks.Call(
                float_identity, building_blocks.Selection(x_ref, index=1))),
        ])
        lambda_wrapper = building_blocks.Lambda('x', param_type, named_calls)

        parsed, modified = parse_tff_to_tf(lambda_wrapper)
        to_computation = (
            computation_wrapper_instances.building_block_to_computation)
        exec_lambda = to_computation(lambda_wrapper)
        exec_tf = to_computation(parsed)

        self.assertIsInstance(parsed, building_blocks.CompiledComputation)
        self.assertTrue(modified)
        self.assertEqual(parsed.type_signature, lambda_wrapper.type_signature)
        # Both executables must agree on the same sample input.
        self.assertEqual(exec_lambda([13, 14.]), exec_tf([13, 14.]))
Example #3
0
 def test_returned_tensorflow_executes_correctly_with_no_unbound_refs(self):
     """With a constant input, the resulting TF runs without any argument."""
     int_type = computation_types.TensorType(tf.int32)
     one = building_block_factory.create_tensorflow_constant(int_type, 1)

     inner_call = building_blocks.Call(
         building_block_factory.create_compiled_identity(
             computation_types.TensorType(tf.int32)), one)
     outer_call = building_blocks.Call(
         building_block_factory.create_compiled_identity(
             computation_types.TensorType(tf.int32)),
         building_blocks.Reference('call', inner_call.type_signature))
     result_ref = building_blocks.Reference('second_call',
                                            inner_call.type_signature)
     block = building_blocks.Block(
         [('call', inner_call), ('second_call', outer_call)],
         building_blocks.Tuple([result_ref, result_ref]))

     representing, _ = transformations.create_tensorflow_representing_block(
         block)
     output = test_utils.run_tensorflow(representing.function.proto)
     self.assertAllEqual(output, [1, 1])
    def test_replaces_lambda_to_named_tuple_of_called_graphs_with_tf_of_same_type(
            self):
        """A lambda building <a=id(x[0]), b=id(x[1])> parses to one graph."""
        int_identity = building_block_factory.create_compiled_identity(
            computation_types.TensorType(tf.int32))
        float_identity = building_block_factory.create_compiled_identity(
            computation_types.TensorType(tf.float32))
        param_type = [tf.int32, tf.float32]
        x_ref = building_blocks.Reference('x', param_type)

        named_calls = building_blocks.Struct([
            ('a', building_blocks.Call(
                int_identity, building_blocks.Selection(x_ref, index=0))),
            ('b', building_blocks.Call(
                float_identity, building_blocks.Selection(x_ref, index=1))),
        ])
        lambda_wrapper = building_blocks.Lambda('x', param_type, named_calls)

        parsed, modified = parse_tff_to_tf(lambda_wrapper)
        to_computation = (
            computation_wrapper_instances.building_block_to_computation)
        exec_lambda = to_computation(lambda_wrapper)
        exec_tf = to_computation(parsed)

        self.assertIsInstance(parsed, building_blocks.CompiledComputation)
        self.assertTrue(modified)
        # TODO(b/157172423): change to assertEqual when Py container is preserved.
        parsed.type_signature.check_equivalent_to(
            lambda_wrapper.type_signature)
        self.assertEqual(exec_lambda([13, 14.]), exec_tf([13, 14.]))
Example #5
0
 def test_executes_correctly_with_tuple_in_result(self):
     """The TF for a block returning <second_call, second_call> echoes `var`."""
     var_ref = building_blocks.Reference('var', tf.int32)
     first_identity = building_block_factory.create_compiled_identity(
         computation_types.TensorType(tf.int32))
     first_call = building_blocks.Call(first_identity, var_ref)
     second_identity = building_block_factory.create_compiled_identity(
         computation_types.TensorType(tf.int32))
     second_call = building_blocks.Call(
         second_identity,
         building_blocks.Reference('call', first_call.type_signature))
     result_ref = building_blocks.Reference('second_call',
                                            first_call.type_signature)
     block = building_blocks.Block(
         [('call', first_call), ('second_call', second_call)],
         building_blocks.Tuple([result_ref, result_ref]))

     representing, _ = transformations.create_tensorflow_representing_block(
         block)
     proto = representing.function.proto
     # Each fed value should come back duplicated in both tuple slots.
     for feed, expected in ((1, [1, 1]), (0, [0, 0])):
         self.assertAllEqual(test_utils.run_tensorflow(proto, feed), expected)
Example #6
0
 def test_returns_correct_structure_with_no_unbound_references(self):
     """A fully-bound block becomes a no-argument call to compiled TF."""
     int_type = computation_types.TensorType(tf.int32)
     one = building_block_factory.create_tensorflow_constant(int_type, 1)

     inner_call = building_blocks.Call(
         building_block_factory.create_compiled_identity(
             computation_types.TensorType(tf.int32)), one)
     outer_call = building_blocks.Call(
         building_block_factory.create_compiled_identity(
             computation_types.TensorType(tf.int32)),
         building_blocks.Reference('call', inner_call.type_signature))
     result_ref = building_blocks.Reference('second_call',
                                            inner_call.type_signature)
     block = building_blocks.Block(
         [('call', inner_call), ('second_call', outer_call)],
         building_blocks.Tuple([result_ref, result_ref]))

     representing, _ = transformations.create_tensorflow_representing_block(
         block)

     self.assertEqual(representing.type_signature, block.type_signature)
     self.assertIsInstance(representing, building_blocks.Call)
     self.assertIsInstance(representing.function,
                           building_blocks.CompiledComputation)
     self.assertIsNone(representing.argument)
Example #7
0
 def test_returns_true_for_compiled_computations_with_different_names(self):
     """Compiled identities differing only in name compare as equal trees."""
     int_type = computation_types.TensorType(tf.int32)
     named_a, named_b = (
         building_block_factory.create_compiled_identity(int_type, name)
         for name in ('a', 'b'))
     self.assertTrue(tree_analysis.trees_equal(named_a, named_b))
 def test_returns_false_for_compiled_computations_with_different_types(
         self):
     """Same-named identities over int32 and float32 are not equal trees."""
     int_identity = building_block_factory.create_compiled_identity(
         tf.int32, 'a')
     float_identity = building_block_factory.create_compiled_identity(
         tf.float32, 'a')
     are_equal = tree_analysis.trees_equal(int_identity, float_identity)
     self.assertFalse(are_equal)
Example #9
0
 def test_returns_correct_structure_with_tuple_in_result(self):
     """A block with a tuple result becomes one TF call on the free `var`."""
     var_ref = building_blocks.Reference('var', tf.int32)
     first_call = building_blocks.Call(
         building_block_factory.create_compiled_identity(tf.int32), var_ref)
     second_call = building_blocks.Call(
         building_block_factory.create_compiled_identity(tf.int32),
         building_blocks.Reference('call', first_call.type_signature))
     result_ref = building_blocks.Reference('second_call',
                                            first_call.type_signature)
     block = building_blocks.Block(
         [('call', first_call), ('second_call', second_call)],
         building_blocks.Tuple([result_ref, result_ref]))

     representing, _ = transformations.create_tensorflow_representing_block(
         block)

     self.assertEqual(representing.type_signature, block.type_signature)
     self.assertIsInstance(representing, building_blocks.Call)
     self.assertIsInstance(representing.function,
                           building_blocks.CompiledComputation)
     self.assertIsInstance(representing.argument, building_blocks.Reference)
     self.assertEqual(representing.argument.name, 'var')
Example #10
0
 def test_basic_functionality_of_compiled_computation_class(self):
     """Checks type, proto, name, string forms and serialization round-trip."""
     x = building_block_factory.create_compiled_identity(tf.int32, 'a')
     self.assertEqual(x.type_signature.compact_representation(),
                      '(int32 -> int32)')
     self.assertIsInstance(x.proto, pb.Computation)
     self.assertEqual(x.name, 'a')
     # `assertTrue(value, expected)` treats `expected` as the failure message
     # and only checks that `value` is truthy, so it never compared anything.
     # `assertEqual` actually pins the representations.
     self.assertEqual(
         repr(x),
         'CompiledComputation(\'a\', FunctionType(TensorType(tf.int32), TensorType(tf.int32)))'
     )
     self.assertEqual(x.compact_representation(), 'comp#a')
     y = building_block_factory.create_compiled_identity(tf.int32)
     self._serialize_deserialize_roundtrip_test(y)
 def test_should_transform_tf_computation(self):
     """RaiseOnDisallowedOp.should_transform accepts a compiled computation."""
     tensor_type = computation_types.TensorType(tf.int32)
     identity = building_block_factory.create_compiled_identity(tensor_type)
     # An empty disallow-list still selects compiled computations.
     checker = compiled_computation_transformations.RaiseOnDisallowedOp(
         frozenset())
     self.assertTrue(checker.should_transform(identity))
Example #12
0
 def test_single_tensorflow_node_count_agrees_with_node_count(self):
   """Node-level and tree-level TF op counts agree on a single graph."""
   identity = building_block_factory.create_compiled_identity(tf.int32)
   self.assertEqual(
       building_block_analysis.count_tensorflow_ops_in(identity),
       tree_analysis.count_tensorflow_ops_under(identity))
Example #13
0
    def test_replaces_lambda_to_called_graph_on_tuple_of_selections_from_arg_with_tf_of_same_type(
            self):
        """Calling a graph on <x[0], x[2]> inside a lambda parses to one graph.

        The parsed TF must be a compiled computation with the lambda's type
        signature, and it must agree with the lambda on a sample input.
        """
        identity_tf_block = building_block_factory.create_compiled_identity(
            [tf.int32, tf.bool])
        tuple_ref = building_blocks.Reference('x',
                                              [tf.int32, tf.float32, tf.bool])
        selected_int = building_blocks.Selection(tuple_ref, index=0)
        selected_bool = building_blocks.Selection(tuple_ref, index=2)
        created_tuple = building_blocks.Tuple([selected_int, selected_bool])
        called_tf_block = building_blocks.Call(identity_tf_block,
                                               created_tuple)
        lambda_wrapper = building_blocks.Lambda(
            'x', [tf.int32, tf.float32, tf.bool], called_tf_block)

        parsed, modified = parse_tff_to_tf(lambda_wrapper)
        # Each executable is built exactly once and reused below.
        exec_lambda = computation_wrapper_instances.building_block_to_computation(
            lambda_wrapper)
        exec_tf = computation_wrapper_instances.building_block_to_computation(
            parsed)

        self.assertIsInstance(parsed, building_blocks.CompiledComputation)
        self.assertTrue(modified)
        self.assertEqual(parsed.type_signature, lambda_wrapper.type_signature)
        self.assertEqual(exec_lambda([7, 8., True]), exec_tf([7, 8., True]))
 def test_raises_with_naked_graph_as_block_local(self):
   """A block local bound to an uncalled graph is rejected with ValueError."""
   compiled_graph = building_block_factory.create_compiled_identity(tf.int32)
   naked_block = building_blocks.Block(
       [('graph', compiled_graph)],
       building_blocks.Reference('graph', compiled_graph.type_signature))
   with self.assertRaises(ValueError):
     compiler_transformations.create_tensorflow_representing_block(naked_block)
Example #15
0
  def test_returns_string_for_compiled_computation(self):
    """Compact, formatted and structural strings all reflect the name 'a'."""
    comp = building_block_factory.create_compiled_identity(
        computation_types.TensorType(tf.int32), 'a')

    compact, formatted, structural = (comp.compact_representation(),
                                      comp.formatted_representation(),
                                      comp.structural_representation())
    self.assertEqual(compact, 'comp#a')
    self.assertEqual(formatted, 'comp#a')
    self.assertEqual(structural, 'Compiled(a)')
Example #16
0
    def test_replaces_lambda_to_called_graph_on_tuple_of_selections_from_arg_with_tf_of_same_type_with_names(
            self):
        """Selecting x.a and x.c and calling a graph on them parses to TF."""
        pair_identity = building_block_factory.create_compiled_identity(
            computation_types.StructType([tf.int32, tf.bool]))
        named_param = [('a', tf.int32), ('b', tf.float32), ('c', tf.bool)]
        x_ref = building_blocks.Reference('x', named_param)
        picked = building_blocks.Struct([
            building_blocks.Selection(x_ref, index=0),
            building_blocks.Selection(x_ref, index=2),
        ])
        lambda_wrapper = building_blocks.Lambda(
            'x', named_param, building_blocks.Call(pair_identity, picked))

        parsed, modified = parse_tff_to_tf(lambda_wrapper)

        self.assertIsInstance(parsed, building_blocks.CompiledComputation)
        self.assertTrue(modified)
        self.assertEqual(parsed.type_signature, lambda_wrapper.type_signature)
        # Running the parsed proto directly yields the unnamed <x.a, x.c>.
        output = test_utils.run_tensorflow(parsed.proto, {
            'a': 9,
            'b': 10.0,
            'c': False,
        })
        self.assertEqual(structure.Struct([(None, 9), (None, False)]), output)
Example #17
0
 def federated_reduce(arg):
     """Expresses a reduce of `arg[0]` as a `federated_aggregate`.

     `arg[1]` is the starting value and `arg[2]` the binary operator. The
     operator is passed as both the third and fourth aggregate arguments
     (presumably accumulate and merge). A compiled identity over the
     operator's result type is passed as the fifth argument.
     """
     values, initial, reduce_op = arg[0], arg[1], arg[2]
     passthrough = building_block_factory.create_compiled_identity(
         reduce_op.type_signature.result)
     return intrinsics.federated_aggregate(values, initial, reduce_op,
                                           reduce_op, passthrough)
Example #18
0
 def test_tensorflow_op_count_doubles_number_of_ops_in_two_tuple(self):
   """A two-tuple of the same graph has twice that graph's TF op count."""
   identity = building_block_factory.create_compiled_identity(tf.int32)
   single_count = building_block_analysis.count_tensorflow_ops_in(identity)
   pair = building_blocks.Tuple([identity, identity])
   self.assertEqual(tree_analysis.count_tensorflow_ops_under(pair),
                    2 * single_count)
Example #19
0
    def test_replaces_lambda_to_called_graph_on_tuple_of_selections_from_arg_with_tf_of_same_type(
            self):
        """Calling a graph on <x[0], x[2]> inside a lambda parses to TF."""
        pair_identity = building_block_factory.create_compiled_identity(
            computation_types.StructType([tf.int32, tf.bool]))
        param_type = [tf.int32, tf.float32, tf.bool]
        x_ref = building_blocks.Reference('x', param_type)
        picked = building_blocks.Struct([
            building_blocks.Selection(x_ref, index=0),
            building_blocks.Selection(x_ref, index=2),
        ])
        lambda_wrapper = building_blocks.Lambda(
            'x', param_type, building_blocks.Call(pair_identity, picked))

        parsed, modified = parse_tff_to_tf(lambda_wrapper)

        self.assertIsInstance(parsed, building_blocks.CompiledComputation)
        self.assertTrue(modified)
        # TODO(b/157172423): change to assertEqual when Py container is preserved.
        parsed.type_signature.check_equivalent_to(
            lambda_wrapper.type_signature)
        output = test_utils.run_tensorflow(parsed.proto, [7, 8.0, True])
        self.assertEqual(structure.Struct([(None, 7), (None, True)]), output)
Example #20
0
    def test_replaces_lambda_to_called_graph_on_selection_from_arg_with_tf_of_same_type_with_names(
            self):
        """Calling a graph on x[0] of a named arg parses to equivalent TF."""
        int_identity = building_block_factory.create_compiled_identity(
            tf.int32)
        named_param = [('a', tf.int32), ('b', tf.float32)]
        x_ref = building_blocks.Reference('x', named_param)
        lambda_wrapper = building_blocks.Lambda(
            'x', named_param,
            building_blocks.Call(int_identity,
                                 building_blocks.Selection(x_ref, index=0)))

        parsed, modified = parse_tff_to_tf(lambda_wrapper)
        to_computation = (
            computation_wrapper_instances.building_block_to_computation)
        exec_lambda = to_computation(lambda_wrapper)
        exec_tf = to_computation(parsed)

        self.assertIsInstance(parsed, building_blocks.CompiledComputation)
        self.assertTrue(modified)
        self.assertEqual(parsed.type_signature, lambda_wrapper.type_signature)
        sample = {'a': 5, 'b': 6.}
        # Pass each executable its own copy of the sample input.
        self.assertEqual(exec_lambda(dict(sample)), exec_tf(dict(sample)))
Example #21
0
 def test_does_not_reduce_no_unnecessary_ops(self):
   """Pruning an identity graph leaves its TF op count unchanged."""
   identity = building_block_factory.create_compiled_identity(tf.int32)
   pruned_proto = proto_transformations.prune_tensorflow_proto(identity.proto)
   pruned = building_blocks.CompiledComputation(pruned_proto)
   self.assertEqual(
       building_block_analysis.count_tensorflow_ops_in(identity),
       building_block_analysis.count_tensorflow_ops_in(pruned))
    def test_replaces_lambda_to_selection_from_called_graph_with_tf_of_same_type(
            self):
        """Selecting element 1 of a called graph inside a lambda parses to TF."""
        pair_identity = building_block_factory.create_compiled_identity(
            computation_types.StructType([tf.int32, tf.float32]))
        x_ref = building_blocks.Reference('x', [tf.int32, tf.float32])
        second_of_call = building_blocks.Selection(
            building_blocks.Call(pair_identity, x_ref), index=1)
        lambda_wrapper = building_blocks.Lambda('x', [tf.int32, tf.float32],
                                                second_of_call)

        parsed, modified = parse_tff_to_tf(lambda_wrapper)
        to_computation = (
            computation_wrapper_instances.building_block_to_computation)
        exec_lambda = to_computation(lambda_wrapper)
        exec_tf = to_computation(parsed)

        self.assertIsInstance(parsed, building_blocks.CompiledComputation)
        self.assertTrue(modified)
        # TODO(b/157172423): change to assertEqual when Py container is preserved.
        parsed.type_signature.check_equivalent_to(
            lambda_wrapper.type_signature)
        self.assertEqual(exec_lambda([0, 1.]), exec_tf([0, 1.]))
Example #23
0
def _create_complex_computation():
    """Builds a lambda that twice broadcasts, maps and averages its argument.

    Returns:
      A `building_blocks.Lambda` over a server-placed int32 `arg`. Its body is
      a `Block` that binds `broadcast_i`, `map_i` and `mean_i` for i in 0 and 1,
      and returns the struct `<mean_0, mean_1>`.
    """
    tensor_type = computation_types.TensorType(tf.int32)
    compiled = building_block_factory.create_compiled_identity(
        tensor_type, 'a')
    federated_type = computation_types.FederatedType(tf.int32,
                                                     placements.SERVER)
    arg_ref = building_blocks.Reference('arg', federated_type)
    bindings = []
    results = []

    def _bind(name, value):
        """Records `(name, value)` in `bindings`; returns a reference to it."""
        bindings.append((name, value))
        return building_blocks.Reference(name, value.type_signature)

    for i in range(2):
        called_federated_broadcast = building_block_factory.create_federated_broadcast(
            arg_ref)
        called_federated_map = building_block_factory.create_federated_map(
            compiled, _bind(f'broadcast_{i}', called_federated_broadcast))
        called_federated_mean = building_block_factory.create_federated_mean(
            _bind(f'map_{i}', called_federated_map), None)
        results.append(_bind(f'mean_{i}', called_federated_mean))
    result = building_blocks.Struct(results)
    block = building_blocks.Block(bindings, result)
    # The parameter must have the same type as the body's `arg` references
    # (int32@SERVER, needed by the broadcast). Declaring it as plain int32
    # made the lambda ill-typed.
    return building_blocks.Lambda('arg', federated_type, block)
 def test_transform_compiled_computation_returns_compiled_computation_with_id(
         self):
     """AddUniqueIDs sets a non-zero cache id that depends on the content."""
     tensor_type = computation_types.TensorType(tf.int32)
     identity_comp = building_block_factory.create_compiled_identity(
         tensor_type)

     def _check_has_nonzero_id(comp, was_mutated, check_is_compiled=False):
         """Asserts the transform mutated `comp` and gave it a non-zero id."""
         self.assertTrue(was_mutated)
         if check_is_compiled:
             self.assertIsInstance(comp, building_blocks.CompiledComputation)
         self.assertTrue(comp.proto.tensorflow.HasField('cache_key'))
         self.assertNotEqual(comp.proto.tensorflow.cache_key.id, 0)

     def _id_of(comp):
         """Returns the TF cache-key id stored on `comp`'s proto."""
         return comp.proto.tensorflow.cache_key.id

     add_ids = compiled_computation_transformations.AddUniqueIDs()
     with self.subTest('first_comp_non_zero_id'):
         first_comp, mutated = add_ids.transform(identity_comp)
         _check_has_nonzero_id(first_comp, mutated, check_is_compiled=True)
     with self.subTest('second_comp_same_id'):
         second_comp, mutated = add_ids.transform(identity_comp)
         _check_has_nonzero_id(second_comp, mutated, check_is_compiled=True)
         self.assertEqual(_id_of(first_comp), _id_of(second_comp))
     with self.subTest('restart_transformation_same_id'):
         # A fresh AddUniqueIDs instance, as a new compiler pass would create
         # (for example one run inside each `invoke` call), must give the same
         # computation the same id. The id therefore cannot depend on
         # per-instance state.
         add_ids = compiled_computation_transformations.AddUniqueIDs()
         third_comp, mutated = add_ids.transform(identity_comp)
         _check_has_nonzero_id(third_comp, mutated)
         self.assertEqual(_id_of(first_comp), _id_of(third_comp))
     with self.subTest('different_computation_different_id'):
         other_comp = _create_compiled_computation(
             lambda x: x + tf.constant(1.0),
             computation_types.TensorType(tf.float32))
         other_transformed, mutated = add_ids.transform(other_comp)
         _check_has_nonzero_id(other_transformed, mutated)
         self.assertNotEqual(_id_of(first_comp), _id_of(other_transformed))
Example #25
0
 def test_returns_tf_computation_block_with_compiled_comp(self):
     """A block whose result is a bare compiled comp compiles to TF."""
     int_type = computation_types.TensorType(tf.int32)
     identity = building_block_factory.create_compiled_identity(int_type)
     # The local `x` is bound but never referenced by the block's result.
     ignored_constant = building_block_factory.create_tensorflow_constant(
         int_type, 1)
     self.assert_compiles_to_tensorflow(
         building_blocks.Block([('x', ignored_constant)], identity))
 def test_transform_no_disallowed_ops(self):
     """RaiseOnDisallowedOp reports no mutation when only 'ShardedFilename' is banned."""
     tensor_type = computation_types.TensorType(tf.int32)
     identity = building_block_factory.create_compiled_identity(tensor_type)
     checker = compiled_computation_transformations.RaiseOnDisallowedOp(
         frozenset(['ShardedFilename']))
     _, was_mutated = checker.transform(identity)
     self.assertFalse(was_mutated)
Example #27
0
 def test_returns_string_for_compiled_computation(self):
   """Compact, formatted and structural strings all reflect the name 'a'."""
   comp = building_block_factory.create_compiled_identity(tf.int32, 'a')
   self.assertEqual(comp.compact_representation(), 'comp#a')
   self.assertEqual(comp.formatted_representation(), 'comp#a')
   self.assertEqual(comp.structural_representation(), 'Compiled(a)')
 def test_should_transform_compiled_computation(self):
     """TensorFlowOptimizer.should_transform accepts a compiled computation."""
     tensor_type = computation_types.TensorType(tf.int32)
     identity = building_block_factory.create_compiled_identity(tensor_type)
     optimizer = compiled_computation_transformations.TensorFlowOptimizer(
         tf.compat.v1.ConfigProto())
     self.assertTrue(optimizer.should_transform(identity))
 def test_transform_only_allowed_ops(self):
     """VerifyAllowedOps reports no mutation when every op is allow-listed."""
     tensor_type = computation_types.TensorType(tf.int32)
     identity = building_block_factory.create_compiled_identity(tensor_type)
     verifier = compiled_computation_transformations.VerifyAllowedOps(
         frozenset(['Const', 'PartitionedCall', 'Identity', 'Placeholder']))
     _, was_mutated = verifier.transform(identity)
     self.assertFalse(was_mutated)
 def test_transform_disallowed_ops(self):
     """RaiseOnDisallowedOp raises when 'Identity' is banned for an identity graph."""
     tensor_type = computation_types.TensorType(tf.int32)
     identity = building_block_factory.create_compiled_identity(tensor_type)
     banned = frozenset(['Identity'])
     expected_error = (tensorflow_computation_transformations
                       .DisallowedOpInTensorFlowComputationError)
     with self.assertRaises(expected_error):
         compiled_computation_transformations.RaiseOnDisallowedOp(
             banned).transform(identity)