def test_setter_tree(self):
  witness = []
  x = {"a": jnp.ones([]), "b": jnp.zeros([123])}
  y = jax.tree_map(lambda x: x + 1, x)

  def my_setter(next_setter, value, ctx):
    self.assertIs(value, x)
    self.assertEqual(ctx.original_shape, {"a": (), "b": (123,)})
    self.assertEqual(ctx.original_dtype, {"a": jnp.float32, "b": jnp.float32})
    self.assertEqual(ctx.full_name, "~/x")
    self.assertEqual(ctx.name, "x")
    self.assertIsNone(ctx.module)
    witness.append(None)
    del next_setter
    return y

  with base.new_context():
    with base.custom_setter(my_setter):
      base.set_state("x", x)
      x = base.get_state("x")
      self.assertIs(x, y)

  self.assertNotEmpty(witness)

def test_do_not_store(self):
  def my_creator(next_creator, shape, dtype, init, context):
    del next_creator, shape, dtype, init, context
    return base.DO_NOT_STORE

  def my_getter(next_getter, value, context):
    assert value is base.DO_NOT_STORE
    return next_getter(
        context.original_init(context.original_shape, context.original_dtype))

  def my_setter(next_setter, value, context):
    del next_setter, value, context
    return base.DO_NOT_STORE

  with base.new_context() as ctx:
    with base.custom_creator(my_creator, state=True), \
         base.custom_getter(my_getter, state=True), \
         base.custom_setter(my_setter):
      self.assertEqual(base.get_parameter("w", [], init=jnp.ones), 1)
      self.assertEqual(base.get_state("s1", [], init=jnp.ones), 1)
      base.set_state("s2", jnp.ones([]))

  self.assertEmpty(ctx.collect_params())
  self.assertEmpty(ctx.collect_state())

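# A minimal sketch of DO_NOT_STORE outside a test, in the same `base` /
# `transform` idiom used throughout these snippets (the public equivalents
# live on `hk`): a custom setter that returns DO_NOT_STORE silently drops
# every `set_state` write, so nothing is recorded in the state dictionary.
def f(x):
  def drop_writes(next_setter, value, context):
    del next_setter, value, context  # Discard the write entirely.
    return base.DO_NOT_STORE

  with base.custom_setter(drop_writes):
    base.set_state("last_x", x)  # Dropped: never appears in `state`.
  return x * 2

f = transform.transform_with_state(f)
params, state = f.init(jax.random.PRNGKey(42), jnp.ones([]))
assert not state  # Nothing was stored.
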
def __call__(self, value, update_stats=True, error_on_non_matrix=False):
  """Performs Spectral Normalization and returns the new value.

  Args:
    value: The array-like object on which you would like to perform spectral
      normalization.
    update_stats: A boolean defaulting to True. Regardless of this arg, this
      function will return the normalized input. When `update_stats` is True,
      the internal state of this object will also be updated to reflect the
      input value. When `update_stats` is False the internal stats will
      remain unchanged.
    error_on_non_matrix: Spectral normalization is only defined on matrices.
      By default, this module will flatten higher-order tensors in their
      leading dimensions; setting this flag to True will instead raise an
      error for such inputs. Scalar and vector inputs always raise,
      regardless of this flag.

  Returns:
    The input value normalized by its first singular value.

  Raises:
    ValueError: If `value` is a scalar or vector, or if `error_on_non_matrix`
      is True and `value` has more than 2 dimensions.
  """
  value = jnp.asarray(value)
  value_shape = value.shape

  # Handle scalars and vectors: spectral normalization is undefined here.
  if value.ndim <= 1:
    raise ValueError("Spectral normalization is not well defined for "
                     "scalar or vector inputs.")
  # Handle higher-order tensors.
  elif value.ndim > 2:
    if error_on_non_matrix:
      raise ValueError(
          "Input is {}D but error_on_non_matrix is True".format(value.ndim))
    else:
      value = jnp.reshape(value, [-1, value.shape[-1]])

  u0 = base.get_state("u0", shape=[1, value.shape[-1]], dtype=value.dtype,
                      init=initializers.RandomNormal())

  # Power iteration for the weight's leading singular value.
  for _ in range(self._n_steps):
    v0 = _l2_normalize(jnp.matmul(u0, value.transpose([1, 0])),
                       eps=self._eps)
    u0 = _l2_normalize(jnp.matmul(v0, value), eps=self._eps)

  u0 = jax.lax.stop_gradient(u0)
  v0 = jax.lax.stop_gradient(v0)

  sigma = jnp.matmul(jnp.matmul(v0, value), jnp.transpose(u0))[0, 0]

  value /= sigma
  value_bar = value.reshape(value_shape)

  if update_stats:
    base.set_state("u0", u0)
    base.set_state("sigma", sigma)

  return value_bar

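# A usage sketch, assuming this method belongs to the module Haiku exposes as
# hk.SpectralNorm: "u0" (and "sigma") live in Haiku state, so the module runs
# under transform_with_state and the state is threaded through each apply.
import haiku as hk

def spectral_norm(w):
  return hk.SpectralNorm(eps=1e-4, n_steps=1)(w)

spectral_norm = hk.transform_with_state(spectral_norm)
w = jnp.ones([3, 4])
params, state = spectral_norm.init(jax.random.PRNGKey(42), w)  # rng inits u0.
w_bar, state = spectral_norm.apply(params, state, None, w)  # w / sigma.
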
def test_stateful(self):
  with base.new_context() as ctx:
    for _ in range(10):
      count = base.get_state("count", (), jnp.int32, jnp.zeros)
      base.set_state("count", count + 1)

  self.assertEqual(ctx.collect_initial_state(), {"~": {"count": 0}})
  self.assertEqual(ctx.collect_state(), {"~": {"count": 10}})

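# For contrast, a sketch of the same counter through the transform API rather
# than a raw context (document's own `base` / `transform` idiom; rng is None
# since nothing here needs randomness): init captures the initial state and
# each apply returns the updated state to thread back in.
def counter():
  count = base.get_state("count", (), jnp.int32, jnp.zeros)
  base.set_state("count", count + 1)
  return count

counter = transform.transform_with_state(counter)
params, state = counter.init(None)  # state == {"~": {"count": 0}}
for _ in range(10):
  count, state = counter.apply(params, state, None)
# Now count == 9 and state == {"~": {"count": 10}}.
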
def __call__(self, x):
  assert x.ndim == 0
  p = base.get_parameter("p", [], jnp.int32, init=lambda *_: jnp.array(2))
  y = x ** p
  base.set_state("y", y)
  return y

def test_context_copies_input(self):
  before = {"~": {"w": jnp.array(1.)}}
  with base.new_context(params=before, state=before) as ctx:
    base.get_parameter("w", [], init=jnp.ones)
    base.set_state("w", jnp.array(2.))

  self.assertEqual(ctx.collect_params(), {"~": {"w": jnp.array(1.)}})
  self.assertIsNot(ctx.collect_initial_state(), before)
  self.assertEqual(ctx.collect_initial_state(), before)
  self.assertEqual(ctx.collect_state(), {"~": {"w": jnp.array(2.)}})
  self.assertEqual(before, {"~": {"w": jnp.array(1.)}})

def test_difference_update_state(self):
  base.get_state("a", [], init=jnp.zeros)
  base.get_state("b", [], init=jnp.zeros)
  before = stateful.internal_state()
  base.set_state("b", jnp.ones([]))
  after = stateful.internal_state()
  diff = stateful.difference(before, after)
  self.assertEmpty(diff.params)
  self.assertEqual(diff.state,
                   {"~": {"a": None, "b": base.StatePair(0., 1.)}})
  self.assertIsNone(diff.rng)

def test_set_then_get(self):
  with base.new_context() as ctx:
    base.set_state("i", 1)
    base.get_state("i")

  self.assertEqual(ctx.collect_initial_state(), {"~": {"i": 1}})

  for _ in range(10):
    with ctx:
      base.set_state("i", 1)
      y = base.get_state("i")
      self.assertEqual(y, 1)
    self.assertEqual(ctx.collect_initial_state(), {"~": {"i": 1}})

def test_without_state_raises_if_state_used_on_apply(self):
  f = lambda: base.set_state("~", 1)
  f = transform.without_state(transform.transform_with_state(f))
  rng = jax.random.PRNGKey(42)
  with self.assertRaisesRegex(ValueError, "use.*transform_with_state"):
    params = f.init(rng)
    f.apply(params, rng)

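# For contrast, a sketch of the working counterpart: keep
# transform_with_state (no without_state) so the value written by set_state
# is threaded explicitly through init and apply instead of raising.
g = transform.transform_with_state(lambda: base.set_state("s", 1))
params, state = g.init(jax.random.PRNGKey(42))
_, state = g.apply(params, state, None)  # state == {"~": {"s": 1}}
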
def test_setter_array(self):
  witness = []
  x = jnp.ones([])
  y = x + 1

  def my_setter(next_setter, value, context):
    self.assertIs(value, x)
    self.assertEqual(context.original_shape, value.shape)
    self.assertEqual(context.original_dtype, value.dtype)
    self.assertEqual(context.full_name, "~/x")
    self.assertEqual(context.name, "x")
    self.assertIsNone(context.module)
    witness.append(None)
    del next_setter
    return y

  with base.new_context():
    with base.custom_setter(my_setter):
      base.set_state("x", x)
      x = base.get_state("x")
      self.assertIs(x, y)

  self.assertNotEmpty(witness)

def __call__(self, value, update_stats=True):
  """Updates the EMA and returns the new value.

  Args:
    value: The array-like object on which you would like to perform an
      exponential decay.
    update_stats: A boolean, whether to update the internal state of this
      object to reflect the input value. When `update_stats` is False the
      internal stats will remain unchanged.

  Returns:
    The exponentially weighted average of the input value.
  """
  if not isinstance(value, jnp.ndarray):
    value = jnp.asarray(value)

  counter = base.get_state(
      "counter", (), jnp.int32,
      init=initializers.Constant(-self._warmup_length))
  counter += 1

  decay = jax.lax.convert_element_type(self._decay, value.dtype)
  if self._warmup_length > 0:
    decay = self._cond(counter <= 0, 0.0, decay, value.dtype)

  one = jnp.ones([], value.dtype)
  hidden = base.get_state("hidden", value.shape, value.dtype, init=jnp.zeros)
  hidden = hidden * decay + value * (one - decay)

  average = hidden
  if self._zero_debias:
    average /= (one - jnp.power(decay, counter))

  if update_stats:
    base.set_state("counter", counter)
    base.set_state("hidden", hidden)
    base.set_state("average", average)

  return average

def __call__(self, value, update_stats=True):
  """Updates the EMA and returns the new value.

  Args:
    value: The array-like object on which you would like to perform an
      exponential decay.
    update_stats: A boolean, whether to update the internal state of this
      object to reflect the input value. When `update_stats` is False the
      internal stats will remain unchanged.

  Returns:
    The exponentially weighted average of the input value.
  """
  value = jnp.asarray(value)  # Ensure value has a dtype.

  prev_counter = base.get_state(
      "counter", shape=(), dtype=jnp.int32,
      init=initializers.Constant(-self._warmup_length))
  prev_hidden = base.get_state("hidden", shape=value.shape,
                               dtype=value.dtype, init=jnp.zeros)

  decay = jnp.asarray(self._decay).astype(value.dtype)
  counter = prev_counter + 1
  decay = self._cond(jnp.less_equal(counter, 0), 0.0, decay, value.dtype)
  hidden = prev_hidden * decay + value * (1 - decay)

  if self._zero_debias:
    average = hidden / (1. - jnp.power(decay, counter))
  else:
    average = hidden

  if update_stats:
    base.set_state("counter", counter)
    base.set_state("hidden", hidden)
    base.set_state("average", average)

  return average

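# A usage sketch, assuming this method belongs to the module Haiku exposes as
# hk.ExponentialMovingAverage: "counter", "hidden" and "average" live in
# Haiku state; pass update_stats=False to read the EMA without advancing it.
import haiku as hk

def ema(x):
  return hk.ExponentialMovingAverage(decay=0.99)(x)

ema = hk.transform_with_state(ema)
params, state = ema.init(None, jnp.array(1.0))  # No rng needed here.
for x in (1.0, 2.0, 3.0):
  avg, state = ema.apply(params, state, None, jnp.array(x))
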
def __call__(self, x):
  y = self.op(x)
  base.set_state("count", self.count + 1)
  return y

def net():
  base.set_state("i", 1)
  return base.get_state("i")

def test_new_state_in_apply(self):
  with base.new_context(params={}, state={}) as ctx:
    base.set_state("count", 1)

  self.assertEqual(ctx.collect_initial_state(), {"~": {"count": 1}})
  self.assertEqual(ctx.collect_state(), {"~": {"count": 1}})

identity_carry = lambda f: lambda carry, x: (carry, f(x))
ignore_index = lambda f: lambda i, x: f(x)


def with_rng_example():
  with base.with_rng(jax.random.PRNGKey(42)):
    pass


# Methods in Haiku that mutate internal state.
SIDE_EFFECTING_FUNCTIONS = (
    ("get_parameter", lambda: base.get_parameter("w", [], init=jnp.zeros)),
    ("get_state", lambda: base.get_state("w", [], init=jnp.zeros)),
    ("set_state", lambda: base.set_state("w", 1)),
    ("next_rng_key", base.next_rng_key),
    ("next_rng_keys", lambda: base.next_rng_keys(2)),
    ("reserve_rng_keys", lambda: base.reserve_rng_keys(2)),
    ("with_rng", with_rng_example),
)

# JAX transforms and control flow that need to be aware of Haiku internal
# state to operate unsurprisingly.
# pylint: disable=g-long-lambda
JAX_PURE_EXPECTING_FNS = (
    # Just-in-time compilation.
    ("jit", jax.jit),
    ("make_jaxpr", jax.make_jaxpr),
    ("eval_shape", lambda f: (lambda x: jax.eval_shape(f, x))),

def __call__(self, x):
  y = x ** 2
  base.set_state("count", self.count + 1)
  return y

def apply():
  s = base.get_state('s')
  base.set_state('s', s + 1)

def __call__(self):
  for _ in range(10):
    count = base.get_state("count", (), jnp.int32, jnp.zeros)
    base.set_state("count", count + 1)
  return count

def inner():
  w = base.get_state("w", [], init=jnp.zeros)
  w += 1
  base.set_state("w", w)
  return w

def __call__(self, inputs, is_training):
  """Connects the module to some inputs.

  Args:
    inputs: Tensor, final dimension must be equal to embedding_dim. All other
      leading dimensions will be flattened and treated as a large batch.
    is_training: boolean, whether this connection is to training data. When
      this is set to False, the internal moving average statistics will not
      be updated.

  Returns:
    dict containing the following keys and values:
      quantize: Tensor containing the quantized version of the input.
      loss: Tensor containing the loss to optimize.
      perplexity: Tensor containing the perplexity of the encodings.
      encodings: Tensor containing the discrete encodings, i.e. which element
        of the quantized space each input element was mapped to.
      encoding_indices: Tensor containing the discrete encoding indices, i.e.
        which element of the quantized space each input element was mapped
        to.
      distances: Tensor containing the squared distances from each flattened
        input to each codebook embedding.
  """
  flat_inputs = jnp.reshape(inputs, [-1, self.embedding_dim])
  embeddings = self.embeddings

  distances = (
      jnp.sum(flat_inputs**2, 1, keepdims=True) -
      2 * jnp.matmul(flat_inputs, embeddings) +
      jnp.sum(embeddings**2, 0, keepdims=True))

  encoding_indices = jnp.argmax(-distances, 1)
  encodings = jax.nn.one_hot(encoding_indices,
                             self.num_embeddings,
                             dtype=distances.dtype)

  # NB: if your code crashes with a reshape error on the line below about a
  # Tensor containing the wrong number of values, then the most likely cause
  # is that the input passed in does not have a final dimension equal to
  # self.embedding_dim. Ideally we would catch this with an Assert but that
  # creates various other problems related to device placement / TPUs.
  encoding_indices = jnp.reshape(encoding_indices, inputs.shape[:-1])

  quantized = self.quantize(encoding_indices)
  e_latent_loss = jnp.mean((jax.lax.stop_gradient(quantized) - inputs)**2)

  if is_training:
    updated_ema_cluster_size = self.ema_cluster_size(
        jnp.sum(encodings, axis=0))

    dw = jnp.matmul(flat_inputs.T, encodings)
    updated_ema_dw = self.ema_dw(dw)

    n = jnp.sum(updated_ema_cluster_size)
    updated_ema_cluster_size = ((updated_ema_cluster_size + self.epsilon) /
                                (n + self.num_embeddings * self.epsilon) * n)

    normalised_updated_ema_w = (
        updated_ema_dw / jnp.reshape(updated_ema_cluster_size, [1, -1]))

    base.set_state('embeddings', normalised_updated_ema_w)

  # The commitment loss is the same whether or not we are training; only the
  # EMA codebook update above is gated on `is_training`.
  loss = self.commitment_cost * e_latent_loss

  # Straight Through Estimator.
  quantized = inputs + jax.lax.stop_gradient(quantized - inputs)
  avg_probs = jnp.mean(encodings, 0)
  perplexity = jnp.exp(-jnp.sum(avg_probs * jnp.log(avg_probs + 1e-10)))

  return {
      'quantize': quantized,
      'loss': loss,
      'perplexity': perplexity,
      'encodings': encodings,
      'encoding_indices': encoding_indices,
      'distances': distances,
  }

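# A usage sketch, assuming this method belongs to the module Haiku exposes as
# hk.nets.VectorQuantizerEMA (constructor arguments shown are assumptions):
# the EMA-updated codebook is Haiku state, so a training step must thread the
# returned state back in; with is_training=False the codebook is untouched.
import haiku as hk

def vq(x, is_training):
  quantizer = hk.nets.VectorQuantizerEMA(embedding_dim=16, num_embeddings=32,
                                         commitment_cost=0.25, decay=0.99)
  return quantizer(x, is_training=is_training)

vq = hk.transform_with_state(vq)
x = jnp.ones([8, 16])
params, state = vq.init(jax.random.PRNGKey(42), x, is_training=True)
out, state = vq.apply(params, state, None, x, is_training=True)  # EMA update.
out, _ = vq.apply(params, state, None, x, is_training=False)     # Read-only.
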
def init():
  s = base.get_state('s', [], init=jnp.zeros)
  base.set_state('s', s + 1)
