Example 1
# Assumed imports: in Theano these helpers live in theano.tensor, and
# GpuReshape in theano.gpuarray.basic_ops (exact paths may vary by version).
from theano.gpuarray.basic_ops import GpuReshape
from theano.tensor import as_tensor, cast, join, prod, shape_padright


def pad_dims(input, leftdims, rightdims):
    """Reshapes the input to a (leftdims + rightdims) tensor

    This helper function is used to convert pooling inputs with arbitrary
    non-pooling dimensions to the correct number of dimensions for the
    GPU pooling ops.

    This reduces or expands the number of non-pooling dimensions of the
    input to exactly `leftdims`, by adding extra dimensions on the left or
    by combining some existing dimensions on the left of the input.

    Use `unpad_dims` to reshape back to the original dimensions.

    Examples
    --------
    Given input of shape (3, 5, 7), ``pad_dims(input, 2, 2)``
    adds a singleton dimension and reshapes to (1, 3, 5, 7).
    Given that output from pad_dims, ``unpad_dims(output, input, 2, 2)``
    reshapes back to (3, 5, 7).

    Given input of shape (3, 5, 7, 9), ``pad_dims(input, 2, 2)``
    does not reshape and returns output with shape (3, 5, 7, 9).

    Given input of shape (3, 5, 7, 9, 11), ``pad_dims(input, 2, 2)``
    combines the first two dimensions and reshapes to (15, 7, 9, 11).

    Given input of shape (3, 5, 7, 9), ``pad_dims(input, 2, 3)``
    adds a singleton dimension and reshapes to (1, 3, 5, 7, 9).
    """
    assert input.ndim >= rightdims

    if input.ndim == (leftdims + rightdims):
        return input

    # extract image dimensions
    img_shape = input.shape[-rightdims:]

    non_pool_ndim = input.ndim - rightdims
    if non_pool_ndim < leftdims:
        # too few dimensions, pad on the left
        dummy_dims = as_tensor([1] * (leftdims - non_pool_ndim))
        new_shape = join(0, dummy_dims, input.shape[:non_pool_ndim], img_shape)
    else:
        # too many dimensions, combine the leading dimensions
        batched_ndim = non_pool_ndim - leftdims + 1
        batch_size = prod(input.shape[:batched_ndim])
        # convert to a vector for join
        batch_size = shape_padright(batch_size, 1)
        new_shape = join(
            0, batch_size, input.shape[batched_ndim:non_pool_ndim], img_shape
        )

    # store in the required shape
    new_shape = cast(new_shape, "int64")
    input_ND = GpuReshape(leftdims + rightdims)(input, new_shape)
    return input_ND
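
The shape arithmetic above is easier to follow on concrete arrays. Below is a
minimal NumPy sketch of the same logic; `pad_dims_np` is a hypothetical helper
for illustration only, with a plain `reshape` standing in for `GpuReshape`.

import numpy as np


def pad_dims_np(x, leftdims, rightdims):
    # Same shape logic as pad_dims, on NumPy arrays for illustration only.
    assert x.ndim >= rightdims
    if x.ndim == leftdims + rightdims:
        return x
    img_shape = x.shape[-rightdims:]
    non_pool_ndim = x.ndim - rightdims
    if non_pool_ndim < leftdims:
        # too few dimensions: pad singleton dimensions on the left
        new_shape = (1,) * (leftdims - non_pool_ndim) + x.shape
    else:
        # too many dimensions: fold the leading dimensions into one batch dim
        batched_ndim = non_pool_ndim - leftdims + 1
        batch_size = int(np.prod(x.shape[:batched_ndim]))
        new_shape = (batch_size,) + x.shape[batched_ndim:non_pool_ndim] + img_shape
    return x.reshape(new_shape)


print(pad_dims_np(np.zeros((3, 5, 7)), 2, 2).shape)         # (1, 3, 5, 7)
print(pad_dims_np(np.zeros((3, 5, 7, 9, 11)), 2, 2).shape)  # (15, 7, 9, 11)
print(pad_dims_np(np.zeros((3, 5, 7, 9)), 2, 3).shape)      # (1, 3, 5, 7, 9)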
Example 2
# Assumed imports, as in Example 1 (exact paths may vary by version):
from theano.gpuarray.basic_ops import GpuReshape
from theano.tensor import join


def unpad_dims(output, input, leftdims, rightdims):
    """Reshapes the output after pad_dims.

    This reverts the reshaping performed by `pad_dims`.
    """
    if output.ndim == input.ndim:
        return output

    # restore the output to the original shape
    outshp = join(0, input.shape[:-rightdims], output.shape[-rightdims:])
    return GpuReshape(input.ndim)(output, outshp)
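
Continuing the NumPy sketch from Example 1, the reverse reshape can be checked
the same way; `unpad_dims_np` is again a hypothetical helper that mirrors
`unpad_dims` without the GPU op.

def unpad_dims_np(out, x, leftdims, rightdims):
    # Restore the original non-pooling dimensions of `x`, keeping the
    # (possibly changed) trailing pooling dimensions of `out`.
    if out.ndim == x.ndim:
        return out
    return out.reshape(x.shape[:-rightdims] + out.shape[-rightdims:])


x = np.zeros((3, 5, 7, 9, 11))
y = pad_dims_np(x, 2, 2)                # shape (15, 7, 9, 11)
print(unpad_dims_np(y, x, 2, 2).shape)  # (3, 5, 7, 9, 11)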
Example 3
# Assumed imports for the Aesara test suite; module paths differ between
# versions, and compare_jax_and_py is a helper defined in the test module
# itself (it compiles the graph with both the JAX and Python backends and
# compares the outputs).
import numpy as np

import aesara.tensor as aet
from aesara import config
from aesara.graph.fg import FunctionGraph
from aesara.tensor.type import matrix


def test_jax_Join():
    a = matrix("a")
    b = matrix("b")

    # join along axis 0 (stack rows)
    x = aet.join(0, a, b)
    x_fg = FunctionGraph([a, b], [x])
    compare_jax_and_py(
        x_fg,
        [
            np.c_[[1.0, 2.0, 3.0]].astype(config.floatX),
            np.c_[[4.0, 5.0, 6.0]].astype(config.floatX),
        ],
    )
    # operand lengths along the join axis may differ
    compare_jax_and_py(
        x_fg,
        [
            np.c_[[1.0, 2.0, 3.0]].astype(config.floatX),
            np.c_[[4.0, 5.0]].astype(config.floatX),
        ],
    )

    # join along axis 1 (stack columns)
    x = aet.join(1, a, b)
    x_fg = FunctionGraph([a, b], [x])
    compare_jax_and_py(
        x_fg,
        [
            np.c_[[1.0, 2.0, 3.0]].astype(config.floatX),
            np.c_[[4.0, 5.0, 6.0]].astype(config.floatX),
        ],
    )
    # widths along the join axis may differ; other dimensions must match
    compare_jax_and_py(
        x_fg,
        [
            np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX),
            np.c_[[5.0, 6.0]].astype(config.floatX),
        ],
    )
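
The expected shapes in the test can be sanity-checked with plain NumPy
concatenation, which is what the symbolic join computes; this standalone
sketch only assumes NumPy.

import numpy as np

a = np.c_[[1.0, 2.0, 3.0]]  # shape (3, 1)
b = np.c_[[4.0, 5.0, 6.0]]  # shape (3, 1)
print(np.concatenate([a, b], axis=0).shape)  # (6, 1)
print(np.concatenate([a, b], axis=1).shape)  # (3, 2)

# Lengths along the join axis may differ:
c = np.c_[[4.0, 5.0]]  # shape (2, 1)
print(np.concatenate([a, c], axis=0).shape)  # (5, 1)

# For axis=1, only the non-join dimensions must match:
d = np.c_[[1.0, 2.0], [3.0, 4.0]]  # shape (2, 2)
e = np.c_[[5.0, 6.0]]              # shape (2, 1)
print(np.concatenate([d, e], axis=1).shape)  # (2, 3)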