Пример #1
0
    def init_fun(rng, input_shape):
        """Compute the output shape and initialise the layer parameters.

        Args:
            rng: PRNG key; split so the kernel and the bias use separate keys.
            input_shape: shape of the layer input. The first and last axes
                are left untouched by the padding step below (presumably the
                batch and channel axes -- confirm against ``lhs_spec``).

        Returns:
            ``(output_shape, params)`` where ``output_shape`` is a tuple and
            ``params`` is ``(W,)`` when ``ignore_b`` is truthy, ``(W, b)``
            otherwise.
        """
        # Add padding dimensions for periodic BC: every middle axis grows by
        # ``filter_size - 1``. This only works with stride=(1, 1).
        # TODO: move this into conv_general_shape_tuple after defining
        # padding='PERIODIC'.
        add_input = list(np.array(filter_shape) - 1)
        # Build a new array instead of using ``+=`` so that an ndarray passed
        # in by the caller is never modified in place.
        input_shape = np.array(input_shape) + np.array([0] + add_input + [0])

        # Kernel layout follows rhs_spec: 'O' -> output channels,
        # 'I' -> input channels, anything else -> next filter dimension.
        filter_shape_iter = iter(filter_shape)
        kernel_shape = [
            out_chan if c == 'O' else input_shape[lhs_spec.index('C')]
            if c == 'I' else next(filter_shape_iter) for c in rhs_spec
        ]

        output_shape = lax.conv_general_shape_tuple(input_shape, kernel_shape,
                                                    strides, padding,
                                                    dimension_numbers)

        k1, k2 = random.split(rng)
        W = W_init(k1, kernel_shape, dtype=dtype)

        if ignore_b:
            # Both branches return the shape as a tuple for consistency.
            return tuple(output_shape), (W, )

        # Bias has out_chan along the channel axis of out_spec; leading
        # singleton axes are stripped.
        bias_shape = [out_chan if c == 'C' else 1 for c in out_spec]
        bias_shape = tuple(
            itertools.dropwhile(lambda x: x == 1, bias_shape))
        b = b_init(k2, bias_shape, dtype=dtype)
        return tuple(output_shape), (W, b)
Пример #2
0
 def init_fun(rng, input_shape):
   """Return the convolution output shape and a bias-free parameter tuple.

   The kernel shape is assembled axis by axis from ``rhs_spec``: 'O' maps to
   ``out_chan``, 'I' to the channel size of ``input_shape`` (located through
   ``lhs_spec``), and every other axis takes the next entry of
   ``filter_shape``.
   """
   spatial_sizes = iter(filter_shape)
   kernel_shape = []
   for axis in rhs_spec:
     if axis == 'O':
       kernel_shape.append(out_chan)
     elif axis == 'I':
       kernel_shape.append(input_shape[lhs_spec.index('C')])
     else:
       kernel_shape.append(next(spatial_sizes))
   output_shape = lax.conv_general_shape_tuple(
       input_shape, kernel_shape, strides, padding, dimension_numbers)
   # The whole key goes to the kernel; this variant has no bias.
   return output_shape, (W_init(rng, kernel_shape),)
Пример #3
0
 def init_fun(rng, input_shape):
   """Compute the output shape and initialise the kernel and bias.

   Args:
     rng: PRNG key; split into two sub-keys so that the kernel and the bias
       are not drawn from the same key.
     input_shape: shape of the layer input, indexed according to
       ``lhs_spec``.

   Returns:
     ``(output_shape, (W, b))``.
   """
   # Imported locally so this block does not rely on a module-level name.
   from jax import random

   # Kernel layout follows rhs_spec: 'O' -> output channels,
   # 'I' -> input channels, anything else -> next filter dimension.
   filter_shape_iter = iter(filter_shape)
   kernel_shape = [out_chan if c == 'O' else
                   input_shape[lhs_spec.index('C')] if c == 'I' else
                   next(filter_shape_iter) for c in rhs_spec]
   output_shape = lax.conv_general_shape_tuple(
       input_shape, kernel_shape, strides, padding, dimension_numbers)
   # Bias has out_chan along the channel axis of out_spec; leading singleton
   # axes are stripped.
   bias_shape = [out_chan if c == 'C' else 1 for c in out_spec]
   bias_shape = tuple(itertools.dropwhile(lambda x: x == 1, bias_shape))
   # A PRNG key must not be consumed twice: give each initializer its own.
   k1, k2 = random.split(rng)
   W, b = W_init(k1, kernel_shape), b_init(k2, bias_shape)
   return output_shape, (W, b)
Пример #4
0
 def init_fun(rng, input_shape):
   """Return the convolution output shape and freshly initialised (W, b).

   The bias shape keeps one entry per axis of ``out_spec`` (``out_chan`` on
   the 'C' axis, 1 elsewhere) and is passed to ``b_init`` as a list.
   """
   spatial_sizes = iter(filter_shape)
   kernel_shape = []
   for axis in rhs_spec:
     if axis == 'O':
       kernel_shape.append(out_chan)
     elif axis == 'I':
       kernel_shape.append(input_shape[lhs_spec.index('C')])
     else:
       kernel_shape.append(next(spatial_sizes))
   output_shape = lax.conv_general_shape_tuple(
       input_shape, kernel_shape, strides, padding, dimension_numbers)
   bias_shape = []
   for axis in out_spec:
     bias_shape.append(out_chan if axis == 'C' else 1)
   # Separate sub-keys for the kernel and the bias.
   kernel_key, bias_key = random.split(rng)
   W = W_init(kernel_key, kernel_shape)
   b = b_init(bias_key, bias_shape)
   return output_shape, (W, b)
Пример #5
0
    def init_fun(rng, input_shape):
        """Compute the output shape and initialise the kernel and bias.

        The kernel has a single input channel and
        ``out_chan * input_shape[3]`` output channels, with fixed
        ``("NHWC", "HWIO", "NHWC")`` dimension numbers.

        Args:
            rng: PRNG key; split so the kernel and the bias use separate keys.
            input_shape: shape of the layer input; indices 0 and 3 are read
                (batch and channel under the NHWC layout used here).

        Returns:
            ``(output_shape, (W, b))``.
        """
        kernel_shape = (filter_shape[0], filter_shape[1], 1,
                        out_chan * input_shape[3])
        output_shape = lax.conv_general_shape_tuple(input_shape, kernel_shape,
                                                    strides, padding,
                                                    ("NHWC", "HWIO", "NHWC"))
        # ``tuple(a, b)`` raises TypeError; a tuple literal is what was meant.
        # NOTE(review): the bias shape depends on the batch size
        # (input_shape[0]) -- confirm this is intended.
        bias_shape = (input_shape[0], out_chan * input_shape[3])
        k1, k2 = random.split(rng)

        # Use a separate local name: assigning to ``b_init`` here would make
        # it local to this function, and the ``is None`` test would raise
        # UnboundLocalError.
        bias_init = b_init
        if bias_init is None:
            bias_init = normal(1. / np.sqrt(np.prod(kernel_shape[:-1])))
        W, b = W_init(k1, kernel_shape), bias_init(k2, bias_shape)
        return output_shape, (W, b)
Пример #6
0
 def compute_output_shape(self):
     """Return the convolution output shape for a built layer.

     Raises:
         Exception: if ``self.built`` is falsy.
     """
     # Guard clause: the shape needs ``self.kernel_weights``, so refuse to
     # continue while the layer is not built.
     if not self.built:
         raise Exception(
             f"{self.name} is not built yet, use call() or build() to build it."
         )

     shape_kwargs = dict(
         lhs_shape=self.input_shape,
         rhs_shape=self.kernel_weights.shape,
         window_strides=self.strides,
         padding=self.padding,
         dimension_numbers=self.dimension_numbers,
     )
     return lax.conv_general_shape_tuple(**shape_kwargs)
Пример #7
0
    def init_fun(rng, input_shape):
        """Compute the output shape and initialise the kernel and bias.

        Args:
            rng: PRNG key; split so the kernel and the bias use separate keys.
            input_shape: shape of the layer input, indexed according to
                ``lhs_spec``.

        Returns:
            ``(output_shape, (W, b))``.
        """
        # Kernel layout follows rhs_spec: "O" -> output channels,
        # "I" -> input channels, anything else -> next filter dimension.
        filter_shape_iter = iter(filter_shape)
        kernel_shape = [
            out_chan if c == "O" else input_shape[lhs_spec.index("C")]
            if c == "I" else next(filter_shape_iter) for c in rhs_spec
        ]
        output_shape = lax.conv_general_shape_tuple(input_shape, kernel_shape,
                                                    strides, padding,
                                                    dimension_numbers)
        # Bias has out_chan along the channel axis of out_spec; leading
        # singleton axes are stripped.
        bias_shape = [out_chan if c == "C" else 1 for c in out_spec]
        bias_shape = tuple(itertools.dropwhile(lambda x: x == 1, bias_shape))
        k1, k2 = random.split(rng)

        # Use a separate local name: assigning to ``b_init`` here would make
        # it local to this function, and the ``is None`` test would raise
        # UnboundLocalError.
        bias_init = b_init
        if bias_init is None:
            bias_init = normal(1. / np.sqrt(np.prod(kernel_shape[:-1])))
        W, b = W_init(k1, kernel_shape), bias_init(k2, bias_shape)
        return output_shape, (W, b)
Пример #8
0
def conv_info(in_shape,
              out_chan,
              filter_shape,
              strides=None,
              padding='VALID',
              kernel_init=None,
              bias_init=stax.randn(1e-6),
              transpose=False):
    """Return shapes, initializers and conv settings for an unbatched input.

    Essentially the `stax` implementation.

    Args:
        in_shape: shape of a single, unbatched input; must have length 3.
        out_chan: number of output channels.
        filter_shape: sizes for the kernel axes other than 'O' and 'I'.
        strides: window strides; falls back to all ones when falsy.
        padding: padding argument forwarded to the `lax` shape helper.
        kernel_init: kernel initializer; falls back to `stax.glorot` when
            falsy.
        bias_init: bias initializer.
        transpose: when truthy, use the transposed-convolution shape helper.

    Returns:
        A 3-tuple ``(shapes, inits, (strides, padding, one))`` where
        ``shapes`` is ``(out_shape, kernel_shape, bias_shape)`` with the
        leading axis removed from ``out_shape``, ``inits`` is
        ``(kernel_init, bias_init)`` and ``one`` is a tuple of ones as long
        as ``filter_shape``.

    Raises:
        ValueError: if ``in_shape`` does not have exactly three entries.
    """
    if len(in_shape) != 3:
        raise ValueError('Need to `jax.vmap` in order to batch')

    # The lax shape helpers are given a leading axis of size 1; it is
    # stripped from the result again below.
    batched_shape = (1, ) + in_shape
    lhs_spec, rhs_spec, out_spec = DIMENSION_NUMBERS
    one = (1, ) * len(filter_shape)
    if not strides:
        strides = one
    if not kernel_init:
        kernel_init = stax.glorot(rhs_spec.index('O'), rhs_spec.index('I'))

    # Kernel layout follows rhs_spec: 'O' -> output channels,
    # 'I' -> input channels, anything else -> next filter dimension.
    spatial_sizes = iter(filter_shape)
    kernel_dims = []
    for axis in rhs_spec:
        if axis == 'O':
            kernel_dims.append(out_chan)
        elif axis == 'I':
            kernel_dims.append(batched_shape[lhs_spec.index('C')])
        else:
            kernel_dims.append(next(spatial_sizes))
    kernel_shape = tuple(kernel_dims)

    shape_fn = (lax.conv_transpose_shape_tuple
                if transpose else lax.conv_general_shape_tuple)
    batched_out_shape = shape_fn(batched_shape, kernel_shape, strides,
                                 padding, DIMENSION_NUMBERS)

    # Bias has out_chan along the channel axis of out_spec; leading singleton
    # axes are stripped.
    per_axis_bias = (out_chan if axis == 'C' else 1 for axis in out_spec)
    bias_shape = tuple(itertools.dropwhile(lambda d: d == 1, per_axis_bias))

    out_shape = batched_out_shape[1:]
    return ((out_shape, kernel_shape, bias_shape),
            (kernel_init, bias_init),
            (strides, padding, one))
Пример #9
0
 def output_shape(self, input_shape):
     """Return the convolution output shape for ``input_shape``.

     The kernel shape comes from ``self._kernel_shape(input_shape)``; strides,
     padding and dimension numbers are read from the instance.
     """
     conv_args = (
         input_shape,
         self._kernel_shape(input_shape),
         self._strides,
         self._padding,
         self._dimension_numbers,
     )
     return lax.conv_general_shape_tuple(*conv_args)