Ejemplo n.º 1
0
    def initialize(cls,
                   rng,
                   in_spec,
                   dim_out,
                   kernel_init=stax.glorot(),
                   bias_init=stax.zeros):
        """Create the kernel and bias parameters of a Dense layer.

        Args:
          rng: Random key. It is split in two: one key for the kernel and
            one for the bias. Must not be None.
          in_spec: Input spec. The last entry of `in_spec.shape` is used as
            the input dimension.
          dim_out: Output dimension.
          kernel_init: Initializer handed to `base.create_parameter` for the
            `(dim_in, dim_out)` kernel.
          bias_init: Initializer handed to `base.create_parameter` for the
            `(dim_out,)` bias.

        Returns:
          A `base.LayerParams` wrapping `DenseParams(kernel, bias)`.

        Raises:
          ValueError: If `rng` is None.
        """
        if rng is None:
            raise ValueError('Need valid RNG to instantiate Dense layer.')
        kernel_shape = (in_spec.shape[-1], dim_out)
        bias_shape = (dim_out, )
        kernel_key, bias_key = random.split(rng)
        kernel = base.create_parameter(
            kernel_key, kernel_shape, init=kernel_init)
        bias = base.create_parameter(bias_key, bias_shape, init=bias_init)
        return base.LayerParams(DenseParams(kernel, bias))
Ejemplo n.º 2
0
 def test_update_state(self):
     """Checks the layer state before and after calling `update(3)`."""
     initial_params = base.LayerParams(params=1, state=2)
     dummy_init = DummyLayer(initial_params)
     dummy = dummy_init.init(self._seed, state.Shape((1, 1)))
     # The freshly initialized layer carries the state it was built with.
     self.assertEqual(dummy.state, 2)
     # NOTE(review): the expected 102 depends on DummyLayer's update rule,
     # which is defined elsewhere in the test module.
     updated = dummy.update(3)
     self.assertEqual(updated.state, 102)
Ejemplo n.º 3
0
 def initialize(cls,
                key,
                in_spec,
                axis=(0, 1),
                momentum=0.99,
                epsilon=1e-5,
                center=True,
                scale=True,
                beta_init=stax.zeros,
                gamma_init=stax.ones):
     """Create the parameters, info and state of a BatchNorm layer.

     Args:
       key: Random key, split into four sub-keys (beta, gamma, moving mean,
         moving variance).
       in_spec: Input spec; only `in_spec.shape` is read.
       axis: Axis index or tuple of axis indices. A scalar is wrapped into
         a one-element tuple.
       momentum: The stored decay is `1.0 - momentum`.
       epsilon: Stored unchanged in the layer info.
       center: If true, a beta parameter is created; otherwise beta is `()`.
       scale: If true, a gamma parameter is created; otherwise gamma is `()`.
       beta_init: Initializer used for beta.
       gamma_init: Initializer used for gamma.

     Returns:
       A `base.LayerParams` built from `BatchNormParams(beta, gamma)`, a
       `BatchNormInfo` and `BatchNormState(moving_mean, moving_var)`.
     """
     input_shape = in_spec.shape
     if np.isscalar(axis):
         axis = (axis, )
     decay = 1.0 - momentum
     # beta/gamma only span the dimensions that are not listed in `axis`.
     param_shape = tuple(
         dim for index, dim in enumerate(input_shape) if index not in axis)
     # The moving statistics keep the input rank, with size 1 on every
     # dimension listed in `axis`.
     stats_shape = tuple(
         1 if index in axis else dim
         for index, dim in enumerate(input_shape))
     beta_key, gamma_key, mean_key, var_key = random.split(key, 4)
     if center:
         beta = base.create_parameter(beta_key, param_shape, init=beta_init)
     else:
         beta = ()
     if scale:
         gamma = base.create_parameter(
             gamma_key, param_shape, init=gamma_init)
     else:
         gamma = ()
     moving_mean = base.create_parameter(
         mean_key, stats_shape, init=stax.zeros)
     moving_var = base.create_parameter(
         var_key, stats_shape, init=stax.ones)
     return base.LayerParams(
         BatchNormParams(beta, gamma),
         BatchNormInfo(axis, epsilon, center, scale, decay, input_shape),
         BatchNormState(moving_mean, moving_var))
Ejemplo n.º 4
0
    def initialize(cls,
                   rng,
                   in_spec,
                   window_shape,
                   strides=None,
                   padding='VALID'):
        """Build the static info of a Pooling layer.

        Args:
          rng: Random key. Not used by this method.
          in_spec: Input spec. Discarded; it does not influence the result.
          window_shape: Int tuple with the pooling window shape.
          strides: Optional tuple with the pooling strides. When it is None
            (or otherwise falsy), stride 1 is used for each dimension in
            `window_shape`.
          padding: Padding mode, e.g. the string "SAME" or "VALID". It is
            stored unchanged in the layer info.

        Returns:
          A `base.LayerParams` whose `info` is a `PoolingInfo` holding the
          window shape, the window dims and strides (each padded with a
          leading and trailing 1), and the padding mode.
        """
        del in_spec
        if not strides:
            strides = (1, ) * len(window_shape)
        unit = (1, )
        # Surround the window and strides with 1s (NHWC or NHC layout).
        dims = unit + window_shape + unit
        strides = unit + strides + unit
        return base.LayerParams(
            info=PoolingInfo(window_shape, dims, strides, padding))
Ejemplo n.º 5
0
 def initialize(cls,
                key,
                in_spec,
                out_chan,
                filter_shape,
                strides=None,
                padding='VALID',
                kernel_init=None,
                bias_init=stax.randn(1e-6),
                use_bias=True):
     """Create the parameters and info of a convolution layer.

     Args:
       key: Random key, split into a kernel key and a bias key.
       in_spec: Input spec; only `in_spec.shape` is read.
       out_chan: Forwarded to `conv_info`.
       filter_shape: Forwarded to `conv_info`.
       strides: Forwarded to `conv_info`; the value it returns is stored.
       padding: Forwarded to `conv_info`; the value it returns is stored.
       kernel_init: Forwarded to `conv_info`; the initializer it returns is
         used for the kernel.
       bias_init: Forwarded to `conv_info`; the initializer it returns is
         used for the bias.
       use_bias: If true a bias parameter is created, otherwise the bias
         slot of `ConvParams` is None.

     Returns:
       A `base.LayerParams` with `ConvParams(kernel, bias)` as params and a
       `ConvInfo` as info.
     """
     shapes, inits, conv_settings = conv_info(in_spec.shape,
                                              out_chan,
                                              filter_shape,
                                              strides=strides,
                                              padding=padding,
                                              kernel_init=kernel_init,
                                              bias_init=bias_init)
     strides, padding, one = conv_settings
     info = ConvInfo(strides, padding, one, use_bias)
     _, kernel_shape, bias_shape = shapes
     kernel_init, bias_init = inits
     kernel_key, bias_key = random.split(key)
     # The kernel is always created; the bias only when requested.
     kernel = base.create_parameter(
         kernel_key, kernel_shape, init=kernel_init)
     bias = None
     if use_bias:
         bias = base.create_parameter(bias_key, bias_shape, init=bias_init)
     return base.LayerParams(ConvParams(kernel, bias), info=info)
Ejemplo n.º 6
0
 def test_flatten_and_unflatten(self):
     """Checks that unflattening a flattened layer keeps params, info, name."""
     dummy_init = DummyLayer(base.LayerParams((1, 2), 3))
     original = dummy_init.init(self._seed, state.Shape((1, 1)))
     xs, data = original.flatten()
     rebuilt = DummyLayer.unflatten(data, xs)
     self.assertTupleEqual(original.params, rebuilt.params)
     self.assertEqual(original.info, rebuilt.info)
     self.assertEqual(original.name, rebuilt.name)
Ejemplo n.º 7
0
 def unflatten(cls, data, xs):
   """Rebuild the layer from its flattened PyTree representation.

   Args:
     data: Indexable whose first three entries are the child classes, the
       per-child auxiliary data and the layer name.
     xs: Flattened children, matched positionally with the child classes.

   Returns:
     A new `cls` instance initialized with the rebuilt child layers.
   """
   child_classes = data[0]
   child_aux = data[1]
   name = data[2]
   rebuilt = []
   for child_cls, leaves, aux in zip(child_classes, xs, child_aux):
     rebuilt.append(child_cls.unflatten(aux, leaves))
   # Allocate without running __init__, then call it explicitly.
   layer = object.__new__(cls)
   layer.__init__(base.LayerParams(tuple(rebuilt)), name=name)
   return layer
Ejemplo n.º 8
0
    def initialize(cls, rng, in_spec):
        """Build the static info of a Flatten layer.

        Args:
          rng: Random key. Not used by this method.
          in_spec: Input spec; its `shape` is stored as the layer info.

        Returns:
          A `base.LayerParams` whose `info` is `in_spec.shape`.
        """
        original_shape = in_spec.shape
        return base.LayerParams(info=original_shape)
Ejemplo n.º 9
0
    def initialize(cls, rng, in_spec, dim_out):
        """Build the static info of a Reshape layer.

        Args:
          rng: Random key. Not used by this method.
          in_spec: Input spec; its `shape` is recorded in the layer info.
          dim_out: Iterable with the desired output dimensions; it is
            converted to a tuple.

        Returns:
          A `base.LayerParams` whose `info` is the pair
          `(in_spec.shape, tuple(dim_out))`.
        """
        source_shape = in_spec.shape
        target_shape = tuple(dim_out)
        return base.LayerParams(info=(source_shape, target_shape))
Ejemplo n.º 10
0
  def initialize(cls, init_key, *args):
    """Initialize the child layers of a Serial layer.

    Args:
      init_key: Random key passed to the combined layer initializer.
      *args: The input specs followed by, as the last positional argument,
        the iterable of layer initializers.

    Returns:
      A `base.LayerParams` whose params are the tuple of initialized layers.
    """
    layer_inits = args[-1]
    in_specs = args[:-1]
    init_layers = state.init(list(layer_inits), name='layers')
    layers = init_layers(init_key, *in_specs)
    return base.LayerParams(tuple(layers))  # pytype: disable=wrong-arg-types
Ejemplo n.º 11
0
 def test_call_with_needs_rng(self):
     """Checks the output of calling the layer with an `rng` keyword."""
     dummy_init = DummyLayer(base.LayerParams(params=1, state=1, info=2))
     dummy = dummy_init.init(self._seed, state.Shape((1, 1)))
     # NOTE(review): the tuple layout is whatever DummyLayer returns from its
     # call; presumably (params, input, kwargs) - confirm against DummyLayer.
     expected = (1, 3, {'rng': 1})
     result = dummy(3, rng=1)
     self.assertTupleEqual(expected, result)
Ejemplo n.º 12
0
 def test_flatten(self):
     """Checks the exact tuple produced by `flatten` on a named layer."""
     dummy_init = DummyLayer(base.LayerParams(params=(1, 2), state=3))
     dummy = dummy_init.init(
         self._seed, state.Shape((1, 1)), name='foo')
     expected = (((1, 2), 3), ((), 'foo'))
     self.assertTupleEqual(expected, dummy.flatten())
Ejemplo n.º 13
0
 def test_to_string(self):
     """Checks that `repr` of the layer shows its params and info."""
     dummy_init = DummyLayer(base.LayerParams(params=1, state=1, info=2))
     dummy = dummy_init.init(self._seed, state.Shape((1, 1)))
     expected_repr = 'DummyLayer(params=1, info=2)'
     self.assertEqual(expected_repr, repr(dummy))
Ejemplo n.º 14
0
 def initialize(cls, rng, in_shape):
   """Initializes an Activation layer.

   Both arguments are discarded.

   Args:
     rng: Random key. Not used.
     in_shape: Input shape. Not used.

   Returns:
     A `base.LayerParams` constructed with no arguments.
   """
   del rng, in_shape
   empty_params = base.LayerParams()
   return empty_params
Ejemplo n.º 15
0
 def test_defaults(self):
     """Checks that a default `LayerParams` equals three empty tuples."""
     expected = ((), (), ())
     self.assertTupleEqual(expected, base.LayerParams())
Ejemplo n.º 16
0
 def initialize(cls, rng, in_shape, rate):
   """Build the static info of a Dropout layer.

   Args:
     rng: Random key. Not used by this method.
     in_shape: Input shape. Discarded.
     rate: Value stored in the layer's `DropoutInfo`.

   Returns:
     A `base.LayerParams` whose `info` is `DropoutInfo(rate)`.
   """
   del in_shape
   return base.LayerParams(info=DropoutInfo(rate))
Ejemplo n.º 17
0
 def test_init(self):
     """Checks that `init` exposes the given params and info on the layer."""
     dummy_init = DummyLayer(base.LayerParams((1, 2), 3))
     dummy = dummy_init.init(self._seed, state.Shape((1, 1)))
     self.assertTupleEqual((1, 2), dummy.params)
     self.assertEqual(3, dummy.info)
Ejemplo n.º 18
0
 def test_init_adds_tuple_to_params(self):
     """Checks that non-tuple params and info come back unchanged from `init`."""
     dummy_init = DummyLayer(base.LayerParams(1, 2))
     dummy = dummy_init.init(self._seed, state.Shape((1, 1)))
     self.assertEqual(1, dummy.params)
     self.assertEqual(2, dummy.info)
Ejemplo n.º 19
0
 def test_call_with_has_training(self):
     """Checks the output of calling the layer with a `training` keyword."""
     dummy_init = DummyLayer(base.LayerParams(params=1, state=1, info=2))
     dummy = dummy_init.init(self._seed, state.Shape((1, 1)))
     # NOTE(review): the tuple layout is whatever DummyLayer returns from its
     # call; presumably (params, input, kwargs) - confirm against DummyLayer.
     expected = (1, 3, {'training': True})
     result = dummy(3, training=True)
     self.assertTupleEqual(expected, result)
Ejemplo n.º 20
0
 def test_call_pass_params(self):
     """Checks the output of calling the layer with no keyword arguments."""
     dummy_init = DummyLayer(base.LayerParams(params=1, state=1))
     dummy = dummy_init.init(self._seed, state.Shape((1, 1)))
     # With no keywords passed, the last element is an empty dict.
     expected = (1, 3, {})
     result = dummy(3)
     self.assertTupleEqual(expected, result)