Пример #1
0
def reduce_window_shape_tuple(operand_shape, window_dimensions, window_strides,
                              padding, base_dilation=None,
                              window_dilation=None):
  """Compute the output shape of a reduce-window operation.

  Args:
    operand_shape: shape of the operand being reduced.
    window_dimensions: size of the reduction window in each dimension.
    window_strides: stride of the window in each dimension.
    padding: iterable of ``(lo, hi)`` pairs, presumably one per operand
      dimension; they are added to the (possibly dilated) operand shape via
      ``core.sum_shapes`` before windowing.
    base_dilation: optional dilation applied to ``operand_shape`` via
      ``lax._dilate_shape``; skipped when ``None``.
    window_dilation: optional dilation applied to ``window_dimensions`` via
      ``lax._dilate_shape``; skipped when ``None``.

  Returns:
    The output shape computed by ``core.stride_shape`` from the padded operand
    shape, the (possibly dilated) window and ``window_strides``.
  """
  if base_dilation is not None:
    operand_shape = lax._dilate_shape(operand_shape, base_dilation)
  if window_dilation is not None:
    window_dimensions = lax._dilate_shape(window_dimensions, window_dilation)
  # Split the (lo, hi) pairs explicitly instead of ``zip(*padding)``: for a
  # rank-0 operand ``padding`` is empty, ``zip()`` yields nothing, and
  # unpacking that into two names raises ValueError. Materialize first so a
  # one-shot iterator is only consumed once.
  pairs = tuple(padding)
  pads_lo = tuple(lo for lo, _ in pairs)
  pads_hi = tuple(hi for _, hi in pairs)
  operand_padded = core.sum_shapes(operand_shape, pads_lo, pads_hi)
  return core.stride_shape(operand_padded, window_dimensions, window_strides)
Пример #2
0
  def test_stride_shape(self):
    """Exercise core.stride_shape on constant dims and on shape variable 'a'."""
    (dim_a,) = shape_poly.parse_spec("a,", (2,))

    # Constant dims, then a variable dim whose window and stride are both 1.
    self.assertEqual((8, 9),
                     core.stride_shape((10, 20), (3, 3), (1, 2)))
    self.assertEqual((dim_a, 9),
                     core.stride_shape((dim_a, 20), (1, 3), (1, 2)))

    # The expected texts below omit the closing parenthesis on purpose:
    # assertRaisesRegex performs a regex search, so a prefix is enough.

    # A window size other than 1 on the variable dim must be rejected.
    window_msg = "Only striding with window_size == window_stride == 1 is supported for shape variables (var = a, window_size = 2, stride = 1"
    with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
                                re.escape(window_msg)):
      core.stride_shape((dim_a, 20), (2, 3), (1, 2))

    # A stride other than 1 on the variable dim must be rejected.
    stride_msg = "Only striding with window_size == window_stride == 1 is supported for shape variables (var = a, window_size = 1, stride = 2"
    with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
                                re.escape(stride_msg)):
      core.stride_shape((dim_a, 20), (1, 3), (2, 2))
Пример #3
0
def conv_shape_tuple(lhs_shape, rhs_shape, strides, pads, batch_group_count=1):
  """Compute the output shape tuple of a conv given canonically ordered shapes.

  Dimensions 0 and 1 of both shapes are non-spatial; everything from index 2
  onward is treated as spatial.

  Args:
    lhs_shape: shape of the left-hand operand, in canonical order.
    rhs_shape: shape of the right-hand operand, in canonical order.
    strides: window strides for the spatial dimensions.
    pads: either a padding-type string, resolved with
      ``lax.padtype_to_pads``, or one ``(lo, hi)`` pair per spatial dimension.
    batch_group_count: when > 1, ``lhs_shape[0]`` is divided by it (and is
      asserted to be divisible by it).

  Returns:
    ``(leading, rhs_shape[0], *spatial_out)``. Here ``leading`` is
    ``lhs_shape[0]`` (divided by ``batch_group_count`` when that is > 1), and
    each spatial output size is clamped to be at least 0.

  Raises:
    TypeError: if the number of pads differs from the number of spatial dims.
  """
  lhs_spatial = lhs_shape[2:]
  rhs_spatial = rhs_shape[2:]
  if isinstance(pads, str):
    pads = lax.padtype_to_pads(lhs_spatial, rhs_spatial, strides, pads)

  num_spatial = len(lhs_shape) - 2
  if len(pads) != num_spatial:
    msg = "Wrong number of explicit pads for convolution: expected {}, got {}."
    raise TypeError(msg.format(num_spatial, len(pads)))

  # Each pad entry is a (lo, hi) pair; its sum is the total growth of that dim.
  pad_totals = np.sum(np.array(pads).reshape(-1, 2), axis=1)
  lhs_padded = np.add(lhs_spatial, pad_totals)
  spatial_out = np.maximum(
      0, core.stride_shape(lhs_padded, rhs_spatial, strides))

  if batch_group_count > 1:
    assert lhs_shape[0] % batch_group_count == 0
    leading = lhs_shape[0] // batch_group_count
  else:
    leading = lhs_shape[0]
  return (leading, rhs_shape[0]) + tuple(spatial_out)
Пример #4
0
  def test_stride_shape(self):
    """Check that each output dim equals (s - window_size) // window_stride + 1."""
    dim_a, dim_s = shape_poly.parse_spec("a, s", (2, 3))

    # Constant dims: (10 - 3) // 1 + 1 == 8 and (20 - 3) // 2 + 1 == 9.
    constant_out = core.stride_shape((10, 20), window_size=(3, 3),
                                     window_stride=(1, 2))
    self.assertEqual((8, 9), constant_out)
    # Window 1, stride 1: the variable dim passes through unchanged.
    self.assertEqual((dim_a, 9),
                     core.stride_shape((dim_a, 20), (1, 3), (1, 2)))

    # Window 2, stride 1: (a - 2) // 1 + 1 == a - 1.
    self.assertEqual((dim_a - 1, 9),
                     core.stride_shape((dim_a, 20), (2, 3), (1, 2)))
    # Symbolic stride s: (a*s + 2 - 2) // s + 1 == a + 1.
    self.assertEqual((dim_a + 1, 9),
                     core.stride_shape((dim_a * dim_s + 2, 20), (2, 3),
                                       (dim_s, 2)))

    # Window 1, stride 2: a - 1 is not known to be a multiple of 2.
    expected_msg = "Cannot compute stride for dimension 'a', window_size '1', stride '2'. Reason: Dimension polynomial 'a + -1' is not a multiple of '2'"
    with self.assertRaisesRegex(
        core.InconclusiveDimensionOperation,
        re.escape(expected_msg)):
      core.stride_shape((dim_a, 20), (1, 3), (2, 2))