def get_normalized_weights(self, weights: jnp.ndarray,
                           renormalize: bool = False) -> jnp.ndarray:
    """Return `weights` rescaled so their estimated spectral norm is at most `self.lipschitz_coeff`.

    The spectral-norm estimate `sigma` lives in Haiku state. When
    `renormalize` is True, `self.num_iterations` rounds of power iteration
    refine the stored vector `u`, a fresh `sigma = v @ W @ u^T` is computed,
    and `u`, `v` and `sigma` are written back to state. Otherwise the
    previously stored `sigma` (initially 1) is used as-is.

    Args:
        weights: Weight matrix whose last dimension must equal
            `self.output_size` (assumed 2-D; the transposes/matmuls below
            treat it as a matrix — TODO confirm for higher ranks).
        renormalize: Whether to run power iteration and update the state.

    Returns:
        `weights / max(1, sigma / self.lipschitz_coeff)`; weights whose
        estimated norm is already within the bound are returned unscaled.

    Raises:
        AssertionError: If `weights.shape[-1] != self.output_size`.
        ValueError: If `renormalize` is True but `self.num_iterations < 1`
            (no left singular vector `v` would be produced to store).
    """
    def _l2_normalize(x, axis=None, eps=1e-12):
        return x * jax.lax.rsqrt((x * x).sum(axis=axis, keepdims=True) + eps)

    output_size = self.output_size
    dtype = weights.dtype
    assert output_size == weights.shape[-1]
    # Create `sigma` in the same dtype it is later overwritten with, so the
    # state entry does not change dtype between init and subsequent updates.
    sigma = hk.get_state('sigma', (), dtype, init=jnp.ones)
    if renormalize:
        if self.num_iterations < 1:
            raise ValueError(
                "num_iterations must be >= 1 when renormalize=True, got "
                f"{self.num_iterations}")
        # Power iterations to compute spectral norm V*W*U^T.
        u = hk.get_state(
            'u', (1, output_size), dtype, init=hk.initializers.RandomNormal())
        for _ in range(self.num_iterations):
            v = _l2_normalize(jnp.matmul(u, weights.transpose()), eps=self.eps)
            u = _l2_normalize(jnp.matmul(v, weights), eps=self.eps)
        # The singular vectors are treated as constants for differentiation.
        u = jax.lax.stop_gradient(u)
        v = jax.lax.stop_gradient(v)
        sigma = jnp.matmul(jnp.matmul(v, weights), jnp.transpose(u))[0, 0]
        hk.set_state('u', u)
        hk.set_state('v', v)
        hk.set_state('sigma', sigma)
    factor = jnp.maximum(1, sigma / self.lipschitz_coeff)
    return weights / factor
def noise(a, ep):
    """Add decaying Ornstein-Uhlenbeck exploration noise to action `a`.

    Two Haiku state entries are kept: `scale` (starts at 1 and is multiplied
    by `hparams.noise_decay` on every call) and `noise` (the OU process value,
    starting at 0). The OU parameters come from `hparams.exploration_mu`,
    `hparams.exploration_sigma` and `hparams.exploration_theta`.

    Args:
        a: Action (array-like); converted with `jnp.asarray`.
        ep: Not used by this function.

    Returns:
        `a + noise * scale` using the updated noise and scale.
    """
    a = jnp.asarray(a)
    shape = env.action_space.shape
    mu = hparams.exploration_mu
    sigma = hparams.exploration_sigma
    theta = hparams.exploration_theta

    scale = hk.get_state('scale', shape=(), dtype=a.dtype, init=jnp.ones)
    ou_state = hk.get_state('noise', shape=a.shape, dtype=a.dtype, init=jnp.zeros)

    scale = scale * hparams.noise_decay
    # Discrete OU step with dt = 1:
    #   x_{t+1} = x_t + theta * (mu - x_t) + sigma * eps,  eps ~ N(0, 1).
    # The previous value x_t must be carried forward; without it the result
    # is just a scaled increment rather than a mean-reverting process.
    ou_state = (ou_state + theta * (mu - ou_state)
                + sigma * jax.random.normal(hk.next_rng_key(), shape))

    hk.set_state('scale', scale)
    hk.set_state('noise', ou_state)
    return a + ou_state * scale
def _update_memory(self, mem: jnp.ndarray, mask: jnp.ndarray,
                   input_length: int, cache_steps: int,
                   should_reset: jnp.ndarray
                   ) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]:
    """Logic for using and updating cached activations.

    Args:
        mem: Activations for the current step; `mem.shape[0]` is the batch
            size.
        mask: Attention mask or None. When built here it has shape
            `(batch, 1, input_length, input_length)`.
        input_length: Number of new timesteps in this call.
        cache_steps: Number of cached timesteps to prepend (0 disables the
            cache).
        should_reset: Optional per-timestep reset signal, passed to
            `get_reset_attention_mask`.

    Returns:
        `(mem, mask)`: the (possibly cache-extended) memory and the combined
        mask. The mask stays None only if it was None, `cache_steps == 0` and
        `should_reset` is None.
    """
    batch_size = mem.shape[0]
    if cache_steps > 0:
        # Tells us how much of the cache should be used.
        cache_progress_idx = hk.get_state(
            'cache_progress_idx', [batch_size], dtype=jnp.int32, init=jnp.zeros)
        hk.set_state('cache_progress_idx', cache_progress_idx + input_length)
        mem = self._update_cache('mem', mem, cache_steps=cache_steps)
        if mask is None:
            mask = jnp.ones((batch_size, 1, input_length, input_length))
        # Cache slot j is usable once cache_steps - 1 - j < progress, i.e.
        # only the last `progress` slots have been filled so far.
        cache_mask = (jnp.arange(cache_steps - 1, -1, -1)[None, None, None, :]
                      < cache_progress_idx[:, None, None, None])
        cache_mask = jnp.broadcast_to(
            cache_mask, (batch_size, 1, input_length, cache_steps))
        mask = jnp.concatenate([cache_mask, mask], axis=-1)
    if should_reset is not None:
        if cache_steps > 0:
            should_reset = self._update_cache(
                'should_reset', should_reset, cache_steps=cache_steps)
        reset_mask = get_reset_attention_mask(should_reset)[:, None, :, :]
        if mask is None:
            # No cache and no caller-supplied mask: start from all-ones so the
            # reset mask has something to combine with (previously `None *=`
            # raised a TypeError).
            mask = jnp.ones((batch_size, 1, input_length, input_length))
        # Rebind instead of `*=` so a caller-owned NumPy mask is not mutated.
        mask = mask * reset_mask[:, :, cache_steps:, :]
    return mem, mask
def get_pos_start(timesteps: int, batch_size: int) -> jnp.ndarray:
    """Find the right slice of positional embeddings for incremental sampling.

    Reads the per-example counter `cache_progress_idx` from Haiku state
    (int32 zeros of shape `[batch_size]` on first use), stores the counter
    advanced by `timesteps`, and returns the value *before* advancing.
    """
    current = hk.get_state(
        'cache_progress_idx',
        [batch_size],
        dtype=jnp.int32,
        init=jnp.zeros,
    )
    advanced = current + timesteps
    hk.set_state('cache_progress_idx', advanced)
    return current
def conv_weight_with_spectral_norm(x: jnp.ndarray,
                                   kernel_shape: Sequence[int],
                                   out_channel: int,
                                   name_suffix: str = "",
                                   w_init: Callable = None,
                                   b_init: Callable = None,
                                   use_bias: bool = True,
                                   is_training: bool = True,
                                   update_params: bool = True,
                                   max_singular_value: float = 0.95,
                                   max_power_iters: int = 1,
                                   **conv_kwargs):
    """Create a spectrally normalized conv kernel (and optional bias).

    Args:
        x: Input of rank 4, unpacked as `(batch, H, W, C)`.
        kernel_shape: Spatial kernel shape; the kernel is
            `kernel_shape + (C, out_channel)`.
        out_channel: Number of output channels.
        name_suffix: Suffix for the `w_`, `b_`, `u_`, `v_` names.
        w_init, b_init: Initializers for the kernel and bias.
        use_bias: Whether to create and return a bias.
        is_training: When truthy, the updated `u`/`v` vectors are stored.
        update_params: Forwarded to `sn.spectral_norm_conv_apply`.
        max_singular_value: Bound forwarded to `sn.spectral_norm_conv_apply`.
        max_power_iters: Iterations forwarded to `sn.spectral_norm_conv_apply`.
        **conv_kwargs: Must contain `stride` and `padding`.

    Returns:
        `(w, b)` if `use_bias` else `w`.
    """
    _, H, W, C = x.shape
    # `kernel_shape` is typed as a Sequence; normalize so a list argument does
    # not fail on list + tuple concatenation.
    w_shape = tuple(kernel_shape) + (C, out_channel)
    w = hk.get_parameter(f"w_{name_suffix}", w_shape, x.dtype, init=w_init)
    # Fetch the bias once, in the input dtype. It used to be requested twice:
    # first with the default float32 dtype (which is what got created), then
    # again with x.dtype (a no-op lookup of the existing parameter).
    if use_bias:
        b = hk.get_parameter(f"b_{name_suffix}", (out_channel,), x.dtype,
                             init=b_init)

    # Power-iteration vectors live in Haiku state (sized per spatial position;
    # presumably assumes a shape-preserving conv — TODO confirm).
    u = hk.get_state(f"u_{name_suffix}", (H, W, out_channel),
                     init=hk.initializers.RandomNormal())
    v = hk.get_state(f"v_{name_suffix}", (H, W, C),
                     init=hk.initializers.RandomNormal())
    w, u, v = sn.spectral_norm_conv_apply(w, u, v, conv_kwargs["stride"],
                                          conv_kwargs["padding"],
                                          max_singular_value, max_power_iters,
                                          update_params)

    # Run for a lot of steps when we're first initializing (iteration count
    # None is passed through to `sn`).
    running_init_fn = not hk_base.params_frozen()
    if running_init_fn:
        w, u, v = sn.spectral_norm_conv_apply(w, u, v, conv_kwargs["stride"],
                                              conv_kwargs["padding"],
                                              max_singular_value, None, True)

    if is_training or running_init_fn:
        hk.set_state(f"u_{name_suffix}", u)
        hk.set_state(f"v_{name_suffix}", v)

    if use_bias:
        return w, b
    return w
def __call__(self, x):
    """Register `c1` (param) and `c2` (state), then apply and summarize a ReLU."""
    # The returned arrays are not used; the calls only register the entries.
    haiku.get_parameter("c1", [5], jnp.int32, init=jnp.ones)
    haiku.get_state("c2", [6], jnp.int32, init=jnp.ones)
    activated = jax.nn.relu(x)
    elegy.haiku_summary("relu", jax.nn.relu, activated)
    return activated
def call(self, values: jnp.ndarray, sample_weight: tp.Optional[jnp.ndarray] = None) -> jnp.ndarray:
    """Accumulate statistics for computing the reduction metric.

    For example, if `values` is [1, 3, 5, 7] and
    reduction=SUM_OVER_BATCH_SIZE, then `result()` is 4. With
    `sample_weight` [1, 1, 0, 0], `result()` is 2.

    Arguments:
        values: Per-example value.
        sample_weight: Optional weighting of each example. Defaults to 1.

    Returns:
        Array with the cumulative reduction.
    """
    running_total = hk.get_state(
        "total", shape=[], dtype=self._dtype,
        init=hk.initializers.Constant(0))

    # Only averaging reductions need an example count.
    tracks_count = self._reduction in (
        Reduction.SUM_OVER_BATCH_SIZE,
        Reduction.WEIGHTED_MEAN,
    )
    running_count = None
    if tracks_count:
        running_count = hk.get_state(
            "count", shape=[], dtype=jnp.int32,
            init=hk.initializers.Constant(0))

    result, running_total, running_count = reduce(
        total=running_total,
        count=running_count,
        values=values,
        reduction=self._reduction,
        sample_weight=sample_weight,
        dtype=self._dtype,
    )

    hk.set_state("total", running_total)
    if running_count is not None:
        hk.set_state("count", running_count)
    return result
def weight_with_spectral_norm(x: jnp.ndarray,
                              out_dim: int,
                              name_suffix: str = "",
                              w_init: Callable = None,
                              b_init: Callable = None,
                              is_training: bool = True,
                              update_params: bool = True,
                              use_bias: bool = True,
                              force_in_dim: Optional = None,
                              max_singular_value: float = 0.99,
                              max_power_iters: int = 1,
                              **kwargs):
    """Create a spectrally normalized `(out_dim, in_dim)` weight and optional bias.

    `in_dim` is `force_in_dim` when that is truthy, otherwise `x.shape[-1]`.
    `sn.spectral_norm_apply` is applied once with `max_power_iters`; while the
    init function runs (params not frozen), it is applied again with `None`
    iterations. The power-iteration vectors `u`/`v` are stored back to state
    when `is_training == True` or during init.

    Returns:
        `(w, b)` if `use_bias` else `w`.
    """
    dtype = x.dtype
    in_dim = force_in_dim if force_in_dim else x.shape[-1]

    w_name, b_name = f"w_{name_suffix}", f"b_{name_suffix}"
    u_name, v_name = f"u_{name_suffix}", f"v_{name_suffix}"

    w = hk.get_parameter(w_name, (out_dim, in_dim), dtype, init=w_init)
    b = hk.get_parameter(b_name, (out_dim,), dtype, init=b_init) if use_bias else None
    u = hk.get_state(u_name, (out_dim,), dtype, init=hk.initializers.RandomNormal())
    v = hk.get_state(v_name, (in_dim,), dtype, init=hk.initializers.RandomNormal())

    w, u, v = sn.spectral_norm_apply(w, u, v, max_singular_value, max_power_iters, update_params)

    initializing = not hk_base.params_frozen()
    if initializing:
        w, u, v = sn.spectral_norm_apply(w, u, v, max_singular_value, None, True)

    if is_training == True or initializing:
        hk.set_state(u_name, u)
        hk.set_state(v_name, v)

    return (w, b) if use_bias else w
def _get_hyperplanes(self, side_info_size):
    """Get (or initialize) hyperplane weights and bias.

    Both arrays are kept in Haiku state (not params):
      * `hyperplanes`: shape `(output_size, context_dim, side_info_size)`,
        initialized by `self._hyp_w_init` or, if falsy, a
        `NormalizedRandomNormal(stddev=1., normalize_axis=1)`.
      * `hyperplane_bias`: shape `(output_size, context_dim)`, initialized by
        `self._hyp_b_init` or, if falsy, `RandomNormal(stddev=0.05)`.

    Returns:
        `(hyperplanes, hyperplane_bias)`.
    """
    weight_init = self._hyp_w_init
    if not weight_init:
        weight_init = NormalizedRandomNormal(stddev=1., normalize_axis=1)
    weight_shape = (self._output_size, self._context_dim, side_info_size)
    planes = hk.get_state("hyperplanes", shape=weight_shape, init=weight_init)

    bias_init = self._hyp_b_init
    if not bias_init:
        bias_init = hk.initializers.RandomNormal(stddev=0.05)
    bias_shape = (self._output_size, self._context_dim)
    plane_bias = hk.get_state("hyperplane_bias", shape=bias_shape, init=bias_init)

    return planes, plane_bias
def __call__(self, x):
    """Register `b1`/`b2`, run `ModuleC`, then apply and summarize a ReLU."""
    # The returned arrays are not used; the calls only register the entries.
    haiku.get_parameter("b1", [3], jnp.int32, init=jnp.ones)
    haiku.get_state("b2", [4], jnp.int32, init=jnp.ones)
    hidden = ModuleC()(x)
    activated = jax.nn.relu(hidden)
    elegy.haiku_summary("relu", jax.nn.relu, activated)
    return activated
def __call__(self, x):
    """Register `a1`/`a2`, run `ModuleB`, then apply and summarize a ReLU."""
    # The returned arrays are not used; the calls only register the entries.
    haiku.get_parameter("a1", [1], jnp.int32, init=jnp.ones)
    haiku.get_state("a2", [2], jnp.int32, init=jnp.ones)
    hidden = ModuleB()(x)
    activated = jax.nn.relu(hidden)
    elegy.haiku_summary("relu", jax.nn.relu, activated)
    return activated
def __call__(self, x):
    """Multiply `x` by parameter `w` (initially 2.0) and increment state counter `n`."""
    calls_so_far = haiku.get_state(
        "n", shape=[], dtype=jnp.int32, init=lambda *_: np.array(0)
    )
    scale = haiku.get_parameter("w", [], init=lambda *_: np.array(2.0))
    haiku.set_state("n", calls_so_far + 1)
    return x * scale
def __call__(self, sample: jnp.ndarray) -> Tuple[Array, Array]:
    """Fold `sample` into running mean/stddev statistics and return them.

    Uses Welford's online update on three Haiku state entries: `count`
    (int32 scalar), `mean` and `m2` (float32, shaped like `sample`).

    Args:
        sample: Array of rank 0 or 1. Annotated as `jnp.ndarray`: the old
            `jnp.DeviceArray` name was removed from JAX, and because
            annotations are evaluated at definition time, it broke import.

    Returns:
        `(mean, stddev)` where `stddev = sqrt(m2 / count)` (population form).

    Raises:
        ValueError: If `sample` has rank greater than 1.
    """
    if len(sample.shape) > 1:
        raise ValueError("sample must be a rank 0 or 1 DeviceArray.")

    count = hk.get_state("count", shape=(), dtype=jnp.int32, init=jnp.zeros)
    mean = hk.get_state(
        "mean", shape=sample.shape, dtype=jnp.float32, init=jnp.zeros)
    m2 = hk.get_state(
        "m2", shape=sample.shape, dtype=jnp.float32, init=jnp.zeros)

    # Welford: update the mean, then accumulate the product of the deviations
    # measured before and after that update.
    count += 1
    delta = sample - mean
    mean += delta / count
    delta_2 = sample - mean
    m2 += delta * delta_2

    hk.set_state("count", count)
    hk.set_state("mean", mean)
    hk.set_state("m2", m2)

    stddev = jnp.sqrt(m2 / count)
    return mean, stddev
def weight_with_spectral_norm(x: jnp.ndarray,
                              out_dim: int,
                              name_suffix: str = "",
                              w_init: Callable = None,
                              b_init: Callable = None,
                              is_training: bool = True,
                              update_params: bool = True,
                              use_bias: bool = True,
                              max_singular_value: float = 0.99,
                              max_power_iters: int = 5,
                              **kwargs):
    """Create a spectrally normalized `(out_dim, x.shape[-1])` weight and optional bias.

    The weight is initialized as `util.whiten(w_init(shape, dtype)) * 0.9`,
    so `w_init` must be a callable. `max_singular_value` and
    `max_power_iters` were previously hard-coded to 0.99 and 5. They are
    now keyword parameters with those defaults. The updated `u`/`v`
    vectors are written back to state only when `is_training` is truthy.

    Returns:
        `(w, b)` if `use_bias` else `w`.
    """
    in_dim, dtype = x.shape[-1], x.dtype

    def w_init_whiten(shape, dtype):
        w = w_init(shape, dtype)
        return util.whiten(w) * 0.9

    w = hk.get_parameter(f"w_{name_suffix}", (out_dim, in_dim), dtype,
                         init=w_init_whiten)
    if use_bias:
        b = hk.get_parameter(f"b_{name_suffix}", (out_dim,), dtype, init=b_init)

    u = hk.get_state(f"u_{name_suffix}", (out_dim,), dtype,
                     init=hk.initializers.RandomNormal())
    v = hk.get_state(f"v_{name_suffix}", (in_dim,), dtype,
                     init=hk.initializers.RandomNormal())
    w, u, v = sn.spectral_norm_apply(w, u, v, max_singular_value,
                                     max_power_iters, update_params)

    if is_training:
        hk.set_state(f"u_{name_suffix}", u)
        hk.set_state(f"v_{name_suffix}", v)

    if use_bias:
        return w, b
    return w
def conv_weight_with_spectral_norm(x: jnp.ndarray,
                                   kernel_shape: Sequence[int],
                                   out_channel: int,
                                   name_suffix: str = "",
                                   w_init: Callable = None,
                                   b_init: Callable = None,
                                   use_bias: bool = True,
                                   is_training: bool = True,
                                   max_singular_value: float = 0.9,
                                   max_power_iters: int = 1,
                                   **conv_kwargs):
    """Create a spectrally normalized conv kernel (and optional bias).

    The kernel has shape `kernel_shape + (C, out_channel)`, where `C` is the
    last dimension of the rank-4 input `x`. It is initialized as
    `w_init(shape, dtype) * 0.7`. `conv_kwargs` must contain `stride` and
    `padding`. `max_singular_value` and `max_power_iters` (previously
    hard-coded to 0.9 and 1) are forwarded to `sn.spectral_norm_conv_apply`.

    Returns:
        `(w, b)` if `use_bias` else `w`.
    """
    _, _, _, C = x.shape
    # Normalize so a list `kernel_shape` does not fail on list + tuple.
    kernel_shape = tuple(kernel_shape)
    w_shape = kernel_shape + (C, out_channel)

    def w_init_whiten(shape, dtype):
        # Despite the name, this only rescales the base initializer's output.
        w = w_init(shape, dtype)
        return w * 0.7

    w = hk.get_parameter(f"w_{name_suffix}", w_shape, x.dtype,
                         init=w_init_whiten)
    # Fetch the bias once, in the input dtype. It used to be requested twice:
    # the first call created it as float32, and the later x.dtype request was
    # an ignored lookup of the existing parameter.
    if use_bias:
        b = hk.get_parameter(f"b_{name_suffix}", (out_channel,), x.dtype,
                             init=b_init)

    u = hk.get_state(f"u_{name_suffix}", kernel_shape + (out_channel,),
                     init=hk.initializers.RandomNormal())
    w, u = sn.spectral_norm_conv_apply(w, u, conv_kwargs["stride"],
                                       conv_kwargs["padding"],
                                       max_singular_value, max_power_iters)
    if is_training:
        hk.set_state(f"u_{name_suffix}", u)

    if use_bias:
        return w, b
    return w
def _update_cache(self,
                  key: str,
                  value: jnp.ndarray,
                  cache_steps: Optional[int] = None,
                  axis: int = 1) -> jnp.ndarray:
    """Update the rolling cache stored in hk.state under `key` and return it.

    The cache has `value.shape[axis] + cache_steps` entries along `axis`
    (exactly `value.shape[axis]` when `cache_steps` is None) and starts as
    zeros. Each call overwrites the first `value.shape[axis]` entries with
    `value` and then rotates left by that amount. The result therefore holds
    the last `cache_steps` previously cached entries followed by `value`.

    Args:
        key: Name of the Haiku state entry (a string; it was previously
            mis-annotated as an array).
        value: New entries to append.
        cache_steps: Extra history length kept alongside `value`.
        axis: Time axis of `value` and the cache.

    Returns:
        The updated cache (also written back to state).
    """
    cache_shape = list(value.shape)
    value_steps = cache_shape[axis]
    if cache_steps is not None:
        cache_shape[axis] += cache_steps
    cache = hk.get_state(key, shape=cache_shape, dtype=value.dtype,
                         init=jnp.zeros)

    # Overwrite the oldest `value_steps` entries at index 0, then rotate left
    # by `value_steps` so the newly inserted entries end up *last*.
    updated = jax.lax.dynamic_update_slice(
        cache, value, jnp.zeros(len(cache_shape), dtype=jnp.int32))
    updated = jnp.roll(updated, -value_steps, axis)
    hk.set_state(key, updated)
    return updated
def masked_call(self,
                inputs: Mapping[str, jnp.ndarray],
                rng: jnp.ndarray = None,
                sample: Optional[bool] = False,
                **kwargs) -> Mapping[str, jnp.ndarray]:
    """Perform coupling by masking the input.

    A fixed boolean `mask` (stored in Haiku state, shaped like the unbatched
    input) splits `inputs["x"]` into two halves. One half, optionally
    transformed by `self._transform` when `apply_to_both_halves` is set,
    conditions the conditioner network, whose output parameterizes the
    transform of the other half.

    Args:
        inputs: Must contain `"x"`; must contain `"condition"` when
            `self.use_condition` is set.
        rng: Key split into three sub-keys for the transforms and conditioner.
        sample: False for the forward direction; otherwise the inverse.
        **kwargs: Forwarded to `self.apply_conditioner_network`.

    Returns:
        `{"x": output, "log_det": log_det_a + log_det_b}`.

    Raises:
        ValueError: If `self.split_kind` is neither "checkerboard" nor
            "channel" when the mask is first created.
    """
    k1, k2, k3 = random.split(rng, 3)

    # Generate the mask
    def mask_init(shape, dtype):
        if len(shape) == 3:
            H, W, C = shape
            # indexing="ij" makes X, Y, Z shaped (H, W, C) and varying along
            # axes 0, 1, 2 respectively. The default "xy" indexing yields a
            # (W, H, C) grid and pairs the wrong coordinate with shape[axis].
            X, Y, Z = jnp.meshgrid(jnp.arange(H), jnp.arange(W), jnp.arange(C),
                                   indexing="ij")
            if self.split_kind == "checkerboard":
                mask = (X + Y + Z) % 2
            elif self.split_kind == "channel":
                mask = (X, Y, Z)[self.axis] > shape[self.axis] // 2
            else:
                raise ValueError(f"Unknown split_kind {self.split_kind!r}")
        else:
            dim, = shape
            if self.split_kind == "checkerboard":
                mask = jnp.arange(dim) % 2
            elif self.split_kind == "channel":
                mask = jnp.arange(dim) > dim // 2
            else:
                raise ValueError(f"Unknown split_kind {self.split_kind!r}")
        return mask.astype(dtype)

    x_shape = self.unbatched_input_shapes["x"]
    mask = hk.get_state("mask", shape=x_shape, dtype=bool, init=mask_init)
    nmask = ~mask

    x = inputs["x"]
    if self.use_condition:
        assert "condition" in inputs
        condition = inputs["condition"]
    else:
        condition = None

    # Mask the input
    x_mask = x * mask
    x_nmask = x * nmask

    # Initialize the network
    out_shape = self.get_out_shape(x_mask)
    self.network = self.get_network(out_shape)

    if sample == False:
        # zb = f(xb; theta)
        if self.apply_to_both_halves:
            z_nmask, log_det_b = self._transform(x_nmask, sample=False,
                                                 mask=nmask, rng=k1)
        else:
            z_nmask, log_det_b = x_nmask, 0.0

        # za = f(xa; NN(xb))
        network_out = self.apply_conditioner_network(
            k2, x_nmask, condition, **kwargs)
        z_mask, log_det_a = self._transform(x, params=network_out,
                                            sample=False, mask=mask, rng=k3)
    else:
        # xb = f^{-1}(zb; theta). (x and z are swapped so that the code is a bit cleaner)
        if self.apply_to_both_halves:
            z_nmask, log_det_b = self._transform(x_nmask, sample=True,
                                                 mask=nmask, rng=k1)
        else:
            z_nmask, log_det_b = x_nmask, 0.0

        # xa = f^{-1}(za; NN(xb)).
        network_out = self.apply_conditioner_network(
            k2, z_nmask, condition, **kwargs)
        x_in = z_nmask + x_mask
        z_mask, log_det_a = self._transform(x_in, params=network_out,
                                            sample=True, mask=mask, rng=k3)

    # Apply the other half of the mask to the output
    z = z_nmask + z_mask
    log_det = log_det_a + log_det_b

    outputs = {"x": z, "log_det": log_det}
    return outputs
def call(self,
         inputs: Mapping[str, jnp.ndarray],
         rng: PRNGKey,
         sample: Optional[bool] = False,
         reconstruction: Optional[bool] = False,
         **kwargs) -> Mapping[str, jnp.ndarray]:
    """Evaluate or sample an equally weighted diagonal-Gaussian mixture prior.

    Component means (random-normal init) and log-diagonal covariances (zero
    init) are kept fixed in Haiku state. With `K = self.n_classes`:
      * forward (`sample == False`): `log_pz = log p(x|y) - log K` for a label
        `y >= 0`, otherwise `logsumexp_k log p(x|k) - log K`;
      * sampling with `reconstruction`: returns `x` unchanged, `log_pz = 0`;
      * sampling otherwise: picks a cluster (random when `y < 0`) per element.

    Returns:
        Dict with `"x"`, `"log_pz"` and `"prediction"` (argmax component per
        batch element).
    """
    x = inputs["x"]
    outputs = {}

    x_flat = x.reshape(self.batch_shape + (-1,))
    y = inputs.get("y", jnp.ones(self.batch_shape, dtype=jnp.int32) * -1)
    log_n_classes = jnp.log(self.n_classes)

    # Keep these fixed. Learning doesn't make much difference apparently.
    means = hk.get_state("means",
                         shape=(self.n_classes, x_flat.shape[-1]),
                         dtype=x.dtype,
                         init=hk.initializers.RandomNormal())
    log_diag_covs = hk.get_state("log_diag_covs",
                                 shape=(self.n_classes, x_flat.shape[-1]),
                                 dtype=x.dtype,
                                 init=jnp.zeros)

    # Log density of x under each component (vmapped over components).
    @partial(jax.vmap, in_axes=(0, 0, None))
    def diag_gaussian(mean, log_diag_cov, x_flat):
        dx = x_flat - mean
        log_pdf = jnp.dot(dx * jnp.exp(-log_diag_cov), dx)
        log_pdf += log_diag_cov.sum()
        log_pdf += x_flat.size * jnp.log(2 * jnp.pi)
        return -0.5 * log_pdf

    log_pdfs = self.auto_batch(partial(diag_gaussian, means, log_diag_covs))(x_flat)

    if sample == False:
        # log p(x, y) = log p(x|y) + log p(y) with uniform p(y) = 1/K, or
        # log p(x) = logsumexp_k log p(x|k) - log K without a label.
        # (The labeled branch previously *added* log K, disagreeing with the
        # sampling branch below.)
        def log_prob(y, log_pdfs):
            return jax.lax.cond(y >= 0,
                                lambda _: log_pdfs[y] - log_n_classes,
                                lambda _: logsumexp(log_pdfs) - log_n_classes,
                                None)

        outputs["log_pz"] = self.auto_batch(log_prob)(y, log_pdfs)
        outputs["x"] = x
    else:
        if reconstruction:
            outputs = {"x": x, "log_pz": jnp.array(0.0)}
        else:
            # NOTE(review): `xs` is standard-normal noise shaped like x_flat,
            # and `xs[y]` below indexes it by cluster id; the component means
            # and covariances are not applied. Verify this sampling path.
            xs = random.normal(rng, x_flat.shape)

            def sample_cluster(log_pdfs, y, rng):
                def no_label(y):
                    y = random.randint(rng, minval=0, maxval=self.n_classes,
                                       shape=(1,))[0]
                    return y, logsumexp(log_pdfs) - log_n_classes

                def with_label(y):
                    return y, log_pdfs[y] - log_n_classes

                # Either sample or use a specified cluster
                return jax.lax.cond(y < 0, no_label, with_label, y)

            n_keys = util.list_prod(self.batch_shape)
            rngs = random.split(rng, n_keys).reshape(self.batch_shape + (-1,))
            y, log_pz = self.auto_batch(sample_cluster)(log_pdfs, y, rngs)

            # Take a specific cluster
            outputs = {"x": xs[y].reshape(x.shape), "log_pz": log_pz}

    # Argmax over components only; without `axis` a batched input was
    # flattened into a single index.
    outputs["prediction"] = jnp.argmax(log_pdfs, axis=-1)
    return outputs
def apply_sn(*,
             mvp,
             mvpT,
             w_shape,
             b_shape,
             out_shape,
             dtype,
             w_init,
             b_init,
             name_suffix,
             is_training,
             use_bias,
             max_singular_value,
             max_power_iters,
             use_proximal_gradient=False,
             monitor_progress=False,
             monitor_iters=20,
             return_sigma=False,
             **kwargs):
    """Fetch a weight and rescale it by an iteratively estimated spectral norm.

    The largest singular value `sigma` of `w` is estimated by repeatedly
    calling `sn.max_singular_value` (or `sn.max_singular_value_no_grad` when
    `use_proximal_gradient` is truthy), jitted with `mvp`/`mvpT` as static
    arguments. `mvp`/`mvpT` are presumably matrix-vector / transposed
    matrix-vector product functions for `w` — TODO confirm; being static jit
    arguments, they must be hashable. The iteration state is `u` (plus
    `zeta` when not using proximal gradient), stored in Haiku state with
    shape `out_shape`.

    If the weight parameter did not exist before this call, 1000 iterations
    are run instead of `max_power_iters`.

    Args:
        is_training: When `== True` (or on first creation), iteration state
            is written back to Haiku state.
        max_singular_value: If `sigma` exceeds this, `w` is scaled by
            `max_singular_value / sigma`.
        monitor_progress: Also collect per-iteration `sigma` estimates (run up
            to `monitor_iters` total iterations) and return them.
        return_sigma: Return `(w, sigma)` with `w` unscaled instead of
            rescaling `w`.

    Returns:
        `w_ret` or `(w_ret, b)` when `use_bias`; wrapped as `(ret, estimates)`
        when `monitor_progress`. `w_ret` is the rescaled `w`, or `(w, sigma)`
        when `return_sigma`.

    NOTE(review): if `max_power_iters == 0` and the weight already exists,
    `sigma` is never assigned and the rescale/return below raises NameError.
    """
    w_exists = util.check_if_parameter_exists(f"w_{name_suffix}")
    w = hk.get_parameter(f"w_{name_suffix}", w_shape, dtype, init=w_init)
    u = hk.get_state(f"u_{name_suffix}", out_shape, dtype, init=hk.initializers.RandomNormal())
    # Extra iteration state `zeta` is only kept when not using proximal gradient.
    if use_proximal_gradient == False:
        zeta = hk.get_state(f"zeta_{name_suffix}", out_shape, dtype, init=hk.initializers.RandomNormal())
        state = (u, zeta)
    else:
        state = (u, )

    if use_proximal_gradient == False:
        estimate_max_singular_value = jax.jit(sn.max_singular_value, static_argnums=(0, 1))
    else:
        estimate_max_singular_value = jax.jit(sn.max_singular_value_no_grad, static_argnums=(0, 1))

    # Many more iterations the first time the weight is created.
    if w_exists == False:
        max_power_iters = 1000

    if monitor_progress:
        estimates = []

    for i in range(max_power_iters):
        sigma, *state = estimate_max_singular_value(mvp, mvpT, w, *state)
        if monitor_progress:
            estimates.append(sigma)

    if monitor_progress:
        # Continue on copies so the stored state is unaffected; these extra
        # iterations only extend `estimates` up to `monitor_iters` entries.
        sigma_for_test = sigma
        state_for_test = state
        for i in range(monitor_iters - max_power_iters):
            sigma_for_test, *state_for_test = estimate_max_singular_value(
                mvp, mvpT, w, *state_for_test)
            estimates.append(sigma_for_test)

        estimates = jnp.array(estimates)

        # NOTE(review): these two results are not used after this point.
        sigma_for_test = jax.lax.stop_gradient(sigma_for_test)
        state_for_test = jax.lax.stop_gradient(state_for_test)

    # No gradients flow into the stored iteration state.
    state = jax.lax.stop_gradient(state)

    if is_training == True or w_exists == False:
        u = state[0]
        hk.set_state(f"u_{name_suffix}", u)
        if use_proximal_gradient == False:
            zeta = state[1]
            hk.set_state(f"zeta_{name_suffix}", zeta)

    if return_sigma == False:
        # Only shrink: weights whose estimate is within the bound are unchanged.
        factor = jnp.where(max_singular_value < sigma, max_singular_value / sigma, 1.0)
        w = w * factor
        w_ret = w
    else:
        w_ret = (w, sigma)

    if use_bias:
        b = hk.get_parameter(f"b_{name_suffix}", b_shape, dtype, init=b_init)
        ret = (w_ret, b)
    else:
        ret = w_ret

    if monitor_progress:
        ret = (ret, estimates)

    return ret
def call(self,
         inputs: Mapping[str, jnp.ndarray],
         rng: jnp.ndarray = None,
         sample: Optional[bool] = False,
         **kwargs) -> Mapping[str, jnp.ndarray]:
    """Masked autoregressive flow step driven by a MADE conditioner.

    The input ordering `input_sel` (degrees in [1, dim]) is created once and
    kept in Haiku state: random when `self.method == "random"`, otherwise
    1..dim.

    Forward (`sample == False`): `z = (x - mu) * exp(-alpha)` and
    `log_det = -sum(alpha)`. Inverse: `x` is rebuilt one position at a time
    in order of increasing degree, `x = mu + z * exp(alpha)`, with
    `log_det = -sum` of the `alpha` values used.

    Returns:
        `{"x": ..., "log_det": ...}`.
    """
    def initialize_input_sel(shape, dtype):
        dim = shape[-1]
        if self.method == "random":
            init_key = hk.next_rng_key()
            input_sel = random.randint(init_key, shape=(dim,), minval=1,
                                       maxval=dim + 1)
        else:
            input_sel = jnp.arange(1, dim + 1)
        return input_sel

    # Initialize the input selection
    dim = inputs["x"].shape[-1]
    input_sel = hk.get_state("input_sel", (dim,), jnp.int32,
                             init=initialize_input_sel)

    # Create the MADE network that will generate the parameters for the MAF.
    made = net.MADE(input_sel,
                    dim,
                    self.hidden_layer_sizes,
                    self.method,
                    nonlinearity=self.nonlinearity,
                    triangular_jacobian=False)

    if sample == False:
        x = inputs["x"]
        mu, alpha = made(x, rng)
        z = (x - mu) * jnp.exp(-alpha)
        log_det = -alpha.sum(axis=-1) * jnp.ones(self.batch_shape)
        outputs = {"x": z, "log_det": log_det}
    else:
        z = inputs["x"]

        def inverse(z):
            x = jnp.zeros_like(z)

            # We need to build output a dimension at a time
            def carry_body(x, idx):
                mu, alpha = made(x, rng)
                w = mu + z * jnp.exp(alpha)
                # Functional update; `jax.ops.index_update` was removed from JAX.
                x = x.at[idx].set(w[idx])
                return x, alpha[idx]

            # Row r of the comparison marks positions with degree r + 1;
            # nonzero() returns row-major order, so column indices come out
            # sorted by increasing degree (assumes z is unbatched here, i.e.
            # x.shape[0] == dim — TODO confirm for auto_batch).
            indices = jnp.nonzero(
                input_sel == (1 + jnp.arange(x.shape[0])[:, None]))[1]
            x, alpha_diag = jax.lax.scan(carry_body, x, indices)
            log_det = -alpha_diag.sum(axis=-1)
            return x, log_det

        x, log_det = self.auto_batch(inverse)(z)
        outputs = {"x": x, "log_det": log_det}

    return outputs
def get_state_tree(name, init):
    """Return the value stored under `name` in Haiku state, creating it on first use.

    On creation, `init()` is called with no arguments and its result is
    wrapped in `Box` before being stored; the unwrapped `.value` is returned.
    """
    def _boxed_init(*_):
        return Box(init())

    boxed = hk.get_state(name, (), jnp.float32, init=_boxed_init)
    return boxed.value
def u0(self):
    """Return the `u0` Haiku state entry (no shape or initializer is supplied here)."""
    stored_u0 = hk.get_state("u0")
    return stored_u0
def sigma(self):
    """Return the scalar `sigma` Haiku state entry, created as ones if absent."""
    stored_sigma = hk.get_state("sigma", shape=(), init=jnp.ones)
    return stored_sigma
def __call__(
    self,
    value,
    update_stats: bool = True,
    error_on_non_matrix: bool = False,
) -> jnp.ndarray:
    """Performs spectral normalization and returns the new value.

    Args:
      value: The array-like object to spectrally normalize. It must have at
        least 2 dimensions.
      update_stats: A boolean defaulting to True. Regardless of this arg, this
        function returns the normalized input. When `update_stats` is True,
        the `u0` and `sigma` state entries are updated from this input; when
        False they are left unchanged.
      error_on_non_matrix: Spectral normalization is only defined on matrices.
        By default, inputs with more than 2 dimensions are flattened in their
        leading dimensions to `[-1, value.shape[-1]]`. Setting this flag to
        True raises an error for such inputs instead.

    Returns:
      The input divided by its estimated largest singular value, multiplied
      by `self.val`, and reshaped back to the input's shape.

    Raises:
      ValueError: If `value` has fewer than 2 dimensions, or if
        `error_on_non_matrix` is True and `value` has more than 2 dimensions.
    """
    value = jnp.asarray(value)
    value_shape = value.shape

    # Scalars and vectors are rejected.
    if value.ndim <= 1:
        raise ValueError("Spectral normalization is not well defined for "
                         "scalar or vector inputs.")
    # Handle higher-order tensors.
    elif value.ndim > 2:
        if error_on_non_matrix:
            raise ValueError(
                f"Input is {value.ndim}D but error_on_non_matrix is True")
        else:
            value = jnp.reshape(value, [-1, value.shape[-1]])

    u0 = hk.get_state("u0", [1, value.shape[-1]], value.dtype,
                      init=hk.initializers.RandomNormal())

    # Power iteration for the weight's singular value.
    for _ in range(self.n_steps):
        v0 = _l2_normalize(jnp.matmul(u0, value.transpose([1, 0])), eps=self.eps)
        u0 = _l2_normalize(jnp.matmul(v0, value), eps=self.eps)

    # The singular vectors are treated as constants for differentiation.
    u0 = jax.lax.stop_gradient(u0)
    v0 = jax.lax.stop_gradient(v0)

    sigma = jnp.matmul(jnp.matmul(v0, value), jnp.transpose(u0))[0, 0]

    value /= sigma
    value *= self.val
    value_bar = value.reshape(value_shape)

    if update_stats:
        hk.set_state("u0", u0)
        hk.set_state("sigma", sigma)

    return value_bar