def __enter__(self):
  """Swap the global conv primitive for one named after this context.

  Saves the current ``lax_convolution.conv_general_dilated_p`` and rebinds
  it to a new primitive whose ``name`` is ``self._op_name`` — the name is
  passed into the HLO metadata field, and is the only argument changed from
  the original lax implementation.  All of the primitive's rules
  (translation, transpose, batching, masking) are then re-registered so the
  renamed primitive behaves identically to the original.

  NOTE(review): presumably ``__exit__`` restores
  ``self._conv_general_dilated_p_original`` — not visible here, confirm.
  """
  # pylint: disable=protected-access
  # Keep a handle on the original primitive so it can be restored later.
  self._conv_general_dilated_p_original = (
      lax_convolution.conv_general_dilated_p)
  # The following primitive accepts a name argument which is passed into
  # the HLO metadata field. Here, it is the only argument changed from
  # the original lax implementation.
  lax_convolution.conv_general_dilated_p = lax.standard_primitive(
      shape_rule=lax_convolution._conv_general_dilated_shape_rule,
      dtype_rule=lax_convolution._conv_general_dilated_dtype_rule,
      name=self._op_name,
      translation_rule=functools.partial(
          lax_convolution._conv_general_dilated_translation_rule,
          expand_complex_convolutions=False))
  # CPU/GPU override the default translation to expand complex
  # convolutions (the default registered above does not).
  xla.register_translation(
      lax_convolution.conv_general_dilated_p,
      functools.partial(
          lax_convolution._conv_general_dilated_translation_rule,
          expand_complex_convolutions=True),
      platform='cpu')
  xla.register_translation(
      lax_convolution.conv_general_dilated_p,
      functools.partial(
          lax_convolution._conv_general_dilated_translation_rule,
          expand_complex_convolutions=True),
      platform='gpu')
  # Re-attach differentiation, vmap, and masking rules so the renamed
  # primitive is a drop-in replacement for the original.
  ad.defbilinear(lax_convolution.conv_general_dilated_p,
                 lax_convolution._conv_general_dilated_transpose_lhs,
                 lax_convolution._conv_general_dilated_transpose_rhs)
  batching.primitive_batchers[lax_convolution.conv_general_dilated_p] = (
      lax_convolution._conv_general_dilated_batch_rule)
  masking.masking_rules[lax_convolution.conv_general_dilated_p] = (
      lax_convolution._conv_general_dilated_masking_rule)
def __enter__(self):
  """Swap the global dot_general primitive for one named after this context.

  Saves the current ``lax.dot_general_p`` and rebinds it to a new primitive
  whose ``name`` is ``self._op_name``; the name is passed into the HLO
  metadata field and is the only argument changed from the original lax
  implementation.  Transpose, batching, and masking rules are re-registered
  so the renamed primitive is a drop-in replacement.

  NOTE(review): presumably ``__exit__`` restores
  ``self._dot_general_p_original`` — not visible here, confirm.
  """
  # pylint: disable=protected-access
  # The following primitive accepts a name argument which is passed into
  # the HLO metadata field. Here, it is the only argument changed from
  # the original lax implementation.
  self._dot_general_p_original = lax.dot_general_p
  lax.dot_general_p = lax.standard_primitive(
      shape_rule=lax._dot_general_shape_rule,
      dtype_rule=lax._dot_general_dtype_rule,
      name=self._op_name,
      translation_rule=lax._dot_general_translation_rule)
  # Re-attach differentiation, vmap, and masking rules to the new primitive.
  ad.defbilinear(lax.dot_general_p,
                 lax._dot_general_transpose_lhs,
                 lax._dot_general_transpose_rhs)
  batching.primitive_batchers[lax.dot_general_p] = lax._dot_general_batch_rule
  masking.masking_rules[lax.dot_general_p] = lax._dot_general_masking_rule
# Register the parallel (SPMD) translation rule for the pdot primitive.
xla.parallel_translations[pdot_p] = _pdot_translation_rule


def _pdot_transpose_lhs(g, y, *, axis_name, pos_contract, pos_batch):
  """Transpose rule for pdot w.r.t. its left operand.

  Delegates to the dot_general transpose, rebuilding dimension_numbers from
  the positional contract/batch dims; the named ``axis_name`` dims are not
  yet handled (see TODO).
  """
  # TODO: avals with names, call pbroadcast with axis_name
  return lax._dot_general_transpose_lhs(
      g, y, dimension_numbers=[pos_contract, pos_batch], precision=None)


def _pdot_transpose_rhs(g, x, *, axis_name, pos_contract, pos_batch):
  """Transpose rule for pdot w.r.t. its right operand (mirror of lhs)."""
  # TODO: avals with names, call pbroadcast with axis_name
  return lax._dot_general_transpose_rhs(
      g, x, dimension_numbers=[pos_contract, pos_batch], precision=None)


# pdot is bilinear, so both transpose rules are registered together.
ad.defbilinear(pdot_p, _pdot_transpose_lhs, _pdot_transpose_rhs)
pxla.multi_host_supported_collectives.add(pdot_p)


@config.register_omnistaging_disabler
def omnistaging_disabler() -> None:
  """Re-register pre-omnistaging implementations for these primitives."""
  global axis_index

  # Restore the plain Primitive.bind path and a direct parallel impl for psum.
  psum_p.bind = partial(core.Primitive.bind, psum_p)  # type: ignore
  psum_p.def_impl(partial(pxla.apply_parallel_primitive, psum_p))  # type: ignore
  # Pure rule: psum of a replicated value is value * axis size (prod(shape)).
  pxla.parallel_pure_rules[psum_p] = lambda *args, shape: (
      x * prod(shape) for x in args)  # type: ignore

  def _axis_index_bind(*, axis_name):
window_strides=window_strides, padding=padding, lhs_dilation=lhs_dilation,
      rhs_dilation=rhs_dilation, dimension_numbers=dimension_numbers,
      feature_group_count=feature_group_count,
      batch_group_count=batch_group_count, precision=precision,
      preferred_element_type=preferred_element_type)


# Define the conv_general_dilated primitive and attach its interpreter rules:
# shape/dtype inference, bilinear transpose (for autodiff), vmap batching,
# and masking.
conv_general_dilated_p = lax.standard_primitive(
    _conv_general_dilated_shape_rule, _conv_general_dilated_dtype_rule,
    'conv_general_dilated')
ad.defbilinear(conv_general_dilated_p,
               _conv_general_dilated_transpose_lhs,
               _conv_general_dilated_transpose_rhs)
batching.primitive_batchers[conv_general_dilated_p] = \
    _conv_general_dilated_batch_rule
masking.masking_rules[conv_general_dilated_p] = \
    _conv_general_dilated_masking_rule


def _complex_mul(mul, x, y):
  # We use a trick for complex multiplication sometimes attributed to Gauss
  # which uses three multiplications and five additions; instead of the naive
  # method of four multiplications and two additions.
  # https://en.wikipedia.org/wiki/Multiplication_algorithm#Complex_multiplication_algorithm
  #
  # This performance win comes with a trade-off in accuracy; especially in
  # cases when the real and imaginary differ hugely in magnitude. The relative
  # error bound (e.g. 1p-24 in case of float32) would be relative to the