Example #1
0
def _select_and_scatter_shape_rule(
    operand, source, init_value, *, select_jaxpr, select_consts, scatter_jaxpr,
    scatter_consts, window_dimensions, window_strides, padding):
  """Shape rule for select_and_scatter: validate window params, return operand shape.

  The output of select_and_scatter has the same shape as ``operand``. This rule
  only validates the window parameters before returning that shape.

  Args:
    operand: abstract value whose ``.shape`` is returned.
    source, init_value, select_jaxpr, select_consts, scatter_jaxpr,
    scatter_consts, padding: part of the primitive's parameter signature; not
      inspected here.
    window_dimensions: window extent, one entry per dimension of ``operand``.
    window_strides: window stride, one entry per entry of ``window_dimensions``.

  Returns:
    ``operand.shape``.

  Raises:
    TypeError: if ``window_strides`` and ``window_dimensions`` differ in length,
      or if ``window_dimensions`` does not have one entry per dimension of
      ``operand``. ``lax._check_shapelike`` may also raise for parameters that
      are not shape-like.
  """
  lax._check_shapelike("select_and_scatter", "window_dimensions",
                       window_dimensions)
  lax._check_shapelike("select_and_scatter", "window_strides", window_strides)
  if len(window_dimensions) != len(window_strides):
    msg = ("select_and_scatter got inconsistent window_strides and "
           "window_dimensions: got window_strides {} and window_dimensions {}.")
    raise TypeError(msg.format(window_strides, window_dimensions))
  # Mirror the reduce_window shape rule: the window must describe every
  # operand dimension. Without this check, a rank mismatch would only surface
  # later, away from the user's call site.
  if len(window_dimensions) != len(operand.shape):
    msg = ("select_and_scatter got the wrong number of window_dimensions for "
           "operand: got operand shape {} with window_dimensions {}.")
    raise TypeError(msg.format(operand.shape, window_dimensions))
  return operand.shape
Example #2
0
def _common_reduce_window_shape_rule(operand, window_dimensions,
                                     window_strides, padding, base_dilation,
                                     window_dilation):
  """Validate reduce_window parameters and compute the output shape.

  Args:
    operand: abstract value exposing ``.ndim`` and ``.shape``.
    window_dimensions: window extent per operand dimension; checked with
      ``non_zero_shape=True``.
    window_strides: stride per dimension; checked with ``non_zero_shape=True``.
    padding: per-dimension padding handed to ``reduce_window_shape_tuple``.
      Presumably one ``(low, high)`` pair per dimension (TODO confirm); only
      its length is validated here.
    base_dilation: dilation applied to the operand, one entry per dimension.
    window_dilation: dilation applied to the window, one entry per dimension.

  Returns:
    The shape computed by ``reduce_window_shape_tuple`` from the arguments.

  Raises:
    TypeError: if ``window_dimensions`` does not have one entry per operand
      dimension, or if any of ``window_strides``, ``base_dilation``,
      ``window_dilation`` or ``padding`` has a different length than
      ``window_dimensions``. ``lax._check_shapelike`` may also raise for
      parameters that are not shape-like.
  """
  lax._check_shapelike("reduce_window", "window_dimensions", window_dimensions,
                       non_zero_shape=True)
  lax._check_shapelike("reduce_window", "window_strides", window_strides,
                       non_zero_shape=True)
  lax._check_shapelike("reduce_window", "base_dilation", base_dilation)
  lax._check_shapelike("reduce_window", "window_dilation", window_dilation)
  if operand.ndim != len(window_dimensions):
    msg = ("reduce_window got the wrong number of window_dimensions for "
           "operand: got operand shape {} with window_dimensions {}.")
    raise TypeError(msg.format(operand.shape, window_dimensions))
  # Every per-dimension parameter must line up with window_dimensions. The
  # checks run in the original order, and padding is checked last, so existing
  # error messages keep their precedence. Before this check, padding was passed
  # through to the shape computation without any length check.
  for param_name, param in (("window_strides", window_strides),
                            ("base_dilation", base_dilation),
                            ("window_dilation", window_dilation),
                            ("padding", padding)):
    if len(param) != len(window_dimensions):
      msg = ("reduce_window got inconsistent {0} and window_dimensions: "
             "got {0} {1} and window_dimensions {2}.")
      raise TypeError(msg.format(param_name, param, window_dimensions))

  return reduce_window_shape_tuple(operand.shape, window_dimensions,
                                   window_strides, padding, base_dilation,
                                   window_dilation)
Example #3
0
def _check_conv_shapes(name, lhs_shape, rhs_shape, window_strides):
    """Check that conv shapes are valid and are consistent with window_strides."""
    if len(lhs_shape) != len(rhs_shape):
        msg = "Arguments to {} must have same rank, got {} and {}."
        raise TypeError(msg.format(name, len(lhs_shape), len(rhs_shape)))
    if len(lhs_shape) < 2:
        msg = "Arguments to {} must have rank at least 2, got {} and {}."
        raise TypeError(msg.format(name, len(lhs_shape), len(rhs_shape)))
    if lhs_shape[1] != rhs_shape[1]:
        msg = "Arguments to {} must agree on input feature size, got {} and {}."
        raise TypeError(msg.format(name, lhs_shape[1], rhs_shape[1]))
    lax._check_shapelike(name, "window_strides", window_strides)
    if not np.all(np.greater(window_strides, 0)):
        msg = "All elements of window_strides must be positive, got {}."
        raise TypeError(msg.format(window_strides))
    if len(window_strides) != len(lhs_shape) - 2:
        msg = "{} window_strides has wrong length: expected {}, got {}."
        expected_length = len(lhs_shape) - 2
        raise TypeError(msg.format(name, expected_length, len(window_strides)))