def fwd(inputs, layer_params):
    """JAX forward pass of Linear, Conv, or ConvTranspose."""
    if isinstance(layer_params, dict):
        # Conv/ConvTranspose: reshape input if necessary.
        if len(inputs.shape) < 4:
            w = h = int(np.sqrt(inputs.shape[-1] / layer_params['n_cin']))
            inputs = inputs.reshape(inputs.shape[0], h, w, layer_params['n_cin'])
        W, b = layer_params['W'], np.reshape(layer_params['b'], [1, 1, 1, -1])
        if 'transpose' in layer_params and layer_params['transpose']:
            # ConvTranspose
            return default_conv_transpose(
                inputs, W, (layer_params['stride'], layer_params['stride']),
                layer_params['padding']) + b
        else:
            # Normal Conv2D
            dn = lax.conv_dimension_numbers(inputs.shape, W.shape,
                                            ('NHWC', 'HWIO', 'NHWC'))
            return lax.conv_general_dilated(
                inputs, W, (layer_params['stride'], layer_params['stride']),
                layer_params['padding'], (1, 1), (1, 1), dn) + b
    elif isinstance(layer_params, tuple):
        # Linear fully-connected layer
        # TODO: Figure out why we were dropping batch dim before here
        inputs = (inputs.reshape(inputs.shape[0], -1)
                  if len(inputs.shape) == 4 else inputs)
        (W, b) = layer_params
        return jnp.dot(inputs, W) + b
    elif callable(layer_params):
        # Most general way of specifying an affine layer is to provide its function.
        return layer_params(inputs)
    else:
        raise NotImplementedError('Unknown layer')
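A minimal usage sketch of fwd(): the shapes and the dict values below are illustrative, not taken from the source, and numpy, jax.numpy, and jax.lax are assumed to be imported as np, jnp, and lax for the definition above.

import numpy as np
import jax.numpy as jnp
from jax import lax, random  # np/jnp/lax are also needed by fwd() itself

key = random.PRNGKey(0)
k1, k2, k3 = random.split(key, 3)

x = random.normal(k1, (2, 8, 8, 3))                    # NHWC batch of images
conv_params = {'W': random.normal(k2, (3, 3, 3, 8)),   # HWIO kernel (illustrative)
               'b': jnp.zeros(8),
               'n_cin': 3, 'stride': 1, 'padding': 'SAME'}
y = fwd(x, conv_params)                                # Conv2D branch
print(y.shape)                                         # (2, 8, 8, 8)

dense_params = (random.normal(k3, (8 * 8 * 8, 10)), jnp.zeros(10))
print(fwd(y, dense_params).shape)                      # Linear branch -> (2, 10)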
def build(self, input_shape: Tuple):
    """Initializes the kernel and stores the Conv2D weights."""
    input_shape = (None, *input_shape[-3:])
    k1, k2 = random.split(self.seed)
    kernel_shape = self.compute_kernel_shape(input_shape)
    self.kernel_weights = self.add_weight(
        key=k1,
        shape=kernel_shape,
        initializer=self.kernel_initializer,
        dtype=self.dtype,
        name=f"{self.name}_kernel",
        trainable=self.trainable,
    )
    if self.use_bias:
        bias_shape = self.compute_bias_shape()
        self.bias_weights = self.add_weight(
            key=k2,
            shape=bias_shape,
            initializer=self.bias_initializer,
            dtype=self.dtype,
            name=f"{self.name}_bias",
            trainable=self.trainable,
        )
    self._input_shape = input_shape
    self.dn = lax.conv_dimension_numbers(input_shape, kernel_shape,
                                         self.dimension_numbers)
    self.built = True
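For reference, a small sketch of the dimension numbers this build() precomputes, assuming self.dimension_numbers is the common ('NHWC', 'HWIO', 'NHWC') spec; the concrete shapes are placeholders, since lax.conv_dimension_numbers only looks at the rank and the layout strings.

from jax import lax

dn = lax.conv_dimension_numbers((None, 28, 28, 3),  # (N, H, W, C) input, batch dim unknown
                                (3, 3, 3, 32),      # (H, W, I, O) kernel
                                ('NHWC', 'HWIO', 'NHWC'))
print(dn)  # ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))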
def build(self, input_shape):
    if len(input_shape) > 3:
        raise ValueError(
            "`input_shape` should have at most 3 dimensions (batch, H, C). "
            f"Received: input_shape={input_shape}"
        )
    input_shape = (None, *input_shape[-2:])
    self._input_shape = input_shape
    self._output_shape = self.compute_output_shape(input_shape=input_shape)
    k1, k2 = random.split(self.seed)
    kernel_shape = self.compute_kernel_shape(input_shape)
    self.kernel_weights = self.add_weight(
        key=k1,
        shape=kernel_shape,
        initializer=self.kernel_initializer,
        dtype=self.dtype,
        name=f"{self.name}_kernel",
        trainable=self.trainable,
    )
    if self.use_bias:
        bias_shape = self.compute_bias_shape()
        self.bias_weights = self.add_weight(
            key=k2,
            shape=bias_shape,
            initializer=self.bias_initializer,
            dtype=self.dtype,
            name=f"{self.name}_bias",
            trainable=self.trainable,
        )
    self.dn = lax.conv_dimension_numbers(
        input_shape, kernel_shape, self.dimension_numbers
    )
    self.built = True
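A minimal sketch of how the stored dn could be used in a 1D forward pass, assuming an ('NWC', 'WIO', 'NWC') layout; the shapes and names below are illustrative, not taken from the layer above.

import jax.numpy as jnp
from jax import lax, random

x = random.normal(random.PRNGKey(0), (4, 100, 3))   # (batch, length, channels)
w = random.normal(random.PRNGKey(1), (5, 3, 16))    # (width, in, out) kernel
dn = lax.conv_dimension_numbers(x.shape, w.shape, ('NWC', 'WIO', 'NWC'))
y = lax.conv_general_dilated(x, w, (1,), 'SAME', (1,), (1,), dn)
print(y.shape)  # (4, 100, 16)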
def _conv_block(stride, with_non_linearity, inp, kernel, bias):
    no_dilation = (1, 1)
    some_height_width = 10  # values don't matter; just shape of input
    input_shape = (1, some_height_width, some_height_width, 3)
    kernel_shape = (3, 3, 1, 1)
    input_kernel_output = ('NHWC', 'HWIO', 'NHWC')
    conv_dimension_numbers = lax.conv_dimension_numbers(input_shape,
                                                        kernel_shape,
                                                        input_kernel_output)
    block = lax.conv_general_dilated(inp, kernel, (stride, stride), 'VALID',
                                     no_dilation, no_dilation,
                                     conv_dimension_numbers)
    if bias is not None:
        block += bias
    if with_non_linearity:
        block = gelu(block)
    return block
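A minimal call sketch, assuming gelu refers to jax.nn.gelu and lax to jax.lax; the input and kernel shapes are illustrative.

import jax.numpy as jnp
from jax import lax, random
from jax.nn import gelu

inp = random.normal(random.PRNGKey(0), (1, 8, 8, 1))     # NHWC input
kernel = random.normal(random.PRNGKey(1), (3, 3, 1, 1))  # HWIO kernel
out = _conv_block(stride=1, with_non_linearity=True, inp=inp,
                  kernel=kernel, bias=None)
print(out.shape)  # (1, 6, 6, 1): 'VALID' padding trims each spatial dim by 2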
def conv1d_lax(signal, kernel):
    '''CPU impl. is insanely slow for large kernels; jaxlib-cuda
    (i.e. cuDNN's GPU impl.) is highly recommended.'''
    x = device_put(signal)
    h = device_put(kernel)
    x = x[jnp.newaxis, :, jnp.newaxis]
    h = h[::-1, jnp.newaxis, jnp.newaxis]
    dn = lax.conv_dimension_numbers(x.shape, h.shape, ('NWC', 'WIO', 'NWC'))
    # lax.conv_general_dilated runs much slower than numpy.convolve on a CPU device
    x = lax.conv_general_dilated(x,       # lhs = image tensor
                                 h,       # rhs = conv kernel tensor
                                 (1,),    # window strides
                                 'SAME',  # padding mode
                                 (1,),    # lhs/image dilation
                                 (1,),    # rhs/kernel dilation
                                 dn)      # dimension_numbers = lhs, rhs, out dimension permutation
    return x[0, :, 0]
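A quick sanity check against numpy.convolve, assuming the definition above is in scope along with its jax imports; for an odd-length kernel the 'SAME'-padded, flipped-kernel result should line up with mode='same' (the sizes below are arbitrary).

import numpy as np
import jax.numpy as jnp
from jax import lax, device_put

signal = np.random.randn(1000).astype(np.float32)
kernel = np.random.randn(31).astype(np.float32)
out_lax = np.asarray(conv1d_lax(signal, kernel))
out_np = np.convolve(signal, kernel, mode='same')
print(out_lax.shape)                     # (1000,)
print(np.max(np.abs(out_lax - out_np)))  # small (float32 vs float64 rounding)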
def conv_general_dilated_patches(
    lhs: jnp.ndarray,
    filter_shape: Sequence[int],
    window_strides: Sequence[int],
    padding: Union[str, Sequence[Tuple[int, int]]],
    lhs_dilation: Sequence[int] = None,
    rhs_dilation: Sequence[int] = None,
    dimension_numbers: lax.ConvGeneralDilatedDimensionNumbers = None,
) -> jnp.ndarray:
    # This hasn't been added to JAX yet
    filter_shape = tuple(filter_shape)
    dimension_numbers = lax.conv_dimension_numbers(lhs.shape,
                                                   (1, 1) + filter_shape,
                                                   dimension_numbers)
    lhs_spec, rhs_spec, out_spec = dimension_numbers

    spatial_size = np.prod(filter_shape)
    n_channels = lhs.shape[lhs_spec[1]]

    # Move separate `lhs` spatial locations into separate `rhs` channels.
    rhs = jnp.eye(spatial_size, dtype=lhs.dtype).reshape(filter_shape * 2)
    rhs = rhs.reshape((spatial_size, 1) + filter_shape)
    rhs = jnp.tile(rhs, (n_channels,) + (1,) * (rhs.ndim - 1))
    rhs = jnp.moveaxis(rhs, (0, 1), (rhs_spec[0], rhs_spec[1]))

    out = lax.conv_general_dilated(lhs=lhs,
                                   rhs=rhs,
                                   window_strides=window_strides,
                                   padding=padding,
                                   lhs_dilation=lhs_dilation,
                                   rhs_dilation=rhs_dilation,
                                   dimension_numbers=dimension_numbers,
                                   precision=None,
                                   feature_group_count=n_channels)
    return out
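A small usage sketch of the helper above, assuming its numpy/jax/typing imports are in scope; with a single-channel NCHW input and the default dimension numbers, each output channel holds one entry of the 3x3 patch around the corresponding spatial location (the shapes are illustrative).

import jax.numpy as jnp
from jax import lax

x = jnp.arange(25.0).reshape(1, 1, 5, 5)   # NCHW, one channel
patches = conv_general_dilated_patches(x, filter_shape=(3, 3),
                                        window_strides=(1, 1), padding='VALID')
print(patches.shape)  # (1, 9, 3, 3): 9 = 3*3 patch entries per output position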
import pyprobml_utils as pml

# The number of states K and the size of the board
K = 10
ix = 128
iy = 128

# Initialising the key and the kernel
key = random.PRNGKey(12234)
kernel = jnp.zeros((3, 3, 1, 1), dtype=jnp.float32)
kernel += jnp.array([[0, 1, 0],
                     [1, 0, 1],
                     [0, 1, 0]])[:, :, jnp.newaxis, jnp.newaxis]

dn = lax.conv_dimension_numbers((K, ix, iy, 1),  # only ndim matters, not shape
                                kernel.shape,    # only ndim matters, not shape
                                ('NHWC', 'HWIO', 'NHWC'))  # the important bit

# Creating the checkerboard
mask = jnp.indices((K, iy, ix, 1)).sum(axis=0) % 2

def checkerboard_pattern1(x):
    return mask[0, :, :, 0]

def checkerboard_pattern2(x):
    return mask[1, :, :, 0]

def make_checkerboard_pattern1():
    arr = vmap(checkerboard_pattern1, in_axes=0)(jnp.array(K * [1]))
    return jnp.expand_dims(arr, -1)
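A short usage sketch of the mask and dimension numbers above: convolve one checkerboard pattern with the cross-shaped kernel to sum the four nearest neighbours of each cell. This assumes jnp, lax, vmap, and random are imported from jax as elsewhere in these snippets.

pattern = make_checkerboard_pattern1()
print(pattern.shape)     # (10, 128, 128, 1), i.e. (K, iy, ix, 1)
neighbours = lax.conv_general_dilated(pattern.astype(jnp.float32), kernel,
                                      (1, 1), 'SAME', (1, 1), (1, 1), dn)
print(neighbours.shape)  # (10, 128, 128, 1)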