def generic_plus(arg):
    """Adds two arguments when possible.

    The two operands are taken from ``arg[0]`` and ``arg[1]``. The function
    tries three things in order:

    1. If ``_generic_op_can_be_applied`` accepts the operands, the sum is built
       by ``_apply_generic_op`` with ``tf.add``.
    2. If the left operand's type signature is a
       ``computation_types.NamedTupleType``, the function recurses element by
       element. It then reassembles a named tuple that uses the left operand's
       element names.
    3. Otherwise it raises.

    Args:
      arg: A two-element indexable holding the left and right operands.

    Returns:
      Whatever ``_apply_generic_op`` returns in case 1. In case 2, a
      ``value_impl.ValueImpl`` wrapping the rebuilt named tuple.

    Raises:
      TypeError: If neither the generic op nor the named-tuple recursion
        applies to the operands.
    """
    lhs, rhs = arg[0], arg[1]
    _check_top_level_compatibility_with_generic_operators(
        lhs, rhs, 'Generic plus')

    if _generic_op_can_be_applied(lhs, rhs):
        return _apply_generic_op(tf.add, lhs, rhs)

    # TODO(b/136587334): Push this logic down a level
    if not isinstance(lhs.type_signature, computation_types.NamedTupleType):
        raise TypeError(
            'Generic plus encountered unexpected type {}, {}'.format(
                lhs.type_signature, rhs.type_signature))

    # This case is needed if federated types are nested deeply: add the
    # operands element by element and rebuild the structure.
    element_names = []
    for element in anonymous_tuple.iter_elements(lhs.type_signature):
        element_names.append(element[0])

    summed_elements = []
    for index in range(len(element_names)):
        # NOTE(review): only lhs's element count is consulted. rhs is assumed
        # to have matching arity, presumably enforced by the compatibility
        # check above. Confirm.
        partial_sum = generic_plus([lhs[index], rhs[index]])
        summed_elements.append(value_impl.ValueImpl.get_comp(partial_sum))

    rebuilt = building_block_factory.create_named_tuple(
        building_blocks.Tuple(summed_elements), element_names)
    return value_impl.ValueImpl(rebuilt, context_stack)
def generic_multiply(arg):
    """Multiplies two arguments when possible.

    The two operands are taken from ``arg[0]`` and ``arg[1]``. The function
    tries three things in order:

    1. If ``_generic_op_can_be_applied`` accepts the operands, the product is
       built by ``_apply_generic_op`` with ``tf.multiply``.
    2. If the left operand's type signature reports ``is_struct()``, the
       function recurses element by element. It then reassembles a named tuple
       that uses the left operand's element names.
    3. Otherwise it raises.

    Args:
      arg: A two-element indexable holding the left and right operands.

    Returns:
      Whatever ``_apply_generic_op`` returns in case 1. In case 2, a
      ``value_impl.ValueImpl`` wrapping the rebuilt named tuple.

    Raises:
      TypeError: If neither the generic op nor the struct recursion applies to
        the operands.
    """
    left, right = arg[0], arg[1]
    _check_top_level_compatibility_with_generic_operators(
        left, right, 'Generic multiply')

    if _generic_op_can_be_applied(left, right):
        return _apply_generic_op(tf.multiply, left, right)

    if not left.type_signature.is_struct():
        raise TypeError(
            'Generic multiply encountered unexpected type {}, {}'.format(
                left.type_signature, right.type_signature))

    # This case is needed if federated types are nested deeply: multiply the
    # operands element by element and rebuild the structure.
    field_names = []
    for field in structure.iter_elements(left.type_signature):
        field_names.append(field[0])

    products = []
    for position in range(len(field_names)):
        # NOTE(review): only left's element count is consulted. right is
        # assumed to have matching arity, presumably enforced by the
        # compatibility check above. Confirm.
        product = generic_multiply([left[position], right[position]])
        products.append(value_impl.ValueImpl.get_comp(product))

    rebuilt = building_block_factory.create_named_tuple(
        building_blocks.Struct(products), field_names)
    return value_impl.ValueImpl(rebuilt, context_stack)