def init_fun(rng, input_shape):
    """Compute the output shape of the pooling layer for ``input_shape``.

    ``rng`` is accepted to satisfy the init-function calling convention but is
    never used. ``window_shape``, ``strides`` and ``padding`` are read from the
    enclosing scope.

    Returns:
      A pair ``(out_shape, ())``; nothing is created besides the shape.
    """
    # Translate the padding spec into explicit (low, high) pairs per axis.
    pads = lax.padtype_to_pads(input_shape, window_shape, strides, padding)
    # No dilation on either the operand or the window.
    unit_dilation = tuple(1 for _ in window_shape)
    result_shape = lax.reduce_window_shape_tuple(
        input_shape,
        window_shape,
        strides,
        pads,
        base_dilation=unit_dilation,
        window_dilation=unit_dilation,
    )
    return result_shape, ()
def _pooling_output_shape(input_shape, pool_size=(2, 2), strides=None, padding='VALID'): """Helper: compute the output shape for the pooling layer.""" dims = (1, ) + pool_size + (1, ) # NHWC strides = strides or (1, ) * len(pool_size) strides = (1, ) + strides + (1, ) return lax.reduce_window_shape_tuple(input_shape, dims, strides, padding)
def spec(cls, in_spec, window_shape, strides=None, padding='VALID'):
    """Infer the output shape spec of pooling a single, unbatched input.

    Args:
      in_spec: Input spec exposing ``.shape`` (at most 3 dims, no batch axis)
        and ``.dtype``.
      window_shape: Window size over the spatial axes.
      strides: Spatial strides; defaults to 1 on every spatial axis.
      padding: A padding string understood by ``lax.padtype_to_pads`` (e.g.
        ``'VALID'`` / ``'SAME'``) or explicit ``(low, high)`` pairs covering
        the dummy batch axis, the spatial axes and the channel axis.

    Returns:
      ``state.Shape`` of the pooled output (dummy batch axis removed), with
      the same dtype as ``in_spec``.

    Raises:
      ValueError: If ``in_spec.shape`` has more than 3 dims; batching is
        expected to go through ``jax.vmap`` instead.
    """
    in_shape = in_spec.shape
    if len(in_shape) > 3:
        raise ValueError('Need to `jax.vmap` in order to batch')
    # Prepend a dummy batch axis of size 1; it is stripped again below.
    in_shape = (1, ) + in_shape
    dims = (1, ) + window_shape + (1, )  # NHWC or NHC
    strides = strides or (1, ) * len(window_shape)
    strides = (1, ) + strides + (1, )
    if isinstance(padding, str):
        # `lax.reduce_window_shape_tuple` takes explicit (low, high) pad pairs,
        # not a padding-type string, so translate the string first.
        padding = lax.padtype_to_pads(in_shape, dims, strides, padding)
    out_shape = lax.reduce_window_shape_tuple(in_shape, dims, strides, padding)
    return state.Shape(out_shape[1:], dtype=in_spec.dtype)
def spec(cls, in_spec, window_shape, strides=None, padding='VALID'):
    """Infer the output ``state.Shape`` of pooling one unbatched input.

    ``in_spec.shape`` may have at most 3 dims. A dummy batch axis of size 1 is
    prepended for the shape computation and dropped from the result. Unit
    window/stride entries are inserted for the batch (first) and channel
    (last) axes. The padding argument is converted with
    ``lax.padtype_to_pads``.

    Raises:
      ValueError: If ``in_spec.shape`` has more than 3 dims.
    """
    shape = in_spec.shape
    if len(shape) > 3:
        raise ValueError('Need to `jax.vmap` in order to batch')
    shape = (1, ) + shape
    dims = (1, ) + window_shape + (1, )  # NHWC or NHC
    if not strides:
        strides = (1, ) * len(window_shape)
    # Positions (batch, channel) that receive a unit window and stride.
    unit_axes = (0, len(window_shape) + 1)
    full_window, full_strides = window_shape, strides
    for axis in unit_axes:
        full_window = full_window[:axis] + (1, ) + full_window[axis:]
        full_strides = full_strides[:axis] + (1, ) + full_strides[axis:]
    pads = lax.padtype_to_pads(shape, full_window, full_strides, padding)
    pooled = lax.reduce_window_shape_tuple(shape, dims, full_strides, pads)
    return state.Shape(pooled[1:], dtype=in_spec.dtype)
def compute_output_shape(self):
    """Return the pooled output shape for ``self._input_shape``.

    Reads ``self._input_shape``, ``self.pool_size``, ``self.strides``,
    ``self.padding`` and ``self.dims``. ``lax.reduce_window_shape_tuple``
    does not accept a ``None`` batch size, so the leading entry of
    ``self._input_shape`` is replaced with 1 for this computation. No
    dilation is applied to either the operand or the window.
    """
    # NOTE(review): the first entry of the result is computed from the
    # substituted 1, not from the original (possibly None) batch size;
    # confirm callers expect that.
    # NOTE(review): base/window dilation have length `self.dims + 2`, so
    # `self.pool_size` / `self.strides` presumably include batch and channel
    # entries too. TODO confirm.
    shape = (1,) + tuple(self._input_shape[1:])
    pads = lax.padtype_to_pads(shape, self.pool_size, self.strides, self.padding)
    no_dilation = (1,) + (1,) * self.dims + (1,)
    return lax.reduce_window_shape_tuple(
        operand_shape=shape,
        window_dimensions=self.pool_size,
        window_strides=self.strides,
        padding=pads,
        base_dilation=no_dilation,
        window_dilation=no_dilation,
    )
def init_fun(input_shape):
    """Compute the pooling layer's output shape for ``input_shape``.

    ``dims``, ``strides`` and ``padding`` are read from the enclosing scope.
    ``padding`` may be a padding string understood by ``lax.padtype_to_pads``
    (e.g. ``'VALID'`` / ``'SAME'``) or explicit ``(low, high)`` pairs.

    Returns:
      A pair ``(out_shape, ())``; nothing is created besides the shape.
    """
    # `lax.reduce_window_shape_tuple` takes explicit (low, high) pad pairs, not
    # a padding-type string. Bind the converted value to a new local name:
    # assigning to `padding` here would make it local and break the closure.
    if isinstance(padding, str):
        padding_vals = lax.padtype_to_pads(input_shape, dims, strides, padding)
    else:
        padding_vals = padding
    out_shape = lax.reduce_window_shape_tuple(input_shape, dims, strides, padding_vals)
    return out_shape, ()