Ejemplo n.º 1
0
 def test_returns_false_for_blocks_with_different_variable_values(self):
     """Blocks whose single local differs only in dtype must not be equal."""
     shared_result = building_blocks.Data('data', tf.int32)
     float_local = building_blocks.Data('data', tf.float32)
     float_block = building_blocks.Block([('a', float_local)], shared_result)
     bool_local = building_blocks.Data('data', tf.bool)
     bool_block = building_blocks.Block([('a', bool_local)], shared_result)
     self.assertFalse(tree_analysis.trees_equal(float_block, bool_block))
 def test_returns_true_for_blocks_resulting_reference_to_same_local(self):
     """Blocks that differ only in the name of the bound local are equal."""
     value = building_blocks.Data('data', tf.int32)
     value_type = value.type_signature
     result_a = building_blocks.Reference('a', value_type)
     result_b = building_blocks.Reference('b', value_type)
     block_a = building_blocks.Block([('a', value)], result_a)
     block_b = building_blocks.Block([('b', value)], result_b)
     self.assertTrue(tree_analysis.trees_equal(block_a, block_b))
Ejemplo n.º 3
0
 def test_returns_false_for_blocks_referring_to_different_local(self):
   """Swapping which name holds `data` changes what the result refers to."""
   data = building_blocks.Data('data', tf.int32)
   int_type = data.type_signature
   ref_a = building_blocks.Reference('a', int_type)
   ref_b = building_blocks.Reference('b', int_type)
   first = building_blocks.Block([('a', data), ('b', ref_a)], ref_a)
   second = building_blocks.Block([('b', data), ('a', ref_b)], ref_a)
   # Check both argument orders so that inequality is verified symmetrically.
   for lhs, rhs in ((first, second), (second, first)):
     self.assertFalse(tree_analysis.trees_equal(lhs, rhs))
Ejemplo n.º 4
0
 def test_removes_nested_blocks_with_unused_reference(self):
     """Both nested blocks have unused locals and reduce to the bare data."""
     expected = building_blocks.Data('b', tf.int32)
     unused_local = building_blocks.Data('a', tf.int32)
     inner = building_blocks.Block([('x', unused_local)], expected)
     outer = building_blocks.Block([('y', expected)], inner)
     transformed, modified = transformation_utils.transform_postorder(
         outer, self._unused_block_remover.transform)
     self.assertTrue(modified)
     self.assertEqual(transformed.compact_representation(),
                      expected.compact_representation())
Ejemplo n.º 5
0
 def test_returns_tf_computation_with_functional_type_block_to_lambda_with_block(
          self):
     """A block (unused int local) whose result is a lambda over a block compiles."""
     int_type = computation_types.TensorType(tf.int32)
     param_ref = building_blocks.Reference('x', tf.float32)
     inner_block = building_blocks.Block([('x', param_ref)], param_ref)
     identity_lambda = building_blocks.Lambda(param_ref.name,
                                              param_ref.type_signature,
                                              inner_block)
     unused_constant = building_block_factory.create_tensorflow_constant(
         int_type, 1)
     outer_block = building_blocks.Block([('y', unused_constant)],
                                         identity_lambda)
     self.assert_compiles_to_tensorflow(outer_block)
Ejemplo n.º 6
0
    def test_nested_blocks(self):
        """Repeated `a` bindings across nested blocks receive fresh names."""
        ref_a = building_blocks.Reference('a', tf.int32)
        data = building_blocks.Data('data', tf.int32)
        inner = building_blocks.Block([('a', data), ('a', ref_a)], ref_a)
        outer = building_blocks.Block([('a', data), ('a', ref_a)], inner)

        renamed, modified = tree_transformations.uniquify_reference_names(
            outer)

        # The input tree must still print exactly as it was constructed.
        self.assertEqual(outer.compact_representation(),
                         '(let a=data,a=a in (let a=data,a=a in a))')
        self.assertEqual(
            renamed.compact_representation(),
            '(let a=data,_var1=a in (let _var2=data,_var3=_var2 in _var3))')
        tree_analysis.check_has_unique_names(renamed)
        self.assertTrue(modified)
Ejemplo n.º 7
0
 def test_returns_single_called_graph_after_resolving_multiple_variables(
          self):
     """Two chained identity calls on `var` become one graph called on `var`."""
     var_ref = building_blocks.Reference('var', tf.int32)
     first_call = building_blocks.Call(
         building_block_factory.create_compiled_identity(
             computation_types.TensorType(tf.int32)), var_ref)
     call_type = first_call.type_signature
     second_call = building_blocks.Call(
         building_block_factory.create_compiled_identity(
             computation_types.TensorType(tf.int32)),
         building_blocks.Reference('call', call_type))
     block = building_blocks.Block(
         [('call', first_call), ('second_call', second_call)],
         building_blocks.Reference('second_call', call_type))

     tf_block, _ = transformations.create_tensorflow_representing_block(
         block)

     self.assertEqual(tf_block.type_signature, block.type_signature)
     self.assertIsInstance(tf_block, building_blocks.Call)
     self.assertIsInstance(tf_block.function,
                           building_blocks.CompiledComputation)
     # The only free variable left should be the original `var` reference.
     self.assertIsInstance(tf_block.argument, building_blocks.Reference)
     self.assertEqual(tf_block.argument.name, 'var')
Ejemplo n.º 8
0
 def test_returned_tensorflow_executes_correctly_with_no_unbound_refs(self):
     """With a constant input, the generated graph runs without arguments."""
     int_type = computation_types.TensorType(tf.int32)
     constant_one = building_block_factory.create_tensorflow_constant(
         int_type, 1)
     first_call = building_blocks.Call(
         building_block_factory.create_compiled_identity(
             computation_types.TensorType(tf.int32)), constant_one)
     call_type = first_call.type_signature
     second_call = building_blocks.Call(
         building_block_factory.create_compiled_identity(
             computation_types.TensorType(tf.int32)),
         building_blocks.Reference('call', call_type))
     second_ref = building_blocks.Reference('second_call', call_type)
     block = building_blocks.Block(
         [('call', first_call), ('second_call', second_call)],
         building_blocks.Tuple([second_ref, second_ref]))

     tf_block, _ = transformations.create_tensorflow_representing_block(
         block)

     outputs = test_utils.run_tensorflow(tf_block.function.proto)
     self.assertAllEqual(outputs, [1, 1])
Ejemplo n.º 9
0
def _create_complex_computation():
    """Builds a lambda whose block broadcasts, maps and averages `arg` twice.

    For each `i` in `{0, 1}` the block binds:
      * `broadcast_i`: `create_federated_broadcast` applied to `arg`;
      * `map_i`: `create_federated_map` of the compiled identity over
        `broadcast_i`;
      * `mean_i`: `create_federated_mean` of `map_i`.
    The block result is the struct `<mean_0, mean_1>`.

    Returns:
      A `building_blocks.Lambda` with parameter `arg` of type `int32@SERVER`.
    """
    tensor_type = computation_types.TensorType(tf.int32)
    compiled = building_block_factory.create_compiled_identity(
        tensor_type, 'a')
    federated_type = computation_types.FederatedType(tf.int32,
                                                     placements.SERVER)
    arg_ref = building_blocks.Reference('arg', federated_type)
    bindings = []
    results = []

    def _bind(name, value):
        """Records `(name, value)` as a block local; returns a reference to it."""
        bindings.append((name, value))
        return building_blocks.Reference(name, value.type_signature)

    for i in range(2):
        called_federated_broadcast = building_block_factory.create_federated_broadcast(
            arg_ref)
        called_federated_map = building_block_factory.create_federated_map(
            compiled, _bind(f'broadcast_{i}', called_federated_broadcast))
        called_federated_mean = building_block_factory.create_federated_mean(
            _bind(f'map_{i}', called_federated_map), None)
        results.append(_bind(f'mean_{i}', called_federated_mean))
    result = building_blocks.Struct(results)
    block = building_blocks.Block(bindings, result)
    # The body refers to `arg` with type `federated_type`, so the lambda must
    # bind its parameter with that same type for the computation to be
    # well-typed.
    return building_blocks.Lambda(arg_ref.name, federated_type, block)
Ejemplo n.º 10
0
 def test_propogates_dependence_up_through_block_result(self):
   """A block whose result is the dummy intrinsic is reported as dependent."""
   intrinsic = building_blocks.Intrinsic('dummy_intrinsic', tf.int32)
   unused_local = building_blocks.Reference('int', tf.int32)
   block = building_blocks.Block([('x', unused_local)], intrinsic)
   consumers = tree_analysis.extract_nodes_consuming(
       block, dummy_intrinsic_predicate)
   self.assertIn(block, consumers)
Ejemplo n.º 11
0
 def test_raises_lambda_rebinding_of_block_variable(self):
     """A lambda parameter reusing an enclosing block local's name is rejected."""
     shadowing_lambda = building_blocks.Lambda(
         'x', tf.int32, building_blocks.Reference('x', tf.int32))
     block = building_blocks.Block(
         [('x', building_blocks.Data('x', tf.int32))], shadowing_lambda)
     with self.assertRaises(tree_analysis.NonuniqueNameError):
         tree_analysis.check_has_unique_names(block)
  def test_with_block(self):
    """Evaluates `a -> let f=a.f, x=a.x in f(f(x))` with f=add_one, x=10.

    Applying `add_one` twice to 10 must yield 12.
    """
    ex = reference_resolving_executor.ReferenceResolvingExecutor(
        eager_tf_executor.EagerTFExecutor())
    run = asyncio.get_event_loop().run_until_complete

    fn_type = computation_types.FunctionType(tf.int32, tf.int32)
    param = building_blocks.Reference(
        'a', computation_types.StructType([('f', fn_type), ('x', tf.int32)]))
    bindings = [('f', building_blocks.Selection(param, name='f')),
                ('x', building_blocks.Selection(param, name='x'))]
    f_of_f_of_x = building_blocks.Call(
        building_blocks.Reference('f', fn_type),
        building_blocks.Call(
            building_blocks.Reference('f', fn_type),
            building_blocks.Reference('x', tf.int32)))
    comp = building_blocks.Lambda(
        param.name, param.type_signature,
        building_blocks.Block(bindings, f_of_f_of_x))

    @computations.tf_computation(tf.int32)
    def add_one(x):
      return x + 1

    comp_value = run(ex.create_value(comp.proto, comp.type_signature))
    add_one_value = run(ex.create_value(add_one))
    ten_value = run(ex.create_value(10, tf.int32))
    arg_value = run(
        ex.create_struct(
            anonymous_tuple.AnonymousTuple([('f', add_one_value),
                                            ('x', ten_value)])))
    call_value = run(ex.create_call(comp_value, arg_value))
    result = run(call_value.compute())
    self.assertEqual(result.numpy(), 12)
Ejemplo n.º 13
0
 def test_returns_correct_structure_with_tuple_in_result(self):
     """A tuple-valued block still becomes one graph called on `var`."""
     var_ref = building_blocks.Reference('var', tf.int32)
     first_call = building_blocks.Call(
         building_block_factory.create_compiled_identity(tf.int32), var_ref)
     call_type = first_call.type_signature
     second_call = building_blocks.Call(
         building_block_factory.create_compiled_identity(tf.int32),
         building_blocks.Reference('call', call_type))
     second_ref = building_blocks.Reference('second_call', call_type)
     block = building_blocks.Block(
         [('call', first_call), ('second_call', second_call)],
         building_blocks.Tuple([second_ref, second_ref]))

     tf_block, _ = compiler_transformations.create_tensorflow_representing_block(
         block)

     self.assertEqual(tf_block.type_signature, block.type_signature)
     self.assertIsInstance(tf_block, building_blocks.Call)
     self.assertIsInstance(tf_block.function,
                           building_blocks.CompiledComputation)
     # The outer `var` reference is the graph's only argument.
     self.assertIsInstance(tf_block.argument, building_blocks.Reference)
     self.assertEqual(tf_block.argument.name, 'var')
Ejemplo n.º 14
0
def _insert_comp_in_top_level_lambda(comp, name, comp_to_insert):
    """Inserts a computation into `comp` with the given `name`.

    The binding `(name, comp_to_insert)` becomes the first local of the block
    returned by the top-level lambda. If the lambda's result is not already a
    block, it is wrapped in a new block whose only local is that binding.

    Args:
      comp: The `building_blocks.Lambda` to transform. The names of lambda
        parameters and block variables in `comp` must be unique.
      name: The name to use.
      comp_to_insert: The `building_blocks.ComputationBuildingBlock` to insert.

    Returns:
      A new `building_blocks.Lambda` with the same parameter name and type as
      `comp`, whose result is a `building_blocks.Block`. The existing locals are
      copied, never mutated in place.

    Raises:
      Whatever `py_typecheck.check_type` raises (presumably `TypeError`) for an
      argument of the wrong type, and `tree_analysis.NonuniqueNameError` if the
      names in `comp` are not unique.
    """
    py_typecheck.check_type(comp, building_blocks.Lambda)
    py_typecheck.check_type(name, str)
    py_typecheck.check_type(comp_to_insert,
                            building_blocks.ComputationBuildingBlock)
    tree_analysis.check_has_unique_names(comp)

    result = comp.result
    if result.is_block():
        # Copy: inserting into the list returned by `locals` must not affect
        # the original block.
        existing_locals = list(result.locals)
        result = result.result
    else:
        existing_locals = []
    variables = [(name, comp_to_insert)] + existing_locals
    block = building_blocks.Block(variables, result)
    return building_blocks.Lambda(comp.parameter_name, comp.parameter_type,
                                  block)
Ejemplo n.º 15
0
 def _transform(comp, context_tree):
   """Renames References in `comp` to unique names."""
   if comp.is_reference():
     payload = context_tree.get_payload_with_name(comp.name)
     if payload is None:
       return comp, False
     new_name = payload.new_name
     if new_name is comp.name:
       return comp, False
     return building_blocks.Reference(new_name, comp.type_signature,
                                      comp.context), True
   elif comp.is_block():
     new_locals = []
     modified = False
     for name, val in comp.locals:
       context_tree.walk_down_one_variable_binding()
       new_name = context_tree.get_payload_with_name(name).new_name
       modified = modified or (new_name is not name)
       new_locals.append((new_name, val))
     return building_blocks.Block(new_locals, comp.result), modified
   elif comp.is_lambda():
     if comp.parameter_type is None:
       return comp, False
     context_tree.walk_down_one_variable_binding()
     new_name = context_tree.get_payload_with_name(
         comp.parameter_name).new_name
     if new_name is comp.parameter_name:
       return comp, False
     return building_blocks.Lambda(new_name, comp.parameter_type,
                                   comp.result), True
   return comp, False
Ejemplo n.º 16
0
 def test_raises_with_naked_graph_as_block_local(self):
   """A compiled graph bound directly (uncalled) as a local is rejected."""
   graph = building_block_factory.create_compiled_identity(tf.int32)
   block = building_blocks.Block(
       [('graph', graph)],
       building_blocks.Reference('graph', graph.type_signature))
   with self.assertRaises(ValueError):
     compiler_transformations.create_tensorflow_representing_block(block)
Ejemplo n.º 17
0
 def test_returns_correct_structure_with_no_unbound_references(self):
     """A fully bound block becomes a graph that is called with no argument."""
     int_type = computation_types.TensorType(tf.int32)
     constant_one = building_block_factory.create_tensorflow_constant(
         int_type, 1)
     first_call = building_blocks.Call(
         building_block_factory.create_compiled_identity(
             computation_types.TensorType(tf.int32)), constant_one)
     call_type = first_call.type_signature
     second_call = building_blocks.Call(
         building_block_factory.create_compiled_identity(
             computation_types.TensorType(tf.int32)),
         building_blocks.Reference('call', call_type))
     second_ref = building_blocks.Reference('second_call', call_type)
     block = building_blocks.Block(
         [('call', first_call), ('second_call', second_call)],
         building_blocks.Tuple([second_ref, second_ref]))

     tf_block, _ = transformations.create_tensorflow_representing_block(
         block)

     self.assertEqual(tf_block.type_signature, block.type_signature)
     self.assertIsInstance(tf_block, building_blocks.Call)
     self.assertIsInstance(tf_block.function,
                           building_blocks.CompiledComputation)
     self.assertIsNone(tf_block.argument)
Ejemplo n.º 18
0
    def test_with_block(self):
        """Evaluates `a -> let f=a.f, x=a.x in f(f(x))` with f=add_one, x=10.

        Applying `add_one` twice to 10 must yield 12.
        """
        ex = reference_resolving_executor.ReferenceResolvingExecutor(
            eager_tf_executor.EagerTFExecutor())

        fn_type = computation_types.FunctionType(tf.int32, tf.int32)
        param = building_blocks.Reference(
            'a', computation_types.StructType([('f', fn_type),
                                               ('x', tf.int32)]))
        bindings = [('f', building_blocks.Selection(param, name='f')),
                    ('x', building_blocks.Selection(param, name='x'))]
        f_of_f_of_x = building_blocks.Call(
            building_blocks.Reference('f', fn_type),
            building_blocks.Call(building_blocks.Reference('f', fn_type),
                                 building_blocks.Reference('x', tf.int32)))
        comp = building_blocks.Lambda(
            param.name, param.type_signature,
            building_blocks.Block(bindings, f_of_f_of_x))

        @tensorflow_computation.tf_computation(tf.int32)
        def add_one(x):
            return x + 1

        run = asyncio.run
        comp_value = run(ex.create_value(comp.proto, comp.type_signature))
        add_one_value = run(ex.create_value(add_one))
        ten_value = run(ex.create_value(10, tf.int32))
        arg_value = run(
            ex.create_struct(
                structure.Struct([('f', add_one_value), ('x', ten_value)])))
        call_value = run(ex.create_call(comp_value, arg_value))
        result = run(call_value.compute())
        self.assertEqual(result.numpy(), 12)
Ejemplo n.º 19
0
 def test_returns_comp_with_block_untransformed(self):
     """A block containing no called graphs comes back equal and unmodified."""
     value = building_blocks.Data('a', tf.int32)
     original = building_blocks.Block([('x', value), ('y', value)], value)
     returned, was_modified = compiler_transformations.remove_duplicate_called_graphs(
         original)
     self.assertEqual(returned, original)
     self.assertFalse(was_modified)
Ejemplo n.º 20
0
 def test_basic_functionality_of_block_class(self):
   """Checks `Block` accessors, string forms and proto serialization."""
   x = building_blocks.Block(
       [('x', building_blocks.Reference('arg', (tf.int32, tf.int32))),
        ('y',
         building_blocks.Selection(
             building_blocks.Reference('x', (tf.int32, tf.int32)), index=0))],
       building_blocks.Reference('y', tf.int32))
   self.assertEqual(str(x.type_signature), 'int32')
   self.assertEqual([(k, v.compact_representation()) for k, v in x.locals],
                    [('x', 'arg'), ('y', 'x[0]')])
   self.assertEqual(x.result.compact_representation(), 'y')
   self.assertEqual(
       repr(x), 'Block([(\'x\', Reference(\'arg\', '
       'StructType([TensorType(tf.int32), TensorType(tf.int32)]) as tuple)), '
       '(\'y\', Selection(Reference(\'x\', '
       'StructType([TensorType(tf.int32), TensorType(tf.int32)]) as tuple), '
       'index=0))], '
       'Reference(\'y\', TensorType(tf.int32)))')
   self.assertEqual(x.compact_representation(), '(let x=arg,y=x[0] in y)')
   x_proto = x.proto
   self.assertEqual(
       type_serialization.deserialize_type(x_proto.type), x.type_signature)
   self.assertEqual(x_proto.WhichOneof('computation'), 'block')
   self.assertEqual(str(x_proto.block.result), str(x.result.proto))
   # zip() stops at the shorter input, so check the counts first; otherwise a
   # dropped or extra serialized local would go unnoticed.
   self.assertEqual(len(x_proto.block.local), len(x.locals))
   for loc_proto, (loc_name, loc_value) in zip(x_proto.block.local, x.locals):
     self.assertEqual(loc_proto.name, loc_name)
     self.assertEqual(str(loc_proto.value), str(loc_value.proto))
   # The roundtrip covers the whole block and does not depend on the loop
   # variables, so run it once rather than once per local.
   self._serialize_deserialize_roundtrip_test(x)
Ejemplo n.º 21
0
 def test_executes_correctly_with_tuple_in_result(self):
     """The generated graph echoes its int input into both tuple slots."""
     var_ref = building_blocks.Reference('var', tf.int32)
     first_call = building_blocks.Call(
         building_block_factory.create_compiled_identity(
             computation_types.TensorType(tf.int32)), var_ref)
     call_type = first_call.type_signature
     second_call = building_blocks.Call(
         building_block_factory.create_compiled_identity(
             computation_types.TensorType(tf.int32)),
         building_blocks.Reference('call', call_type))
     second_ref = building_blocks.Reference('second_call', call_type)
     block = building_blocks.Block(
         [('call', first_call), ('second_call', second_call)],
         building_blocks.Tuple([second_ref, second_ref]))

     tf_block, _ = transformations.create_tensorflow_representing_block(
         block)

     graph_proto = tf_block.function.proto
     for feed, expected in ((1, [1, 1]), (0, [0, 0])):
         self.assertAllEqual(test_utils.run_tensorflow(graph_proto, feed),
                             expected)
Ejemplo n.º 22
0
 def test_unwraps_block_with_empty_locals(self):
     """`Block([], e)` reduces to something printing the same as `e`."""
     inner = building_blocks.Data('b', tf.int32)
     empty_block = building_blocks.Block([], inner)
     reduced, modified = transformation_utils.transform_postorder(
         empty_block, self._unused_block_remover.transform)
     self.assertTrue(modified)
     self.assertEqual(reduced.compact_representation(),
                      inner.compact_representation())
Ejemplo n.º 23
0
 def test_returns_tf_computation_block_with_compiled_comp(self):
     """A block with an unused constant local and a compiled result compiles."""
     int_type = computation_types.TensorType(tf.int32)
     identity = building_block_factory.create_compiled_identity(int_type)
     unused_constant = building_block_factory.create_tensorflow_constant(
         int_type, 1)
     block = building_blocks.Block([('x', unused_constant)], identity)
     self.assert_compiles_to_tensorflow(block)
Ejemplo n.º 24
0
 def test_returns_tf_computation_with_functional_type_lambda_with_block(
          self):
     """A lambda over a named struct whose body is a block compiles."""
     struct_ref = building_blocks.Reference('x', [('a', tf.int32),
                                                  ('b', tf.float32)])
     body = building_blocks.Block([('x', struct_ref)], struct_ref)
     identity_lambda = building_blocks.Lambda(struct_ref.name,
                                              struct_ref.type_signature, body)
     self.assert_compiles_to_tensorflow(identity_lambda)
Ejemplo n.º 25
0
 def test_passes_unbound_type_signature_obscured_under_block(self):
     """`strip_placement` must not raise on this block.

     `y` binds the unbound, SERVER-placed reference `x`. A later local then
     rebinds `x` to unplaced data that `z` refers to.
     """
     server_ref = building_blocks.Reference(
         'x', computation_types.FederatedType(tf.int32, placements.SERVER))
     bindings = [
         ('y', server_ref),
         ('x', building_blocks.Data('whimsy', tf.int32)),
         ('z', building_blocks.Reference('x', tf.int32)),
     ]
     block = building_blocks.Block(
         bindings, building_blocks.Reference('y', server_ref.type_signature))
     # No assertion: the test only checks that this call does not raise.
     tree_transformations.strip_placement(block)
Ejemplo n.º 26
0
def _split_by_intrinsics_in_top_level_lambda(comp):
    """Splits by the intrinsics in the first block local in the result of `comp`.

    This function splits `comp` into two computations, `before` and `after`.
    The split point is the called intrinsic, or struct of called intrinsics,
    found as the first local in the `building_blocks.Block` returned by the
    top level lambda. It returns a Python tuple holding the pair of `before`
    and `after` computations.

    `before` is a lambda over the original parameter that returns the
    argument of the called intrinsic (or a struct of the arguments).

    `after` is a lambda over a fresh parameter of type
    `<original parameter type, type of the first local>`. In its block:
      * the original parameter name is bound to element 0 of that parameter;
      * the first local's name is bound to element 1;
      * the remaining locals and the result of the original block follow.

    Args:
      comp: The `building_blocks.Lambda` to split.

    Returns:
      A pair of `building_blocks.ComputationBuildingBlock`s.

    Raises:
      ValueError: If the `building_blocks.Block` returned by the top level
        lambda has no locals, or its first local is not a called intrinsic or
        a `building_blocks.Struct` of called intrinsics.
    """
    py_typecheck.check_type(comp, building_blocks.Lambda)
    py_typecheck.check_type(comp.result, building_blocks.Block)
    tree_analysis.check_has_unique_names(comp)

    name_generator = building_block_factory.unique_name_generator(comp)

    # Work on a copy: the list is modified below and must not alias the
    # locals of `comp`'s block.
    variables = list(comp.result.locals)
    if not variables:
        raise ValueError(
            'Expected the `building_blocks.Block` returned by the top level '
            'lambda to have at least one local, but found none.')

    name, first_local = variables[0]
    if building_block_analysis.is_called_intrinsic(first_local):
        result = first_local.argument
    elif first_local.is_struct():
        elements = []
        for element in first_local:
            if not building_block_analysis.is_called_intrinsic(element):
                raise ValueError(
                    'Expected all the elements of the `building_blocks.Struct` to be '
                    'called intrinsics, but found: \n{}'.format(element))
            elements.append(element.argument)
        result = building_blocks.Struct(elements)
    else:
        raise ValueError(
            'Expected either a called intrinsic or a `building_blocks.Struct` of '
            'called intrinsics, but found: \n{}'.format(first_local))

    before = building_blocks.Lambda(comp.parameter_name, comp.parameter_type,
                                    result)

    ref_name = next(name_generator)
    ref_type = computation_types.StructType(
        (comp.parameter_type, first_local.type_signature))
    ref = building_blocks.Reference(ref_name, ref_type)
    sel_after_arg_1 = building_blocks.Selection(ref, index=0)
    sel_after_arg_2 = building_blocks.Selection(ref, index=1)

    # Rebind the first local to the intrinsic result passed in element 1, and
    # the original parameter to element 0, ahead of all other locals.
    variables[0] = (name, sel_after_arg_2)
    variables.insert(0, (comp.parameter_name, sel_after_arg_1))
    block = building_blocks.Block(variables, comp.result.result)
    after = building_blocks.Lambda(ref.name, ref.type_signature, block)
    return before, after
Ejemplo n.º 27
0
 def test_leaves_single_used_reference(self):
     """A block whose only local is referenced survives the remover intact."""
     used_local = building_blocks.Data('a', tf.int32)
     blk = building_blocks.Block([('x', used_local)],
                                 building_blocks.Reference('x', tf.int32))
     result, modified = transformation_utils.transform_postorder(
         blk, self._unused_block_remover.transform)
     self.assertFalse(modified)
     self.assertEqual(result.compact_representation(),
                      blk.compact_representation())
Ejemplo n.º 28
0
 def test_propogates_dependence_up_through_block_locals(self):
   """A block binding the dummy intrinsic as a local is reported as dependent."""
   intrinsic = building_blocks.Intrinsic(
       'dummy_intrinsic', computation_types.TensorType(tf.int32))
   block = building_blocks.Block([('x', intrinsic)],
                                 building_blocks.Reference('int', tf.int32))
   consumers = tree_analysis.extract_nodes_consuming(
       block, dummy_intrinsic_predicate)
   self.assertIn(block, consumers)
Ejemplo n.º 29
0
 def test_raises_multiple_placements(self):
     """Stripping placements from a tree mixing SERVER and CLIENTS fails."""
     at_server = building_blocks.Reference(
         'x', computation_types.at_server(tf.int32))
     at_clients = building_blocks.Reference(
         'y', computation_types.at_clients(tf.int32))
     mixed = building_blocks.Block([('x', at_server)], at_clients)
     with self.assertRaisesRegex(ValueError, 'multiple different placements'):
         tree_transformations.strip_placement(mixed)
Ejemplo n.º 30
0
 def test_with_simple_block(self):
   """`let x=a in x` reduces to `a`, with no lambdas or blocks left."""
   data = building_blocks.Data('a', tf.int32)
   block = building_blocks.Block(
       [('x', data)], building_blocks.Reference('x', tf.int32))
   reduced, modified = compiler_transformations.remove_lambdas_and_blocks(
       block)
   self.assertTrue(modified)
   self.assertNoLambdasOrBlocks(reduced)
   self.assertEqual(reduced.compact_representation(), 'a')