Example no. 1
    def __call__(self, inputs):

        channel_index = utils.get_channel_index(self._data_format)
        weight_shape = self._kernel_shape + (1, self._channel_multiplier *
                                             inputs.shape[channel_index])
        fan_in_shape = np.prod(weight_shape[:-1])
        stddev = 1. / np.sqrt(fan_in_shape)
        w_init = self._w_init or initializers.TruncatedNormal(stddev=stddev)
        w = base.get_parameter("w", weight_shape, inputs.dtype, init=w_init)
        if self._channel_index == -1:
            dn = DIMENSION_NUMBERS[self._num_spatial_dims]
        else:
            dn = DIMENSION_NUMBERS_NCSPATIAL[self._num_spatial_dims]
        result = lax.conv_general_dilated(
            inputs,
            w,
            self._stride,
            self._padding,
            self._lhs_dilation,
            self._rhs_dilation,
            dn,
            feature_group_count=inputs.shape[channel_index])
        if self._with_bias:
            if channel_index == -1:
                bias_shape = (self._channel_multiplier *
                              inputs.shape[channel_index], )
            else:
                bias_shape = (self._channel_multiplier *
                              inputs.shape[channel_index], 1, 1)
            b = base.get_parameter("b", bias_shape, init=self._b_init)
            result = result + b
        return result
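A minimal usage sketch (not part of the source above): the same depthwise computation driven through Haiku's public wrapper and hk.transform. The hk.DepthwiseConv2D constructor arguments are assumed here and may differ between Haiku versions.

import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
    # channel_multiplier=2 doubles the channel count: 3 -> 6 output channels.
    conv = hk.DepthwiseConv2D(channel_multiplier=2, kernel_shape=3,
                              stride=1, padding="SAME")
    return conv(x)

forward = hk.transform(forward)
x = jnp.ones([8, 32, 32, 3])                     # NHWC input
params = forward.init(jax.random.PRNGKey(42), x)
y = forward.apply(params, None, x)               # shape [8, 32, 32, 6]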
Example no. 2
  def __call__(self, inputs, state):
    if len(inputs.shape) > 2 or not inputs.shape:
      raise ValueError("GRU input must be rank-1 or rank-2.")

    input_size = inputs.shape[-1]
    hidden_size = self.hidden_size
    w_i = base.get_parameter(
        name="w_i", shape=[input_size, 3 * hidden_size], init=self._w_i_init)
    w_h = base.get_parameter(
        name="w_h", shape=[hidden_size, 3 * hidden_size], init=self._w_h_init)
    b = base.get_parameter(
        name="b",
        shape=[3 * hidden_size],
        dtype=inputs.dtype,
        init=self._b_init)
    w_h_z, w_h_a = jnp.split(w_h, indices_or_sections=[2 * hidden_size], axis=1)
    b_z, b_a = jnp.split(b, indices_or_sections=[2 * hidden_size], axis=0)

    gates_x = jnp.matmul(inputs, w_i)
    zr_x, a_x = jnp.split(
        gates_x, indices_or_sections=[2 * hidden_size], axis=-1)
    zr_h = jnp.matmul(state, w_h_z)
    zr = zr_x + zr_h + jnp.broadcast_to(b_z, zr_h.shape)
    z, r = jnp.split(jax.nn.sigmoid(zr), indices_or_sections=2, axis=-1)

    a_h = jnp.matmul(r * state, w_h_a)
    a = jnp.tanh(a_x + a_h + jnp.broadcast_to(b_a, a_h.shape))

    next_state = (1 - z) * state + z * a
    return next_state, next_state
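A minimal sketch of stepping this core once through hk.transform, assuming the public hk.GRU(hidden_size) wrapper with the (inputs, state) call signature shown above and a zero initial state:

import haiku as hk
import jax
import jax.numpy as jnp

def step(x, state):
    core = hk.GRU(hidden_size=16)
    return core(x, state)                 # (output, next_state); for a GRU they are equal

step = hk.transform(step)
x = jnp.ones([4, 8])                      # [batch, input_size]
state = jnp.zeros([4, 16])                # [batch, hidden_size], zero initial state
params = step.init(jax.random.PRNGKey(0), x, state)
out, next_state = step.apply(params, None, x, state)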
Example no. 3
    def test_unable_to_mutate_name(self):
        def mutates_name(next_creator, name, shape, dtype, init):
            next_creator(name + "_foo", shape, dtype, init)

        with base.new_context(), base.custom_creator(mutates_name):
            with self.assertRaisesRegex(ValueError,
                                        "Modifying .*name.* not supported"):
                base.get_parameter("w", [], init=jnp.ones)
Example no. 4
 def test_context_copies_input(self):
     before = {"~": {"w": jnp.array(1.)}}
     with base.new_context(params=before, state=before) as ctx:
         base.get_parameter("w", [], init=jnp.ones)
         base.set_state("w", jnp.array(2.))
     self.assertEqual(ctx.collect_params(), {"~": {"w": jnp.array(1.)}})
     self.assertIsNot(ctx.collect_initial_state(), before)
     self.assertEqual(ctx.collect_initial_state(), before)
     self.assertEqual(ctx.collect_state(), {"~": {"w": jnp.array(2.)}})
     self.assertEqual(before, {"~": {"w": jnp.array(1.)}})
Example no. 5
    def __call__(self, inputs, scale=None, offset=None):
        """Connects the layer norm.

    Args:
      inputs: An array, where the data format is [N, ..., C].
      scale: An array up to n-D. The shape of this tensor must be broadcastable
        to the shape of `inputs`. This is the scale applied to the normalized
        inputs. This cannot be passed in if the module was constructed with
        `create_scale=True`.
      offset: An array up to n-D. The shape of this tensor must be broadcastable
        to the shape of `inputs`. This is the offset applied to the normalized
        inputs. This cannot be passed in if the module was constructed with
        `create_offset=True`.

    Returns:
      The array, normalized.
    """
        if isinstance(self._axis, slice):
            axes = tuple(range(len(inputs.shape)))
            axis = axes[self._axis]
        else:
            axis = self._axis

        m = jnp.mean(inputs, axis=axis, keepdims=True)
        variance = jnp.var(inputs, axis=axis, keepdims=True)
        param_shape = inputs.shape[-1:]
        if self._create_scale:
            if scale is not None:
                raise ValueError(
                    "Cannot pass `scale` at call time if `create_scale=True`.")
            scale = base.get_parameter("scale",
                                       param_shape,
                                       init=self._scale_init)
        elif scale is None:
            scale = 1.

        if self._create_offset:
            if offset is not None:
                raise ValueError(
                    "Cannot pass `offset` at call time if `create_offset=True`."
                )
            offset = base.get_parameter("offset",
                                        param_shape,
                                        init=self._offset_init)
        elif offset is None:
            offset = 0.

        scale = jnp.broadcast_to(scale, inputs.shape)
        offset = jnp.broadcast_to(offset, inputs.shape)
        m = jnp.broadcast_to(m, inputs.shape)

        inv = scale * jax.lax.rsqrt(variance + self._eps)
        return inv * (inputs - m) + offset
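A minimal usage sketch, assuming the public hk.LayerNorm module wraps this __call__: with create_scale=True and create_offset=True the module owns both parameters, so passing scale or offset at call time would raise, as the checks above show.

import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
    ln = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)
    return ln(x)

forward = hk.transform(forward)
x = jnp.ones([2, 5, 64])                         # data format [N, ..., C]
params = forward.init(jax.random.PRNGKey(0), x)  # contains "scale" and "offset"
y = forward.apply(params, None, x)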
Example no. 6
    def test_init_custom_creator(self):
        def zeros_creator(next_creator, name, shape, dtype, init):
            self.assertEqual(name, "~/w")
            self.assertEqual(shape, [])
            self.assertEqual(dtype, jnp.float32)
            self.assertEqual(init, jnp.ones)
            return next_creator(name, shape, dtype, jnp.zeros)

        with base.new_context() as ctx:
            with base.custom_creator(zeros_creator):
                base.get_parameter("w", [], init=jnp.ones)

        self.assertEqual(ctx.collect_params(), {"~": {"w": jnp.zeros([])}})
Example no. 7
  def __call__(self, inputs):
    if not inputs.shape:
      raise ValueError("Input must not be scalar.")

    self.input_size = inputs.shape[-1]
    default_stddev = 1. / jnp.sqrt(self.input_size)
    w_init = self.w_init or initializers.TruncatedNormal(stddev=default_stddev)

    w = base.get_parameter("w", [self.input_size, self.output_size],
                           inputs.dtype, init=w_init)
    out = jnp.dot(inputs, w)
    if self.with_bias:
      out += base.get_parameter("b", [self.output_size], inputs.dtype,
                                init=self.b_init)
    return out
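A minimal sketch of the same get_parameter pattern via the public hk.Linear module:

import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
    return hk.Linear(output_size=10)(x)

forward = hk.transform(forward)
x = jnp.ones([32, 128])
params = forward.init(jax.random.PRNGKey(0), x)  # {"linear": {"w": ..., "b": ...}}
y = forward.apply(params, None, x)               # shape [32, 10]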
Example no. 8
    def test_parameter_in_apply(self, params):
        _, apply_fn = transform.transform(
            lambda: base.get_parameter("w", [], init=jnp.zeros))

        with self.assertRaisesRegex(
                ValueError, "parameters must be created as part of `init`"):
            apply_fn(params)
Example no. 9
    def __call__(self, inputs: jnp.ndarray, multiplier: FloatLike = None):
        """Adds bias to `inputs` and optionally multiplies by `multiplier`.

    Args:
      inputs: A Tensor of size `[batch_size, input_size1, ...]`.
      multiplier: A scalar or Tensor which the bias term is multiplied by before
        adding it to `inputs`. Anything which works in the expression `bias *
        multiplier` is acceptable here. This may be useful if you want to add a
        bias in one place and subtract the same bias in another place via
        `multiplier=-1`.

    Returns:
      A Tensor of size `[batch_size, input_size1, ...]`.
    """
        utils.assert_minimum_rank(inputs, 2)

        input_shape = inputs.shape
        self.bias_shape = calculate_bias_shape(input_shape, self.bias_dims)

        input_size = input_shape[1:]
        if self.output_size is not None and self.output_size != input_size:
            raise ValueError("Input shape must be {} not {}".format(
                (-1, ) + self.output_size, input_shape))

        self.input_size = input_size
        b = base.get_parameter("b",
                               self.bias_shape,
                               inputs.dtype,
                               init=self.b_init)
        b = jnp.broadcast_to(b, inputs.shape)

        if multiplier is not None:
            return inputs + (b * multiplier)
        else:
            return inputs + b
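A minimal sketch of the add-then-subtract pattern the docstring mentions, assuming the public hk.Bias module with default arguments; calling the same instance twice reuses the same "b" parameter:

import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
    bias = hk.Bias()
    shifted = bias(x)                        # x + b
    restored = bias(shifted, multiplier=-1)  # (x + b) - b, recovers x
    return shifted, restored

forward = hk.transform(forward)
x = jnp.ones([4, 3])                             # rank >= 2, as asserted above
params = forward.init(jax.random.PRNGKey(0), x)
shifted, restored = forward.apply(params, None, x)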
Example no. 10
    def test_do_not_store(self):
        def my_creator(next_creator, shape, dtype, init, context):
            del next_creator, shape, dtype, init, context
            return base.DO_NOT_STORE

        def my_getter(next_getter, value, context):
            assert value is base.DO_NOT_STORE
            return next_getter(
                context.original_init(context.original_shape,
                                      context.original_dtype))

        def my_setter(next_setter, value, context):
            del next_setter, value, context
            return base.DO_NOT_STORE

        with base.new_context() as ctx:
            with base.custom_creator(my_creator, state=True), \
                 base.custom_getter(my_getter, state=True), \
                 base.custom_setter(my_setter):
                self.assertEqual(base.get_parameter("w", [], init=jnp.ones), 1)
                self.assertEqual(base.get_state("s1", [], init=jnp.ones), 1)
                base.set_state("s2", jnp.ones([]))

        self.assertEmpty(ctx.collect_params())
        self.assertEmpty(ctx.collect_state())
Example no. 11
 def __call__(self, *args, **kwargs):
   frame = base.current_frame()
   bundle_name = self.module_name
   if _SENTINEL_NAME in frame.params[bundle_name]:
     prefix = bundle_name + "/"
     lifted_params = unpack_from_dict(frame.params, prefix)
     lifted_state = unpack_from_dict(frame.state, prefix)
     return lifted_params, lifted_state
   else:
     # Ensure sentinel is set for apply.
     base.get_parameter(_SENTINEL_NAME, (), init=jnp.zeros)
     # Lift parameters into this transform's params_dict.
     params, state = self._init_fn(*args, **kwargs)
     pack_into_dict(params, frame.params, bundle_name)
     pack_into_dict(state, frame.state, bundle_name)
     return params, state
Example no. 12
  def __init__(self,
               embedding_dim: int,
               num_embeddings: int,
               commitment_cost: float,
               dtype: DType = jnp.float32,
               name: str = None):
    """Initializes a VQ-VAE module.

    Args:
      embedding_dim: dimensionality of the tensors in the quantized space.
        Inputs to the modules must be in this format as well.
      num_embeddings: number of vectors in the quantized space.
      commitment_cost: scalar which controls the weighting of the loss terms
        (see equation 4 in the paper - this variable is Beta).
      dtype: dtype for the embeddings variable, defaults to jnp.float32.
      name: name of the module.
    """
    super(VectorQuantizer, self).__init__(name=name)
    self.embedding_dim = embedding_dim
    self.num_embeddings = num_embeddings
    self.commitment_cost = commitment_cost

    embedding_shape = [embedding_dim, num_embeddings]
    initializer = initializers.VarianceScaling(distribution='uniform')
    self.embeddings = base.get_parameter('embeddings', embedding_shape, dtype,
                                         init=initializer)
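For intuition only (this is not the module's own __call__): a plain-JAX sketch of the nearest-codebook lookup that an [embedding_dim, num_embeddings] codebook like the one created above is typically used for. The function name quantize is hypothetical.

import jax.numpy as jnp

def quantize(inputs, embeddings):
    # inputs: [..., embedding_dim]; embeddings: [embedding_dim, num_embeddings]
    flat = inputs.reshape(-1, inputs.shape[-1])
    distances = (jnp.sum(flat ** 2, axis=1, keepdims=True)
                 - 2 * flat @ embeddings
                 + jnp.sum(embeddings ** 2, axis=0, keepdims=True))
    indices = jnp.argmin(distances, axis=1)              # nearest code per input vector
    quantized = embeddings.T[indices].reshape(inputs.shape)
    return quantized, indices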
Example no. 13
  def test_getter_types(self, params, state):
    log = []
    def logging_getter(next_getter, value, context):
      log.append(context.full_name)
      return next_getter(value)

    with base.new_context():
      with base.custom_getter(logging_getter, params=params, state=state):
        base.get_parameter("params", [], init=jnp.zeros)
        base.get_state("state", [], init=jnp.zeros)

    self.assertLen(log, int(params) + int(state))
    if params:
      self.assertIn("~/params", log)
    if state:
      self.assertIn("~/state", log)
Example no. 14
File: conv.py Project: ibab/haiku
    def __call__(self, inputs):
        """Connects `ConvND` layer.

    Args:
      inputs: A rank-N+2 array with shape [N, spatial_dims, C].

    Returns:
      A rank-N+2 array with shape [N, spatial_dims, output_channels].
    """
        if len(inputs.shape) != self._num_spatial_dims + 2:
            raise ValueError(
                "Input to ConvND needs to have rank {}, but input "
                "has shape {}.".format(self._num_spatial_dims + 2,
                                       inputs.shape))
        weight_shape = self._kernel_shape + (inputs.shape[self._channel_index],
                                             self._output_channels)

        fan_in_shape = np.prod(weight_shape[:-1])
        stddev = 1. / np.sqrt(fan_in_shape)
        w_init = self._w_init or initializers.TruncatedNormal(stddev=stddev)
        w = base.get_parameter("w", weight_shape, inputs.dtype, init=w_init)

        if self._mask is not None:
            if self._mask.shape != w.shape:
                raise ValueError(
                    "Mask needs to have the same shape as weights. "
                    "Shapes are: {}, {}".format(self._mask.shape, w.shape))
            w *= self._mask
        result = lax.conv_general_dilated(inputs,
                                          w,
                                          self._stride,
                                          self._padding,
                                          lhs_dilation=self._lhs_dilation,
                                          rhs_dilation=self._kernel_dilation,
                                          dimension_numbers=self._dn)
        if self._with_bias:
            if self._channel_index == -1:
                bias_shape = (self._output_channels, )
            else:
                bias_shape = (
                    self._output_channels, ) + (1, ) * self._num_spatial_dims
            b = base.get_parameter("b",
                                   bias_shape,
                                   inputs.dtype,
                                   init=self._b_init)
            result = result + b
        return result
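A minimal sketch of the rank-(N+2) contract documented above, using the public hk.Conv2D wrapper (two spatial dimensions, NHWC data):

import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
    conv = hk.Conv2D(output_channels=16, kernel_shape=3, stride=1, padding="SAME")
    return conv(x)

forward = hk.transform(forward)
x = jnp.ones([8, 28, 28, 1])                     # rank 4: batch, two spatial dims, channels
params = forward.init(jax.random.PRNGKey(0), x)
y = forward.apply(params, None, x)               # shape [8, 28, 28, 16]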
Example no. 15
    def test_custom_getter_bf16(self):
        def bf16_getter(next_getter, value, context):
            del context
            if value.dtype == jnp.float32:
                value = value.astype(jnp.bfloat16)
            return next_getter(value)

        with base.new_context() as ctx:
            with base.custom_getter(bf16_getter):
                f = base.get_parameter("f", [], jnp.float32, init=jnp.ones)
                i = base.get_parameter("i", [], jnp.int32, init=jnp.ones)

        params = ctx.collect_params()
        self.assertEqual(params["~"]["f"].dtype, jnp.float32)
        self.assertEqual(f.dtype, jnp.bfloat16)
        self.assertEqual(params["~"]["i"].dtype, jnp.int32)
        self.assertEqual(i.dtype, jnp.int32)
Example no. 16
 def __call__(self, x):
     assert x.ndim == 0
     p = base.get_parameter("p", [],
                            jnp.int32,
                            init=lambda *_: jnp.array(2))
     y = x**p
     base.set_state("y", y)
     return y
Example no. 17
    def test_nested_creators(self):
        log = []

        def logging_creator(log_msg):
            def _logging_creator(next_creator, name, shape, dtype, init):
                log.append(log_msg)
                return next_creator(name, shape, dtype, init)

            return _logging_creator

        with base.new_context():
            with base.custom_creator(logging_creator("a")), \
                 base.custom_creator(logging_creator("b")), \
                 base.custom_creator(logging_creator("c")):
                base.get_parameter("w", [], init=jnp.ones)

        self.assertEqual(log, ["a", "b", "c"])
Example no. 18
            def __call__(self, carry, x):
                x += base.get_parameter("w", shape=[], init=jnp.zeros)

                inner = transform.transform(inner_fn)
                keys = base.next_rng_key() if transform.running_init() else None
                params = lift.lift(inner.init,
                                   allow_reuse=self._allow_reuse)(keys, x)
                return carry, inner.apply(params, None, x)
Example no. 19
  def test_unable_to_mutate_name(self):
    def mutates_name(next_creator, name, shape, dtype, init):
      next_creator(name + "_foo", shape, dtype, init)

    init_fn, _ = base.transform(
        lambda: base.get_parameter("w", [], init=jnp.ones))

    with self.assertRaisesRegex(ValueError, "Modifying .*name.* not supported"):
      with base.custom_creator(mutates_name):
        init_fn(None)
Example no. 20
            def __call__(self, x):
                x += base.get_parameter("a", shape=[10, 10], init=jnp.zeros)

                def inner_fn(x):
                    return InnerModule(name="inner")(x)

                inner_transformed = transform.transform(inner_fn)
                inner_params = lift.transparent_lift(inner_transformed.init)(
                    base.next_rng_key(), x)
                x = inner_transformed.apply(inner_params, base.next_rng_key(),
                                            x)
                return x
Example no. 21
    def test_init_custom_creator(self):
        def zeros_creator(next_creator, name, shape, dtype, init):
            self.assertEqual(name, "~/w")
            self.assertEqual(shape, [])
            self.assertEqual(dtype, jnp.float32)
            self.assertEqual(init, jnp.ones)
            return next_creator(name, shape, dtype, jnp.zeros)

        init_fn, _ = base.transform(
            lambda: base.get_parameter("w", [], init=jnp.ones))

        with base.custom_creator(zeros_creator):
            params = init_fn(None)

        self.assertEqual(params, {"~": {"w": jnp.zeros([])}})
Example no. 22
        def f():
            w = base.get_parameter('w', [], init=jnp.zeros)
            s = base.get_state('s', [], init=jnp.zeros)
            init = lambda: None

            def add():
                s_add = base.get_state('s', [], init=jnp.zeros)
                w_add = base.get_parameter('w', [], init=jnp.zeros)
                return w, w_add, s, s_add

            def sub():
                s_sub = base.get_state('s', [], init=jnp.zeros)
                w_sub = base.get_parameter('w', [], init=jnp.zeros)
                return w, w_sub, s, s_sub

            return init, (add, sub)
Example no. 23
  def test_nested_creators(self):
    log = []

    def logging_creator(log_msg):
      def _logging_creator(next_creator, name, shape, dtype, init):
        log.append(log_msg)
        return next_creator(name, shape, dtype, init)
      return _logging_creator

    init_fn, _ = base.transform(
        lambda: base.get_parameter("w", [], init=jnp.ones))

    a, b, c = map(logging_creator, ["a", "b", "c"])
    with base.custom_creator(a), base.custom_creator(b), base.custom_creator(c):
      init_fn(None)

    self.assertEqual(log, ["a", "b", "c"])
Example no. 24
    def test_original_shape(self):
        def new_shape_creator(next_creator, shape, dtype, init, context):
            del shape
            del context
            new_shape = (1, 2, 3)
            return next_creator(new_shape, dtype, init)

        def original_shape_restorer(next_creator, shape, dtype, init, context):
            assert shape == (1, 2, 3)
            return next_creator(context.original_shape, dtype, init)

        with base.new_context():
            with base.custom_creator(new_shape_creator):
                with base.custom_creator(original_shape_restorer):
                    param = base.get_parameter("w", [5], jnp.bfloat16,
                                               jnp.ones)
                    assert param.shape == (5, )
Example no. 25
    def test_assert_no_new_parameters(self):
        with base.new_context():
            base.get_parameter("w", [], init=jnp.zeros)
            with base.assert_no_new_parameters():
                # Should not raise, "w" already exists.
                base.get_parameter("w", [], init=jnp.zeros)

            with self.assertRaisesRegex(AssertionError,
                                        "New parameters were created: .*x"):
                with base.assert_no_new_parameters():
                    # Should raise, "x" does not exist.
                    base.get_parameter("x", [], init=jnp.zeros)
Example no. 26
    def test_original_dtype(self):
        def dtype_cast_creator(next_creator, shape, dtype, init, context):
            if context.original_dtype == jnp.bfloat16:
                dtype = jnp.float32
            return next_creator(shape, dtype, init)

        def dtype_recast_getter(next_getter, value, context):
            if context.original_dtype == jnp.bfloat16:
                assert value.dtype == jnp.float32
                value = value.astype(jnp.bfloat16)
            return next_getter(value)

        with base.new_context() as ctx:
            with base.custom_creator(dtype_cast_creator), \
                 base.custom_getter(dtype_recast_getter):
                param = base.get_parameter("w", [], jnp.bfloat16, jnp.ones)
                orig_param = jax.tree_leaves(ctx.collect_params())[0]

                assert param.dtype == jnp.bfloat16
                assert orig_param.dtype == jnp.float32
Example no. 27
    def test_nested_getters(self):
        log = []

        def logging_getter(log_msg, dtype_in, dtype_out):
            def _logging_getter(next_getter, value, context):
                del context
                log.append(log_msg)
                self.assertEqual(value.dtype, dtype_in)
                value = value.astype(dtype_out)
                return next_getter(value)

            return _logging_getter

        with base.new_context():
            with base.custom_getter(logging_getter("a", jnp.float32, jnp.bfloat16)), \
                 base.custom_getter(logging_getter("b", jnp.bfloat16, jnp.int32)), \
                 base.custom_getter(logging_getter("c", jnp.int32, jnp.int8)):
                w = base.get_parameter("w", [], init=jnp.ones)

        self.assertEqual(w.dtype, jnp.int8)
        self.assertEqual(log, ["a", "b", "c"])
Example no. 28
 def bias_fn(x):
     b = base.get_parameter("b", [], init=jnp.ones)
     return x + b
Example no. 29
    def __init__(self,
                 vocab_size=None,
                 embed_dim=None,
                 embedding_matrix=None,
                 w_init=None,
                 lookup_style=EmbedLookupStyle.ARRAY_INDEX.name,
                 name=None):
        """Constructs an Embed module.

    Args:
      vocab_size: int or None: the number of unique tokens to embed. If not
        provided, an existing embedding matrix from which vocab_size can be
        inferred must be provided as `embedding_matrix`.
      embed_dim: int or None. Number of dimensions to assign to each embedding.
        If an existing vocabulary matrix initializes the module, this should not
        be provided as it will be inferred.
      embedding_matrix: A matrix-like object equivalent in size to
        [vocab_size, embed_dim]. If given, it is used as the initial value for
        the embedding matrix, and neither vocab_size nor embed_dim need be given.
        If they are given, their values are checked to be consistent with the
        dimensions of embedding_matrix.
      w_init: An initializer for the embeddings matrix. As a default,
        embeddings are initialized via a truncated normal distribution.
      lookup_style: One of the enum values of EmbedLookupStyle determining how
        to access the value of the embeddings given an ID. Regardless of this
        setting, the input should be a dense array of integer values
        representing ids. It only changes how the module internally maps those
        ids to embeddings; the result is the same, but the speed and memory
        trade-offs differ. It defaults to numpy-style array indexing. This
        value is only the default for the module and can be overridden at any
        given invocation in the __call__ method.
      name: string. Name for this module.

    Raises:
      ValueError: If none of embed_dim, embedding_matrix and vocab_size are
        supplied, or if embedding_matrix is supplied and embed_dim or vocab_size
        is not consistent with the supplied matrix.
    """
        super(Embed, self).__init__(name=name)
        if embedding_matrix is None and not (vocab_size and embed_dim):
            raise ValueError(
                "hk.Embed must be supplied either with an initial `embedding_matrix` "
                "or with `embed_dim` and `vocab_size`.")
        if embedding_matrix is not None:
            embedding_matrix = jnp.asarray(embedding_matrix)
            if vocab_size and embedding_matrix.shape[0] != vocab_size:
                raise ValueError(
                    "An `embedding_matrix` was supplied but the `vocab_size` of {vs} "
                    "was not consistent with its shape {emb_shape}.".format(
                        vs=vocab_size, emb_shape=embedding_matrix.shape))
            if embed_dim and embedding_matrix.shape[1] != embed_dim:
                raise ValueError(
                    "An `embedding_matrix` was supplied but the `embed_dim` of {ed} "
                    "was not consistent with its shape {emb_shape}.".format(
                        ed=embed_dim, emb_shape=embedding_matrix.shape))
            self._embedding = base.get_parameter(
                "embeddings",
                shape=embedding_matrix.shape,
                init=lambda _, __: embedding_matrix)
        else:
            w_init = w_init or hk_init.TruncatedNormal()
            self._embedding = base.get_parameter("embeddings",
                                                 shape=[vocab_size, embed_dim],
                                                 init=w_init)

        self._vocab_size = vocab_size or embedding_matrix.shape[0]
        self._embed_dim = embed_dim or embedding_matrix.shape[1]
        self._lookup_style = lookup_style
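A minimal sketch of the two construction modes the docstring describes, assuming the public hk.Embed module:

import haiku as hk
import jax
import jax.numpy as jnp

def forward(ids):
    # Mode 1: sizes given, embeddings initialised with a truncated normal.
    embed = hk.Embed(vocab_size=1000, embed_dim=64)
    # Mode 2 (alternative): pass embedding_matrix=<existing matrix> and both sizes are inferred.
    return embed(ids)

forward = hk.transform(forward)
ids = jnp.array([[1, 2, 3], [4, 5, 6]])          # dense integer ids
params = forward.init(jax.random.PRNGKey(0), ids)
out = forward.apply(params, None, ids)           # shape [2, 3, 64]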
Example no. 30
 def __call__(self):
     w = base.get_parameter("w", [], init=jnp.zeros)
     with module.name_scope("foo"):
         w_foo = base.get_parameter("w", [], init=jnp.zeros)
     return w, w_foo