def sequence_map(self, fn, arg):
  """Implements `sequence_map` as defined in `api/intrinsics.py`.

  Both inputs are first converted with `value_impl.to_value` using this
  object's context stack. What happens next depends on the type of `arg`:

  * Sequence type: the mapping is constructed directly through
    `building_block_factory.create_sequence_map`.
  * Federated type: the `SEQUENCE_MAP` intrinsic is curried with `fn`, and the
    resulting local function is applied through `self.federated_map`.

  Args:
    fn: The mapping function; after conversion its type signature is checked
      against `computation_types.FunctionType`.
    arg: The value to map over; after conversion its type signature must be
      either a sequence type or a federated type.

  Returns:
    The value produced by the sequence map (a `value_impl.ValueImpl` in the
    sequence case, or whatever `self.federated_map` returns in the federated
    case).

  Raises:
    TypeError: If `arg` is neither a sequence nor a federated value, or if it
      is federated with a placement other than `SERVER` or `CLIENTS`.
  """
  fn = value_impl.to_value(fn, None, self._context_stack)
  fn_type = fn.type_signature
  py_typecheck.check_type(fn_type, computation_types.FunctionType)

  arg = value_impl.to_value(arg, None, self._context_stack)
  arg_type = arg.type_signature

  # Plain sequences are handled by the building block factory.
  if isinstance(arg_type, computation_types.SequenceType):
    fn_comp = value_impl.ValueImpl.get_comp(fn)
    arg_comp = value_impl.ValueImpl.get_comp(arg)
    mapped = building_block_factory.create_sequence_map(fn_comp, arg_comp)
    return value_impl.ValueImpl(mapped, self._context_stack)

  if not isinstance(arg_type, computation_types.FederatedType):
    raise TypeError(
        'Cannot apply `tff.sequence_map()` to a value of type {}.'.format(
            arg_type))

  # Federated case: build the intrinsic and bind `fn` as its first argument.
  # This happens before the placement is inspected.
  sequence_in = computation_types.SequenceType(fn_type.parameter)
  sequence_out = computation_types.SequenceType(fn_type.result)
  signature = computation_types.FunctionType((fn_type, sequence_in),
                                             sequence_out)
  intrinsic_value = value_impl.ValueImpl(
      building_blocks.Intrinsic(intrinsic_defs.SEQUENCE_MAP.uri, signature),
      self._context_stack)
  local_fn = value_utils.get_curried(intrinsic_value)(fn)

  supported_placements = (placements.SERVER, placements.CLIENTS)
  if arg_type.placement not in supported_placements:
    raise TypeError('Unsupported placement {}.'.format(arg_type.placement))
  return self.federated_map(local_fn, arg)
def test_get_curried(self):
  """Checks `value_utils.get_curried` on a two-argument `tf.add` computation.

  The curried value must have the nested functional type signature, and its
  underlying computation (after compiled computation names are uniquified)
  must have the expected compact representation.
  """
  # Wrap a TF computation taking a pair of int32s as a ValueImpl.
  tf_add = computations.tf_computation(tf.add, [tf.int32, tf.int32])
  add_proto = computation_impl.ComputationImpl.get_proto(tf_add)
  add_block = building_blocks.ComputationBuildingBlock.from_proto(add_proto)
  add_numbers = value_impl.ValueImpl(add_block, _context_stack)

  curried = value_utils.get_curried(add_numbers)
  self.assertEqual(
      str(curried.type_signature), '(int32 -> (int32 -> int32))')

  # Uniquify names so the compact representation is comparable to a literal.
  curried_comp = value_impl.ValueImpl.get_comp(curried)
  uniquified, _ = tree_transformations.uniquify_compiled_computation_names(
      curried_comp)
  self.assertEqual(uniquified.compact_representation(),
                   '(arg0 -> (arg1 -> comp#1(<arg0,arg1>)))')
def sequence_map(self, mapping_fn, value):
  """Implements `sequence_map` as defined in `api/intrinsics.py`.

  A `SEQUENCE_MAP` intrinsic typed from `mapping_fn` is built first. It is
  then either called directly on a sequence-typed `value`, or curried with
  `mapping_fn` and applied to a federated `value` via `self.federated_apply`
  (placement `SERVER`) or `self.federated_map` (placement `CLIENTS`).

  Args:
    mapping_fn: As in `api/intrinsics.py`.
    value: As in `api/intrinsics.py`.

  Returns:
    As in `api/intrinsics.py`.

  Raises:
    TypeError: As in `api/intrinsics.py`.
  """
  mapping_fn = value_impl.to_value(mapping_fn, None, self._context_stack)
  fn_type = mapping_fn.type_signature
  py_typecheck.check_type(fn_type, computation_types.FunctionType)

  # The intrinsic is constructed before `value` is converted.
  intrinsic_uri = intrinsic_defs.SEQUENCE_MAP.uri
  intrinsic_type = computation_types.FunctionType(
      [fn_type, computation_types.SequenceType(fn_type.parameter)],
      computation_types.SequenceType(fn_type.result))
  sequence_map_intrinsic = value_impl.ValueImpl(
      computation_building_blocks.Intrinsic(intrinsic_uri, intrinsic_type),
      self._context_stack)

  value = value_impl.to_value(value, None, self._context_stack)
  value_type = value.type_signature

  if isinstance(value_type, computation_types.SequenceType):
    return sequence_map_intrinsic(mapping_fn, value)

  if not isinstance(value_type, computation_types.FederatedType):
    raise TypeError(
        'Cannot apply `tff.sequence_map()` to a value of type {}.'.format(
            str(value_type)))

  # Federated value: bind `mapping_fn` and dispatch on the placement.
  local_func = value_utils.get_curried(sequence_map_intrinsic)(mapping_fn)
  placement = value_type.placement
  if placement is placements.SERVER:
    return self.federated_apply(local_func, value)
  if placement is placements.CLIENTS:
    return self.federated_map(local_func, value)
  raise TypeError('Unsupported placement {}.'.format(str(placement)))