Пример #1
0
def _prepare_for_rebinding(bb):
    """Returns a semantically equivalent rewrite of `bb` suitable for rebinding.

    The building block is passed, in order, through
    `transformations.normalize_all_equal_bit`,
    `tree_transformations.remove_mapped_or_applied_identity`, and
    `compiler_transformations.prepare_for_rebinding`. The last two calls each
    return a pair; only the first element is carried forward and the second
    is ignored.

    Args:
      bb: The building block to rewrite.

    Returns:
      The first element of the pair returned by
      `compiler_transformations.prepare_for_rebinding`.
    """
    comp = transformations.normalize_all_equal_bit(bb)
    # Each remaining pass yields (result, <ignored>); unpacking keeps the
    # original's requirement that exactly two values are returned.
    comp, _ = tree_transformations.remove_mapped_or_applied_identity(comp)
    comp, _ = compiler_transformations.prepare_for_rebinding(comp)
    return comp
 def test_converts_not_all_equal_at_server_reference_to_equal(self):
   # A SERVER-placed reference declared with all_equal=False should come back
   # as a Reference printed as 'x' whose type has all_equal=True.
   server_type = computation_types.FederatedType(
       tf.int32, placements.SERVER, all_equal=False)
   reference = building_blocks.Reference('x', server_type)
   result = transformations.normalize_all_equal_bit(reference)
   expected_type = computation_types.FederatedType(
       tf.int32, placements.SERVER, all_equal=True)
   self.assertEqual(result.type_signature, expected_type)
   self.assertIsInstance(result, building_blocks.Reference)
   self.assertEqual(str(result), 'x')
 def test_converts_all_equal_at_clients_lambda_parameter_to_not_equal(self):
   # An identity lambda whose CLIENTS-placed parameter is declared
   # all_equal=True should be rewritten so that both its parameter and result
   # use the CLIENTS type built without an explicit all_equal argument.
   clients_all_equal = computation_types.FederatedType(
       tf.int32, placements.CLIENTS, all_equal=True)
   expected_clients_type = computation_types.FederatedType(
       tf.int32, placements.CLIENTS)
   identity = building_blocks.Lambda(
       'x', clients_all_equal,
       building_blocks.Reference('x', clients_all_equal))
   result = transformations.normalize_all_equal_bit(identity)
   # The input lambda's type signature must still be the un-normalized one.
   self.assertEqual(
       identity.type_signature,
       computation_types.FunctionType(clients_all_equal, clients_all_equal))
   self.assertIsInstance(result, building_blocks.Lambda)
   self.assertEqual(str(result), '(x -> x)')
   self.assertEqual(
       result.type_signature,
       computation_types.FunctionType(expected_clients_type,
                                      expected_clients_type))
Пример #4
0
 def test_converts_federated_map_all_equal_to_federated_map(self):
     # Normalizing a federated_map_all_equal call should produce a call to the
     # plain federated_map intrinsic whose result is the CLIENTS int32 type
     # built without an explicit all_equal argument.
     all_equal_clients_type = computation_types.FederatedType(
         tf.int32, placements.CLIENTS, all_equal=True)
     expected_result_type = computation_types.FederatedType(
         tf.int32, placements.CLIENTS)
     identity_fn = building_blocks.Lambda(
         'x', tf.int32, building_blocks.Reference('x', tf.int32))
     arg_ref = building_blocks.Reference('y', all_equal_clients_type)
     map_all_equal_call = (
         building_block_factory.create_federated_map_all_equal(
             identity_fn, arg_ref))
     normalized = mapreduce_transformations.normalize_all_equal_bit(
         map_all_equal_call)

     # Precondition: the input really calls the all-equal intrinsic.
     self.assertEqual(map_all_equal_call.function.uri,
                      intrinsic_defs.FEDERATED_MAP_ALL_EQUAL.uri)
     self.assertIsInstance(normalized, building_blocks.Call)
     called_fn = normalized.function
     self.assertIsInstance(called_fn, building_blocks.Intrinsic)
     self.assertEqual(called_fn.uri, intrinsic_defs.FEDERATED_MAP.uri)
     self.assertEqual(normalized.type_signature, expected_result_type)
Пример #5
0
 def test_raises_on_none(self):
     # Passing None instead of a building block must be rejected with TypeError.
     self.assertRaises(
         TypeError, mapreduce_transformations.normalize_all_equal_bit, None)