Example #1
def _conv_general_vjp_lhs_padding(in_shape, window_dimensions, window_strides,
                                  out_shape, padding, lhs_dilation,
                                  rhs_dilation) -> List[Tuple[int, int]]:
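    """Computes the padding for the transposed convolution that produces the
    gradient of conv_general_dilated with respect to its lhs."""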
    lhs_dilated_shape = lax._dilate_shape(in_shape, lhs_dilation)
    rhs_dilated_shape = lax._dilate_shape(window_dimensions, rhs_dilation)
    out_dilated_shape = lax._dilate_shape(out_shape, window_strides)
    pad_before = np.subtract(rhs_dilated_shape, [lo for lo, _ in padding]) - 1
    pad_after = (np.add(lhs_dilated_shape, rhs_dilated_shape) - 1 -
                 out_dilated_shape - pad_before)
    return safe_zip(pad_before, pad_after)
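A minimal sketch of how this helper behaves, using hypothetical 1D shapes and
assuming the JAX-internal names the snippet relies on (np, lax, safe_zip) are in
scope: a length-5 input convolved with a length-3 window at stride 1 and no
padding gives a length-3 output, so the transposed convolution that forms the
lhs gradient needs two elements of padding on each side.

pads = _conv_general_vjp_lhs_padding(
    in_shape=(5,), window_dimensions=(3,), window_strides=(1,),
    out_shape=(3,), padding=[(0, 0)], lhs_dilation=(1,), rhs_dilation=(1,))
# pads == [(2, 2)]: pad_before = 3 - 0 - 1 = 2, pad_after = (5 + 3) - 1 - 3 - 2 = 2.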
Example #2
def reduce_window_shape_tuple(operand_shape, window_dimensions, window_strides,
                              padding, base_dilation=None,
                              window_dilation=None):
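  """Computes the output shape of a reduce_window for the given operand shape,
  window dimensions, strides, padding, and dilations."""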
  if base_dilation is not None:
    operand_shape = lax._dilate_shape(operand_shape, base_dilation)
  if window_dilation is not None:
    window_dimensions = lax._dilate_shape(window_dimensions, window_dilation)
  pads_lo, pads_hi = zip(*padding)
  operand_padded = core.sum_shapes(operand_shape, pads_lo, pads_hi)
  return core.stride_shape(operand_padded, window_dimensions, window_strides)
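A quick sanity check with hypothetical values, assuming the JAX-internal core
module used above is in scope: a 10x10 operand reduced with a 3x3 window at
stride 2 and one element of padding on every side is padded to 12x12, so the
strided window fits five times along each dimension.

out_shape = reduce_window_shape_tuple(
    operand_shape=(10, 10), window_dimensions=(3, 3), window_strides=(2, 2),
    padding=((1, 1), (1, 1)))
# out_shape == (5, 5): ((10 + 1 + 1) - 3) // 2 + 1 == 5 per dimension.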
Example #3
def _conv_general_vjp_rhs_padding(in_shape, window_dimensions, window_strides,
                                  out_shape, padding, lhs_dilation,
                                  rhs_dilation):
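    """Computes the padding for the convolution that produces the gradient of
    conv_general_dilated with respect to its rhs."""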

    if len(in_shape) == 0:  # 0D conv
        return []
    lhs_dilated_shape = lax._dilate_shape(in_shape, lhs_dilation)
    rhs_dilated_shape = lax._dilate_shape(window_dimensions, rhs_dilation)
    out_dilated_shape = lax._dilate_shape(out_shape, window_strides)
    pads_lo, _ = zip(*padding)
    pads_from_lhs = core.diff_shape(out_dilated_shape, lhs_dilated_shape)
    pads_from_rhs = core.diff_shape(
        core.diff_shape(rhs_dilated_shape, pads_lo), (1, ) * len(pads_lo))
    pads_hi = core.sum_shapes(pads_from_lhs, pads_from_rhs)
    return list(zip(pads_lo, pads_hi))
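A hypothetical 1D check, again assuming the JAX-internal lax and core modules
are in scope: with an input of length 5, a window of 3, stride 1, and one
element of padding on each side, the forward output has length 5, and the rhs
gradient is computed with the same (1, 1) padding.

pads = _conv_general_vjp_rhs_padding(
    in_shape=(5,), window_dimensions=(3,), window_strides=(1,),
    out_shape=(5,), padding=[(1, 1)], lhs_dilation=(1,), rhs_dilation=(1,))
# pads == [(1, 1)]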
Example #4
def reduce_window(operand,
                  init_value,
                  computation: Callable,
                  window_dimensions: core.Shape,
                  window_strides: Sequence[int],
                  padding: Union[str, Sequence[Tuple[int, int]]],
                  base_dilation: Optional[Sequence[int]] = None,
                  window_dilation: Optional[Sequence[int]] = None) -> Array:
    """Wraps XLA's `ReduceWindowWithGeneralPadding
  <https://www.tensorflow.org/xla/operation_semantics#reducewindow>`_
  operator.
  """
    flat_operands, operand_tree = tree_util.tree_flatten(operand)
    flat_init_values, init_value_tree = tree_util.tree_flatten(init_value)
    if operand_tree != init_value_tree:
        raise ValueError('Operands must have the same tree structure as '
                         f'init_values: {operand_tree} vs. {init_value_tree}')
    if len(flat_operands) == 0:
        raise ValueError('reduce_window must have at least one operand.')
    if len(flat_operands) != len(flat_init_values):
        raise ValueError(
            'Must have same total number of operands as init_values: '
            f' {len(flat_operands)} vs. {len(flat_init_values)}')
    if isinstance(padding, str):
        dilated_window_dims = (
            window_dimensions if window_dilation is None
            else lax._dilate_shape(window_dimensions, window_dilation))
        padding = tuple(
            lax.padtype_to_pads(flat_operands[0].shape, dilated_window_dims,
                                window_strides, padding))
    else:
        padding = tuple(padding)
    if base_dilation is None:
        base_dilation = (1, ) * len(window_dimensions)
    if window_dilation is None:
        window_dilation = (1, ) * len(window_dimensions)
    monoid_reducer = _get_monoid_window_reducer(computation, flat_init_values)
    if monoid_reducer:
        return monoid_reducer(operand, window_dimensions, window_strides,
                              padding, base_dilation, window_dilation)
    else:
        flat_init_avals = map(lax._abstractify, flat_init_values)
        jaxpr, consts, out_tree = lax._variadic_reduction_jaxpr(
            computation, tuple(flat_init_avals), init_value_tree)
        if operand_tree != out_tree:
            raise ValueError(
                'reduce_window output must have the same tree structure as the '
                f'operands: {operand_tree} vs. {out_tree}')
        out_flat = reduce_window_p.bind(
            *(flat_operands + flat_init_values),
            jaxpr=jaxpr,
            consts=consts,
            window_dimensions=tuple(window_dimensions),
            window_strides=tuple(window_strides),
            padding=padding,
            base_dilation=tuple(base_dilation),
            window_dilation=tuple(window_dilation))
        return tree_util.tree_unflatten(out_tree, out_flat)
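A small, self-contained usage sketch of the public jax.lax.reduce_window API,
which has the same signature as the wrapper above: 1D sum-pooling over a window
of three elements with stride one and no padding.

import jax.numpy as jnp
from jax import lax

x = jnp.arange(6.0)
y = lax.reduce_window(x, 0.0, lax.add, window_dimensions=(3,),
                      window_strides=(1,), padding='VALID')
# y == [3., 6., 9., 12.]: each output is the sum of a sliding window of three.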
Example #5
def _conv_general_dilated_shape_rule(lhs: core.ShapedArray,
                                     rhs: core.ShapedArray, *, window_strides,
                                     padding, lhs_dilation, rhs_dilation,
                                     dimension_numbers, feature_group_count,
                                     batch_group_count,
                                     **unused_kwargs) -> Tuple[int, ...]:
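    """Shape rule for conv_general_dilated: validates the arguments and returns
    the shape of the convolution output."""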
    assert type(dimension_numbers) is ConvDimensionNumbers
    if len(lhs.shape) != len(rhs.shape):
        msg = ("conv_general_dilated lhs and rhs must have the same number of "
               "dimensions, but got {} and {}.")
        raise ValueError(msg.format(lhs.shape, rhs.shape))
    if not feature_group_count > 0:
        msg = ("conv_general_dilated feature_group_count "
               "must be a positive integer, got {}.")
        raise ValueError(msg.format(feature_group_count))
    lhs_feature_count = lhs.shape[dimension_numbers.lhs_spec[1]]
    quot, rem = divmod(lhs_feature_count, feature_group_count)
    if rem:
        msg = (
            "conv_general_dilated feature_group_count must divide lhs feature "
            "dimension size, but {} does not divide {}.")
        raise ValueError(msg.format(feature_group_count, lhs_feature_count))
    if not core.symbolic_equal_dim(quot,
                                   rhs.shape[dimension_numbers.rhs_spec[1]]):
        msg = (
            "conv_general_dilated lhs feature dimension size divided by "
            "feature_group_count must equal the rhs input feature dimension "
            "size, but {} // {} != {}.")
        raise ValueError(
            msg.format(lhs_feature_count, feature_group_count,
                       rhs.shape[dimension_numbers.rhs_spec[1]]))
    if rhs.shape[dimension_numbers.rhs_spec[0]] % feature_group_count:
        msg = (
            "conv_general_dilated rhs output feature dimension size must be a "
            "multiple of feature_group_count, but {} is not a multiple of {}.")
        raise ValueError(
            msg.format(rhs.shape[dimension_numbers.rhs_spec[0]],
                       feature_group_count))

    if not batch_group_count > 0:
        msg = ("conv_general_dilated batch_group_count "
               "must be a positive integer, got {}.")
        raise ValueError(msg.format(batch_group_count))
    lhs_batch_count = lhs.shape[dimension_numbers.lhs_spec[0]]
    if batch_group_count > 1 and lhs_batch_count % batch_group_count != 0:
        msg = ("conv_general_dilated batch_group_count must divide lhs batch "
               "dimension size, but {} does not divide {}.")
        raise ValueError(msg.format(batch_group_count, lhs_batch_count))

    if rhs.shape[dimension_numbers.rhs_spec[0]] % batch_group_count:
        msg = (
            "conv_general_dilated rhs output feature dimension size must be a "
            "multiple of batch_group_count, but {} is not a multiple of {}.")
        raise ValueError(
            msg.format(rhs.shape[dimension_numbers.rhs_spec[0]],
                       batch_group_count))

    if batch_group_count > 1 and feature_group_count > 1:
        msg = (
            "At most one of batch_group_count and feature_group_count may be > "
            "1, got batch_group_count={} and feature_group_count={}")
        raise ValueError(msg.format(batch_group_count, feature_group_count))

    if len(_conv_sdims(dimension_numbers.rhs_spec)) != len(window_strides):
        msg = ("conv_general_dilated window and window_strides must have "
               "the same number of dimensions, but got {} and {}")
        raise ValueError(
            msg.format(len(_conv_sdims(dimension_numbers.rhs_spec)),
                       len(window_strides)))

    lhs_perm, rhs_perm, out_perm = dimension_numbers
    lhs_trans = lax._dilate_shape(np.take(lhs.shape, lhs_perm), lhs_dilation)
    rhs_trans = lax._dilate_shape(np.take(rhs.shape, rhs_perm), rhs_dilation)
    out_trans = conv_shape_tuple(lhs_trans, rhs_trans, window_strides, padding,
                                 batch_group_count)
    return tuple(np.take(out_trans,
                         np.argsort(out_perm)))  # type: ignore[arg-type]
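To see the shape rule in action, a sketch with the public
jax.lax.conv_general_dilated and hypothetical sizes (default NCHW/OIHW
layouts): a 1x3x8x8 input convolved with sixteen 3x3 filters at stride 2 and
'SAME' padding.

import jax.numpy as jnp
from jax import lax

lhs = jnp.zeros((1, 3, 8, 8))   # batch, in features, height, width (NCHW)
rhs = jnp.zeros((16, 3, 3, 3))  # out features, in features, height, width (OIHW)
out = lax.conv_general_dilated(lhs, rhs, window_strides=(2, 2), padding='SAME')
# out.shape == (1, 16, 4, 4): 'SAME' padding yields ceil(8 / 2) == 4 per spatial dim.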