Example #1
0
def _as_function_of_some_federated_subparameters(
    bb: building_blocks.Lambda,
    paths,
) -> building_blocks.Lambda:
    """Turns `x -> ...only uses parts of x...` into `parts_of_x -> ...`.

    Each entry of `paths` selects a (possibly nested) element of the lambda's
    parameter. The selected elements must all be federated values at one
    placement. Their member types are zipped into a single federated struct,
    which becomes the parameter of the returned lambda. The body of `bb` is
    rewritten (via `_replace_selections`) so that selections along those paths
    read from the new parameter instead.

    Args:
      bb: The `building_blocks.Lambda` to rewrite; its names must be unique.
      paths: An iterable of selection paths into `bb.parameter_type`. Each path
        is a sequence of elements that are either integer indices or string
        field names.

    Returns:
      A new `building_blocks.Lambda` whose parameter is a federated struct of
      the selected member types, ordered as in `paths`.

    Raises:
      _ParameterSelectionError: If a path steps into a non-struct type, uses an
        out-of-range integer index, or names a field that does not exist.
      _NonFederatedSelectionError: If a path does not end at a federated type.
      _MismatchedSelectionPlacementError: If the selected values do not all
        share the same placement.
    """
    tree_analysis.check_has_unique_names(bb)
    bb = _prepare_for_rebinding(bb)
    name_generator = building_block_factory.unique_name_generator(bb)

    # Resolve every path to (a) the federated type it selects and (b) an
    # all-integer form of the path, used as the key for replacement below.
    type_list = []
    int_paths = []
    for path in paths:
        selected_type = bb.parameter_type
        int_path = []
        for index in path:
            if not selected_type.is_struct():
                raise _ParameterSelectionError(path, bb)
            if isinstance(index, int):
                # Valid positions are 0..len-1. `index == len(...)` must be
                # rejected here rather than failing later in
                # `selected_type[index]`.
                if index >= len(selected_type):
                    raise _ParameterSelectionError(path, bb)
                int_path.append(index)
            else:
                py_typecheck.check_type(index, str)
                if not structure.has_field(selected_type, index):
                    raise _ParameterSelectionError(path, bb)
                int_path.append(
                    structure.name_to_index_map(selected_type)[index])
            selected_type = selected_type[index]
        if not selected_type.is_federated():
            raise _NonFederatedSelectionError(
                'Attempted to rebind references to parameter selection path '
                f'{path} from type {bb.parameter_type}, but the value at that path '
                f'was of non-federated type {selected_type}. Selections must all '
                f'be of federated type. Original AST:\n{bb}')
        int_paths.append(tuple(int_path))
        type_list.append(selected_type)

    # NOTE(review): an empty `paths` makes `type_list[0]` raise IndexError;
    # callers presumably always pass at least one path — confirm.
    placement = type_list[0].placement
    if not all(x.placement is placement for x in type_list):
        raise _MismatchedSelectionPlacementError(
            'In order to zip the argument to the lower-level lambda together, all '
            'selected arguments should be at the same placement. Your selections '
            f'have resulted in the list of types:\n{type_list}')

    # The new parameter is a single federated struct holding the selected
    # members in `paths` order; element `i` replaces selection path `i`.
    zip_type = computation_types.FederatedType([x.member for x in type_list],
                                               placement=placement)
    ref_to_zip = building_blocks.Reference(next(name_generator), zip_type)
    path_to_replacement = {}
    for i, path in enumerate(int_paths):
        path_to_replacement[path] = _construct_selection_from_federated_tuple(
            ref_to_zip, i, name_generator)

    new_lambda_body = _replace_selections(bb.result, bb.parameter_name,
                                          path_to_replacement)
    lambda_with_zipped_param = building_blocks.Lambda(
        ref_to_zip.name, ref_to_zip.type_signature, new_lambda_body)
    # Guard against the rewrite leaving references to the old parameter.
    tree_analysis.check_contains_no_new_unbound_references(
        bb, lambda_with_zipped_param)

    return lambda_with_zipped_param
 async def create_selection(self, source, index=None, name=None):
     """Requests a remote selection of one element of a struct-typed value.

     Args:
       source: A `RemoteValue` whose type signature is a
         `computation_types.StructType`.
       index: Integer position of the element to select. When `None`, the
         position is looked up from `name`.
       name: Field name of the element to select; consulted only when `index`
         is `None`.

     Returns:
       A `RemoteValue` for the selected element, typed with the matching
       member of `source.type_signature`.
     """
     py_typecheck.check_type(source, RemoteValue)
     source_type = source.type_signature
     py_typecheck.check_type(source_type, computation_types.StructType)
     if index is None:
         # The request below carries only an integer index, so translate the
         # field name into its position first.
         py_typecheck.check_type(name, str)
         positions_by_name = structure.name_to_index_map(source_type)
         index = positions_by_name[name]
     py_typecheck.check_type(index, int)
     # Look up the element type before issuing the request.
     element_type = source_type[index]
     selection_response = _request(
         self._stub.CreateSelection,
         executor_pb2.CreateSelectionRequest(
             source_ref=source.value_ref, index=index))
     py_typecheck.check_type(selection_response,
                             executor_pb2.CreateSelectionResponse)
     return RemoteValue(selection_response.value_ref, element_type, self)
Example #3
0
 def _traverse_selection(comp, transform, context_tree, identifier_seq):
     """Postorder traversal step for a selection node.

     The selection's source is transformed first. If that changed anything,
     the selection is rebuilt over the new source using a positional index.
     Then `transform` is applied to the (possibly rebuilt) selection itself.

     Returns:
       A pair of the resulting computation and whether anything was modified.
     """
     # Consume this node's identifier before descending into the source.
     next(identifier_seq)
     new_source, source_changed = (
         _transform_postorder_with_symbol_bindings_switch(
             comp.source, transform, context_tree, identifier_seq))
     if source_changed:
         # The rebuilt source may have lost its field names, so select by
         # position, computed against the ORIGINAL source's type signature.
         position = comp.index
         if position is None:
             names_to_positions = structure.name_to_index_map(
                 comp.source.type_signature)
             position = names_to_positions[comp.name]
         comp = building_blocks.Selection(new_source, index=position)
     result, result_changed = transform(comp, context_tree)
     return result, result_changed or source_changed
Example #4
0
  def test_name_to_index_map_fully_named_struct(self):
    """Each field name maps to its position, in declaration order."""
    fully_named_struct = structure.Struct([('b', 10), ('a', 20)])

    name_to_index_dict = structure.name_to_index_map(fully_named_struct)
    # 'b' precedes 'a': positions follow declaration order, not sorted names.
    expected_name_to_index_map = {'b': 0, 'a': 1}
    self.assertEqual(name_to_index_dict, expected_name_to_index_map)
Example #5
0
 def test_name_to_index_map_empty_unnamed_struct(self):
   """A struct whose elements are all unnamed yields an empty name map."""
   elements = [(None, 10), (None, 20)]
   index_by_name = structure.name_to_index_map(structure.Struct(elements))
   self.assertEmpty(index_by_name)
Example #6
0
  def test_name_to_index_map_fully_named_struct(self):
    """Each field name maps to its position, in declaration order."""
    fully_named_struct = structure.Struct.named(b=10, a=20)

    name_to_index_dict = structure.name_to_index_map(fully_named_struct)
    # 'b' precedes 'a': positions follow declaration order, not sorted names.
    expected_name_to_index_map = {'b': 0, 'a': 1}
    self.assertEqual(name_to_index_dict, expected_name_to_index_map)
Example #7
0
 def test_name_to_index_map_empty_unnamed_struct(self):
   """A struct built from unnamed values has no name-to-index entries."""
   name_map = structure.name_to_index_map(structure.Struct.unnamed(10, 20))
   self.assertEmpty(name_map)