def test_binding_single_arg_leaves_no_unbound_references(self):
  """Extracting selection [[0]] from `x -> x[0]` must not leave free names.

  The lambda takes a two-element named tuple of federated int32 values (one
  placed at CLIENTS, one at SERVER) and selects element 0. After
  `zip_selection_as_argument_to_lower_level_lambda` rewrites it, the result
  is checked to have an empty set of unbound references.
  """
  int_at_clients = computation_types.FederatedType(tf.int32, placements.CLIENTS)
  int_at_server = computation_types.FederatedType(tf.int32, placements.SERVER)
  arg_type = computation_types.NamedTupleType([int_at_clients, int_at_server])

  # Build `lambda x: x[0]` step by step.
  arg_ref = building_blocks.Reference('x', arg_type)
  select_first = building_blocks.Selection(arg_ref, index=0)
  lam = building_blocks.Lambda('x', arg_type, select_first)

  transformed = (
      mapreduce_transformations
      .zip_selection_as_argument_to_lower_level_lambda(lam, [[0]]))

  # The map is keyed by building block; only the root's entry matters here.
  unbound_by_node = transformations.get_map_of_unbound_references(transformed)
  self.assertEmpty(unbound_by_node[transformed])
def _check_no_unbound_references(comp):
  """Raises if the serialized computation `comp` has unbound references.

  `comp` is converted to a building block, and the unbound references of that
  root block are computed. This is a temporary helper, intended to be removed
  once more complete support is available.

  Args:
    comp: The `pb.Computation` to inspect.

  Raises:
    ValueError: If the computation contains any unbound references.
  """
  py_typecheck.check_type(comp, pb.Computation)
  root = building_blocks.ComputationBuildingBlock.from_proto(comp)
  # The returned map covers every node; only the root's entry is relevant.
  free_refs = transformations.get_map_of_unbound_references(root)[root]
  if not free_refs:
    return
  raise ValueError(
      'The computation contains unbound references: {}.'.format(free_refs))