Example #1
 def __enter__(self):
   # pylint: disable=protected-access
   self._conv_general_dilated_p_original = (
       lax_convolution.conv_general_dilated_p)
   # The following primitive accepts a name argument which is passed into
   # the HLO metadata field. Here, it is the only argument changed from
   # the original lax implementation.
   lax_convolution.conv_general_dilated_p = lax.standard_primitive(
       shape_rule=lax_convolution._conv_general_dilated_shape_rule,
       dtype_rule=lax_convolution._conv_general_dilated_dtype_rule,
       name=self._op_name,
       translation_rule=functools.partial(
           lax_convolution._conv_general_dilated_translation_rule,
           expand_complex_convolutions=False))
   xla.register_translation(
       lax_convolution.conv_general_dilated_p,
       functools.partial(
           lax_convolution._conv_general_dilated_translation_rule,
           expand_complex_convolutions=True),
       platform='cpu')
   xla.register_translation(
       lax_convolution.conv_general_dilated_p,
       functools.partial(
           lax_convolution._conv_general_dilated_translation_rule,
           expand_complex_convolutions=True),
       platform='gpu')
   ad.defbilinear(lax_convolution.conv_general_dilated_p,
                  lax_convolution._conv_general_dilated_transpose_lhs,
                  lax_convolution._conv_general_dilated_transpose_rhs)
   batching.primitive_batchers[lax_convolution.conv_general_dilated_p] = (
       lax_convolution._conv_general_dilated_batch_rule)
   masking.masking_rules[lax_convolution.conv_general_dilated_p] = (
       lax_convolution._conv_general_dilated_masking_rule)
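The __enter__ above saves the original primitive before installing the renamed one. A minimal sketch of a matching __exit__, assuming that restoring the saved module-level primitive is the only cleanup needed (the original primitive keeps its own registered rules):

 def __exit__(self, exc_type, exc_value, traceback):
   # Put the original primitive back; rules registered against the
   # replacement primitive simply become unreachable.
   lax_convolution.conv_general_dilated_p = (
       self._conv_general_dilated_p_original)
   return False  # do not suppress exceptions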
Example #2
 def __enter__(self):
   # pylint: disable=protected-access
   # The following primitive accepts a name argument which is passed into
   # the HLO metadata field. Here, it is the only argument changed from
   # the original lax implementation.
   self._dot_general_p_original = lax.dot_general_p
   lax.dot_general_p = lax.standard_primitive(
       shape_rule=lax._dot_general_shape_rule,
       dtype_rule=lax._dot_general_dtype_rule,
       name=self._op_name,
       translation_rule=lax._dot_general_translation_rule)
   ad.defbilinear(lax.dot_general_p, lax._dot_general_transpose_lhs,
                  lax._dot_general_transpose_rhs)
   batching.primitive_batchers[lax.dot_general_p] = lax._dot_general_batch_rule
   masking.masking_rules[lax.dot_general_p] = lax._dot_general_masking_rule
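Both __enter__ snippets follow the same save/patch/restore pattern: stash the module-level primitive, install a renamed replacement, and put the original back on exit. A self-contained sketch of that pattern using a placeholder namespace (not the real JAX internals), runnable on its own:

import contextlib
import types

# Placeholder standing in for the module whose attribute gets swapped.
fake_lax = types.SimpleNamespace(dot_general_p='original primitive')

@contextlib.contextmanager
def renamed_primitive(replacement):
  original = fake_lax.dot_general_p      # save the original
  fake_lax.dot_general_p = replacement   # install the replacement
  try:
    yield
  finally:
    fake_lax.dot_general_p = original    # restore on exit, even after errors

with renamed_primitive('renamed primitive'):
  assert fake_lax.dot_general_p == 'renamed primitive'
assert fake_lax.dot_general_p == 'original primitive'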
Example #3
xla.parallel_translations[pdot_p] = _pdot_translation_rule


def _pdot_transpose_lhs(g, y, *, axis_name, pos_contract, pos_batch):
    # TODO: avals with names, call pbroadcast with axis_name
    return lax._dot_general_transpose_lhs(
        g, y, dimension_numbers=[pos_contract, pos_batch], precision=None)


def _pdot_transpose_rhs(g, x, *, axis_name, pos_contract, pos_batch):
    # TODO: avals with names, call pbroadcast with axis_name
    return lax._dot_general_transpose_rhs(
        g, x, dimension_numbers=[pos_contract, pos_batch], precision=None)


ad.defbilinear(pdot_p, _pdot_transpose_lhs, _pdot_transpose_rhs)

pxla.multi_host_supported_collectives.add(pdot_p)


@config.register_omnistaging_disabler
def omnistaging_disabler() -> None:
    global axis_index

    psum_p.bind = partial(core.Primitive.bind, psum_p)  # type: ignore
    psum_p.def_impl(partial(pxla.apply_parallel_primitive,
                            psum_p))  # type: ignore
    pxla.parallel_pure_rules[psum_p] = lambda *args, shape: (
        x * prod(shape) for x in args)  # type: ignore

    def _axis_index_bind(*, axis_name):
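The ad.defbilinear call above wires the two transpose functions into reverse-mode autodiff for pdot_p. As a small runnable illustration of what such transpose rules compute for an ordinary (non-parallel) matrix product, using only public jax APIs:

import jax
import jax.numpy as jnp

x = jnp.arange(6.0).reshape(2, 3)
y = jnp.arange(12.0).reshape(3, 4)
g = jnp.ones((2, 4))  # cotangent for the output of x @ y

_, vjp_fn = jax.vjp(jnp.matmul, x, y)
gx, gy = vjp_fn(g)
assert jnp.allclose(gx, g @ y.T)  # cotangent pulled back to the lhs
assert jnp.allclose(gy, x.T @ g)  # cotangent pulled back to the rhs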
Example #4
                                window_strides=window_strides,
                                padding=padding,
                                lhs_dilation=lhs_dilation,
                                rhs_dilation=rhs_dilation,
                                dimension_numbers=dimension_numbers,
                                feature_group_count=feature_group_count,
                                batch_group_count=batch_group_count,
                                precision=precision,
                                preferred_element_type=preferred_element_type)


conv_general_dilated_p = lax.standard_primitive(
    _conv_general_dilated_shape_rule, _conv_general_dilated_dtype_rule,
    'conv_general_dilated')

ad.defbilinear(conv_general_dilated_p, _conv_general_dilated_transpose_lhs,
               _conv_general_dilated_transpose_rhs)
batching.primitive_batchers[conv_general_dilated_p] = \
    _conv_general_dilated_batch_rule
masking.masking_rules[conv_general_dilated_p] = \
    _conv_general_dilated_masking_rule


def _complex_mul(mul, x, y):
    # We use a trick for complex multiplication, sometimes attributed to Gauss,
    # which uses three multiplications and five additions instead of the naive
    # method's four multiplications and two additions.
    # https://en.wikipedia.org/wiki/Multiplication_algorithm#Complex_multiplication_algorithm
    #
    # This performance win comes with a trade-off in accuracy, especially in
    # cases where the real and imaginary parts differ hugely in magnitude.
    # The relative error bound (e.g. 1p-24 in the case of float32) would be
    # relative to the
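A standalone numerical sketch of the three-multiplication trick described in the comment above, written against plain jax.numpy rather than the primitive-level mul callback that _complex_mul receives:

import jax.numpy as jnp
from jax import lax

def complex_mul_gauss(x, y):
    # (a + bi)(c + di) via three real multiplications and five additions.
    a, b = jnp.real(x), jnp.imag(x)
    c, d = jnp.real(y), jnp.imag(y)
    k1 = c * (a + b)  # multiplication 1
    k2 = a * (d - c)  # multiplication 2
    k3 = b * (c + d)  # multiplication 3
    return lax.complex(k1 - k3, k1 + k2)  # real = ac - bd, imag = ad + bc

x = jnp.array(1.0 + 2.0j)
y = jnp.array(3.0 - 4.0j)
assert jnp.allclose(complex_mul_gauss(x, y), x * y)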