Beispiel #1
0
def _select_and_gather_add_transpose(t, tangents, operand, *, select_prim,
                                     window_dimensions, window_strides,
                                     padding, base_dilation, window_dilation):
    """Transpose select-and-gather-add with respect to ``tangents``.

    ``tangents`` must be the undefined (linear) input and ``operand`` must be
    a known value; both conditions are asserted.

    Args:
      t: Incoming cotangent. When it is an ``ad_util.Zero``, a zero built
        from ``tangents.aval`` is returned without further work.
      tangents: Undefined primal; only its ``aval`` is read here.
      operand: Array the selection is performed on.
      select_prim: Comparison primitive, either ``lax.ge_p`` or ``lax.le_p``.
      window_dimensions: Forwarded unchanged to ``_select_and_scatter_add``.
      window_strides: Forwarded unchanged to ``_select_and_scatter_add``.
      padding: Forwarded unchanged to ``_select_and_scatter_add``.
      base_dilation: Per-dimension base dilation factors. When any factor is
        not 1, ``operand`` is interior-padded before the scatter and the
        result is sliced with these factors as strides afterwards.
      window_dilation: Per-dimension window dilation; every entry must be 1.

    Returns:
      A two-element list: the cotangent for ``tangents`` and ``None`` for
      ``operand``.

    Raises:
      NotImplementedError: If any entry of ``window_dilation`` is not 1.
    """
    assert select_prim in (lax.le_p, lax.ge_p)
    assert ad.is_undefined_primal(tangents)
    assert not ad.is_undefined_primal(operand)

    if any(d != 1 for d in window_dilation):
        template = (
            "VJP not implemented for select_and_gather (MaxPool) with window "
            "dilation, got window_dilation={}.")
        raise NotImplementedError(template.format(window_dilation))

    # A symbolic zero cotangent maps to a symbolic zero for `tangents`.
    if type(t) is ad_util.Zero:
        return [ad_util.Zero(tangents.aval), None]

    dilated = any(d != 1 for d in base_dilation)
    if dilated:
        # Fill value for the inserted elements depends on the comparison
        # primitive in use.
        if select_prim is lax.ge_p:
            pad_identity = lax._get_max_identity
        else:
            pad_identity = lax._get_min_identity
        fill_value = pad_identity(operand.dtype)
        # No edge padding; d - 1 interior elements between neighbours.
        interior_config = tuple((0, 0, d - 1) for d in base_dilation)
        operand = lax.pad(operand, fill_value, interior_config)

    scattered = _select_and_scatter_add(t, operand, select_prim,
                                        window_dimensions, window_strides,
                                        padding)
    if dilated:
        # Stride by the dilation factors to drop the interior-padded slots.
        origin = (0, ) * len(scattered.shape)
        scattered = slicing.slice(scattered, origin, scattered.shape,
                                  base_dilation)
    return [scattered, None]
Beispiel #2
0
def _select_and_scatter_add_impl(source, operand, *,
                                 select_prim, window_dimensions, window_strides,
                                 padding, expand_padding):
  """Run select-and-scatter-add via the generic ``_select_and_scatter``.

  Args:
    source: Values to scatter; its dtype picks the combiner
      (``lax.bitwise_or`` for booleans, ``lax.add`` otherwise).
    operand: Array the selection is performed on.
    select_prim: Comparison primitive bound pairwise to choose elements.
    window_dimensions: Forwarded to ``_select_and_scatter``.
    window_strides: Forwarded to ``_select_and_scatter``.
    padding: Sequence of ``(lo, hi)`` pairs, one per dimension.
    expand_padding: When truthy, the padding is applied to ``operand``
      explicitly with ``lax.pad`` and zero padding is passed on; the result
      is then sliced back to the original operand extent.

  Returns:
    The output of ``_select_and_scatter``, sliced to the unpadded operand
    shape when ``expand_padding`` is truthy.
  """
  dtype = source.dtype
  select = lambda x, y: select_prim.bind(x, y)
  if dtype == np.bool_:
    scatter = lax.bitwise_or
  else:
    scatter = lax.add

  # Simple path: hand the caller's padding straight to the generic op.
  if not expand_padding:
    return _select_and_scatter(
        operand, select, window_dimensions, window_strides, padding, source,
        lax._zero(operand), scatter)

  unpadded_shape = operand.shape
  if select_prim is lax.ge_p:
    fill = lax._get_max_identity
  else:
    fill = lax._get_min_identity

  # Materialize the padding on the operand (no interior padding).
  explicit_pads = [(lo, hi, 0) for lo, hi in padding]
  padded_operand = lax.pad(operand, fill(dtype), explicit_pads)
  zero_padding = [(0, 0) for _ in padding]

  padded_out = _select_and_scatter(
      padded_operand, select, window_dimensions, window_strides, zero_padding,
      source, lax._zero(padded_operand), scatter)

  # Cut the explicitly padded border back off.
  starts = [lo for lo, _ in padding]
  stops = [lo + extent for (lo, _), extent in zip(padding, unpadded_shape)]
  return slicing.slice(padded_out, starts, stops)