def get_net(x):
  def init(v):
    return dict(w_init=lambda *args: v * jnp.ones((1, 1)),
                b_init=lambda *args: v * 1.5 * jnp.ones((1,)))

  h = basic.Linear(output_size=1, name="first_layer", **init(1.0))(x)
  h = basic.Linear(output_size=1, name="second_layer", **init(3.0))(h)
  return jnp.mean(h)
def test_sequential_params(self):
  seq = basic.Sequential([
      basic.Sequential([basic.Linear(2), basic.Linear(2)]),
      basic.Sequential([lambda x: basic.Linear(2)(x * 1)])])
  for _ in range(2):
    # Connect seq to ensure params are created. Connect twice to ensure that
    # we see the two instances of the lambda Linear.
    seq(jnp.zeros([1, 1]))
  params = seq.params_dict()
  self.assertCountEqual(
      list(params),
      ["linear/w", "linear/b", "linear_1/w", "linear_1/b",
       "sequential_1/linear/w", "sequential_1/linear/b"])
def inner_fn(x, extra):
  out = basic.Linear(
      x.shape[1],
      w_init=initializers.Constant(extra * jnp.eye(x.shape[1])),
      b_init=initializers.Constant(extra),
  )(x)
  return out, out
def test_connection_and_shapes(self):
  batch_size = 4
  x = make_sequence([batch_size, 3])  # [B, F]
  core = recurrent.DeepRNN([
      recurrent.VanillaRNN(hidden_size=3),
      basic.Linear(2),
      jax.nn.relu,
      recurrent.VanillaRNN(hidden_size=5),
      jax.nn.relu,
  ])
  initial_state = core.initial_state(x.shape[0])
  out, next_state = core(x, initial_state)

  self.assertEqual(out.shape, (batch_size, 5))
  # Verifies that at least the last relu layer is applied.
  self.assertTrue(np.all(out >= np.zeros([batch_size, 5])))

  self.assertLen(next_state, 2)
  self.assertEqual(initial_state[0].shape, (batch_size, 3))
  self.assertEqual(initial_state[1].shape, (batch_size, 5))

  self.assertLen(initial_state, 2)
  np.testing.assert_allclose(initial_state[0], jnp.zeros([batch_size, 3]))
  np.testing.assert_allclose(initial_state[1], jnp.zeros([batch_size, 5]))
def inner_fn(x):
  # Here we initialize the layer to an identity + 1, while later we multiply
  # each parameter by the index `n`.
  return basic.Linear(
      x.shape[1],
      w_init=initializers.Constant(jnp.eye(x.shape[1])),
      b_init=initializers.Constant(1.0),
  )(x)
def f_with_container_state(x):
  hk_layer = basic.Linear(
      width, w_init=initializers.Constant(jnp.eye(width)))
  layer_output = hk_layer(x)
  layer_state = {
      "raw_output": layer_output,
      "output_projection": jnp.sum(layer_output)
  }
  return layer_output + jnp.ones_like(layer_output), layer_state
def inner_fn(x, extra):
  # Compared to the previous test we pass the `extra` argument as an
  # additional input, in order to directly initialize the parameters to the
  # index `n` of the iteration.
  out = basic.Linear(
      x.shape[1],
      w_init=initializers.Constant(extra * jnp.eye(x.shape[1])),
      b_init=initializers.Constant(extra),
  )(x)
  return out, out
def __call__(self, inputs, state):
  if len(inputs.shape) > 2 or not inputs.shape:
    raise ValueError("LSTM input must be rank-1 or rank-2.")
  prev_h, prev_c = state
  x_and_h = jnp.concatenate([inputs, prev_h], axis=-1)
  gated = basic.Linear(4 * self.hidden_size)(x_and_h)
  # i = input, g = cell_gate, f = forget_gate, o = output_gate
  i, g, f, o = jnp.split(gated, indices_or_sections=4, axis=-1)
  f = jax.nn.sigmoid(f + 1)  # Forget bias, as in Sonnet.
  c = f * prev_c + jax.nn.sigmoid(i) * jnp.tanh(g)
  h = jax.nn.sigmoid(o) * jnp.tanh(c)
  return h, (h, c)
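# Hedged usage sketch, not part of the source above: one plausible way to
# drive an LSTM core like this over a sequence via the public Haiku API
# (assuming `haiku as hk`, `hk.LSTM`, and `hk.dynamic_unroll`; the hidden
# size and the [T, B, F] shapes below are illustrative only).
import haiku as hk
import jax
import jax.numpy as jnp


def unroll_fn(sequence):
  core = hk.LSTM(hidden_size=4)
  # `initial_state` expects the batch size; input is time-major [T, B, F].
  initial_state = core.initial_state(sequence.shape[1])
  outputs, final_state = hk.dynamic_unroll(core, sequence, initial_state)
  return outputs, final_state


unroll = hk.transform(unroll_fn)
rng = jax.random.PRNGKey(42)
seq = jnp.zeros([7, 2, 3])  # [T, B, F], arbitrary sizes for illustration.
params = unroll.init(rng, seq)
outputs, final_state = unroll.apply(params, rng, seq)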
def net(x):
  # x is [B, F].
  core = recurrent.DeepRNN([
      recurrent.VanillaRNN(hidden_size=3),
      basic.Linear(2),
      jax.nn.relu,
      recurrent.VanillaRNN(hidden_size=5),
      jax.nn.relu,
  ])
  initial_state = core.initial_state(x.shape[0])
  out, next_state = core(x, initial_state)
  return dict(out=out, next_state=next_state, initial_state=initial_state)
def __call__(self, inputs, is_training):
  initial_conv = conv.Conv2D(
      32, (3, 3), stride=2, padding="VALID", with_bias=self._with_bias)
  net = initial_conv(inputs)
  if self._use_bn:
    net = batch_norm.BatchNorm(create_scale=True,
                               create_offset=True)(net, is_training)
  net = jax.nn.relu(net)
  for i in range(len(self._strides)):
    net = MobileNetV1Block(self._channels[i],
                           self._strides[i],
                           self._use_bn)(net, is_training)
  net = jnp.mean(net, axis=(1, 2))
  net = reshape.Flatten()(net)
  net = basic.Linear(self._num_classes, name="logits")(net)
  return net
def test_fast_eval_shape_already_transformed(self):
  f = transform.transform(lambda x: basic.Linear(20)(x))  # pylint: disable=unnecessary-lambda
  rng = jax.random.PRNGKey(0)
  x = jnp.ones([1, 12])

  # init_fn
  y_slow = jax.eval_shape(f.init, rng, x)
  y_fast = eval_shape.fast_eval_shape(f.init, rng, x)
  self.assertEqual(y_slow, y_fast)
  self.assertEqual(
      y_slow,
      {'linear': {'w': jax.ShapeDtypeStruct((12, 20), jnp.float32),
                  'b': jax.ShapeDtypeStruct((20,), jnp.float32)}})

  # apply_fn
  y_slow = jax.eval_shape(f.apply, y_slow, rng, x)
  y_fast = eval_shape.fast_eval_shape(f.apply, y_fast, rng, x)
  self.assertEqual(y_slow, y_fast)
def __init__(self,
             output_sizes: Iterable[int],
             w_init: Optional[base.Initializer] = None,
             b_init: Optional[base.Initializer] = None,
             with_bias: bool = True,
             activation: Callable[[jnp.ndarray], jnp.ndarray] = jax.nn.relu,
             activate_final: bool = False,
             name: Optional[Text] = None):
  """Constructs an MLP.

  Args:
    output_sizes: Sequence of layer sizes.
    w_init: Initializer for Linear weights.
    b_init: Initializer for Linear bias. Must be `None` if `with_bias` is
      `False`.
    with_bias: Whether or not to apply a bias in each layer.
    activation: Activation function to apply between linear layers. Defaults
      to ReLU.
    activate_final: Whether or not to activate the final layer of the MLP.
    name: Optional name for this module.

  Raises:
    ValueError: If `with_bias` is `False` and `b_init` is not `None`.
  """
  if not with_bias and b_init is not None:
    raise ValueError("When with_bias=False b_init must not be set.")

  super(MLP, self).__init__(name=name)
  self._with_bias = with_bias
  self._w_init = w_init
  self._b_init = b_init
  self._activation = activation
  self._activate_final = activate_final
  self._layers = []
  for index, output_size in enumerate(output_sizes):
    self._layers.append(
        basic.Linear(output_size=output_size,
                     w_init=w_init,
                     b_init=b_init,
                     with_bias=with_bias,
                     name="linear_%d" % index))
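# Hedged usage sketch, not part of the source above: constructing an MLP like
# this one through the public `hk.nets.MLP` alias inside `hk.transform`
# (assuming `haiku as hk`; the layer sizes and input shape are arbitrary).
import haiku as hk
import jax
import jax.numpy as jnp


def forward(x):
  mlp = hk.nets.MLP([300, 100, 10], activate_final=False)
  return mlp(x)


forward_t = hk.transform(forward)
rng = jax.random.PRNGKey(0)
x = jnp.ones([8, 28 * 28])
params = forward_t.init(rng, x)
logits = forward_t.apply(params, rng, x)  # Shape [8, 10].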
def f():
  return basic.Linear(output_size=2)(jnp.zeros([6]))
def f():
  seq = basic.Sequential([basic.Linear(2), jax.nn.relu])
  return seq(jnp.zeros([3, 2]))
def __call__(self, inputs, prev_state):
  # TODO(slebedev): Consider dropping one of the biases.
  in2h = basic.Linear(self.hidden_size)(inputs)
  h2h = basic.Linear(self.hidden_size)(prev_state)
  outputs = jax.nn.relu(in2h + h2h)
  return outputs, outputs
def f():
  return basic.Linear(output_size=2, b_init=jnp.ones)(jnp.zeros([6]))
def f():
  return basic.Linear(output_size=2, name=linear_name)(jnp.zeros([6]))
def inner_fn(x):
  x += basic.Linear(100, name="linear1")(x)
  x += basic.Linear(100, name="linear2")(x)
  x /= jnp.mean(x)
  return x
def f(x):
  return basic.Linear(1)(x, precision=precision)
def layer_fn(x):
  return basic.Linear(100)(x)
def __init__(self,
             blocks_per_group: Sequence[int],
             num_classes: int,
             bn_config: Optional[Mapping[str, float]] = None,
             resnet_v2: bool = False,
             channels_per_group: Sequence[int] = (256, 512, 1024, 2048),
             name: Optional[str] = None):
  """Constructs a ResNet model.

  Args:
    blocks_per_group: A sequence of length 4 that indicates the number of
      blocks created in each group.
    num_classes: The number of classes to classify the inputs into.
    bn_config: A dictionary of two elements, `decay_rate` and `eps` to be
      passed on to the `BatchNorm` layers. By default the `decay_rate` is
      `0.9` and `eps` is `1e-5`.
    resnet_v2: Whether to use the v1 or v2 ResNet implementation. Defaults
      to `False`.
    channels_per_group: A sequence of length 4 that indicates the number of
      channels used for each block in each group.
    name: Name of the module.
  """
  super().__init__(name=name)
  self._resnet_v2 = resnet_v2

  bn_config = dict(bn_config or {})
  bn_config.setdefault("decay_rate", 0.9)
  bn_config.setdefault("eps", 1e-5)
  bn_config.setdefault("create_scale", True)
  bn_config.setdefault("create_offset", True)

  # Number of blocks in each group for ResNet.
  check_length(4, blocks_per_group, "blocks_per_group")
  check_length(4, channels_per_group, "channels_per_group")

  self._initial_conv = conv.Conv2D(
      output_channels=64,
      kernel_shape=7,
      stride=2,
      with_bias=False,
      padding="SAME",
      name="initial_conv")

  if not self._resnet_v2:
    self._initial_batchnorm = batch_norm.BatchNorm(
        name="initial_batchnorm", **bn_config)

  self._block_groups = []
  strides = (1, 2, 2, 2)
  for i in range(4):
    self._block_groups.append(
        BlockGroup(channels=channels_per_group[i],
                   num_blocks=blocks_per_group[i],
                   stride=strides[i],
                   bn_config=bn_config,
                   resnet_v2=resnet_v2,
                   name="block_group_%d" % (i)))

  if self._resnet_v2:
    self._final_batchnorm = batch_norm.BatchNorm(
        name="final_batchnorm", **bn_config)

  self._logits = basic.Linear(num_classes, w_init=jnp.zeros, name="logits")
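# Hedged usage sketch, not part of the source above: one way a ResNet built
# from this constructor could be used through the public `hk.nets.ResNet50`
# wrapper (assuming `haiku as hk`). BatchNorm carries state, so the function
# is wrapped with `hk.transform_with_state`; the image shape is illustrative.
import haiku as hk
import jax
import jax.numpy as jnp


def forward(images, is_training):
  net = hk.nets.ResNet50(num_classes=1000)
  return net(images, is_training=is_training)


forward_t = hk.transform_with_state(forward)
rng = jax.random.PRNGKey(0)
images = jnp.zeros([2, 224, 224, 3])
params, state = forward_t.init(rng, images, is_training=True)
logits, state = forward_t.apply(params, state, rng, images, is_training=True)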
def f(x):
  witness.append(None)
  return basic.Linear(1)(x)
def f(x):
  m = basic.Linear(20)
  y_slow = stateful.eval_shape(m, x)
  y_fast = eval_shape.fast_eval_shape(m, x)
  self.assertEqual(y_slow, y_fast)
  return m(x)
def __init__(self,
             blocks_per_group_list: Sequence[int],
             num_classes: int,
             bn_config: Optional[Mapping[Text, float]] = None,
             resnet_v2: bool = False,
             channels_per_group_list: Sequence[int] = (256, 512, 1024, 2048),
             name: Optional[Text] = None):
  """Constructs a ResNet model.

  Args:
    blocks_per_group_list: A sequence of length 4 that indicates the number
      of blocks created in each group.
    num_classes: The number of classes to classify the inputs into.
    bn_config: A dictionary of two elements, `decay_rate` and `eps` to be
      passed on to the `BatchNorm` layers. By default the `decay_rate` is
      `0.9` and `eps` is `1e-5`.
    resnet_v2: Whether to use the v1 or v2 ResNet implementation. Defaults
      to `False`.
    channels_per_group_list: A sequence of length 4 that indicates the number
      of channels used for each block in each group.
    name: Name of the module.
  """
  super(ResNet, self).__init__(name=name)
  if bn_config is None:
    bn_config = {"decay_rate": 0.9, "eps": 1e-5}
  self._bn_config = bn_config
  self._resnet_v2 = resnet_v2

  # Number of blocks in each group for ResNet.
  if len(blocks_per_group_list) != 4:
    raise ValueError(
        "`blocks_per_group_list` must be of length 4 not {}".format(
            len(blocks_per_group_list)))
  self._blocks_per_group_list = blocks_per_group_list

  # Number of channels in each group for ResNet.
  if len(channels_per_group_list) != 4:
    raise ValueError(
        "`channels_per_group_list` must be of length 4 not {}".format(
            len(channels_per_group_list)))
  self._channels_per_group_list = channels_per_group_list

  self._initial_conv = conv.Conv2D(
      output_channels=64,
      kernel_shape=7,
      stride=2,
      with_bias=False,
      padding="SAME",
      name="initial_conv")

  if not self._resnet_v2:
    self._initial_batchnorm = batch_norm.BatchNorm(
        create_scale=True,
        create_offset=True,
        name="initial_batchnorm",
        **bn_config)

  self._block_groups = []
  strides = [1, 2, 2, 2]
  for i in range(4):
    self._block_groups.append(
        BlockGroup(
            channels=self._channels_per_group_list[i],
            num_blocks=self._blocks_per_group_list[i],
            stride=strides[i],
            bn_config=bn_config,
            resnet_v2=resnet_v2,
            name="block_group_%d" % (i)))

  if self._resnet_v2:
    self._final_batchnorm = batch_norm.BatchNorm(
        create_scale=True,
        create_offset=True,
        name="final_batchnorm",
        **bn_config)

  self._logits = basic.Linear(
      output_size=num_classes, w_init=jnp.zeros, name="logits")
def f():
  return basic.Linear(output_size=2)(jnp.zeros((2, 5, 6)))
def __call__(self, inputs, state):
  in2h = basic.Linear(self.hidden_size)(inputs)
  h2h = basic.Linear(self.hidden_size)(state)
  output = jax.nn.relu(in2h + h2h)
  new_h = output
  return output, new_h
def f():
  return basic.Linear(output_size=6, with_bias=False)(jnp.zeros((5, 6)))
def inner_fn(x, y):
  x_out = x + basic.Linear(100, name="linear1")(y)
  y_out = y + basic.Linear(100, name="linear2")(x)
  return x_out, y_out
def test_sequential(self):
  seq = basic.Sequential([basic.Linear(2), jax.nn.relu])
  out = seq(jnp.zeros([3, 2]))
  self.assertEqual(out.shape, (3, 2))
def f_with_multi_args(x, a, b):
  return basic.Linear(
      width, w_init=initializers.Constant(jnp.eye(width)))(x) * a + b, None