def _select_and_scatter_shape_rule(
    operand, source, init_value, *, select_jaxpr, select_consts,
    scatter_jaxpr, scatter_consts, window_dimensions, window_strides,
    padding):
  """Shape rule for ``select_and_scatter``: the output has ``operand``'s shape.

  Only the window parameters are validated here; ``source``, ``init_value``,
  the select/scatter jaxprs and consts, and ``padding`` are accepted but not
  inspected by this rule.

  Returns:
    ``operand.shape`` unchanged.

  Raises:
    TypeError: if ``window_dimensions`` and ``window_strides`` have different
      lengths. The ``lax._check_shapelike`` calls may also raise.
  """
  # Both window parameters go through the same shape-like validation, in the
  # same order as before (dimensions first, then strides).
  for arg_name, value in (("window_dimensions", window_dimensions),
                          ("window_strides", window_strides)):
    lax._check_shapelike("select_and_scatter", arg_name, value)
  if len(window_strides) == len(window_dimensions):
    return operand.shape
  raise TypeError(
      "select_and_scatter got inconsistent window_strides and "
      "window_dimensions: got window_strides {} and window_dimensions {}."
      .format(window_strides, window_dimensions))
def _common_reduce_window_shape_rule(operand, window_dimensions, window_strides,
                                     padding, base_dilation, window_dilation):
  """Validate ``reduce_window`` parameters and compute the output shape.

  Args:
    operand: the array being reduced; its ``ndim`` and ``shape`` are read.
    window_dimensions: window extent per operand dimension (must be non-zero).
    window_strides: window stride per dimension (must be non-zero).
    padding: per-dimension padding, one entry per window dimension
      (presumably ``(low, high)`` pairs — confirm against callers).
    base_dilation: operand dilation per dimension.
    window_dilation: window dilation per dimension.

  Returns:
    The shape computed by ``reduce_window_shape_tuple`` for these parameters.

  Raises:
    TypeError: if ``operand.ndim`` differs from ``len(window_dimensions)``, or
      if ``window_strides``, ``base_dilation``, ``window_dilation`` or
      ``padding`` has a different length than ``window_dimensions``. The
      ``lax._check_shapelike`` calls may also raise.
  """
  lax._check_shapelike("reduce_window", "window_dimensions", window_dimensions,
                       non_zero_shape=True)
  lax._check_shapelike("reduce_window", "window_strides", window_strides,
                       non_zero_shape=True)
  lax._check_shapelike("reduce_window", "base_dilation", base_dilation)
  lax._check_shapelike("reduce_window", "window_dilation", window_dilation)
  if operand.ndim != len(window_dimensions):
    msg = ("reduce_window got the wrong number of window_dimensions for "
           "operand: got operand shape {} with window_dimensions {}.")
    raise TypeError(msg.format(operand.shape, window_dimensions))
  # Every per-dimension parameter must have one entry per window dimension.
  # The checks run in a fixed order, so the first mismatch is the one
  # reported. padding is checked as well, so that a length mismatch raises
  # here with a clear message rather than being passed on to
  # reduce_window_shape_tuple.
  for arg_name, arg in (("window_strides", window_strides),
                        ("base_dilation", base_dilation),
                        ("window_dilation", window_dilation),
                        ("padding", padding)):
    if len(arg) != len(window_dimensions):
      msg = ("reduce_window got inconsistent {0} and window_dimensions: "
             "got {0} {1} and window_dimensions {2}.")
      raise TypeError(msg.format(arg_name, arg, window_dimensions))
  return reduce_window_shape_tuple(operand.shape, window_dimensions,
                                   window_strides, padding, base_dilation,
                                   window_dilation)
def _check_conv_shapes(name, lhs_shape, rhs_shape, window_strides): """Check that conv shapes are valid and are consistent with window_strides.""" if len(lhs_shape) != len(rhs_shape): msg = "Arguments to {} must have same rank, got {} and {}." raise TypeError(msg.format(name, len(lhs_shape), len(rhs_shape))) if len(lhs_shape) < 2: msg = "Arguments to {} must have rank at least 2, got {} and {}." raise TypeError(msg.format(name, len(lhs_shape), len(rhs_shape))) if lhs_shape[1] != rhs_shape[1]: msg = "Arguments to {} must agree on input feature size, got {} and {}." raise TypeError(msg.format(name, lhs_shape[1], rhs_shape[1])) lax._check_shapelike(name, "window_strides", window_strides) if not np.all(np.greater(window_strides, 0)): msg = "All elements of window_strides must be positive, got {}." raise TypeError(msg.format(window_strides)) if len(window_strides) != len(lhs_shape) - 2: msg = "{} window_strides has wrong length: expected {}, got {}." expected_length = len(lhs_shape) - 2 raise TypeError(msg.format(name, expected_length, len(window_strides)))