Example #1
    def test_setter_tree(self):
        witness = []
        x = {"a": jnp.ones([]), "b": jnp.zeros([123])}
        y = jax.tree_map(lambda x: x + 1, x)

        def my_setter(next_setter, value, ctx):
            self.assertIs(value, x)
            self.assertEqual(ctx.original_shape, {"a": (), "b": (123, )})
            self.assertEqual(ctx.original_dtype, {
                "a": jnp.float32,
                "b": jnp.float32
            })
            self.assertEqual(ctx.full_name, "~/x")
            self.assertEqual(ctx.name, "x")
            self.assertIsNone(ctx.module)
            witness.append(None)
            del next_setter
            return y

        with base.new_context():
            with base.custom_setter(my_setter):
                base.set_state("x", x)
                x = base.get_state("x")
                self.assertIs(x, y)

        self.assertNotEmpty(witness)
Example #2
    def test_do_not_store(self):
        def my_creator(next_creator, shape, dtype, init, context):
            # Returning DO_NOT_STORE keeps the value out of the context, so
            # getters must re-create it from the original initializer.
            del next_creator, shape, dtype, init, context
            return base.DO_NOT_STORE

        def my_getter(next_getter, value, context):
            assert value is base.DO_NOT_STORE
            return next_getter(
                context.original_init(context.original_shape,
                                      context.original_dtype))

        def my_setter(next_setter, value, context):
            del next_setter, value, context
            return base.DO_NOT_STORE

        with base.new_context() as ctx:
            with base.custom_creator(my_creator, state=True), \
                 base.custom_getter(my_getter, state=True), \
                 base.custom_setter(my_setter):
                self.assertEqual(base.get_parameter("w", [], init=jnp.ones), 1)
                self.assertEqual(base.get_state("s1", [], init=jnp.ones), 1)
                base.set_state("s2", jnp.ones([]))

        self.assertEmpty(ctx.collect_params())
        self.assertEmpty(ctx.collect_state())
Example #3
    def __call__(self, value, update_stats=True, error_on_non_matrix=False):
        """Performs spectral normalization and returns the new value.

        Args:
          value: The array-like object on which you would like to perform
            spectral normalization.
          update_stats: A boolean, defaulting to True. Regardless of this arg,
            this function will return the normalized input. When `update_stats`
            is True, the internal state of this object will also be updated to
            reflect the input value; when it is False, the internal stats will
            remain unchanged.
          error_on_non_matrix: Spectral normalization is only defined on
            matrices. By default, this module raises on scalar and vector
            inputs and flattens higher-order tensors over their leading
            dimensions. Setting this flag to True will instead raise for any
            input that is not exactly 2D.

        Returns:
          The input value normalized by its largest singular value.

        Raises:
          ValueError: If `value` is a scalar or vector, or if
            `error_on_non_matrix` is True and `value.ndim > 2`.
        """
        value = jnp.asarray(value)
        value_shape = value.shape

        # Spectral normalization is not defined for scalars or vectors.
        if value.ndim <= 1:
            raise ValueError("Spectral normalization is not well defined for "
                             "scalar or vector inputs.")
        # Handle higher-order tensors.
        elif value.ndim > 2:
            if error_on_non_matrix:
                raise ValueError(
                    "Input is {}D but error_on_non_matrix is True".format(
                        value.ndim))
            else:
                value = jnp.reshape(value, [-1, value.shape[-1]])

        u0 = base.get_state("u0",
                            shape=[1, value.shape[-1]],
                            dtype=value.dtype,
                            init=initializers.RandomNormal())

        # Power iteration to estimate the weight's largest singular value.
        for _ in range(self._n_steps):
            v0 = _l2_normalize(jnp.matmul(u0, value.transpose([1, 0])),
                               eps=self._eps)
            u0 = _l2_normalize(jnp.matmul(v0, value), eps=self._eps)

        u0 = jax.lax.stop_gradient(u0)
        v0 = jax.lax.stop_gradient(v0)

        sigma = jnp.matmul(jnp.matmul(v0, value), jnp.transpose(u0))[0, 0]

        value /= sigma
        value_bar = value.reshape(value_shape)

        if update_stats:
            base.set_state("u0", u0)
            base.set_state("sigma", sigma)
        return value_bar
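
The module above keeps the power-iteration vector `u0` in Haiku state and refreshes it through `set_state`. A minimal usage sketch, assuming this is Haiku's `hk.SpectralNorm` wrapped with `hk.transform_with_state` so that state is threaded explicitly (the shapes and constructor arguments are illustrative):

import haiku as hk
import jax
import jax.numpy as jnp

def forward(w):
    # eps and n_steps mirror self._eps and self._n_steps in the code above.
    return hk.SpectralNorm(eps=1e-4, n_steps=1)(w)

forward = hk.transform_with_state(forward)
w = jnp.ones([3, 4])
params, state = forward.init(jax.random.PRNGKey(0), w)  # creates "u0"
w_bar, state = forward.apply(params, state, None, w)    # w / sigma; state carries the new "u0"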
Example #4
    def test_stateful(self):
        with base.new_context() as ctx:
            for _ in range(10):
                count = base.get_state("count", (), jnp.int32, jnp.zeros)
                base.set_state("count", count + 1)

        # collect_initial_state() records the first value seen for "count" (0);
        # collect_state() returns the latest value after the ten increments.
        self.assertEqual(ctx.collect_initial_state(), {"~": {"count": 0}})
        self.assertEqual(ctx.collect_state(), {"~": {"count": 10}})
Example #5
    def __call__(self, x):
        assert x.ndim == 0
        p = base.get_parameter("p", [],
                               jnp.int32,
                               init=lambda *_: jnp.array(2))
        y = x**p
        base.set_state("y", y)
        return y
Example #6
    def test_context_copies_input(self):
        before = {"~": {"w": jnp.array(1.)}}
        with base.new_context(params=before, state=before) as ctx:
            base.get_parameter("w", [], init=jnp.ones)
            base.set_state("w", jnp.array(2.))
        # new_context copies its inputs, so mutations inside the context never
        # leak back into `before`.
        self.assertEqual(ctx.collect_params(), {"~": {"w": jnp.array(1.)}})
        self.assertIsNot(ctx.collect_initial_state(), before)
        self.assertEqual(ctx.collect_initial_state(), before)
        self.assertEqual(ctx.collect_state(), {"~": {"w": jnp.array(2.)}})
        self.assertEqual(before, {"~": {"w": jnp.array(1.)}})
Example #7
    def test_difference_update_state(self):
        base.get_state("a", [], init=jnp.zeros)
        base.get_state("b", [], init=jnp.zeros)
        before = stateful.internal_state()
        base.set_state("b", jnp.ones([]))
        after = stateful.internal_state()
        diff = stateful.difference(before, after)
        self.assertEmpty(diff.params)
        # Unchanged state maps to None; changed state becomes an
        # (initial, final) StatePair.
        self.assertEqual(diff.state, {"~": {"a": None,
                                            "b": base.StatePair(0., 1.)}})
        self.assertIsNone(diff.rng)
Example #8
    def test_set_then_get(self):
        with base.new_context() as ctx:
            base.set_state("i", 1)
            base.get_state("i")

        self.assertEqual(ctx.collect_initial_state(), {"~": {"i": 1}})

        for _ in range(10):
            with ctx:
                base.set_state("i", 1)
                y = base.get_state("i")
                self.assertEqual(y, 1)
            self.assertEqual(ctx.collect_initial_state(), {"~": {"i": 1}})
Example #9
    def test_without_state_raises_if_state_used_on_apply(self):
        f = lambda: base.set_state("~", 1)
        f = transform.without_state(transform.transform_with_state(f))
        rng = jax.random.PRNGKey(42)
        with self.assertRaisesRegex(ValueError, "use.*transform_with_state"):
            params = f.init(rng)
            f.apply(params, rng)
Example #10
    def test_setter_array(self):
        witness = []
        x = jnp.ones([])
        y = x + 1

        def my_setter(next_setter, value, context):
            self.assertIs(value, x)
            self.assertEqual(context.original_shape, value.shape)
            self.assertEqual(context.original_dtype, value.dtype)
            self.assertEqual(context.full_name, "~/x")
            self.assertEqual(context.name, "x")
            self.assertIsNone(context.module)
            witness.append(None)
            del next_setter
            return y

        with base.new_context():
            with base.custom_setter(my_setter):
                base.set_state("x", x)
                x = base.get_state("x")
                self.assertIs(x, y)

        self.assertNotEmpty(witness)
Example #11
    def __call__(self, value, update_stats=True):
        """Updates the EMA and returns the new value.

        Args:
          value: The array-like object for which you would like to compute an
            exponential moving average.
          update_stats: A boolean, whether to update the internal state of this
            object to reflect the input value. When `update_stats` is False the
            internal stats will remain unchanged.

        Returns:
          The exponentially weighted average of the input value.
        """
        if not isinstance(value, jnp.ndarray):
            value = jnp.asarray(value)

        counter = base.get_state(
            "counter", (),
            jnp.int32,
            init=initializers.Constant(-self._warmup_length))
        counter += 1

        decay = jax.lax.convert_element_type(self._decay, value.dtype)
        if self._warmup_length > 0:
            decay = self._cond(counter <= 0, 0.0, decay, value.dtype)

        one = jnp.ones([], value.dtype)
        hidden = base.get_state("hidden",
                                value.shape,
                                value.dtype,
                                init=jnp.zeros)
        hidden = hidden * decay + value * (one - decay)

        average = hidden
        if self._zero_debias:
            average /= (one - jnp.power(decay, counter))

        if update_stats:
            base.set_state("counter", counter)
            base.set_state("hidden", hidden)
            base.set_state("average", average)

        return average
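
A short usage sketch for the module above, assuming it is Haiku's `hk.ExponentialMovingAverage`; because `counter`, `hidden` and `average` live in Haiku state, the call has to run under `hk.transform_with_state`:

import haiku as hk
import jax.numpy as jnp

def forward(x):
    return hk.ExponentialMovingAverage(decay=0.99)(x)

forward = hk.transform_with_state(forward)
params, state = forward.init(None, jnp.array(1.0))  # rng unused: all initializers are constant
avg, state = forward.apply(params, state, None, jnp.array(1.0))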
Example #12
    def __call__(self, value, update_stats=True):
        """Updates the EMA and returns the new value.

        Args:
          value: The array-like object for which you would like to compute an
            exponential moving average.
          update_stats: A boolean, whether to update the internal state of this
            object to reflect the input value. When `update_stats` is False the
            internal stats will remain unchanged.

        Returns:
          The exponentially weighted average of the input value.
        """
        value = jnp.asarray(value)  # Ensure value has a dtype.
        prev_counter = base.get_state(
            "counter",
            shape=(),
            dtype=jnp.int32,
            init=initializers.Constant(-self._warmup_length))
        prev_hidden = base.get_state("hidden",
                                     shape=value.shape,
                                     dtype=value.dtype,
                                     init=jnp.zeros)

        decay = jnp.asarray(self._decay).astype(value.dtype)
        counter = prev_counter + 1
        decay = self._cond(jnp.less_equal(counter, 0), 0.0, decay, value.dtype)
        hidden = prev_hidden * decay + value * (1 - decay)

        if self._zero_debias:
            average = hidden / (1. - jnp.power(decay, counter))
        else:
            average = hidden

        if update_stats:
            base.set_state("counter", counter)
            base.set_state("hidden", hidden)
            base.set_state("average", average)
        return average
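
The zero-debias branch divides `hidden` by `1 - decay**counter` so that early outputs are not biased toward the zero-initialized `hidden`. A small arithmetic check of that correction, independent of Haiku:

import jax.numpy as jnp

decay, v = 0.99, 5.0
hidden = jnp.zeros([])
hidden = hidden * decay + v * (1 - decay)  # first EMA update: 0.05
average = hidden / (1 - decay ** 1)        # zero-debias with counter == 1
# average ~= 5.0 (recovers v up to float rounding); the raw hidden is only 0.05.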
Example #13
    def __call__(self, x):
        y = self.op(x)
        base.set_state("count", self.count + 1)
        return y
Example #14
    def net():
        base.set_state("i", 1)
        return base.get_state("i")
Example #15
    def test_new_state_in_apply(self):
        with base.new_context(params={}, state={}) as ctx:
            base.set_state("count", 1)

        self.assertEqual(ctx.collect_initial_state(), {"~": {"count": 1}})
        self.assertEqual(ctx.collect_state(), {"~": {"count": 1}})
Example #16
# Helpers adapting a unary function to the signatures expected by carry- and
# index-based control flow (e.g. scan, fori_loop).
identity_carry = lambda f: lambda carry, x: (carry, f(x))
ignore_index = lambda f: lambda i, x: f(x)


def with_rng_example():
    with base.with_rng(jax.random.PRNGKey(42)):
        pass


# Methods in Haiku that mutate internal state.
SIDE_EFFECTING_FUNCTIONS = (
    ("get_parameter", lambda: base.get_parameter("w", [], init=jnp.zeros)),
    ("get_state", lambda: base.get_state("w", [], init=jnp.zeros)),
    ("set_state", lambda: base.set_state("w", 1)),
    ("next_rng_key", base.next_rng_key),
    ("next_rng_keys", lambda: base.next_rng_keys(2)),
    ("reserve_rng_keys", lambda: base.reserve_rng_keys(2)),
    ("with_rng", with_rng_example),
)

# JAX transforms and control flow that need to be aware of Haiku internal
# state to operate unsurprisingly.
# pylint: disable=g-long-lambda
JAX_PURE_EXPECTING_FNS = (
    # Just-in-time compilation.
    ("jit", jax.jit),
    ("make_jaxpr", jax.make_jaxpr),
    ("eval_shape", lambda f: (lambda x: jax.eval_shape(f, x))),
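
Each entry above names a Haiku call that mutates internal state, while the JAX transforms listed all expect pure functions; combining the two directly would drop or duplicate side effects. The usual fix is to wrap the stateful function with `hk.transform_with_state` first, so the transform only ever sees a pure function of explicit params and state. A minimal sketch using the public `haiku as hk` API rather than the internal `base` module exercised by these tests:

import haiku as hk
import jax
import jax.numpy as jnp

def counter():
    # Reads and mutates Haiku state, so it cannot be jit-compiled directly.
    c = hk.get_state("count", [], jnp.int32, init=jnp.zeros)
    hk.set_state("count", c + 1)
    return c

counter = hk.transform_with_state(counter)
apply_fn = jax.jit(counter.apply)  # pure: (params, state, rng) -> (out, new_state)
params, state = counter.init(None)
out, state = apply_fn(params, state, None)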
Example #17
    def __call__(self, x):
        y = x ** 2
        base.set_state("count", self.count + 1)
        return y
Example #18
    def apply():
        s = base.get_state('s')
        base.set_state('s', s + 1)
Example #19
    def __call__(self):
        for _ in range(10):
            count = base.get_state("count", (), jnp.int32, jnp.zeros)
            base.set_state("count", count + 1)
        # Returns the value read on the final iteration (9); the stored state
        # ends at 10.
        return count
Example #20
    def inner():
        w = base.get_state("w", [], init=jnp.zeros)
        w += 1
        base.set_state("w", w)
        return w
Example #21
  def __call__(self, inputs, is_training):
    """Connects the module to some inputs.

    Args:
      inputs: Tensor, final dimension must be equal to embedding_dim. All other
        leading dimensions will be flattened and treated as a large batch.
      is_training: boolean, whether this connection is to training data. When
        this is set to False, the internal moving average statistics will not be
        updated.

    Returns:
      dict containing the following keys and values:
        quantize: Tensor containing the quantized version of the input.
        loss: Tensor containing the loss to optimize.
        perplexity: Tensor containing the perplexity of the encodings.
        encodings: Tensor containing the discrete one-hot encodings, i.e.
          which element of the quantized space each input element was mapped
          to.
        encoding_indices: Tensor containing the discrete encoding indices,
          i.e. which element of the quantized space each input element was
          mapped to.
        distances: Tensor containing the distances from each input element to
          every entry in the codebook.
    """
    flat_inputs = jnp.reshape(inputs, [-1, self.embedding_dim])
    embeddings = self.embeddings

    distances = (
        jnp.sum(flat_inputs**2, 1, keepdims=True) -
        2 * jnp.matmul(flat_inputs, embeddings) +
        jnp.sum(embeddings**2, 0, keepdims=True))

    encoding_indices = jnp.argmax(-distances, 1)
    encodings = jax.nn.one_hot(encoding_indices,
                               self.num_embeddings,
                               dtype=distances.dtype)

    # NB: if your code crashes with a reshape error on the line below about a
    # Tensor containing the wrong number of values, then the most likely cause
    # is that the input passed in does not have a final dimension equal to
    # self.embedding_dim. Ideally we would catch this with an Assert but that
    # creates various other problems related to device placement / TPUs.
    encoding_indices = jnp.reshape(encoding_indices, inputs.shape[:-1])
    quantized = self.quantize(encoding_indices)
    e_latent_loss = jnp.mean((jax.lax.stop_gradient(quantized) - inputs)**2)

    if is_training:
      updated_ema_cluster_size = self.ema_cluster_size(
          jnp.sum(encodings, axis=0))

      dw = jnp.matmul(flat_inputs.T, encodings)
      updated_ema_dw = self.ema_dw(dw)

      n = jnp.sum(updated_ema_cluster_size)
      updated_ema_cluster_size = ((updated_ema_cluster_size + self.epsilon) /
                                  (n + self.num_embeddings * self.epsilon) * n)

      normalised_updated_ema_w = (
          updated_ema_dw / jnp.reshape(updated_ema_cluster_size, [1, -1]))

      base.set_state('embeddings', normalised_updated_ema_w)

    loss = self.commitment_cost * e_latent_loss

    # Straight Through Estimator
    quantized = inputs + jax.lax.stop_gradient(quantized - inputs)
    avg_probs = jnp.mean(encodings, 0)
    perplexity = jnp.exp(-jnp.sum(avg_probs * jnp.log(avg_probs + 1e-10)))

    return {
        'quantize': quantized,
        'loss': loss,
        'perplexity': perplexity,
        'encodings': encodings,
        'encoding_indices': encoding_indices,
        'distances': distances,
    }
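
The straight-through line above, `quantized = inputs + jax.lax.stop_gradient(quantized - inputs)`, leaves the forward value equal to `quantized` while gradients flow to `inputs` as if the quantizer were the identity. A self-contained check of the trick, using `jnp.round` as a stand-in quantizer rather than the codebook lookup used here:

import jax
import jax.numpy as jnp

def straight_through(x):
    q = jnp.round(x)                         # non-differentiable "quantizer"
    return x + jax.lax.stop_gradient(q - x)  # forward: q; backward: identity

x = jnp.array(1.3)
straight_through(x)            # 1.0, the quantized value
jax.grad(straight_through)(x)  # 1.0, the gradient of the identity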
Example #22
    def init():
        s = base.get_state('s', [], init=jnp.zeros)
        base.set_state('s', s + 1)