Example #1
0
def _prepare_for_rebinding(bb):
    """Returns a semantically equivalent version of `bb` suited for rebinding.

    The building block is run through a fixed sequence of passes, in order:
    all-equal-bit normalization, removal of mapped or applied identities,
    conversion to call-dominant form, and removal of unused block locals.

    Args:
      bb: The building block to prepare.

    Returns:
      The building block produced by the final pass.
    """
    normalized = compiler.normalize_all_equal_bit(bb)
    # The `tree_transformations` passes return a pair. Only the first element
    # (the transformed block) is used here; the second is discarded.
    identities_removed, _ = (
        tree_transformations.remove_mapped_or_applied_identity(normalized))
    call_dominant = transformations.to_call_dominant(identities_removed)
    pruned, _ = tree_transformations.remove_unused_block_locals(call_dominant)
    return pruned
Example #2
0
 def test_converts_not_all_equal_at_server_reference_to_equal(self):
     """A SERVER reference typed all_equal=False is retyped all_equal=True."""
     not_all_equal_type = computation_types.FederatedType(
         tf.int32, placements.SERVER, all_equal=False)
     expected_type = computation_types.FederatedType(
         tf.int32, placements.SERVER, all_equal=True)
     server_ref = building_blocks.Reference('x', not_all_equal_type)

     result = compiler.normalize_all_equal_bit(server_ref)

     self.assertEqual(result.type_signature, expected_type)
     # The node is still a Reference named 'x'; only its type is checked to
     # have changed.
     self.assertIsInstance(result, building_blocks.Reference)
     self.assertEqual(str(result), 'x')
Example #3
0
 def test_converts_not_all_equal_at_server_lambda_parameter_to_equal(self):
     """Normalizing a lambda retypes its not-all-equal SERVER parameter."""
     input_type = computation_types.FederatedType(
         tf.int32, placements.SERVER, all_equal=False)
     # Expected parameter/result type: all_equal left at its default.
     output_type = computation_types.FederatedType(
         tf.int32, placements.SERVER)
     identity = building_blocks.Lambda(
         'x', input_type, building_blocks.Reference('x', input_type))

     normalized = compiler.normalize_all_equal_bit(identity)

     # Checked after normalization: the input lambda's type is unchanged.
     self.assertEqual(
         identity.type_signature,
         computation_types.FunctionType(input_type, input_type))
     self.assertIsInstance(normalized, building_blocks.Lambda)
     self.assertEqual(str(normalized), '(x -> x)')
     self.assertEqual(
         normalized.type_signature,
         computation_types.FunctionType(output_type, output_type))
Example #4
0
 def test_converts_federated_map_all_equal_to_federated_map(self):
     """A federated_map_all_equal call is rewritten to a federated_map call."""
     clients_all_equal_type = computation_types.FederatedType(
         tf.int32, placements.CLIENTS, all_equal=True)
     clients_type = computation_types.FederatedType(
         tf.int32, placements.CLIENTS)
     identity_fn = building_blocks.Lambda(
         'x', tf.int32, building_blocks.Reference('x', tf.int32))
     clients_arg = building_blocks.Reference('y', clients_all_equal_type)
     map_all_equal_call = (
         building_block_factory.create_federated_map_all_equal(
             identity_fn, clients_arg))

     normalized = compiler.normalize_all_equal_bit(map_all_equal_call)

     # Sanity check: the input really is a federated_map_all_equal call.
     self.assertEqual(map_all_equal_call.function.uri,
                      intrinsic_defs.FEDERATED_MAP_ALL_EQUAL.uri)
     self.assertIsInstance(normalized, building_blocks.Call)
     called_fn = normalized.function
     self.assertIsInstance(called_fn, building_blocks.Intrinsic)
     self.assertEqual(called_fn.uri, intrinsic_defs.FEDERATED_MAP.uri)
     self.assertEqual(normalized.type_signature, clients_type)
Example #5
0
 def test_raises_on_none(self):
     """Passing None instead of a building block raises TypeError."""
     self.assertRaises(TypeError, compiler.normalize_all_equal_bit, None)