def _update_memory(self, mem: jnp.ndarray, mask: jnp.ndarray,
                    input_length: int, cache_steps: int,
                    should_reset: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
     """Logic for using and updating cached activations."""
     batch_size = mem.shape[0]
     if cache_steps > 0:
         # Tells us how much of the cache should be used.
         cache_progress_idx = hk.get_state('cache_progress_idx',
                                           [batch_size],
                                           dtype=jnp.int32,
                                           init=jnp.zeros)
         hk.set_state('cache_progress_idx',
                      cache_progress_idx + input_length)
         mem = self._update_cache('mem', mem, cache_steps=cache_steps)
         if mask is None:
             mask = jnp.ones((batch_size, 1, input_length, input_length))
         cache_mask = (jnp.arange(cache_steps - 1, -1, -1)[None, None,
                                                           None, :] <
                       cache_progress_idx[:, None, None, None])
         cache_mask = jnp.broadcast_to(
             cache_mask, (batch_size, 1, input_length, cache_steps))
         mask = jnp.concatenate([cache_mask, mask], axis=-1)
     if should_reset is not None:
         if cache_steps > 0:
             should_reset = self._update_cache('should_reset',
                                               should_reset,
                                               cache_steps=cache_steps)
         reset_mask = get_reset_attention_mask(should_reset)[:, None, :, :]
         mask *= reset_mask[:, :, cache_steps:, :]
     return mem, mask
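A small numeric sketch (sizes are illustrative) of the `cache_mask` built above: with `cache_steps=4` and a cache that has only accumulated 2 timesteps so far, only the two most recent cache slots are made visible to attention.

import jax.numpy as jnp

cache_steps = 4
cache_progress_idx = jnp.array([2])  # batch of one, two steps cached so far
cache_mask = (jnp.arange(cache_steps - 1, -1, -1)[None, None, None, :]
              < cache_progress_idx[:, None, None, None])
# cache_mask[0, 0, 0] -> [False, False, True, True]: only the newest two cache
# positions (stored at the end of the cache) can be attended to.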
Example 2
def get_pos_start(timesteps: int, batch_size: int) -> jnp.ndarray:
    """Find the right slice of positional embeddings for incremental sampling."""
    pos_start = hk.get_state('cache_progress_idx', [batch_size],
                             dtype=jnp.int32,
                             init=jnp.zeros)
    hk.set_state('cache_progress_idx', pos_start + timesteps)
    return pos_start
Example 3
    def __call__(self, x: jnp.ndarray):
        g = gram_matrix(x)

        style_loss = jnp.mean(jnp.square(g - self.target_g))
        hk.set_state("style_loss", style_loss)

        return x
Example 4
            def __call__(self, x):

                n = haiku.get_state(
                    "n", shape=[], dtype=jnp.int32, init=lambda *args: np.array(0)
                )
                w = haiku.get_parameter("w", [], init=lambda *args: np.array(2.0))

                haiku.set_state("n", n + 1)

                return x * w
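Every snippet on this page relies on running inside a Haiku transform so that `hk.get_state`/`hk.set_state` have a context. A minimal sketch (names here are illustrative) of driving a stateful counter like the one above with `hk.transform_with_state`:

import haiku as hk
import jax.numpy as jnp

def forward(x):
    n = hk.get_state("n", shape=[], dtype=jnp.int32, init=jnp.zeros)
    hk.set_state("n", n + 1)
    return x * n

model = hk.transform_with_state(forward)
params, state = model.init(None, jnp.ones(3))                 # no rng needed here
out, state = model.apply(params, state, None, jnp.ones(3))    # n == 0 on this call
out, state = model.apply(params, state, None, jnp.ones(3))    # n == 1 on this call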
Example 5
  def upsample(self, x, durations, L):
    ruler = jnp.arange(0, L)[None, :]  # (1, L), broadcasts over the batch
    end_pos = jnp.cumsum(durations, axis=1)
    mid_pos = end_pos - durations/2  # B, T

    d2 = jnp.square((mid_pos[:, None, :] - ruler[:, :, None])) / 10.
    w = jax.nn.softmax(-d2, axis=-1)
    hk.set_state('attn', w)
    x = jnp.einsum('BLT,BTD->BLD', w, x)
    return x
Example 6
def noise(a, ep):
    a = jnp.asarray(a)
    shape = env.action_space.shape
    mu, sigma, theta = hparams.exploration_mu, hparams.exploration_sigma, hparams.exploration_theta
    scale = hk.get_state('scale', shape=(), dtype=a.dtype, init=jnp.ones)
    noise = hk.get_state('noise', shape=a.shape, dtype=a.dtype, init=jnp.zeros)
    scale = scale * hparams.noise_decay
    noise = theta * (mu - noise) + sigma * jax.random.normal(
        hk.next_rng_key(), shape)
    hk.set_state('scale', scale)
    hk.set_state('noise', noise)
    return a + noise * scale
Example 7
def conv_weight_with_spectral_norm(x: jnp.ndarray,
                                   kernel_shape: Sequence[int],
                                   out_channel: int,
                                   name_suffix: str = "",
                                   w_init: Callable = None,
                                   b_init: Callable = None,
                                   use_bias: bool = True,
                                   is_training: bool = True,
                                   update_params: bool = True,
                                   max_singular_value: float = 0.95,
                                   max_power_iters: int = 1,
                                   **conv_kwargs):
    batch_size, H, W, C = x.shape
    w_shape = kernel_shape + (C, out_channel)

    w = hk.get_parameter(f"w_{name_suffix}", w_shape, x.dtype, init=w_init)
    if use_bias:
        b = hk.get_parameter(f"b_{name_suffix}", (out_channel, ), x.dtype,
                             init=b_init)

    u = hk.get_state(f"u_{name_suffix}", (H, W, out_channel),
                     init=hk.initializers.RandomNormal())
    v = hk.get_state(f"v_{name_suffix}", (H, W, C),
                     init=hk.initializers.RandomNormal())
    w, u, v = sn.spectral_norm_conv_apply(w, u, v, conv_kwargs["stride"],
                                          conv_kwargs["padding"],
                                          max_singular_value, max_power_iters,
                                          update_params)

    # Run for a lot of steps when we're first initializing
    running_init_fn = not hk_base.params_frozen()
    if running_init_fn:
        w, u, v = sn.spectral_norm_conv_apply(w, u, v, conv_kwargs["stride"],
                                              conv_kwargs["padding"],
                                              max_singular_value, None, True)

    if is_training or running_init_fn:
        hk.set_state(f"u_{name_suffix}", u)
        hk.set_state(f"v_{name_suffix}", v)

    if use_bias:
        return w, b
    return w
Example 8
    def call(self,
             values: jnp.ndarray,
             sample_weight: tp.Optional[jnp.ndarray] = None) -> jnp.ndarray:
        """
        Accumulates statistics for computing the reduction metric. For example, if `values` is [1, 3, 5, 7] 
        and reduction=SUM_OVER_BATCH_SIZE, then the value of `result()` is 4. If the `sample_weight` 
        is specified as [1, 1, 0, 0] then value of `result()` would be 2.
        
        Arguments:
            values: Per-example value.
            sample_weight: Optional weighting of each example. Defaults to 1.
        
        Returns:
            Array with the cummulative reduce.
        """
        total = hk.get_state("total",
                             shape=[],
                             dtype=self._dtype,
                             init=hk.initializers.Constant(0))

        if self._reduction in (
                Reduction.SUM_OVER_BATCH_SIZE,
                Reduction.WEIGHTED_MEAN,
        ):
            count = hk.get_state("count",
                                 shape=[],
                                 dtype=jnp.int32,
                                 init=hk.initializers.Constant(0))
        else:
            count = None

        value, total, count = reduce(
            total=total,
            count=count,
            values=values,
            reduction=self._reduction,
            sample_weight=sample_weight,
            dtype=self._dtype,
        )

        hk.set_state("total", total)

        if count is not None:
            hk.set_state("count", count)

        return value
Example 9
def weight_with_spectral_norm(x: jnp.ndarray,
                              out_dim: int,
                              name_suffix: str = "",
                              w_init: Callable = None,
                              b_init: Callable = None,
                              is_training: bool = True,
                              update_params: bool = True,
                              use_bias: bool = True,
                              force_in_dim: Optional[int] = None,
                              max_singular_value: float = 0.99,
                              max_power_iters: int = 1,
                              **kwargs):
    in_dim, dtype = x.shape[-1], x.dtype
    if force_in_dim:
        in_dim = force_in_dim

    w = hk.get_parameter(f"w_{name_suffix}", (out_dim, in_dim),
                         dtype,
                         init=w_init)
    if use_bias:
        b = hk.get_parameter(f"b_{name_suffix}", (out_dim, ),
                             dtype,
                             init=b_init)

    u = hk.get_state(f"u_{name_suffix}", (out_dim, ),
                     dtype,
                     init=hk.initializers.RandomNormal())
    v = hk.get_state(f"v_{name_suffix}", (in_dim, ),
                     dtype,
                     init=hk.initializers.RandomNormal())
    w, u, v = sn.spectral_norm_apply(w, u, v, max_singular_value,
                                     max_power_iters, update_params)

    running_init_fn = not hk_base.params_frozen()
    if running_init_fn:
        w, u, v = sn.spectral_norm_apply(w, u, v, max_singular_value, None,
                                         True)

    if is_training or running_init_fn:
        hk.set_state(f"u_{name_suffix}", u)
        hk.set_state(f"v_{name_suffix}", v)

    if use_bias:
        return w, b
    return w
Example 10
  def get_normalized_weights(self,
                             weights: jnp.ndarray,
                             renormalize: bool = False) -> jnp.ndarray:

    def _l2_normalize(x, axis=None, eps=1e-12):
      return x * jax.lax.rsqrt((x * x).sum(axis=axis, keepdims=True) + eps)

    output_size = self.output_size
    dtype = weights.dtype
    assert output_size == weights.shape[-1]
    sigma = hk.get_state('sigma', (), init=jnp.ones)
    if renormalize:
      # Power iterations to compute spectral norm V*W*U^T.
      u = hk.get_state(
          'u', (1, output_size), dtype, init=hk.initializers.RandomNormal())
      for _ in range(self.num_iterations):
        v = _l2_normalize(jnp.matmul(u, weights.transpose()), eps=self.eps)
        u = _l2_normalize(jnp.matmul(v, weights), eps=self.eps)
      u = jax.lax.stop_gradient(u)
      v = jax.lax.stop_gradient(v)
      sigma = jnp.matmul(jnp.matmul(v, weights), jnp.transpose(u))[0, 0]
      hk.set_state('u', u)
      hk.set_state('v', v)
      hk.set_state('sigma', sigma)
    factor = jnp.maximum(1, sigma / self.lipschitz_coeff)
    return weights / factor
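A quick numeric check of the final rescaling in `get_normalized_weights`: the weights are only shrunk when the estimated spectral norm `sigma` exceeds `lipschitz_coeff`, otherwise the factor is 1 and they pass through unchanged.

import jax.numpy as jnp

lipschitz_coeff = 2.0
factor = jnp.maximum(1, 3.0 / lipschitz_coeff)  # sigma = 3.0 -> factor 1.5, weights shrunk
factor = jnp.maximum(1, 1.0 / lipschitz_coeff)  # sigma = 1.0 -> factor 1.0, weights unchanged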
Example 11
    def _update_cache(self,
                      key: jnp.ndarray,
                      value: jnp.ndarray,
                      cache_steps: Optional[int] = None,
                      axis: int = 1) -> jnp.ndarray:
        """Update the cache stored in hk.state."""
        cache_shape = list(value.shape)
        value_steps = cache_shape[axis]
        if cache_steps is not None:
            cache_shape[axis] += cache_steps
        cache = hk.get_state(key,
                             shape=cache_shape,
                             dtype=value.dtype,
                             init=jnp.zeros)

        # Overwrite the (oldest) entries at index 0, then rotate timesteps left
        # so what was just inserted sits at the end of the cache.
        value = jax.lax.dynamic_update_slice(
            cache, value, jnp.zeros(len(cache_shape), dtype=jnp.int32))
        value = jnp.roll(value, -value_steps, axis)
        hk.set_state(key, value)
        return value
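A small numeric sketch (shapes are illustrative) of the roll-based update above, showing how the newest timesteps end up at the end of the cache while the oldest fall off the front:

import jax
import jax.numpy as jnp

cache = jnp.zeros((1, 4), dtype=jnp.int32)        # cache_steps + value_steps = 4
step1 = jnp.array([[1, 2]])                       # two new timesteps
cache = jnp.roll(jax.lax.dynamic_update_slice(cache, step1, (0, 0)), -2, axis=1)
# cache is now [[0, 0, 1, 2]]
step2 = jnp.array([[3, 4]])
cache = jnp.roll(jax.lax.dynamic_update_slice(cache, step2, (0, 0)), -2, axis=1)
# cache is now [[1, 2, 3, 4]]: step1 first, step2 last, zero padding dropped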
Example 12
def conv_weight_with_spectral_norm(x: jnp.ndarray,
                                   kernel_shape: Sequence[int],
                                   out_channel: int,
                                   name_suffix: str = "",
                                   w_init: Callable = None,
                                   b_init: Callable = None,
                                   use_bias: bool = True,
                                   is_training: bool = True,
                                   **conv_kwargs):
    batch_size, H, W, C = x.shape
    w_shape = kernel_shape + (C, out_channel)

    def w_init_whiten(shape, dtype):
        w = w_init(shape, dtype)
        return w * 0.7

    w = hk.get_parameter(f"w_{name_suffix}",
                         w_shape,
                         x.dtype,
                         init=w_init_whiten)
    if use_bias:
        b = hk.get_parameter(f"b_{name_suffix}", (out_channel, ), x.dtype,
                             init=b_init)

    u = hk.get_state(f"u_{name_suffix}",
                     kernel_shape + (out_channel, ),
                     init=hk.initializers.RandomNormal())
    w, u = sn.spectral_norm_conv_apply(w, u, conv_kwargs["stride"],
                                       conv_kwargs["padding"], 0.9, 1)
    if is_training:
        hk.set_state(f"u_{name_suffix}", u)

    if use_bias:
        return w, b
    return w
Example 13
def weight_with_spectral_norm(x: jnp.ndarray,
                              out_dim: int,
                              name_suffix: str = "",
                              w_init: Callable = None,
                              b_init: Callable = None,
                              is_training: bool = True,
                              update_params: bool = True,
                              use_bias: bool = True,
                              **kwargs):
    in_dim, dtype = x.shape[-1], x.dtype

    def w_init_whiten(shape, dtype):
        w = w_init(shape, dtype)
        return util.whiten(w) * 0.9

    w = hk.get_parameter(f"w_{name_suffix}", (out_dim, in_dim),
                         dtype,
                         init=w_init_whiten)
    if use_bias:
        b = hk.get_parameter(f"b_{name_suffix}", (out_dim, ),
                             dtype,
                             init=b_init)

    u = hk.get_state(f"u_{name_suffix}", (out_dim, ),
                     dtype,
                     init=hk.initializers.RandomNormal())
    v = hk.get_state(f"v_{name_suffix}", (in_dim, ),
                     dtype,
                     init=hk.initializers.RandomNormal())
    w, u, v = sn.spectral_norm_apply(w, u, v, 0.99, 5, update_params)
    if is_training:
        hk.set_state(f"u_{name_suffix}", u)
        hk.set_state(f"v_{name_suffix}", v)

    if use_bias:
        return w, b
    return w
Example 14
  def __call__(self, sample: jnp.DeviceArray) -> Tuple[Array, Array]:
    if len(sample.shape) > 1:
      raise ValueError("sample must be a rank 0 or 1 DeviceArray.")

    count = hk.get_state("count", shape=(), dtype=jnp.int32, init=jnp.zeros)
    mean = hk.get_state(
        "mean", shape=sample.shape, dtype=jnp.float32, init=jnp.zeros)
    m2 = hk.get_state(
        "m2", shape=sample.shape, dtype=jnp.float32, init=jnp.zeros)

    count += 1
    delta = sample - mean
    mean += delta / count
    delta_2 = sample - mean
    m2 += delta * delta_2

    hk.set_state("count", count)
    hk.set_state("mean", mean)
    hk.set_state("m2", m2)

    stddev = jnp.sqrt(m2 / count)
    return mean, stddev
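The snippet above is Welford's online algorithm for running mean and variance. A minimal sketch of streaming samples through it with `hk.transform_with_state`; the module name `OnlineMeanStd` is assumed for illustration:

import haiku as hk
import jax.numpy as jnp

def fn(sample):
    return OnlineMeanStd()(sample)  # hypothetical name for the module above

model = hk.transform_with_state(fn)
data = jnp.array([1.0, 2.0, 4.0, 7.0])
params, state = model.init(None, data[0])
for x in data:
    (mean, std), state = model.apply(params, state, None, x)
# After all samples: mean == data.mean() and std == data.std() (population std).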
Example 15
  def __call__(
      self,
      value,
      update_stats: bool = True,
      error_on_non_matrix: bool = False,
  ) -> jnp.ndarray:
    """Performs Spectral Normalization and returns the new value.
    Args:
      value: The array-like object on which you would like to perform
        spectral normalization.
      update_stats: A boolean defaulting to True. Regardless of this arg, this
        function will return the normalized input. When
        `update_stats` is True, the internal state of this object will also be
        updated to reflect the input value. When `update_stats` is False the
        internal stats will remain unchanged.
      error_on_non_matrix: Spectral normalization is only defined on matrices.
        By default, this module will return scalars unchanged and flatten
        higher-order tensors in their leading dimensions. Setting this flag to
        True will instead throw errors in those cases.
    Returns:
      The input value normalized by its first singular value.
    Raises:
      ValueError: If `error_on_non_matrix` is True and `value` has ndims > 2.
    """
    value = jnp.asarray(value)
    value_shape = value.shape

    # Handle scalars.
    if value.ndim <= 1:
      raise ValueError("Spectral normalization is not well defined for "
                       "scalar or vector inputs.")
    # Handle higher-order tensors.
    elif value.ndim > 2:
      if error_on_non_matrix:
        raise ValueError(
            f"Input is {value.ndim}D but error_on_non_matrix is True")
      else:
        value = jnp.reshape(value, [-1, value.shape[-1]])

    u0 = hk.get_state("u0", [1, value.shape[-1]], value.dtype,
                      init=hk.initializers.RandomNormal())

    # Power iteration for the weight's singular value.
    for _ in range(self.n_steps):
      v0 = _l2_normalize(jnp.matmul(u0, value.transpose([1, 0])), eps=self.eps)
      u0 = _l2_normalize(jnp.matmul(v0, value), eps=self.eps)

    u0 = jax.lax.stop_gradient(u0)
    v0 = jax.lax.stop_gradient(v0)

    sigma = jnp.matmul(jnp.matmul(v0, value), jnp.transpose(u0))[0, 0]

    value /= sigma
    value *= self.val
    value_bar = value.reshape(value_shape)

    if update_stats:
      hk.set_state("u0", u0)
      hk.set_state("sigma", sigma)

    return value_bar
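The `__call__` above closely mirrors Haiku's built-in `hk.SpectralNorm` (apart from the extra rescaling by `self.val`). A minimal sketch of driving the built-in module with default settings:

import haiku as hk
import jax
import jax.numpy as jnp

def f(w):
    return hk.SpectralNorm(n_steps=1)(w)

sn_fn = hk.transform_with_state(f)
w = jax.random.normal(jax.random.PRNGKey(0), (8, 4))
params, state = sn_fn.init(jax.random.PRNGKey(1), w)   # rng initializes the u0 state
w_bar, state = sn_fn.apply(params, state, None, w)     # w scaled by 1 / estimated sigma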
Example 16
def apply_sn(*,
             mvp,
             mvpT,
             w_shape,
             b_shape,
             out_shape,
             dtype,
             w_init,
             b_init,
             name_suffix,
             is_training,
             use_bias,
             max_singular_value,
             max_power_iters,
             use_proximal_gradient=False,
             monitor_progress=False,
             monitor_iters=20,
             return_sigma=False,
             **kwargs):

    w_exists = util.check_if_parameter_exists(f"w_{name_suffix}")

    w = hk.get_parameter(f"w_{name_suffix}", w_shape, dtype, init=w_init)
    u = hk.get_state(f"u_{name_suffix}",
                     out_shape,
                     dtype,
                     init=hk.initializers.RandomNormal())
    if not use_proximal_gradient:
        zeta = hk.get_state(f"zeta_{name_suffix}",
                            out_shape,
                            dtype,
                            init=hk.initializers.RandomNormal())
        state = (u, zeta)
    else:
        state = (u, )

    if not use_proximal_gradient:
        estimate_max_singular_value = jax.jit(sn.max_singular_value,
                                              static_argnums=(0, 1))
    else:
        estimate_max_singular_value = jax.jit(sn.max_singular_value_no_grad,
                                              static_argnums=(0, 1))

    if not w_exists:
        max_power_iters = 1000

    if monitor_progress:
        estimates = []

    for i in range(max_power_iters):
        sigma, *state = estimate_max_singular_value(mvp, mvpT, w, *state)
        if monitor_progress:
            estimates.append(sigma)

    if monitor_progress:
        sigma_for_test = sigma
        state_for_test = state
        for i in range(monitor_iters - max_power_iters):
            sigma_for_test, *state_for_test = estimate_max_singular_value(
                mvp, mvpT, w, *state_for_test)
            estimates.append(sigma_for_test)

        estimates = jnp.array(estimates)

        sigma_for_test = jax.lax.stop_gradient(sigma_for_test)
        state_for_test = jax.lax.stop_gradient(state_for_test)

    state = jax.lax.stop_gradient(state)

    if is_training or not w_exists:
        u = state[0]
        hk.set_state(f"u_{name_suffix}", u)
        if not use_proximal_gradient:
            zeta = state[1]
            hk.set_state(f"zeta_{name_suffix}", zeta)

    if not return_sigma:
        factor = jnp.where(max_singular_value < sigma,
                           max_singular_value / sigma, 1.0)
        w = w * factor
        w_ret = w
    else:
        w_ret = (w, sigma)

    if use_bias:
        b = hk.get_parameter(f"b_{name_suffix}", b_shape, dtype, init=b_init)
        ret = (w_ret, b)
    else:
        ret = w_ret

    if monitor_progress:
        ret = (ret, estimates)

    return ret
Example 17
    def __call__(self, x: jnp.ndarray):
        content_loss = jnp.mean(jnp.square(x - self.target))
        hk.set_state("content_loss", content_loss)

        return x
Example 18
def set_state_tree(name, val):
  return hk.set_state(name, Box(val))