def test_raises_on_selections_at_different_placements(self):
  """Selecting CLIENTS- and SERVER-placed subparameters together raises."""
  int_at_clients = computation_types.FederatedType(tf.int32, placements.CLIENTS)
  int_at_server = computation_types.FederatedType(tf.int32, placements.SERVER)
  param_type = computation_types.StructType([int_at_clients, int_at_server])
  param_ref = building_blocks.Reference('x', param_type)
  first_element = building_blocks.Selection(param_ref, index=0)
  lam = building_blocks.Lambda('x', param_type, first_element)
  # Path (0,) addresses the CLIENTS element, (1,) the SERVER element.
  mixed_placement_paths = [(0,), (1,)]
  with self.assertRaises(form_utils._MismatchedSelectionPlacementError):
    form_utils._as_function_of_some_federated_subparameters(
        lam, mixed_placement_paths)
    def test_multiple_nested_named_element_selection(self):
        fed_at_clients = computation_types.FederatedType(
            tf.int32, placements.CLIENTS)
        fed_at_server = computation_types.FederatedType(
            tf.int32, placements.SERVER)
        tuple_of_federated_types = computation_types.StructType([
            ('a', [('a', fed_at_clients)]), ('b', fed_at_server),
            ('c', [('c', fed_at_clients)])
        ])
        first_selection = building_blocks.Selection(building_blocks.Selection(
            building_blocks.Reference('x',
                                      tuple_of_federated_types), name='a'),
                                                    name='a')
        second_selection = building_blocks.Selection(building_blocks.Selection(
            building_blocks.Reference('x',
                                      tuple_of_federated_types), name='c'),
                                                     name='c')
        lam = building_blocks.Lambda(
            'x', tuple_of_federated_types,
            building_blocks.Struct([first_selection, second_selection]))

        new_lam = form_utils._as_function_of_some_federated_subparameters(
            lam, [(0, 0), (2, 0)])

        expected_parameter_type = computation_types.at_clients(
            (tf.int32, tf.int32))
        type_test_utils.assert_types_equivalent(
            new_lam.type_signature,
            computation_types.FunctionType(expected_parameter_type,
                                           lam.result.type_signature))
 def test_single_element_selection_leaves_no_unbound_references(self):
   """The rewritten lambda for path (0,) has no unbound references."""
   int_at_clients = computation_types.FederatedType(tf.int32,
                                                    placements.CLIENTS)
   int_at_server = computation_types.FederatedType(tf.int32, placements.SERVER)
   param_type = computation_types.StructType([int_at_clients, int_at_server])
   param_ref = building_blocks.Reference('x', param_type)
   lam = building_blocks.Lambda(
       'x', param_type, building_blocks.Selection(param_ref, index=0))

   new_lam = form_utils._as_function_of_some_federated_subparameters(
       lam, [(0,)])

   # The map is keyed by building block; look up the root lambda's entry.
   unbound_by_block = transformation_utils.get_map_of_unbound_references(
       new_lam)
   self.assertEmpty(unbound_by_block[new_lam])
  def test_single_element_selection(self):
    fed_at_clients = computation_types.FederatedType(tf.int32,
                                                     placements.CLIENTS)
    fed_at_server = computation_types.FederatedType(tf.int32, placements.SERVER)
    tuple_of_federated_types = computation_types.StructType(
        [fed_at_clients, fed_at_server])
    lam = building_blocks.Lambda(
        'x', tuple_of_federated_types,
        building_blocks.Selection(
            building_blocks.Reference('x', tuple_of_federated_types), index=0))

    new_lam = form_utils._as_function_of_some_federated_subparameters(
        lam, [(0,)])
    expected_parameter_type = computation_types.at_clients((tf.int32,))
    self.assert_types_equivalent(
        new_lam.type_signature,
        computation_types.FunctionType(expected_parameter_type,
                                       lam.result.type_signature))
 def test_binding_multiple_args_results_in_unique_names(self):
   """The rewrite of a two-path selection passes check_has_unique_names."""
   int_at_clients = computation_types.FederatedType(
       tf.int32, placements.CLIENTS)
   int_at_server = computation_types.FederatedType(
       tf.int32, placements.SERVER)
   param_type = computation_types.StructType(
       [[int_at_clients], int_at_server, [int_at_clients]])

   def select_path(outer_index, inner_index):
     # Each selection gets its own Reference node to the parameter.
     outer = building_blocks.Selection(
         building_blocks.Reference('x', param_type), index=outer_index)
     return building_blocks.Selection(outer, index=inner_index)

   selected_paths = [(0, 0), (2, 0)]
   result = building_blocks.Struct(
       [select_path(outer, inner) for outer, inner in selected_paths])
   lam = building_blocks.Lambda('x', param_type, result)

   new_lam = form_utils._as_function_of_some_federated_subparameters(
       lam, selected_paths)
   tree_analysis.check_has_unique_names(new_lam)
 def test_raises_on_non_federated_selection(self):
   """Selecting a non-federated element of the parameter raises."""
   param_ref = building_blocks.Reference('x', [tf.int32])
   identity_lam = building_blocks.Lambda('x', [tf.int32], param_ref)
   non_federated_paths = [(0,)]
   with self.assertRaises(form_utils._NonFederatedSelectionError):
     form_utils._as_function_of_some_federated_subparameters(
         identity_lam, non_federated_paths)