def reduce_window_shape_tuple(operand_shape, window_dimensions, window_strides,
                              padding, base_dilation=None,
                              window_dilation=None):
    """Compute the output shape of a windowed reduction.

    The operand shape is first dilated by ``base_dilation`` and the window by
    ``window_dilation`` (both via ``lax._dilate_shape``, and only when given).
    The operand is then padded on each side, and the strided window count per
    dimension is computed with ``core.stride_shape``.

    Args:
      operand_shape: shape of the operand being reduced.
      window_dimensions: window size per dimension.
      window_strides: window stride per dimension.
      padding: sequence of ``(lo, hi)`` pairs, one per dimension. An empty
        sequence (rank-0 operand) is accepted.
      base_dilation: optional per-dimension dilation applied to the operand
        shape.
      window_dilation: optional per-dimension dilation applied to the window.

    Returns:
      The shape returned by ``core.stride_shape`` for the padded operand.
    """
    if base_dilation is not None:
        operand_shape = lax._dilate_shape(operand_shape, base_dilation)
    if window_dilation is not None:
        window_dimensions = lax._dilate_shape(window_dimensions, window_dilation)
    # Split the (lo, hi) pairs into two tuples. `pads_lo, pads_hi = zip(*padding)`
    # fails for an empty padding sequence: `zip()` yields nothing, so the
    # two-name unpacking raises ValueError. Materialize first so that a
    # one-shot iterable can be traversed twice.
    padding = tuple(padding)
    pads_lo = tuple(lo for lo, _ in padding)
    pads_hi = tuple(hi for _, hi in padding)
    operand_padded = core.sum_shapes(operand_shape, pads_lo, pads_hi)
    return core.stride_shape(operand_padded, window_dimensions, window_strides)
def test_stride_shape(self):
    """Check core.stride_shape on constant dims and its shape-variable limits.

    Constant dimensions are strided normally. For the shape variable ``a``,
    this version of the test expects only window size 1 with stride 1 to
    succeed. Other window/stride combinations on ``a`` must raise
    ``core.InconclusiveDimensionOperation``.
    """
    # NOTE(review): SOURCE also contains another `test_stride_shape` whose
    # expectations contradict this one (it expects window 2 / stride 1 on `a`
    # to succeed with `a - 1`). If both are defined in the same class, the
    # later definition silently shadows this one -- confirm which is current.
    da, = shape_poly.parse_spec("a,", (2,))

    # Constant dims: (10 - 3) // 1 + 1 == 8 and (20 - 3) // 2 + 1 == 9.
    self.assertEqual((8, 9), core.stride_shape((10, 20), (3, 3), (1, 2)))
    # Window 1 / stride 1 on the shape variable leaves it unchanged.
    self.assertEqual((da, 9), core.stride_shape((da, 20), (1, 3), (1, 2)))

    # (expected message prefix, window sizes, window strides)
    inconclusive_cases = (
        ("Only striding with window_size == window_stride == 1 is supported "
         "for shape variables (var = a, window_size = 2, stride = 1",
         (2, 3), (1, 2)),
        ("Only striding with window_size == window_stride == 1 is supported "
         "for shape variables (var = a, window_size = 1, stride = 2",
         (1, 3), (2, 2)),
    )
    for message, window_size, window_stride in inconclusive_cases:
        with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
                                    re.escape(message)):
            core.stride_shape((da, 20), window_size, window_stride)
def conv_shape_tuple(lhs_shape, rhs_shape, strides, pads, batch_group_count=1):
    """Compute the shape tuple of a conv given input shapes in canonical order.

    Args:
      lhs_shape: input shape. ``lhs_shape[0]`` is the leading (batch-like)
        dimension and ``lhs_shape[2:]`` are the spatial dimensions
        (presumably ``(N, C, *spatial)`` -- confirm against callers).
      rhs_shape: kernel shape. ``rhs_shape[0]`` becomes output dimension 1 and
        ``rhs_shape[2:]`` are the spatial window sizes.
      strides: window stride per spatial dimension.
      pads: a padding-type string, resolved with ``lax.padtype_to_pads``, or an
        explicit sequence of ``(lo, hi)`` pairs, one per spatial dimension.
      batch_group_count: when > 1, ``lhs_shape[0]`` is divided by it.

    Returns:
      ``(out_dim0, rhs_shape[0], *out_spatial)`` as a tuple. Spatial sizes are
      clamped below at 0.

    Raises:
      TypeError: if the number of pads differs from the number of spatial
        dimensions.
    """
    spatial_in = lhs_shape[2:]
    window = rhs_shape[2:]
    if isinstance(pads, str):
        pads = lax.padtype_to_pads(spatial_in, window, strides, pads)

    num_spatial = len(lhs_shape) - 2
    if len(pads) != num_spatial:
        raise TypeError(
            f"Wrong number of explicit pads for convolution: "
            f"expected {num_spatial}, got {len(pads)}.")

    # View the pads as (lo, hi) rows; each row's sum is the total padding
    # added to that spatial dimension.
    total_pad = np.asarray(pads).reshape(-1, 2).sum(axis=1)
    padded = np.add(spatial_in, total_pad)
    # Clamp at 0 so an oversized window yields an empty dim, not a negative one.
    out_spatial = np.maximum(0, core.stride_shape(padded, window, strides))

    leading = lhs_shape[0]
    if batch_group_count > 1:
        assert leading % batch_group_count == 0
        leading = leading // batch_group_count
    return (leading, rhs_shape[0]) + tuple(out_spatial)
def test_stride_shape(self):
    """(s - window_size) // window_stride + 1

    Checks core.stride_shape against that formula for:
    - constant dims;
    - the shape variable ``a`` with unit and non-unit windows;
    - a symbolic stride ``s``;
    - the error raised when ``s - window_size`` is not a known multiple of
      the stride.
    """
    # NOTE(review): SOURCE also contains another `test_stride_shape` with
    # older, contradicting expectations. If both are defined in the same
    # class, only the later definition runs -- confirm which is current.
    a, stride = shape_poly.parse_spec("a, s", (2, 3))

    # (10 - 3) // 1 + 1 == 8 and (20 - 3) // 2 + 1 == 9.
    self.assertEqual(
        (8, 9),
        core.stride_shape((10, 20), window_size=(3, 3), window_stride=(1, 2)))
    # (a - 1) // 1 + 1 == a.
    self.assertEqual((a, 9), core.stride_shape((a, 20), (1, 3), (1, 2)))
    # (a - 2) // 1 + 1 == a - 1.
    self.assertEqual((a - 1, 9), core.stride_shape((a, 20), (2, 3), (1, 2)))
    # (a*s + 2 - 2) // s + 1 == a + 1.
    self.assertEqual(
        (a + 1, 9),
        core.stride_shape((a * stride + 2, 20), (2, 3), (stride, 2)))

    # (a - 1) // 2 is not computable: a - 1 is not provably divisible by 2.
    expected_error = re.escape(
        "Cannot compute stride for dimension 'a', window_size '1', stride '2'. "
        "Reason: Dimension polynomial 'a + -1' is not a multiple of '2'")
    with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
                                expected_error):
        core.stride_shape((a, 20), (1, 3), (2, 2))