Example #1
def get_net(x):
    def init(v):
        return dict(w_init=lambda *args: v * jnp.ones((1, 1)),
                    b_init=lambda *args: v * 1.5 * jnp.ones((1, )))

    h = basic.Linear(output_size=1, name="first_layer", **init(1.0))(x)
    h = basic.Linear(output_size=1, name="second_layer", **init(3.0))(h)
    return jnp.mean(h)
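The snippets on this page come from Haiku's internal modules (`basic`, `recurrent`, etc.); in user code the same layers are reached through the public `haiku` namespace. As a minimal sketch (names and shapes below are illustrative, not taken from the example), a builder function like `get_net` is typically turned into a pure init/apply pair with `hk.transform`:

# Sketch: driving a Haiku network function through hk.transform.
# Assumes the public API (hk.Linear, hk.transform); names are illustrative.
import jax
import jax.numpy as jnp
import haiku as hk

def get_net(x):
    h = hk.Linear(output_size=1, name="first_layer")(x)
    h = hk.Linear(output_size=1, name="second_layer")(h)
    return jnp.mean(h)

f = hk.transform(get_net)
rng = jax.random.PRNGKey(42)
x = jnp.ones((4, 3))
params = f.init(rng, x)        # creates {first_layer,second_layer}/{w,b}
out = f.apply(params, rng, x)  # runs the forward pass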
Example #2
 def test_sequential_params(self):
   seq = basic.Sequential([
       basic.Sequential([basic.Linear(2), basic.Linear(2)]),
       basic.Sequential([lambda x: basic.Linear(2)(x * 1)])])
   for _ in range(2):
     # Connect seq to ensure params are created. Connect twice to ensure that
     # we see the two instances of the lambda Linear.
     seq(jnp.zeros([1, 1]))
   params = seq.params_dict()
   self.assertCountEqual(
       list(params),
       ["linear/w", "linear/b", "linear_1/w", "linear_1/b",
        "sequential_1/linear/w", "sequential_1/linear/b"])
Example #3
 def inner_fn(x, extra):
     out = basic.Linear(
         x.shape[1],
         w_init=initializers.Constant(extra * jnp.eye(x.shape[1])),
         b_init=initializers.Constant(extra),
     )(x)
     return out, out
Example #4
    def test_connection_and_shapes(self):
        batch_size = 4
        x = make_sequence([batch_size, 3])  # [B, F]
        core = recurrent.DeepRNN([
            recurrent.VanillaRNN(hidden_size=3),
            basic.Linear(2),
            jax.nn.relu,
            recurrent.VanillaRNN(hidden_size=5),
            jax.nn.relu,
        ])
        initial_state = core.initial_state(x.shape[0])
        out, next_state = core(x, initial_state)

        self.assertEqual(out.shape, (batch_size, 5))
        # Verifies that at least the final relu layer is applied.
        self.assertTrue(np.all(out >= np.zeros([batch_size, 5])))

        self.assertLen(next_state, 2)
        self.assertEqual(initial_state[0].shape, (batch_size, 3))
        self.assertEqual(initial_state[1].shape, (batch_size, 5))

        self.assertLen(initial_state, 2)
        np.testing.assert_allclose(initial_state[0], jnp.zeros([batch_size,
                                                                3]))
        np.testing.assert_allclose(initial_state[1], jnp.zeros([batch_size,
                                                                5]))
Example #5
 def inner_fn(x):
     # Here we initialize the layer to an identity + 1, while later we multiply
     # each parameter by the index `n`.
     return basic.Linear(
         x.shape[1],
         w_init=initializers.Constant(jnp.eye(x.shape[1])),
         b_init=initializers.Constant(1.0),
     )(x)
Example #6
 def f_with_container_state(x):
     hk_layer = basic.Linear(width,
                             w_init=initializers.Constant(
                                 jnp.eye(width)))
     layer_output = hk_layer(x)
     layer_state = {
         "raw_output": layer_output,
         "output_projection": jnp.sum(layer_output)
     }
     return layer_output + jnp.ones_like(layer_output), layer_state
Example #7
 def inner_fn(x, extra):
      # Compared to the previous test, we pass in the `extra` argument as an
     # additional input, in order to directly initialize the parameters to the
     # index `n` of the iteration.
     out = basic.Linear(
         x.shape[1],
         w_init=initializers.Constant(extra * jnp.eye(x.shape[1])),
         b_init=initializers.Constant(extra),
     )(x)
     return out, out
Example #8
 def __call__(self, inputs, state):
     if len(inputs.shape) > 2 or not inputs.shape:
         raise ValueError("LSTM input must be rank-1 or rank-2.")
     prev_h, prev_c = state
     x_and_h = jnp.concatenate([inputs, prev_h], axis=-1)
     gated = basic.Linear(4 * self.hidden_size)(x_and_h)
     # i = input, g = cell_gate, f = forget_gate, o = output_gate
     i, g, f, o = jnp.split(gated, indices_or_sections=4, axis=-1)
     f = jax.nn.sigmoid(f + 1)  # Forget bias, as in sonnet.
     c = f * prev_c + jax.nn.sigmoid(i) * jnp.tanh(g)
     h = jax.nn.sigmoid(o) * jnp.tanh(c)
     return h, (h, c)
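For context, a minimal sketch of how a recurrent core with this `__call__` signature is unrolled over a time-major sequence. It uses the public `hk.LSTM` (which follows the same gating as the cell above) and `hk.dynamic_unroll`; the hidden size and shapes are arbitrary placeholders:

import jax
import jax.numpy as jnp
import haiku as hk

def unroll_net(xs):
    # xs is time-major: [T, B, F]. hk.LSTM projects [x, h] to 4 * hidden_size
    # and splits the result into the i/g/f/o gates, as in the cell above.
    core = hk.LSTM(hidden_size=8)
    initial_state = core.initial_state(xs.shape[1])
    outs, final_state = hk.dynamic_unroll(core, xs, initial_state)
    return outs

f = hk.transform(unroll_net)
rng = jax.random.PRNGKey(0)
xs = jnp.ones([5, 2, 3])         # T=5, B=2, F=3
params = f.init(rng, xs)
outs = f.apply(params, rng, xs)  # shape [5, 2, 8]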
Example #9
        def net(x):
            # x is [B, F].
            core = recurrent.DeepRNN([
                recurrent.VanillaRNN(hidden_size=3),
                basic.Linear(2),
                jax.nn.relu,
                recurrent.VanillaRNN(hidden_size=5),
                jax.nn.relu,
            ])
            initial_state = core.initial_state(x.shape[0])
            out, next_state = core(x, initial_state)

            return dict(out=out,
                        next_state=next_state,
                        initial_state=initial_state)
Example #10
 def __call__(self, inputs, is_training):
     initial_conv = conv.Conv2D(32, (3, 3),
                                stride=2,
                                padding="VALID",
                                with_bias=self._with_bias)
     net = initial_conv(inputs)
     if self._use_bn:
         net = batch_norm.BatchNorm(create_scale=True,
                                    create_offset=True)(net, is_training)
     net = jax.nn.relu(net)
     for i in range(len(self._strides)):
         net = MobileNetV1Block(self._channels[i], self._strides[i],
                                self._use_bn)(net, is_training)
     net = jnp.mean(net, axis=(1, 2))
     net = reshape.Flatten()(net)
     net = basic.Linear(self._num_classes, name="logits")(net)
     return net
Example #11
 def test_fast_eval_shape_already_transformed(self):
     f = transform.transform(lambda x: basic.Linear(20)(x))  # pylint: disable=unnecessary-lambda
     rng = jax.random.PRNGKey(0)
     x = jnp.ones([1, 12])
     # init_fn
     y_slow = jax.eval_shape(f.init, rng, x)
     y_fast = eval_shape.fast_eval_shape(f.init, rng, x)
     self.assertEqual(y_slow, y_fast)
     self.assertEqual(
         y_slow, {
             'linear': {
                 'w': jax.ShapeDtypeStruct((12, 20), jnp.float32),
                 'b': jax.ShapeDtypeStruct((20, ), jnp.float32)
             }
         })
     # apply_fn
     y_slow = jax.eval_shape(f.apply, y_slow, rng, x)
     y_fast = eval_shape.fast_eval_shape(f.apply, y_fast, rng, x)
     self.assertEqual(y_slow, y_fast)
Example #12
    def __init__(self,
                 output_sizes: Iterable[int],
                 w_init: Optional[base.Initializer] = None,
                 b_init: Optional[base.Initializer] = None,
                 with_bias: bool = True,
                 activation: Callable[[jnp.ndarray],
                                      jnp.ndarray] = jax.nn.relu,
                 activate_final: bool = False,
                 name: Optional[Text] = None):
        """Constructs an MLP.

    Args:
      output_sizes: Sequence of layer sizes.
      w_init: Initializer for Linear weights.
      b_init: Initializer for Linear bias. Must be `None` if `with_bias` is
        `False`.
      with_bias: Whether or not to apply a bias in each layer.
      activation: Activation function to apply between linear layers. Defaults
        to ReLU.
      activate_final: Whether or not to activate the final layer of the MLP.
      name: Optional name for this module.

    Raises:
      ValueError: If with_bias is False and b_init is not None.
    """
        if not with_bias and b_init is not None:
            raise ValueError("When with_bias=False b_init must not be set.")

        super(MLP, self).__init__(name=name)
        self._with_bias = with_bias
        self._w_init = w_init
        self._b_init = b_init
        self._activation = activation
        self._activate_final = activate_final
        self._layers = []
        for index, output_size in enumerate(output_sizes):
            self._layers.append(
                basic.Linear(output_size=output_size,
                             w_init=w_init,
                             b_init=b_init,
                             with_bias=with_bias,
                             name="linear_%d" % index))
Example #13
 def f():
     return basic.Linear(output_size=2)(jnp.zeros([6]))
Example #14
 def f():
   seq = basic.Sequential([basic.Linear(2), jax.nn.relu])
   return seq(jnp.zeros([3, 2]))
Example #15
 def __call__(self, inputs, prev_state):
   # TODO(slebedev): Consider dropping one of the biases.
   in2h = basic.Linear(self.hidden_size)(inputs)
   h2h = basic.Linear(self.hidden_size)(prev_state)
   outputs = jax.nn.relu(in2h + h2h)
   return outputs, outputs
Example #16
 def f():
     return basic.Linear(output_size=2, b_init=jnp.ones)(jnp.zeros([6]))
Example #17
 def f():
     return basic.Linear(output_size=2,
                         name=linear_name)(jnp.zeros([6]))
Example #18
 def inner_fn(x):
     x += basic.Linear(100, name="linear1")(x)
     x += basic.Linear(100, name="linear2")(x)
     x /= jnp.mean(x)
     return x
Example #19
 def f(x):
     return basic.Linear(1)(x, precision=precision)
Example #20
 def layer_fn(x):
     return basic.Linear(100)(x)
Example #21
    def __init__(self,
                 blocks_per_group: Sequence[int],
                 num_classes: int,
                 bn_config: Optional[Mapping[str, float]] = None,
                 resnet_v2: bool = False,
                 channels_per_group: Sequence[int] = (256, 512, 1024, 2048),
                 name: Optional[str] = None):
        """Constructs a ResNet model.

    Args:
      blocks_per_group: A sequence of length 4 that indicates the number of
        blocks created in each group.
      num_classes: The number of classes to classify the inputs into.
      bn_config: A dictionary of two elements, `decay_rate` and `eps` to be
        passed on to the `BatchNorm` layers. By default the `decay_rate` is
        `0.9` and `eps` is `1e-5`.
      resnet_v2: Whether to use the v1 or v2 ResNet implementation. Defaults to
        False.
      channels_per_group: A sequence of length 4 that indicates the number
        of channels used for each block in each group.
      name: Name of the module.
    """
        super().__init__(name=name)
        self._resnet_v2 = resnet_v2

        bn_config = dict(bn_config or {})
        bn_config.setdefault("decay_rate", 0.9)
        bn_config.setdefault("eps", 1e-5)
        bn_config.setdefault("create_scale", True)
        bn_config.setdefault("create_offset", True)

        # Number of blocks in each group for ResNet.
        check_length(4, blocks_per_group, "blocks_per_group")
        check_length(4, channels_per_group, "channels_per_group")

        self._initial_conv = conv.Conv2D(output_channels=64,
                                         kernel_shape=7,
                                         stride=2,
                                         with_bias=False,
                                         padding="SAME",
                                         name="initial_conv")

        if not self._resnet_v2:
            self._initial_batchnorm = batch_norm.BatchNorm(
                name="initial_batchnorm", **bn_config)

        self._block_groups = []
        strides = (1, 2, 2, 2)
        for i in range(4):
            self._block_groups.append(
                BlockGroup(channels=channels_per_group[i],
                           num_blocks=blocks_per_group[i],
                           stride=strides[i],
                           bn_config=bn_config,
                           resnet_v2=resnet_v2,
                           name="block_group_%d" % (i)))

        if self._resnet_v2:
            self._final_batchnorm = batch_norm.BatchNorm(
                name="final_batchnorm", **bn_config)

        self._logits = basic.Linear(num_classes,
                                    w_init=jnp.zeros,
                                    name="logits")
Example #22
 def f(x):
     witness.append(None)
     return basic.Linear(1)(x)
Example #23
 def f(x):
     m = basic.Linear(20)
     y_slow = stateful.eval_shape(m, x)
     y_fast = eval_shape.fast_eval_shape(m, x)
     self.assertEqual(y_slow, y_fast)
     return m(x)
Example #24
  def __init__(self,
               blocks_per_group_list: Sequence[int],
               num_classes: int,
               bn_config: Optional[Mapping[Text, float]] = None,
               resnet_v2: bool = False,
               channels_per_group_list: Sequence[int] = (256, 512, 1024, 2048),
               name: Optional[Text] = None):
    """Constructs a ResNet model.

    Args:
      blocks_per_group_list: A sequence of length 4 that indicates the number of
        blocks created in each group.
      num_classes: The number of classes to classify the inputs into.
      bn_config: A dictionary of two elements, `decay_rate` and `eps` to be
        passed on to the `BatchNorm` layers. By default the `decay_rate` is
        `0.9` and `eps` is `1e-5`.
      resnet_v2: Whether to use the v1 or v2 ResNet implementation. Defaults to
        False.
      channels_per_group_list: A sequence of length 4 that indicates the number
        of channels used for each block in each group.
      name: Name of the module.
    """
    super(ResNet, self).__init__(name=name)
    if bn_config is None:
      bn_config = {"decay_rate": 0.9, "eps": 1e-5}
    self._bn_config = bn_config
    self._resnet_v2 = resnet_v2

    # Number of blocks in each group for ResNet.
    if len(blocks_per_group_list) != 4:
      raise ValueError(
          "`blocks_per_group_list` must be of length 4 not {}".format(
              len(blocks_per_group_list)))
    self._blocks_per_group_list = blocks_per_group_list

    # Number of channels in each group for ResNet.
    if len(channels_per_group_list) != 4:
      raise ValueError(
          "`channels_per_group_list` must be of length 4 not {}".format(
              len(channels_per_group_list)))
    self._channels_per_group_list = channels_per_group_list

    self._initial_conv = conv.Conv2D(
        output_channels=64,
        kernel_shape=7,
        stride=2,
        with_bias=False,
        padding="SAME",
        name="initial_conv")
    if not self._resnet_v2:
      self._initial_batchnorm = batch_norm.BatchNorm(
          create_scale=True,
          create_offset=True,
          name="initial_batchnorm",
          **bn_config)

    self._block_groups = []
    strides = [1, 2, 2, 2]
    for i in range(4):
      self._block_groups.append(
          BlockGroup(
              channels=self._channels_per_group_list[i],
              num_blocks=self._blocks_per_group_list[i],
              stride=strides[i],
              bn_config=bn_config,
              resnet_v2=resnet_v2,
              name="block_group_%d" % (i)))

    if self._resnet_v2:
      self._final_batchnorm = batch_norm.BatchNorm(
          create_scale=True,
          create_offset=True,
          name="final_batchnorm",
          **bn_config)

    self._logits = basic.Linear(
        output_size=num_classes, w_init=jnp.zeros, name="logits")
Example #25
 def f():
     return basic.Linear(output_size=2)(jnp.zeros((2, 5, 6)))
Example #26
 def __call__(self, inputs, state):
     in2h = basic.Linear(self.hidden_size)(inputs)
     h2h = basic.Linear(self.hidden_size)(state)
     output = jax.nn.relu(in2h + h2h)
     new_h = output
     return output, new_h
Example #27
 def f():
     return basic.Linear(output_size=6, with_bias=False)(jnp.zeros(
         (5, 6)))
Example #28
 def inner_fn(x, y):
     x_out = x + basic.Linear(100, name="linear1")(y)
     y_out = y + basic.Linear(100, name="linear2")(x)
     return x_out, y_out
Example #29
 def test_sequential(self):
     seq = basic.Sequential([basic.Linear(2), jax.nn.relu])
     out = seq(jnp.zeros([3, 2]))
     self.assertEqual(out.shape, (3, 2))
Example #30
 def f_with_multi_args(x, a, b):
     return basic.Linear(width,
                         w_init=initializers.Constant(
                             jnp.eye(width)))(x) * a + b, None