def pad_dims(input, leftdims, rightdims):
    """Reshape the input to a ``(leftdims + rightdims)``-dimensional tensor.

    This helper function is used to convert pooling inputs with arbitrary
    non-pooling dimensions to the correct number of dimensions for the GPU
    pooling ops.

    The trailing ``rightdims`` dimensions are kept unchanged. The leading
    (non-pooling) dimensions are reduced or expanded to exactly ``leftdims``
    dimensions. If there are too few, singleton dimensions are added on the
    left. If there are too many, the leftmost ones are combined into one.

    Use `unpad_dims` to reshape back to the original dimensions.

    Parameters
    ----------
    input
        Symbolic tensor to reshape. It must have at least ``rightdims``
        dimensions.
    leftdims : int
        Number of leading (non-pooling) dimensions in the result.
    rightdims : int
        Number of trailing dimensions that are kept as-is.

    Returns
    -------
    ``input`` itself if it already has ``leftdims + rightdims`` dimensions,
    otherwise a ``GpuReshape`` of ``input`` with that many dimensions.

    Examples
    --------
    Given input of shape (3, 5, 7), ``pad_dims(input, 2, 2)``
    adds a singleton dimension and reshapes to (1, 3, 5, 7).
    Given that output from pad_dims, ``unpad_dims(output, input, 2, 2)``
    reshapes back to (3, 5, 7).

    Given input of shape (3, 5, 7, 9), ``pad_dims(input, 2, 2)``
    does not reshape and returns output with shape (3, 5, 7, 9).

    Given input of shape (3, 5, 7, 9, 11), ``pad_dims(input, 2, 2)``
    combines the first two dimensions and reshapes to (15, 7, 9, 11).

    Given input of shape (3, 5, 7, 9), ``pad_dims(input, 2, 3)``
    adds a singleton dimension and reshapes to (1, 3, 5, 7, 9).
    """
    assert input.ndim >= rightdims

    if input.ndim == (leftdims + rightdims):
        return input

    # Split the shape at a positive index. ``input.shape[-rightdims:]``
    # would return the *whole* shape when ``rightdims == 0``, because
    # ``-0 == 0``.
    non_pool_ndim = input.ndim - rightdims
    # extract image dimensions
    img_shape = input.shape[non_pool_ndim:]

    if non_pool_ndim < leftdims:
        # too few dimensions, pad on the left
        dummy_dims = as_tensor([1] * (leftdims - non_pool_ndim))
        new_shape = join(0, dummy_dims, input.shape[:non_pool_ndim], img_shape)
    else:
        # too many dimensions, combine the leading dimensions
        batched_ndim = non_pool_ndim - leftdims + 1
        batch_size = prod(input.shape[:batched_ndim])
        # convert the scalar product to a length-1 vector so it can be joined
        batch_size = shape_padright(batch_size, 1)
        new_shape = join(
            0, batch_size, input.shape[batched_ndim:non_pool_ndim], img_shape
        )

    # store in the required shape
    new_shape = cast(new_shape, "int64")
    input_ND = GpuReshape(leftdims + rightdims)(input, new_shape)
    return input_ND
def unpad_dims(output, input, leftdims, rightdims):
    """Reshape the output after `pad_dims`.

    This reverts the padding done by `pad_dims`. The leading (non-pooling)
    dimensions are restored from ``input``. The trailing ``rightdims``
    dimensions are taken from ``output``.

    Parameters
    ----------
    output
        Symbolic tensor computed from the padded input.
    input
        The original, unpadded tensor that was passed to `pad_dims`.
    leftdims : int
        The ``leftdims`` value that was given to `pad_dims`.
    rightdims : int
        Number of trailing dimensions to take from ``output``.

    Returns
    -------
    ``output`` unchanged if it already has ``input.ndim`` dimensions,
    otherwise a ``GpuReshape`` of ``output`` with ``input.ndim`` dimensions.
    """
    if output.ndim == input.ndim:
        return output

    # Positive split indices, clamped at 0, give the same result as
    # ``[:-rightdims]`` / ``[-rightdims:]`` for every ``rightdims >= 1``.
    # They also stay correct for ``rightdims == 0``, where ``-0 == 0``
    # would select the wrong parts of each shape.
    input_split = max(input.ndim - rightdims, 0)
    output_split = max(output.ndim - rightdims, 0)

    # restore the output to the original shape
    outshp = join(0, input.shape[:input_split], output.shape[output_split:])
    return GpuReshape(input.ndim)(output, outshp)
def test_jax_Join():
    """Check that ``join`` along axes 0 and 1 matches between JAX and Python."""
    a = matrix("a")
    b = matrix("b")

    def columns(*cols):
        # ``np.c_`` stacks each given 1-D list as one column of a 2-D array.
        return np.c_[cols].astype(config.floatX)

    # For each join axis, the (a, b) input pairs to check. Some pairs
    # deliberately have different sizes along the joined axis.
    join_cases = (
        (
            0,
            (
                (columns([1.0, 2.0, 3.0]), columns([4.0, 5.0, 6.0])),
                (columns([1.0, 2.0, 3.0]), columns([4.0, 5.0])),
            ),
        ),
        (
            1,
            (
                (columns([1.0, 2.0, 3.0]), columns([4.0, 5.0, 6.0])),
                (columns([1.0, 2.0], [3.0, 4.0]), columns([5.0, 6.0])),
            ),
        ),
    )

    for axis, input_pairs in join_cases:
        x_fg = FunctionGraph([a, b], [aet.join(axis, a, b)])
        for a_val, b_val in input_pairs:
            compare_jax_and_py(x_fg, [a_val, b_val])