def init_fun(rng, input_shape):
    # add padding dimensions for periodic BC; move this line into
    # conv_general_shape_tuple after defining padding='PERIODIC'
    add_input = list(np.array(filter_shape) - 1)  # new
    input_shape += np.array([0] + add_input + [0])  # only works with stride=(1,1)
    filter_shape_iter = iter(filter_shape)
    kernel_shape = [out_chan if c == 'O' else
                    input_shape[lhs_spec.index('C')] if c == 'I' else
                    next(filter_shape_iter) for c in rhs_spec]
    output_shape = lax.conv_general_shape_tuple(input_shape, kernel_shape,
                                                strides, padding,
                                                dimension_numbers)
    k1, k2 = random.split(rng)
    if not ignore_b:
        bias_shape = [out_chan if c == 'C' else 1 for c in out_spec]
        bias_shape = tuple(itertools.dropwhile(lambda x: x == 1, bias_shape))
        W, b = (W_init(k1, kernel_shape, dtype=dtype),
                b_init(k2, bias_shape, dtype=dtype))
        return tuple(output_shape), (W, b)
    else:
        W = W_init(k1, kernel_shape, dtype=dtype)
        return output_shape, (W,)
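The comment above gestures at a 'PERIODIC' padding mode that lax does not provide; the usual workaround is to wrap-pad the spatial dimensions by filter_shape - 1 and then run a VALID convolution, which reproduces periodic boundary conditions at stride 1. A minimal sketch, assuming NHWC inputs and an HWIO kernel (the function name is hypothetical):

import jax.numpy as jnp
from jax import lax

def periodic_conv2d(x, w, filter_shape):
    """Stride-1 convolution with periodic boundary conditions.

    x: (N, H, W, C) input, w: (KH, KW, I, O) kernel. Wrap-padding by
    KH-1 / KW-1 followed by a VALID conv keeps the spatial dims of x.
    """
    kh, kw = filter_shape
    # Split the wrap-padding the way 'SAME' would, so outputs stay centered.
    x = jnp.pad(x, ((0, 0),
                    ((kh - 1) // 2, kh // 2),
                    ((kw - 1) // 2, kw // 2),
                    (0, 0)), mode='wrap')
    return lax.conv_general_dilated(x, w, window_strides=(1, 1),
                                    padding='VALID',
                                    dimension_numbers=('NHWC', 'HWIO', 'NHWC'))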
def init_fun(rng, input_shape):
    filter_shape_iter = iter(filter_shape)
    kernel_shape = [out_chan if c == 'O' else
                    input_shape[lhs_spec.index('C')] if c == 'I' else
                    next(filter_shape_iter) for c in rhs_spec]
    output_shape = lax.conv_general_shape_tuple(
        input_shape, kernel_shape, strides, padding, dimension_numbers)
    W = W_init(rng, kernel_shape)
    return output_shape, (W,)
def init_fun(rng, input_shape):
    filter_shape_iter = iter(filter_shape)
    kernel_shape = [out_chan if c == 'O' else
                    input_shape[lhs_spec.index('C')] if c == 'I' else
                    next(filter_shape_iter) for c in rhs_spec]
    output_shape = lax.conv_general_shape_tuple(
        input_shape, kernel_shape, strides, padding, dimension_numbers)
    bias_shape = [out_chan if c == 'C' else 1 for c in out_spec]
    bias_shape = tuple(itertools.dropwhile(lambda x: x == 1, bias_shape))
    # Note: this variant reuses one rng key for both W and b;
    # the next variant splits it with random.split.
    W, b = W_init(rng, kernel_shape), b_init(rng, bias_shape)
    return output_shape, (W, b)
def init_fun(rng, input_shape):
    filter_shape_iter = iter(filter_shape)
    kernel_shape = [out_chan if c == 'O' else
                    input_shape[lhs_spec.index('C')] if c == 'I' else
                    next(filter_shape_iter) for c in rhs_spec]
    output_shape = lax.conv_general_shape_tuple(
        input_shape, kernel_shape, strides, padding, dimension_numbers)
    bias_shape = [out_chan if c == 'C' else 1 for c in out_spec]
    k1, k2 = random.split(rng)
    W, b = W_init(k1, kernel_shape), b_init(k2, bias_shape)
    return output_shape, (W, b)
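Each of the init_fun variants above is a closure: out_chan, filter_shape, strides, padding, dimension_numbers, W_init, b_init and the lhs_spec/rhs_spec/out_spec strings come from an enclosing stax-style layer constructor that returns an (init_fun, apply_fun) pair. A minimal self-contained sketch of that enclosing context, modeled on jax.example_libraries.stax.GeneralConv (the apply_fun and defaults here are illustrative, not any library's exact source):

from jax import lax, random
from jax.nn.initializers import glorot_normal, normal

def GeneralConv(dimension_numbers, out_chan, filter_shape,
                strides=None, padding='VALID',
                W_init=None, b_init=normal(1e-6)):
    """Layer constructor returning an (init_fun, apply_fun) pair."""
    lhs_spec, rhs_spec, out_spec = dimension_numbers
    one = (1,) * len(filter_shape)
    strides = strides or one
    W_init = W_init or glorot_normal(rhs_spec.index('I'), rhs_spec.index('O'))

    def init_fun(rng, input_shape):
        filter_shape_iter = iter(filter_shape)
        kernel_shape = [out_chan if c == 'O' else
                        input_shape[lhs_spec.index('C')] if c == 'I' else
                        next(filter_shape_iter) for c in rhs_spec]
        # Shape inference happens here, before any weights are created.
        output_shape = lax.conv_general_shape_tuple(
            input_shape, kernel_shape, strides, padding, dimension_numbers)
        bias_shape = [out_chan if c == 'C' else 1 for c in out_spec]
        k1, k2 = random.split(rng)
        W, b = W_init(k1, kernel_shape), b_init(k2, bias_shape)
        return output_shape, (W, b)

    def apply_fun(params, inputs, **kwargs):
        W, b = params
        return lax.conv_general_dilated(inputs, W, strides, padding,
                                        one, one, dimension_numbers) + b

    return init_fun, apply_fun

# Example use:
init_fun, apply_fun = GeneralConv(('NHWC', 'HWIO', 'NHWC'), 32, (3, 3))
out_shape, params = init_fun(random.PRNGKey(0), (1, 28, 28, 1))  # (1, 26, 26, 32)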
def init_fun(rng, input_shape):
    # Depthwise-style kernel: the 'I' dim is 1, with out_chan output
    # channels per input channel.
    kernel_shape = (filter_shape[0], filter_shape[1], 1,
                    out_chan * input_shape[3])
    output_shape = lax.conv_general_shape_tuple(input_shape, kernel_shape,
                                                strides, padding,
                                                ("NHWC", "HWIO", "NHWC"))
    bias_shape = (input_shape[0], out_chan * input_shape[3])
    k1, k2 = random.split(rng)
    # Bind the closed-over b_init to a local name: assigning to b_init
    # directly would make it local throughout init_fun and raise
    # UnboundLocalError on the `is None` check.
    bias_init = b_init
    if bias_init is None:
        bias_init = normal(1. / np.sqrt(np.prod(kernel_shape[:-1])))
    W, b = W_init(k1, kernel_shape), bias_init(k2, bias_shape)
    return output_shape, (W, b)
def compute_output_shape(self):
    if self.built:
        return lax.conv_general_shape_tuple(
            lhs_shape=self.input_shape,
            rhs_shape=self.kernel_weights.shape,
            window_strides=self.strides,
            padding=self.padding,
            dimension_numbers=self.dimension_numbers,
        )
    else:
        raise Exception(
            f"{self.name} is not built yet, use call() or build() to build it."
        )
def init_fun(rng, input_shape):
    filter_shape_iter = iter(filter_shape)
    kernel_shape = [out_chan if c == "O" else
                    input_shape[lhs_spec.index("C")] if c == "I" else
                    next(filter_shape_iter) for c in rhs_spec]
    output_shape = lax.conv_general_shape_tuple(input_shape, kernel_shape,
                                                strides, padding,
                                                dimension_numbers)
    bias_shape = [out_chan if c == "C" else 1 for c in out_spec]
    bias_shape = tuple(itertools.dropwhile(lambda x: x == 1, bias_shape))
    k1, k2 = random.split(rng)
    # Bind the closed-over b_init to a local name before applying the
    # fallback; assigning to b_init directly would make it local throughout
    # init_fun and raise UnboundLocalError on the `is None` check.
    bias_init = b_init
    if bias_init is None:
        bias_init = normal(1. / np.sqrt(np.prod(kernel_shape[:-1])))
    W, b = W_init(k1, kernel_shape), bias_init(k2, bias_shape)
    return output_shape, (W, b)
def conv_info(in_shape,
              out_chan,
              filter_shape,
              strides=None,
              padding='VALID',
              kernel_init=None,
              bias_init=stax.randn(1e-6),
              transpose=False):
    """Returns parameters and output shape information given input shapes."""
    # Essentially the `stax` implementation
    if len(in_shape) != 3:
        raise ValueError('Need to `jax.vmap` in order to batch')
    in_shape = (1,) + in_shape
    lhs_spec, rhs_spec, out_spec = DIMENSION_NUMBERS
    one = (1,) * len(filter_shape)
    strides = strides or one
    kernel_init = kernel_init or stax.glorot(rhs_spec.index('O'),
                                             rhs_spec.index('I'))
    filter_shape_iter = iter(filter_shape)
    kernel_shape = tuple([out_chan if c == 'O' else
                          in_shape[lhs_spec.index('C')] if c == 'I' else
                          next(filter_shape_iter) for c in rhs_spec])
    if transpose:
        out_shape = lax.conv_transpose_shape_tuple(in_shape, kernel_shape,
                                                   strides, padding,
                                                   DIMENSION_NUMBERS)
    else:
        out_shape = lax.conv_general_shape_tuple(in_shape, kernel_shape,
                                                 strides, padding,
                                                 DIMENSION_NUMBERS)
    bias_shape = [out_chan if c == 'C' else 1 for c in out_spec]
    bias_shape = tuple(itertools.dropwhile(lambda x: x == 1, bias_shape))
    out_shape = out_shape[1:]  # drop the dummy batch dimension added above
    shapes = (out_shape, kernel_shape, bias_shape)
    inits = (kernel_init, bias_init)
    return shapes, inits, (strides, padding, one)
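For reference, a hypothetical call site for conv_info, assuming the module-level DIMENSION_NUMBERS constant holds the usual channel-last image layout (the unbatched 3-tuple in_shape and the `jax.vmap` error message above suggest single examples):

DIMENSION_NUMBERS = ('NHWC', 'HWIO', 'NHWC')  # assumed value of the constant

# One unbatched 28x28 RGB example; batching is handled by `jax.vmap`.
shapes, inits, (strides, padding, one) = conv_info(
    (28, 28, 3), out_chan=16, filter_shape=(3, 3))
out_shape, kernel_shape, bias_shape = shapes
# out_shape == (26, 26, 16): a VALID 3x3 conv, with no batch dimension.
# kernel_shape == (3, 3, 3, 16), bias_shape == (16,).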
def output_shape(self, input_shape):
    kernel_shape = self._kernel_shape(input_shape)
    return lax.conv_general_shape_tuple(input_shape, kernel_shape,
                                        self._strides, self._padding,
                                        self._dimension_numbers)