def _update_memory(self,
                   mem: jnp.ndarray,
                   mask: jnp.ndarray,
                   input_length: int,
                   cache_steps: int,
                   should_reset: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Logic for using and updating cached activations."""
    batch_size = mem.shape[0]
    if cache_steps > 0:
        # Tells us how much of the cache should be used.
        cache_progress_idx = hk.get_state(
            'cache_progress_idx', [batch_size], dtype=jnp.int32, init=jnp.zeros)
        hk.set_state('cache_progress_idx', cache_progress_idx + input_length)
        mem = self._update_cache('mem', mem, cache_steps=cache_steps)
        if mask is None:
            mask = jnp.ones((batch_size, 1, input_length, input_length))
        cache_mask = (jnp.arange(cache_steps - 1, -1, -1)[None, None, None, :]
                      < cache_progress_idx[:, None, None, None])
        cache_mask = jnp.broadcast_to(
            cache_mask, (batch_size, 1, input_length, cache_steps))
        mask = jnp.concatenate([cache_mask, mask], axis=-1)
    if should_reset is not None:
        if cache_steps > 0:
            should_reset = self._update_cache('should_reset', should_reset,
                                              cache_steps=cache_steps)
        reset_mask = get_reset_attention_mask(should_reset)[:, None, :, :]
        mask *= reset_mask[:, :, cache_steps:, :]
    return mem, mask
def get_pos_start(timesteps: int, batch_size: int) -> jnp.ndarray:
    """Find the right slice of positional embeddings for incremental sampling."""
    pos_start = hk.get_state(
        'cache_progress_idx', [batch_size], dtype=jnp.int32, init=jnp.zeros)
    hk.set_state('cache_progress_idx', pos_start + timesteps)
    return pos_start
def __call__(self, x: jnp.ndarray):
    g = gram_matrix(x)
    style_loss = jnp.mean(jnp.square(g - self.target_g))
    hk.set_state("style_loss", style_loss)
    return x
def __call__(self, x):
    n = haiku.get_state(
        "n", shape=[], dtype=jnp.int32, init=lambda *args: np.array(0)
    )
    w = haiku.get_parameter("w", [], init=lambda *args: np.array(2.0))
    haiku.set_state("n", n + 1)
    return x * w
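# A minimal usage sketch (not part of the snippets above): stateful functions
# like the counter in the previous snippet are typically wrapped with
# hk.transform_with_state, which threads the "n" state and the "w" parameter
# through init/apply explicitly. The free-function wrapper below is an
# illustrative assumption.
import jax
import jax.numpy as jnp
import numpy as np
import haiku as hk


def counter(x):
    # Same logic as the __call__ above, written as a free function.
    n = hk.get_state("n", shape=[], dtype=jnp.int32, init=lambda *args: np.array(0))
    w = hk.get_parameter("w", [], init=lambda *args: np.array(2.0))
    hk.set_state("n", n + 1)
    return x * w


forward = hk.transform_with_state(counter)
params, state = forward.init(jax.random.PRNGKey(0), jnp.ones([]))
out, state = forward.apply(params, state, None, jnp.ones([]))  # counter state advances by 1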
def upsample(self, x, durations, L):
    ruler = jnp.arange(0, L)[None, :]          # (1, L)
    end_pos = jnp.cumsum(durations, axis=1)
    mid_pos = end_pos - durations / 2          # (B, T)

    # Squared distance from every output frame to every token's mid position.
    d2 = jnp.square(mid_pos[:, None, :] - ruler[:, :, None]) / 10.
    w = jax.nn.softmax(-d2, axis=-1)           # (B, L, T) attention weights
    hk.set_state('attn', w)
    x = jnp.einsum('BLT,BTD->BLD', w, x)
    return x
def noise(a, ep):
    a = jnp.asarray(a)
    shape = env.action_space.shape
    mu, sigma, theta = (hparams.exploration_mu,
                        hparams.exploration_sigma,
                        hparams.exploration_theta)
    scale = hk.get_state('scale', shape=(), dtype=a.dtype, init=jnp.ones)
    noise = hk.get_state('noise', shape=a.shape, dtype=a.dtype, init=jnp.zeros)
    # Decay the exploration scale and update the persistent noise state.
    scale = scale * hparams.noise_decay
    noise = theta * (mu - noise) + sigma * jax.random.normal(
        hk.next_rng_key(), shape)
    hk.set_state('scale', scale)
    hk.set_state('noise', noise)
    return a + noise * scale
def conv_weight_with_spectral_norm(x: jnp.ndarray,
                                   kernel_shape: Sequence[int],
                                   out_channel: int,
                                   name_suffix: str = "",
                                   w_init: Callable = None,
                                   b_init: Callable = None,
                                   use_bias: bool = True,
                                   is_training: bool = True,
                                   update_params: bool = True,
                                   max_singular_value: float = 0.95,
                                   max_power_iters: int = 1,
                                   **conv_kwargs):
    batch_size, H, W, C = x.shape
    w_shape = kernel_shape + (C, out_channel)

    w = hk.get_parameter(f"w_{name_suffix}", w_shape, x.dtype, init=w_init)
    if use_bias:
        b = hk.get_parameter(f"b_{name_suffix}", (out_channel,), x.dtype, init=b_init)

    u = hk.get_state(f"u_{name_suffix}", (H, W, out_channel),
                     init=hk.initializers.RandomNormal())
    v = hk.get_state(f"v_{name_suffix}", (H, W, C),
                     init=hk.initializers.RandomNormal())
    w, u, v = sn.spectral_norm_conv_apply(w, u, v,
                                          conv_kwargs["stride"],
                                          conv_kwargs["padding"],
                                          max_singular_value,
                                          max_power_iters,
                                          update_params)

    # Run for a lot of steps when we're first initializing.
    running_init_fn = not hk_base.params_frozen()
    if running_init_fn:
        w, u, v = sn.spectral_norm_conv_apply(w, u, v,
                                              conv_kwargs["stride"],
                                              conv_kwargs["padding"],
                                              max_singular_value,
                                              None,
                                              True)

    if is_training or running_init_fn:
        hk.set_state(f"u_{name_suffix}", u)
        hk.set_state(f"v_{name_suffix}", v)

    if use_bias:
        return w, b
    return w
def call(self,
         values: jnp.ndarray,
         sample_weight: tp.Optional[jnp.ndarray] = None) -> jnp.ndarray:
    """
    Accumulates statistics for computing the reduction metric.

    For example, if `values` is [1, 3, 5, 7] and reduction=SUM_OVER_BATCH_SIZE,
    then the value of `result()` is 4. If the `sample_weight` is specified as
    [1, 1, 0, 0] then the value of `result()` would be 2.

    Arguments:
        values: Per-example value.
        sample_weight: Optional weighting of each example. Defaults to 1.

    Returns:
        Array with the cumulative reduction.
    """
    total = hk.get_state(
        "total", shape=[], dtype=self._dtype, init=hk.initializers.Constant(0))

    if self._reduction in (
        Reduction.SUM_OVER_BATCH_SIZE,
        Reduction.WEIGHTED_MEAN,
    ):
        count = hk.get_state(
            "count", shape=[], dtype=jnp.int32, init=hk.initializers.Constant(0))
    else:
        count = None

    value, total, count = reduce(
        total=total,
        count=count,
        values=values,
        reduction=self._reduction,
        sample_weight=sample_weight,
        dtype=self._dtype,
    )

    hk.set_state("total", total)
    if count is not None:
        hk.set_state("count", count)

    return value
def weight_with_spectral_norm(x: jnp.ndarray,
                              out_dim: int,
                              name_suffix: str = "",
                              w_init: Callable = None,
                              b_init: Callable = None,
                              is_training: bool = True,
                              update_params: bool = True,
                              use_bias: bool = True,
                              force_in_dim: Optional[int] = None,
                              max_singular_value: float = 0.99,
                              max_power_iters: int = 1,
                              **kwargs):
    in_dim, dtype = x.shape[-1], x.dtype
    if force_in_dim:
        in_dim = force_in_dim

    w = hk.get_parameter(f"w_{name_suffix}", (out_dim, in_dim), dtype, init=w_init)
    if use_bias:
        b = hk.get_parameter(f"b_{name_suffix}", (out_dim,), dtype, init=b_init)

    u = hk.get_state(f"u_{name_suffix}", (out_dim,), dtype,
                     init=hk.initializers.RandomNormal())
    v = hk.get_state(f"v_{name_suffix}", (in_dim,), dtype,
                     init=hk.initializers.RandomNormal())
    w, u, v = sn.spectral_norm_apply(w, u, v, max_singular_value,
                                     max_power_iters, update_params)

    # Run the power iteration to convergence when we're first initializing.
    running_init_fn = not hk_base.params_frozen()
    if running_init_fn:
        w, u, v = sn.spectral_norm_apply(w, u, v, max_singular_value, None, True)

    if is_training or running_init_fn:
        hk.set_state(f"u_{name_suffix}", u)
        hk.set_state(f"v_{name_suffix}", v)

    if use_bias:
        return w, b
    return w
def get_normalized_weights(self,
                           weights: jnp.ndarray,
                           renormalize: bool = False) -> jnp.ndarray:

    def _l2_normalize(x, axis=None, eps=1e-12):
        return x * jax.lax.rsqrt((x * x).sum(axis=axis, keepdims=True) + eps)

    output_size = self.output_size
    dtype = weights.dtype
    assert output_size == weights.shape[-1]
    sigma = hk.get_state('sigma', (), init=jnp.ones)
    if renormalize:
        # Power iterations to compute spectral norm V*W*U^T.
        u = hk.get_state(
            'u', (1, output_size), dtype, init=hk.initializers.RandomNormal())
        for _ in range(self.num_iterations):
            v = _l2_normalize(jnp.matmul(u, weights.transpose()), eps=self.eps)
            u = _l2_normalize(jnp.matmul(v, weights), eps=self.eps)
        u = jax.lax.stop_gradient(u)
        v = jax.lax.stop_gradient(v)
        sigma = jnp.matmul(jnp.matmul(v, weights), jnp.transpose(u))[0, 0]
        hk.set_state('u', u)
        hk.set_state('v', v)
        hk.set_state('sigma', sigma)
    factor = jnp.maximum(1, sigma / self.lipschitz_coeff)
    return weights / factor
def _update_cache(self,
                  key: str,
                  value: jnp.ndarray,
                  cache_steps: Optional[int] = None,
                  axis: int = 1) -> jnp.ndarray:
    """Update the cache stored in hk.state."""
    cache_shape = list(value.shape)
    value_steps = cache_shape[axis]
    if cache_steps is not None:
        cache_shape[axis] += cache_steps
    cache = hk.get_state(key, shape=cache_shape, dtype=value.dtype, init=jnp.zeros)

    # Overwrite at index 0, then rotate timesteps left so the newly inserted
    # values sit at the end and the oldest entries are the ones overwritten
    # on the next call.
    value = jax.lax.dynamic_update_slice(
        cache, value, jnp.zeros(len(cache_shape), dtype=jnp.int32))
    value = jnp.roll(value, -value_steps, axis)
    hk.set_state(key, value)
    return value
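# Standalone sketch (plain arrays, no hk state) of the overwrite-then-roll
# trick used by _update_cache above: the new block is written at index 0 and
# the buffer is rotated left, so the cache stays ordered oldest -> newest
# along the time axis. Shapes here are illustrative.
import jax
import jax.numpy as jnp

cache = jnp.arange(1, 7).reshape(1, 6)    # buffer of 6 timesteps, 1 and 2 are oldest
new = jnp.array([[7, 8]])                 # 2 fresh timesteps
cache = jax.lax.dynamic_update_slice(cache, new, (0, 0))  # [[7, 8, 3, 4, 5, 6]]
cache = jnp.roll(cache, -new.shape[1], axis=1)            # [[3, 4, 5, 6, 7, 8]]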
def conv_weight_with_spectral_norm(x: jnp.ndarray,
                                   kernel_shape: Sequence[int],
                                   out_channel: int,
                                   name_suffix: str = "",
                                   w_init: Callable = None,
                                   b_init: Callable = None,
                                   use_bias: bool = True,
                                   is_training: bool = True,
                                   **conv_kwargs):
    batch_size, H, W, C = x.shape
    w_shape = kernel_shape + (C, out_channel)

    def w_init_whiten(shape, dtype):
        w = w_init(shape, dtype)
        return w * 0.7

    w = hk.get_parameter(f"w_{name_suffix}", w_shape, x.dtype, init=w_init_whiten)
    if use_bias:
        b = hk.get_parameter(f"b_{name_suffix}", (out_channel,), x.dtype, init=b_init)

    u = hk.get_state(f"u_{name_suffix}", kernel_shape + (out_channel,),
                     init=hk.initializers.RandomNormal())
    w, u = sn.spectral_norm_conv_apply(w, u,
                                       conv_kwargs["stride"],
                                       conv_kwargs["padding"],
                                       0.9, 1)
    if is_training:
        hk.set_state(f"u_{name_suffix}", u)

    if use_bias:
        return w, b
    return w
def weight_with_spectral_norm(x: jnp.ndarray,
                              out_dim: int,
                              name_suffix: str = "",
                              w_init: Callable = None,
                              b_init: Callable = None,
                              is_training: bool = True,
                              update_params: bool = True,
                              use_bias: bool = True,
                              **kwargs):
    in_dim, dtype = x.shape[-1], x.dtype

    def w_init_whiten(shape, dtype):
        w = w_init(shape, dtype)
        return util.whiten(w) * 0.9

    w = hk.get_parameter(f"w_{name_suffix}", (out_dim, in_dim), dtype,
                         init=w_init_whiten)
    if use_bias:
        b = hk.get_parameter(f"b_{name_suffix}", (out_dim,), dtype, init=b_init)

    u = hk.get_state(f"u_{name_suffix}", (out_dim,), dtype,
                     init=hk.initializers.RandomNormal())
    v = hk.get_state(f"v_{name_suffix}", (in_dim,), dtype,
                     init=hk.initializers.RandomNormal())
    w, u, v = sn.spectral_norm_apply(w, u, v, 0.99, 5, update_params)

    if is_training:
        hk.set_state(f"u_{name_suffix}", u)
        hk.set_state(f"v_{name_suffix}", v)

    if use_bias:
        return w, b
    return w
def __call__(self, sample: jnp.DeviceArray) -> Tuple[Array, Array]:
    if len(sample.shape) > 1:
        raise ValueError("sample must be a rank 0 or 1 DeviceArray.")

    count = hk.get_state("count", shape=(), dtype=jnp.int32, init=jnp.zeros)
    mean = hk.get_state(
        "mean", shape=sample.shape, dtype=jnp.float32, init=jnp.zeros)
    m2 = hk.get_state(
        "m2", shape=sample.shape, dtype=jnp.float32, init=jnp.zeros)

    # Welford's online update of the running mean and (unnormalized) variance.
    count += 1
    delta = sample - mean
    mean += delta / count
    delta_2 = sample - mean
    m2 += delta * delta_2

    hk.set_state("count", count)
    hk.set_state("mean", mean)
    hk.set_state("m2", m2)

    stddev = jnp.sqrt(m2 / count)
    return mean, stddev
def __call__(
    self,
    value,
    update_stats: bool = True,
    error_on_non_matrix: bool = False,
) -> jnp.ndarray:
    """Performs Spectral Normalization and returns the new value.

    Args:
        value: The array-like object for which you would like to perform
            spectral normalization.
        update_stats: A boolean defaulting to True. Regardless of this arg, this
            function will return the normalized input. When `update_stats` is
            True, the internal state of this object will also be updated to
            reflect the input value. When `update_stats` is False the internal
            stats will remain unchanged.
        error_on_non_matrix: Spectral normalization is only defined on matrices.
            By default, this module will return scalars unchanged and flatten
            higher-order tensors in their leading dimensions. Setting this flag
            to True will instead throw errors in those cases.

    Returns:
        The input value normalized by its first singular value.

    Raises:
        ValueError: If `error_on_non_matrix` is True and `value` has ndims > 2.
    """
    value = jnp.asarray(value)
    value_shape = value.shape

    # Handle scalars.
    if value.ndim <= 1:
        raise ValueError("Spectral normalization is not well defined for "
                         "scalar or vector inputs.")
    # Handle higher-order tensors.
    elif value.ndim > 2:
        if error_on_non_matrix:
            raise ValueError(
                f"Input is {value.ndim}D but error_on_non_matrix is True")
        else:
            value = jnp.reshape(value, [-1, value.shape[-1]])

    u0 = hk.get_state("u0", [1, value.shape[-1]], value.dtype,
                      init=hk.initializers.RandomNormal())

    # Power iteration for the weight's singular value.
    for _ in range(self.n_steps):
        v0 = _l2_normalize(jnp.matmul(u0, value.transpose([1, 0])), eps=self.eps)
        u0 = _l2_normalize(jnp.matmul(v0, value), eps=self.eps)

    u0 = jax.lax.stop_gradient(u0)
    v0 = jax.lax.stop_gradient(v0)

    sigma = jnp.matmul(jnp.matmul(v0, value), jnp.transpose(u0))[0, 0]

    value /= sigma
    value *= self.val
    value_bar = value.reshape(value_shape)

    if update_stats:
        hk.set_state("u0", u0)
        hk.set_state("sigma", sigma)

    return value_bar
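# Standalone sketch of the power iteration used above: a handful of u/v updates
# approximates the largest singular value of a matrix, which is what the
# u0/sigma state tracks across calls. Shapes and the iteration count here are
# illustrative, not taken from the snippet.
import jax
import jax.numpy as jnp


def _l2_normalize(x, eps=1e-12):
    return x * jax.lax.rsqrt(jnp.sum(x * x) + eps)


w = jax.random.normal(jax.random.PRNGKey(0), (16, 8))
u = jax.random.normal(jax.random.PRNGKey(1), (1, 8))
for _ in range(50):
    v = _l2_normalize(jnp.matmul(u, w.T))   # (1, 16)
    u = _l2_normalize(jnp.matmul(v, w))     # (1, 8)
sigma = jnp.matmul(jnp.matmul(v, w), u.T)[0, 0]
# sigma should be close to jnp.linalg.svd(w, compute_uv=False)[0]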
def apply_sn(*,
             mvp,
             mvpT,
             w_shape,
             b_shape,
             out_shape,
             dtype,
             w_init,
             b_init,
             name_suffix,
             is_training,
             use_bias,
             max_singular_value,
             max_power_iters,
             use_proximal_gradient=False,
             monitor_progress=False,
             monitor_iters=20,
             return_sigma=False,
             **kwargs):
    w_exists = util.check_if_parameter_exists(f"w_{name_suffix}")

    w = hk.get_parameter(f"w_{name_suffix}", w_shape, dtype, init=w_init)
    u = hk.get_state(f"u_{name_suffix}", out_shape, dtype,
                     init=hk.initializers.RandomNormal())
    if not use_proximal_gradient:
        zeta = hk.get_state(f"zeta_{name_suffix}", out_shape, dtype,
                            init=hk.initializers.RandomNormal())
        state = (u, zeta)
    else:
        state = (u,)

    if not use_proximal_gradient:
        estimate_max_singular_value = jax.jit(sn.max_singular_value,
                                              static_argnums=(0, 1))
    else:
        estimate_max_singular_value = jax.jit(sn.max_singular_value_no_grad,
                                              static_argnums=(0, 1))

    # Run many more power iterations the first time the parameter is created.
    if not w_exists:
        max_power_iters = 1000

    if monitor_progress:
        estimates = []

    for i in range(max_power_iters):
        sigma, *state = estimate_max_singular_value(mvp, mvpT, w, *state)
        if monitor_progress:
            estimates.append(sigma)

    if monitor_progress:
        sigma_for_test = sigma
        state_for_test = state
        for i in range(monitor_iters - max_power_iters):
            sigma_for_test, *state_for_test = estimate_max_singular_value(
                mvp, mvpT, w, *state_for_test)
            estimates.append(sigma_for_test)
        estimates = jnp.array(estimates)

        sigma_for_test = jax.lax.stop_gradient(sigma_for_test)
        state_for_test = jax.lax.stop_gradient(state_for_test)

    state = jax.lax.stop_gradient(state)

    if is_training or not w_exists:
        u = state[0]
        hk.set_state(f"u_{name_suffix}", u)
        if not use_proximal_gradient:
            zeta = state[1]
            hk.set_state(f"zeta_{name_suffix}", zeta)

    if not return_sigma:
        factor = jnp.where(max_singular_value < sigma,
                           max_singular_value / sigma, 1.0)
        w = w * factor
        w_ret = w
    else:
        w_ret = (w, sigma)

    if use_bias:
        b = hk.get_parameter(f"b_{name_suffix}", b_shape, dtype, init=b_init)
        ret = (w_ret, b)
    else:
        ret = w_ret

    if monitor_progress:
        ret = (ret, estimates)

    return ret
def __call__(self, x: jnp.ndarray):
    content_loss = jnp.mean(jnp.square(x - self.target))
    hk.set_state("content_loss", content_loss)
    return x
def set_state_tree(name, val):
    return hk.set_state(name, Box(val))