def test_raises_on_different_parameter_types(self):
  """Lambdas whose parameter types differ cannot be concatenated."""
  # One identity-shaped lambda per dtype: int32 first, then float32.
  mismatched_lambdas = [
      building_blocks.Lambda('x', dtype, building_blocks.Reference('x', dtype))
      for dtype in (tf.int32, tf.float32)
  ]
  with self.assertRaises(TypeError):
    transformations.concatenate_function_outputs(*mismatched_lambdas)
def test_raises_on_non_lambda_args(self):
  """A non-Lambda in either argument position raises TypeError."""
  ref = building_blocks.Reference('x', tf.int32)
  identity_fn = building_blocks.Lambda('x', tf.int32, ref)
  # Put the bare Reference in the second slot first, then in the first slot.
  for first_arg, second_arg in ((identity_fn, ref), (ref, identity_fn)):
    with self.assertRaises(TypeError):
      transformations.concatenate_function_outputs(first_arg, second_arg)
def test_raises_on_non_unique_names(self):
  """Concatenating lambdas that re-bind the same name raises ValueError."""
  ref = building_blocks.Reference('x', tf.int32)
  inner = building_blocks.Lambda('x', tf.int32, ref)
  # `outer` binds 'x' and its body `inner` binds 'x' again, so the names in
  # the combined tree are not unique.
  outer = building_blocks.Lambda('x', tf.int32, inner)
  for ordering in ((inner, outer), (outer, inner)):
    with self.assertRaises(ValueError):
      transformations.concatenate_function_outputs(*ordering)
def extract_work(before_aggregate, after_aggregate, canonical_form_types):
  """Combines `before_aggregate` and `after_aggregate` into `work`.

  Args:
    before_aggregate: The first result of splitting `after_broadcast` on
      `tff_framework.FEDERATED_AGGREGATE`.
    after_aggregate: The second result of that same split.
    canonical_form_types: `dict` holding the `canonical_form.CanonicalForm`
      type signatures specified by the `tff.utils.IterativeProcess` being
      compiled. Its `'work_type'` entry is compared against the extracted
      computation's type signature.

  Returns:
    `work` as specified by `canonical_form.CanonicalForm`, an instance of
    `tff_framework.CompiledComputation`.

  Raises:
    transformations.CanonicalFormCompilationError: If the extracted value is
      not a `tff_framework.CompiledComputation`, or if its type signature is
      not equal to `canonical_form_types['work_type']`.
  """
  # The cN names follow the notation described in
  # `get_iterative_process_for_canonical_form()`.
  c3_elements_in_before_aggregate_parameter = [[0, 1], [1]]
  c5_index_in_before_aggregate_result = 0
  c6_index_in_after_aggregate_result = 2
  c3_elements_in_after_aggregate_parameter = [[0, 0, 1], [0, 1]]

  def _zip_selection(comp, selection):
    # Calls `zip_selection_as_argument_to_lower_level_lambda` and unwraps the
    # `.result.function` of the building block it returns.
    zipped = transformations.zip_selection_as_argument_to_lower_level_lambda(
        comp, selection)
    return zipped.result.function

  # c3 -> c5: zip the c3 selection into `before_aggregate`, then keep
  # output 0.
  c3_to_c5 = transformations.select_output_from_lambda(
      _zip_selection(before_aggregate,
                     c3_elements_in_before_aggregate_parameter),
      c5_index_in_before_aggregate_result)

  # c3 -> c6: keep output 2 of `after_aggregate`, then zip the c3 selection
  # into it.
  c3_to_c6 = _zip_selection(
      transformations.select_output_from_lambda(
          after_aggregate, c6_index_in_after_aggregate_result),
      c3_elements_in_after_aggregate_parameter)

  # Join both outputs under one parameter, then federated-zip the result to
  # obtain c3 -> c4.
  unzipped = transformations.concatenate_function_outputs(c3_to_c5, c3_to_c6)
  c3_to_c4 = tff_framework.Lambda(
      unzipped.parameter_name, unzipped.parameter_type,
      tff_framework.create_federated_zip(unzipped.result))

  work = transformations.consolidate_and_extract_local_processing(c3_to_c4)

  if not isinstance(work, tff_framework.CompiledComputation):
    raise transformations.CanonicalFormCompilationError(
        'Failed to extract a `tff_framework.CompiledComputation` from '
        'work, instead received a {} (of type {}).'.format(
            type(work), work.type_signature))
  expected_type = canonical_form_types['work_type']
  if work.type_signature != expected_type:
    raise transformations.CanonicalFormCompilationError(
        'Extracted a TF block of the wrong type. Expected a function with type '
        '{}, but the type signature of the TF block was {}'.format(
            expected_type, work.type_signature))
  return work
def test_concatenates_identities(self):
  """Two int32 identity lambdas concatenate into a single pair-returning one."""
  identity_fns = []
  for param in ('x', 'y'):
    identity_fns.append(
        building_blocks.Lambda(param, tf.int32,
                               building_blocks.Reference(param, tf.int32)))
  concatenated = mapreduce_transformations.concatenate_function_outputs(
      *identity_fns)
  # Both output slots read the one shared parameter.
  self.assertEqual(str(concatenated), '(_var1 -> <_var1,_var1>)')
def _extract_work(before_aggregate, after_aggregate):
  """Builds `work` from the `before_aggregate`/`after_aggregate` split.

  Intended to be called only by `get_canonical_form_for_iterative_process`.
  This function does not verify the structure of `before_aggregate` or
  `after_aggregate`; the caller must perform those checks beforehand.

  Args:
    before_aggregate: The first result of splitting `after_broadcast` on
      aggregate intrinsics.
    after_aggregate: The second result of that same split.

  Returns:
    `work` as specified by `canonical_form.CanonicalForm`, documented as an
    instance of `building_blocks.CompiledComputation`.

  Raises:
    transformations.CanonicalFormCompilationError: Documented as raised when
      an AST of the wrong type is extracted. NOTE(review): this function
      performs no type check itself; any such error would have to originate
      in `transformations.consolidate_and_extract_local_processing` — confirm.
  """
  c3_elements_in_before_aggregate_parameter = [[0, 1], [1]]
  c6_index_in_before_aggregate_result = [[0, 0], [1, 0]]
  c8_index_in_after_aggregate_result = 2
  c3_elements_in_after_aggregate_parameter = [[0, 0, 1], [0, 1]]

  def _zip_selection(comp, selection):
    # Calls `zip_selection_as_argument_to_lower_level_lambda` and unwraps the
    # `.result.function` of the building block it returns.
    zipped = transformations.zip_selection_as_argument_to_lower_level_lambda(
        comp, selection)
    return zipped.result.function

  # c3 -> c6: zip the c3 selection into `before_aggregate`, then select the
  # outputs listed in `c6_index_in_before_aggregate_result`.
  c3_to_c6 = transformations.select_output_from_lambda(
      _zip_selection(before_aggregate,
                     c3_elements_in_before_aggregate_parameter),
      c6_index_in_before_aggregate_result)

  # c3 -> c8: keep output 2 of `after_aggregate`, then zip the c3 selection
  # into it.
  c3_to_c8 = _zip_selection(
      transformations.select_output_from_lambda(
          after_aggregate, c8_index_in_after_aggregate_result),
      c3_elements_in_after_aggregate_parameter)

  # Join both outputs under one parameter and federated-zip them (c3 -> c4).
  unzipped = transformations.concatenate_function_outputs(c3_to_c6, c3_to_c8)
  c3_to_c4 = building_blocks.Lambda(
      unzipped.parameter_name, unzipped.parameter_type,
      building_block_factory.create_federated_zip(unzipped.result))
  return transformations.consolidate_and_extract_local_processing(c3_to_c4)
def _extract_work(before_aggregate, after_aggregate):
  """Combines `before_aggregate` and `after_aggregate` into `work`.

  Args:
    before_aggregate: The first result of splitting `after_broadcast` on
      `intrinsic_defs.FEDERATED_AGGREGATE`.
    after_aggregate: The second result of that same split.

  Returns:
    `work` as specified by `canonical_form.CanonicalForm`, documented as an
    instance of `building_blocks.CompiledComputation`.

  Raises:
    transformations.CanonicalFormCompilationError: Documented as raised when
      an AST of the wrong type is extracted. NOTE(review): no type check is
      done in this body; any such error would have to come from
      `transformations.consolidate_and_extract_local_processing` — confirm.
  """
  # The cN names follow the notation described in
  # `get_iterative_process_for_canonical_form()`.
  c3_elements_in_before_aggregate_parameter = [[0, 1], [1]]
  c5_index_in_before_aggregate_result = 0
  c6_index_in_after_aggregate_result = 2
  c3_elements_in_after_aggregate_parameter = [[0, 0, 1], [0, 1]]

  def _zip_selection(comp, selection):
    # Calls `zip_selection_as_argument_to_lower_level_lambda` and unwraps the
    # `.result.function` of the building block it returns.
    zipped = transformations.zip_selection_as_argument_to_lower_level_lambda(
        comp, selection)
    return zipped.result.function

  # c3 -> c5: zip the c3 selection into `before_aggregate`, keep output 0.
  c3_to_c5 = transformations.select_output_from_lambda(
      _zip_selection(before_aggregate,
                     c3_elements_in_before_aggregate_parameter),
      c5_index_in_before_aggregate_result)

  # c3 -> c6: keep output 2 of `after_aggregate`, then zip c3 into it.
  c3_to_c6 = _zip_selection(
      transformations.select_output_from_lambda(
          after_aggregate, c6_index_in_after_aggregate_result),
      c3_elements_in_after_aggregate_parameter)

  # Join both outputs under one parameter and federated-zip them (c3 -> c4).
  unzipped = transformations.concatenate_function_outputs(c3_to_c5, c3_to_c6)
  c3_to_c4 = building_blocks.Lambda(
      unzipped.parameter_name, unzipped.parameter_type,
      building_block_factory.create_federated_zip(unzipped.result))
  return transformations.consolidate_and_extract_local_processing(c3_to_c4)
def test_parameters_are_mapped_together(self):
  """Every Reference in the concatenated lambda names its single parameter."""
  identity_fns = [
      building_blocks.Lambda(name, tf.int32,
                             building_blocks.Reference(name, tf.int32))
      for name in ('x', 'y')
  ]
  concatenated = transformations.concatenate_function_outputs(*identity_fns)
  shared_name = concatenated.parameter_name

  def _fail_on_foreign_reference(comp):
    # Postorder visitor. It raises if any Reference names something other
    # than the concatenated lambda's parameter, and it leaves the tree
    # otherwise unchanged.
    if not isinstance(comp, building_blocks.Reference):
      return comp, True
    if comp.name != shared_name:
      raise ValueError
    return comp, True

  tree_analysis.check_has_unique_names(concatenated)
  transformation_utils.transform_postorder(concatenated,
                                           _fail_on_foreign_reference)