Example #1
import numpy as np
import jax.numpy as jnp
from jax import lax


def fwd(inputs, layer_params):
    """JAX forward pass of Linear, Conv, or ConvTranspose."""
    if isinstance(layer_params, dict):
        # Conv/ConvTranspose: reshape flat inputs back to NHWC if necessary
        if len(inputs.shape) < 4:
            w = h = int(np.sqrt(inputs.shape[-1] / layer_params['n_cin']))
            inputs = inputs.reshape(inputs.shape[0], h, w,
                                    layer_params['n_cin'])
        W, b = layer_params['W'], np.reshape(layer_params['b'], [1, 1, 1, -1])
        if 'transpose' in layer_params and layer_params['transpose']:
            # ConvTranspose
            return default_conv_transpose(
                inputs, W, (layer_params['stride'], layer_params['stride']),
                layer_params['padding']) + b
        else:
            # Normal Conv2D
            dn = lax.conv_dimension_numbers(inputs.shape, W.shape,
                                            ('NHWC', 'HWIO', 'NHWC'))
            return lax.conv_general_dilated(
                inputs, W, (layer_params['stride'], layer_params['stride']),
                layer_params['padding'], (1, 1), (1, 1), dn) + b
    elif isinstance(layer_params, tuple):
        # Linear fully-connected layer
        # TODO: Figure out why we were dropping batch dim before here
        inputs = (inputs.reshape(inputs.shape[0], -1)
                  if len(inputs.shape) == 4 else inputs)
        (W, b) = layer_params
        return jnp.dot(inputs, W) + b
    elif callable(layer_params):
        # The most general way to specify an affine layer is to provide its function.
        return layer_params(inputs)
    else:
        raise NotImplementedError('Unknown layer')
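A minimal usage sketch (not from the original source; shapes and parameter values are illustrative) exercising the Conv and Linear branches of fwd:

import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
x = random.normal(key, (8, 16, 16, 3))      # NHWC batch

# Conv branch: a dict holding an HWIO kernel plus layer metadata.
conv_params = {
    'W': random.normal(key, (3, 3, 3, 4)),  # HWIO
    'b': jnp.zeros(4),
    'n_cin': 3,
    'stride': 1,
    'padding': 'SAME',
}
print(fwd(x, conv_params).shape)            # (8, 16, 16, 4)

# Linear branch: a (W, b) tuple; 4-D inputs are flattened per example.
linear_params = (random.normal(key, (16 * 16 * 3, 10)), jnp.zeros(10))
print(fwd(x, linear_params).shape)          # (8, 10)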
Example #2
    def build(self, input_shape: Tuple):
        "Initializes the Kernel and stores the Conv2D weights"
        input_shape = (None, *input_shape[-3:])

        k1, k2 = random.split(self.seed)
        kernel_shape = self.compute_kernel_shape(input_shape)
        self.kernel_weights = self.add_weight(
            key=k1,
            shape=kernel_shape,
            initializer=self.kernel_initializer,
            dtype=self.dtype,
            name=f"{self.name}_kernel",
            trainable=self.trainable,
        )
        if self.use_bias:
            bias_shape = self.compute_bias_shape()
            self.bias_weights = self.add_weight(
                key=k2,
                shape=bias_shape,
                initializer=self.bias_initializer,
                dtype=self.dtype,
                name=f"{self.name}_bias",
                trainable=self.trainable,
            )

        self._input_shape = input_shape
        self.dn = lax.conv_dimension_numbers(input_shape, kernel_shape,
                                             self.dimension_numbers)
        self.built = True
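For reference, a standalone sketch (assuming NHWC inputs and HWIO kernels; self.dimension_numbers may differ) of the dimension-numbers object this build() caches in self.dn:

from jax import lax

input_shape = (None, 28, 28, 3)   # batch dim replaced with None, as in build()
kernel_shape = (3, 3, 3, 32)      # HWIO
dn = lax.conv_dimension_numbers(input_shape, kernel_shape,
                                ('NHWC', 'HWIO', 'NHWC'))
print(dn)  # ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))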
Example #3
    def build(self, input_shape):
        if len(input_shape) > 3:
            raise ValueError(
                f"`input_shape` should have at most 3 dimensions (N, H, C). "
                f"Received: input_shape={input_shape}"
            )

        input_shape = (None, *input_shape[-2:])
        self._input_shape = input_shape
        self._output_shape = self.compute_output_shape(input_shape=input_shape)
        k1, k2 = random.split(self.seed)
        kernel_shape = self.compute_kernel_shape(input_shape)
        self.kernel_weights = self.add_weight(
            key=k1,
            shape=kernel_shape,
            initializer=self.kernel_initializer,
            dtype=self.dtype,
            name=f"{self.name}_kernel",
            trainable=self.trainable,
        )
        if self.use_bias:
            bias_shape = self.compute_bias_shape()
            self.bias_weights = self.add_weight(
                key=k2,
                shape=bias_shape,
                initializer=self.bias_initializer,
                dtype=self.dtype,
                name=f"{self.name}_bias",
                trainable=self.trainable,
            )

        self.dn = lax.conv_dimension_numbers(
            input_shape, kernel_shape, self.dimension_numbers
        )
        self.built = True
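Since self.dimension_numbers is defined elsewhere in this class, here is a standalone sketch of the 1D case (assuming ('NWC', 'WIO', 'NWC') specs; only the ndim of each shape matters):

from jax import lax

input_shape = (None, 100, 8)   # (N, H, C), batch dim replaced with None
kernel_shape = (5, 8, 16)      # (W, I, O)
dn = lax.conv_dimension_numbers(input_shape, kernel_shape, ('NWC', 'WIO', 'NWC'))
print(dn)  # ConvDimensionNumbers(lhs_spec=(0, 2, 1), rhs_spec=(2, 1, 0), out_spec=(0, 2, 1))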
Example #4
from jax import lax
from jax.nn import gelu  # assumed: the gelu referenced below


def _conv_block(stride, with_non_linearity, inp, kernel, bias):
    no_dilation = (1, 1)
    some_height_width = 10  # values don't matter; just shape of input
    input_shape = (1, some_height_width, some_height_width, 3)
    kernel_shape = (3, 3, 1, 1)
    input_kernel_output = ('NHWC', 'HWIO', 'NHWC')
    conv_dimension_numbers = lax.conv_dimension_numbers(
        input_shape, kernel_shape, input_kernel_output)
    block = lax.conv_general_dilated(inp, kernel, (stride, stride), 'VALID',
                                     no_dilation, no_dilation,
                                     conv_dimension_numbers)
    if bias is not None:
        block += bias
    if with_non_linearity:
        block = gelu(block)
    return block
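A minimal invocation sketch (shapes are illustrative; stride 2 with 'VALID' padding gives (10 - 3) // 2 + 1 = 4 output pixels per spatial dim):

import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
inp = random.normal(key, (1, 10, 10, 3))    # NHWC
kernel = random.normal(key, (3, 3, 3, 8))   # HWIO; only its ndim must match (3, 3, 1, 1)
bias = jnp.zeros((1, 1, 1, 8))
out = _conv_block(2, True, inp, kernel, bias)
print(out.shape)                            # (1, 4, 4, 8)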
Example #5
import jax.numpy as jnp
from jax import lax, device_put


def conv1d_lax(signal, kernel):
    '''
    The CPU implementation is very slow for large kernels; jaxlib-cuda
    (i.e. cuDNN's GPU implementation) is highly recommended.
    '''
    x = device_put(signal)
    h = device_put(kernel)

    x = x[jnp.newaxis,:,jnp.newaxis]
    h = h[::-1,jnp.newaxis,jnp.newaxis]
    dn = lax.conv_dimension_numbers(x.shape, h.shape, ('NWC', 'WIO', 'NWC'))

    # lax.conv_general_dilated runs much slower than numpy.convolve on CPU_device
    x = lax.conv_general_dilated(x,      # lhs = image tensor
                                 h,      # rhs = conv kernel tensor
                                 (1,),   # window strides
                                 'SAME', # padding mode
                                 (1,),   # lhs/image dilation
                                 (1,),   # rhs/kernel dilation
                                 dn)     # dimension_numbers: (lhs, rhs, out) axis permutations

    return x[0,:,0]
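Illustrative usage: because conv1d_lax flips the kernel (h[::-1]), it computes a true convolution, matching numpy.convolve(signal, kernel, mode='same') for odd-length kernels:

import numpy as np
import jax.numpy as jnp

signal = jnp.asarray(np.random.randn(1024).astype(np.float32))
kernel = jnp.asarray(np.random.randn(31).astype(np.float32))
out = conv1d_lax(signal, kernel)
print(out.shape)  # (1024,): 'SAME' padding preserves the signal length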
Example #6
from typing import Optional, Sequence, Tuple, Union

import numpy as np
import jax.numpy as jnp
from jax import lax


def conv_general_dilated_patches(
    lhs: jnp.ndarray,
    filter_shape: Sequence[int],
    window_strides: Sequence[int],
    padding: Union[str, Sequence[Tuple[int, int]]],
    lhs_dilation: Optional[Sequence[int]] = None,
    rhs_dilation: Optional[Sequence[int]] = None,
    dimension_numbers: lax.ConvGeneralDilatedDimensionNumbers = None,
) -> jnp.ndarray:
    # This hasn't been added to JAX yet
    filter_shape = tuple(filter_shape)
    dimension_numbers = lax.conv_dimension_numbers(lhs.shape,
                                                   (1, 1) + filter_shape,
                                                   dimension_numbers)

    lhs_spec, rhs_spec, out_spec = dimension_numbers

    spatial_size = np.prod(filter_shape)
    n_channels = lhs.shape[lhs_spec[1]]

    # Move separate `lhs` spatial locations into separate `rhs` channels.
    rhs = jnp.eye(spatial_size, dtype=lhs.dtype).reshape(filter_shape * 2)

    rhs = rhs.reshape((spatial_size, 1) + filter_shape)
    rhs = jnp.tile(rhs, (n_channels, ) + (1, ) * (rhs.ndim - 1))
    rhs = jnp.moveaxis(rhs, (0, 1), (rhs_spec[0], rhs_spec[1]))

    out = lax.conv_general_dilated(lhs=lhs,
                                   rhs=rhs,
                                   window_strides=window_strides,
                                   padding=padding,
                                   lhs_dilation=lhs_dilation,
                                   rhs_dilation=rhs_dilation,
                                   dimension_numbers=dimension_numbers,
                                   precision=None,
                                   feature_group_count=n_channels)
    return out
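An illustrative call (note that an official jax.lax.conv_general_dilated_patches has since been added to JAX). Each output channel holds one (input channel, spatial tap) combination:

from jax import random

lhs = random.normal(random.PRNGKey(0), (2, 8, 8, 3))   # NHWC
patches = conv_general_dilated_patches(
    lhs, filter_shape=(3, 3), window_strides=(1, 1), padding='SAME',
    dimension_numbers=('NHWC', 'HWIO', 'NHWC'))
print(patches.shape)  # (2, 8, 8, 27): 3 channels x 9 spatial taps per location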
Example #7
import jax.numpy as jnp
from jax import lax, random, vmap
import pyprobml_utils as pml

# The number of states K and the size of the board
K = 10
ix = 128
iy = 128

# Initializing the key and the kernel
key = random.PRNGKey(12234)
kernel = jnp.zeros((3, 3, 1, 1), dtype=jnp.float32)
kernel += jnp.array([[0, 1, 0],
                     [1, 0, 1],
                     [0, 1, 0]])[:, :, jnp.newaxis, jnp.newaxis]

dn = lax.conv_dimension_numbers((K, ix, iy, 1),  # only ndim matters, not shape
                                kernel.shape,    # only ndim matters, not shape
                                ('NHWC', 'HWIO', 'NHWC'))  # the important bit

# Creating the checkerboard
mask = jnp.indices((K, iy, ix, 1)).sum(axis=0) % 2

def checkerboard_pattern1(x):
  # x is unused; it only drives vmap over the K states
  return mask[0, :, :, 0]

def checkerboard_pattern2(x):
  return mask[1, :, :, 0]

def make_checkerboard_pattern1():
  arr = vmap(checkerboard_pattern1, in_axes=0)(jnp.array(K*[1]))
  return jnp.expand_dims(arr, -1)
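Quick shape check (illustrative): vmap maps the mask lookup over the K states, yielding one checkerboard per state.

pattern = make_checkerboard_pattern1()
print(pattern.shape)  # (10, 128, 128, 1), i.e. (K, iy, ix, 1)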