def initialize(cls,
               rng,
               in_spec,
               dim_out,
               kernel_init=stax.glorot(),
               bias_init=stax.zeros):
  """Initializes Dense Layer.

  Args:
    rng: Random key.
    in_spec: Input Spec.
    dim_out: Output dimensions.
    kernel_init: Kernel initialization function.
    bias_init: Bias initialization function.

  Returns:
    The LayerParams for the Dense layer.
  """
  if rng is None:
    raise ValueError('Need valid RNG to instantiate Dense layer.')
  dim_in = in_spec.shape[-1]
  k1, k2 = random.split(rng)
  params = DenseParams(
      base.create_parameter(k1, (dim_in, dim_out), init=kernel_init),
      base.create_parameter(k2, (dim_out,), init=bias_init))
  return base.LayerParams(params)

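# Hedged sketch (not this library's API): the Dense initialization above
# boils down to splitting one PRNG key into independent keys for the kernel
# and bias, then drawing each parameter with its initializer. The stand-ins
# below use only plain JAX (jax.random, jax.nn.initializers).
import jax
from jax.nn import initializers

_rng = jax.random.PRNGKey(0)
_k1, _k2 = jax.random.split(_rng)
_dim_in, _dim_out = 64, 128
_kernel = initializers.glorot_normal()(_k1, (_dim_in, _dim_out))
_bias = initializers.zeros(_k2, (_dim_out,))
assert _kernel.shape == (64, 128) and _bias.shape == (128,)
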
def test_update_state(self):
  layer_params = base.LayerParams(params=1, state=2)
  layer_init = DummyLayer(layer_params)
  layer = layer_init.init(self._seed, state.Shape((1, 1)))
  self.assertEqual(layer.state, 2)
  new_layer = layer.update(3)
  self.assertEqual(new_layer.state, 102)

def initialize(cls,
               key,
               in_spec,
               axis=(0, 1),
               momentum=0.99,
               epsilon=1e-5,
               center=True,
               scale=True,
               beta_init=stax.zeros,
               gamma_init=stax.ones):
  in_shape = in_spec.shape
  axis = (axis,) if np.isscalar(axis) else axis
  decay = 1.0 - momentum
  shape = tuple(d for i, d in enumerate(in_shape) if i not in axis)
  moving_shape = tuple(1 if i in axis else d for i, d in enumerate(in_shape))
  k1, k2, k3, k4 = random.split(key, 4)
  beta = base.create_parameter(k1, shape, init=beta_init) if center else ()
  gamma = base.create_parameter(k2, shape, init=gamma_init) if scale else ()
  moving_mean = base.create_parameter(k3, moving_shape, init=stax.zeros)
  moving_var = base.create_parameter(k4, moving_shape, init=stax.ones)
  params = BatchNormParams(beta, gamma)
  info = BatchNormInfo(axis, epsilon, center, scale, decay, in_shape)
  state = BatchNormState(moving_mean, moving_var)
  return base.LayerParams(params, info, state)

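# Hedged aside: `decay = 1.0 - momentum` drives an exponential moving
# average of the batch statistics. A minimal, self-contained sketch of that
# update rule (illustrative names, not the library's):
import jax.numpy as jnp

def _update_moving(moving, batch_stat, decay):
  # With decay = 0.01 (momentum = 0.99), each step moves the running
  # statistic 1% of the way toward the current batch statistic.
  return (1.0 - decay) * moving + decay * batch_stat

assert jnp.allclose(_update_moving(jnp.array(0.0), jnp.array(1.0), 0.01), 0.01)
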
def initialize(cls,
               rng,
               in_spec,
               window_shape,
               strides=None,
               padding='VALID'):
  """Initializes Pooling layers.

  Args:
    rng: Random key.
    in_spec: Spec, specifying the input shape and dtype.
    window_shape: Int Tuple, specifying the pooling window shape.
    strides: Optional tuple with pooling strides. If None, it will use
      stride 1 for each dimension in window_shape.
    padding: Either the string "SAME" or "VALID", indicating the type of
      padding algorithm to use. "SAME" preserves the input spatial size,
      while "VALID" may reduce it.

  Returns:
    The LayerParams for the Pooling layer.
  """
  del in_spec
  strides = strides or (1,) * len(window_shape)
  dims = (1,) + window_shape + (1,)  # NHWC or NHC
  strides = (1,) + strides + (1,)
  info = PoolingInfo(window_shape, dims, strides, padding)
  return base.LayerParams(info=info)

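# Illustrative sketch: the (1,) + window_shape + (1,) dims and strides
# computed above match the NHWC layout expected by jax.lax.reduce_window,
# a standard JAX pooling primitive (shown here for max pooling; this is an
# assumption about intent, not the library's own apply code).
import jax.numpy as jnp
from jax import lax

_x = jnp.ones((2, 8, 8, 3))   # NHWC input
_dims = (1, 2, 2, 1)          # (1,) + window_shape + (1,)
_strides = (1, 2, 2, 1)       # (1,) + strides + (1,)
_y = lax.reduce_window(_x, -jnp.inf, lax.max, _dims, _strides, 'VALID')
assert _y.shape == (2, 4, 4, 3)
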
def initialize(cls,
               key,
               in_spec,
               out_chan,
               filter_shape,
               strides=None,
               padding='VALID',
               kernel_init=None,
               bias_init=stax.randn(1e-6),
               use_bias=True):
  in_shape = in_spec.shape
  shapes, inits, (strides, padding, one) = conv_info(
      in_shape,
      out_chan,
      filter_shape,
      strides=strides,
      padding=padding,
      kernel_init=kernel_init,
      bias_init=bias_init)
  info = ConvInfo(strides, padding, one, use_bias)
  _, kernel_shape, bias_shape = shapes
  kernel_init, bias_init = inits
  k1, k2 = random.split(key)
  if use_bias:
    params = ConvParams(
        base.create_parameter(k1, kernel_shape, init=kernel_init),
        base.create_parameter(k2, bias_shape, init=bias_init),
    )
  else:
    params = ConvParams(
        base.create_parameter(k1, kernel_shape, init=kernel_init), None)
  return base.LayerParams(params, info=info)

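# Hedged sketch of the shapes being set up above: in NHWC layout, an HWIO
# kernel of shape filter_shape + (in_chan, out_chan) feeds a convolution
# such as jax.lax.conv_general_dilated (illustrative only; conv_info's
# internals are not shown in this section).
import jax.numpy as jnp
from jax import lax

_x = jnp.ones((1, 28, 28, 1))     # NHWC input
_kernel = jnp.ones((3, 3, 1, 8))  # HWIO: (3, 3) filter, 1 -> 8 channels
_y = lax.conv_general_dilated(
    _x, _kernel, window_strides=(1, 1), padding='VALID',
    dimension_numbers=('NHWC', 'HWIO', 'NHWC'))
assert _y.shape == (1, 26, 26, 8)
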
def test_flatten_and_unflatten(self):
  layer_params = base.LayerParams((1, 2), 3)
  layer_init = DummyLayer(layer_params)
  layer = layer_init.init(self._seed, state.Shape((1, 1)))
  xs, data = layer.flatten()
  new_layer = DummyLayer.unflatten(data, xs)
  self.assertTupleEqual(layer.params, new_layer.params)
  self.assertEqual(layer.info, new_layer.info)
  self.assertEqual(layer.name, new_layer.name)

def unflatten(cls, data, xs):
  """Reconstruct the Layer from the PyTree tuple."""
  children_cls, children_data, name = data[0], data[1], data[2]
  layers = tuple(
      c.unflatten(d, x)
      for c, x, d in zip(children_cls, xs, children_data))
  layer = object.__new__(cls)
  layer_params = base.LayerParams(layers)
  layer.__init__(layer_params, name=name)
  return layer

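# Hedged aside: unflatten mirrors JAX's PyTree protocol, where flattening
# yields (leaves, auxiliary data) and unflattening rebuilds the container.
# A toy analogue using jax.tree_util directly:
import jax

_leaves, _treedef = jax.tree_util.tree_flatten({'w': 1, 'b': 2})
_rebuilt = jax.tree_util.tree_unflatten(_treedef, _leaves)
assert _rebuilt == {'w': 1, 'b': 2}
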
def initialize(cls, rng, in_spec):
  """Initializes Flatten Layer.

  Args:
    rng: Random key.
    in_spec: Input Spec.

  Returns:
    The LayerParams for the Flatten layer.
  """
  return base.LayerParams(info=in_spec.shape)

def initialize(cls, rng, in_spec, dim_out):
  """Initializes Reshape Layer.

  Args:
    rng: Random key.
    in_spec: Input Spec.
    dim_out: Desired output dimensions.

  Returns:
    The LayerParams for the Reshape layer.
  """
  return base.LayerParams(info=(in_spec.shape, tuple(dim_out)))

def initialize(cls, init_key, *args):
  """Initializes Serial Layer.

  Args:
    init_key: Random key.
    *args: Input Specs for the layer, followed by the tuple of child layer
      initializers.

  Returns:
    The LayerParams holding the initialized child layers.
  """
  in_specs, layer_inits = args[:-1], args[-1]
  layers = state.init(list(layer_inits), name='layers')(init_key, *in_specs)
  return base.LayerParams(tuple(layers))  # pytype: disable=wrong-arg-types

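# Hedged analogue: Serial threads a single init key through a sequence of
# child layer initializers. The same composition pattern appears in stax
# (jax.example_libraries.stax), shown here purely for comparison, not as
# this library's implementation:
import jax
from jax.example_libraries import stax as jstax

_init_fn, _apply_fn = jstax.serial(
    jstax.Dense(32), jstax.Relu, jstax.Dense(8))
_out_shape, _params = _init_fn(jax.random.PRNGKey(0), (-1, 16))
assert _out_shape == (-1, 8)
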
def test_call_with_needs_rng(self):
  layer_params = base.LayerParams(params=1, state=1, info=2)
  layer_init = DummyLayer(layer_params)
  layer = layer_init.init(self._seed, state.Shape((1, 1)))
  outputs = layer(3, rng=1)
  self.assertTupleEqual((1, 3, {'rng': 1}), outputs)

def test_flatten(self):
  layer_params = base.LayerParams(params=(1, 2), state=3)
  layer_init = DummyLayer(layer_params)
  layer = layer_init.init(self._seed, state.Shape((1, 1)), name='foo')
  self.assertTupleEqual((((1, 2), 3), ((), 'foo')), layer.flatten())

def test_to_string(self):
  layer_params = base.LayerParams(params=1, state=1, info=2)
  layer_init = DummyLayer(layer_params)
  layer = layer_init.init(self._seed, state.Shape((1, 1)))
  exp_str = 'DummyLayer(params=1, info=2)'
  self.assertEqual(exp_str, repr(layer))

def initialize(cls, rng, in_shape):
  """Initializes Activation Layer."""
  del in_shape, rng
  return base.LayerParams()

def test_defaults(self):
  layer_params = base.LayerParams()
  self.assertTupleEqual(((), (), ()), layer_params)

def initialize(cls, rng, in_shape, rate):
  del in_shape
  layer_params = base.LayerParams(info=DropoutInfo(rate))
  return layer_params

def test_init(self):
  layer_params = base.LayerParams((1, 2), 3)
  layer_init = DummyLayer(layer_params)
  layer = layer_init.init(self._seed, state.Shape((1, 1)))
  self.assertTupleEqual((1, 2), layer.params)
  self.assertEqual(3, layer.info)

def test_init_adds_tuple_to_params(self):
  layer_params = base.LayerParams(1, 2)
  layer_init = DummyLayer(layer_params)
  layer = layer_init.init(self._seed, state.Shape((1, 1)))
  self.assertEqual(1, layer.params)
  self.assertEqual(2, layer.info)

def test_call_with_has_training(self):
  layer_params = base.LayerParams(params=1, state=1, info=2)
  layer_init = DummyLayer(layer_params)
  layer = layer_init.init(self._seed, state.Shape((1, 1)))
  outputs = layer(3, training=True)
  self.assertTupleEqual((1, 3, {'training': True}), outputs)

def test_call_pass_params(self):
  layer_params = base.LayerParams(params=1, state=1)
  layer_init = DummyLayer(layer_params)
  layer = layer_init.init(self._seed, state.Shape((1, 1)))
  outputs = layer(3)
  self.assertTupleEqual((1, 3, {}), outputs)