Example #1
  def get_normalized_weights(self,
                             weights: jnp.ndarray,
                             renormalize: bool = False) -> jnp.ndarray:

    def _l2_normalize(x, axis=None, eps=1e-12):
      return x * jax.lax.rsqrt((x * x).sum(axis=axis, keepdims=True) + eps)

    output_size = self.output_size
    dtype = weights.dtype
    assert output_size == weights.shape[-1]
    sigma = hk.get_state('sigma', (), init=jnp.ones)
    if renormalize:
      # Power iterations to compute spectral norm V*W*U^T.
      u = hk.get_state(
          'u', (1, output_size), dtype, init=hk.initializers.RandomNormal())
      for _ in range(self.num_iterations):
        v = _l2_normalize(jnp.matmul(u, weights.transpose()), eps=self.eps)
        u = _l2_normalize(jnp.matmul(v, weights), eps=self.eps)
      u = jax.lax.stop_gradient(u)
      v = jax.lax.stop_gradient(v)
      sigma = jnp.matmul(jnp.matmul(v, weights), jnp.transpose(u))[0, 0]
      hk.set_state('u', u)
      hk.set_state('v', v)
      hk.set_state('sigma', sigma)
    factor = jnp.maximum(1, sigma / self.lipschitz_coeff)
    return weights / factor
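
A quick way to sanity-check the power iteration in this example (a standalone sketch, not part of the original snippet) is to run the same update outside of Haiku and compare the estimate against the exact largest singular value; the matrix shape and iteration count below are arbitrary.

import jax
import jax.numpy as jnp

def _l2_normalize(x, axis=None, eps=1e-12):
    return x * jax.lax.rsqrt((x * x).sum(axis=axis, keepdims=True) + eps)

def estimate_top_singular_value(weights, num_iterations=20, eps=1e-12):
    # Same power iteration as in the example, with u carried locally
    # instead of in hk.get_state.
    u = jax.random.normal(jax.random.PRNGKey(0), (1, weights.shape[-1]))
    for _ in range(num_iterations):
        v = _l2_normalize(jnp.matmul(u, weights.transpose()), eps=eps)
        u = _l2_normalize(jnp.matmul(v, weights), eps=eps)
    return jnp.matmul(jnp.matmul(v, weights), jnp.transpose(u))[0, 0]

w = jax.random.normal(jax.random.PRNGKey(1), (16, 8))
print(estimate_top_singular_value(w))          # power-iteration estimate
print(jnp.linalg.svd(w, compute_uv=False)[0])  # exact largest singular value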
Example #2
def noise(a, ep):
    a = jnp.asarray(a)
    shape = env.action_space.shape
    mu, sigma, theta = hparams.exploration_mu, hparams.exploration_sigma, hparams.exploration_theta
    scale = hk.get_state('scale', shape=(), dtype=a.dtype, init=jnp.ones)
    noise = hk.get_state('noise', shape=a.shape, dtype=a.dtype, init=jnp.zeros)
    scale = scale * hparams.noise_decay
    noise = theta * (mu - noise) + sigma * jax.random.normal(
        hk.next_rng_key(), shape)
    hk.set_state('scale', scale)
    hk.set_state('noise', noise)
    return a + noise * scale
Example #3
 def _update_memory(self, mem: jnp.ndarray, mask: jnp.ndarray,
                    input_length: int, cache_steps: int,
                    should_reset: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
     """Logic for using and updating cached activations."""
     batch_size = mem.shape[0]
     if cache_steps > 0:
         # Tells us how much of the cache should be used.
         cache_progress_idx = hk.get_state('cache_progress_idx',
                                           [batch_size],
                                           dtype=jnp.int32,
                                           init=jnp.zeros)
         hk.set_state('cache_progress_idx',
                      cache_progress_idx + input_length)
         mem = self._update_cache('mem', mem, cache_steps=cache_steps)
         if mask is None:
             mask = jnp.ones((batch_size, 1, input_length, input_length))
         cache_mask = (jnp.arange(cache_steps - 1, -1, -1)[None, None,
                                                           None, :] <
                       cache_progress_idx[:, None, None, None])
         cache_mask = jnp.broadcast_to(
             cache_mask, (batch_size, 1, input_length, cache_steps))
         mask = jnp.concatenate([cache_mask, mask], axis=-1)
     if should_reset is not None:
         if cache_steps > 0:
             should_reset = self._update_cache('should_reset',
                                               should_reset,
                                               cache_steps=cache_steps)
         reset_mask = get_reset_attention_mask(should_reset)[:, None, :, :]
         mask *= reset_mask[:, :, cache_steps:, :]
     return mem, mask
Example #4
def get_pos_start(timesteps: int, batch_size: int) -> jnp.ndarray:
    """Find the right slice of positional embeddings for incremental sampling."""
    pos_start = hk.get_state('cache_progress_idx', [batch_size],
                             dtype=jnp.int32,
                             init=jnp.zeros)
    hk.set_state('cache_progress_idx', pos_start + timesteps)
    return pos_start
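
For context, functions like get_pos_start that call hk.get_state / hk.set_state are meant to run inside hk.transform_with_state, which returns an init/apply pair that threads the state explicitly. A minimal sketch of driving the example above this way (the RNG keys and call counts are arbitrary):

import haiku as hk
import jax
import jax.numpy as jnp

forward = hk.transform_with_state(get_pos_start)

# init creates the state ('cache_progress_idx' starts at zero).
params, state = forward.init(jax.random.PRNGKey(0), timesteps=4, batch_size=2)

# Each apply returns the current positions and a new state in which the
# pointer has advanced by `timesteps`.
pos, state = forward.apply(params, state, None, timesteps=4, batch_size=2)
pos, state = forward.apply(params, state, None, timesteps=4, batch_size=2)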
Example #5
def conv_weight_with_spectral_norm(x: jnp.ndarray,
                                   kernel_shape: Sequence[int],
                                   out_channel: int,
                                   name_suffix: str = "",
                                   w_init: Callable = None,
                                   b_init: Callable = None,
                                   use_bias: bool = True,
                                   is_training: bool = True,
                                   update_params: bool = True,
                                   max_singular_value: float = 0.95,
                                   max_power_iters: int = 1,
                                   **conv_kwargs):
    batch_size, H, W, C = x.shape
    w_shape = kernel_shape + (C, out_channel)

    w = hk.get_parameter(f"w_{name_suffix}", w_shape, x.dtype, init=w_init)
    if use_bias:
        b = hk.get_parameter(f"b_{name_suffix}", (out_channel, ), x.dtype, init=b_init)

    u = hk.get_state(f"u_{name_suffix}", (H, W, out_channel),
                     init=hk.initializers.RandomNormal())
    v = hk.get_state(f"v_{name_suffix}", (H, W, C),
                     init=hk.initializers.RandomNormal())
    w, u, v = sn.spectral_norm_conv_apply(w, u, v, conv_kwargs["stride"],
                                          conv_kwargs["padding"],
                                          max_singular_value, max_power_iters,
                                          update_params)

    # Run for a lot of steps when we're first initializing
    running_init_fn = not hk_base.params_frozen()
    if running_init_fn:
        w, u, v = sn.spectral_norm_conv_apply(w, u, v, conv_kwargs["stride"],
                                              conv_kwargs["padding"],
                                              max_singular_value, None, True)

    if is_training == True or running_init_fn:
        hk.set_state(f"u_{name_suffix}", u)
        hk.set_state(f"v_{name_suffix}", v)

    if use_bias:
        return w, b
    return w
Example #6
            def __call__(self, x):
                c1 = haiku.get_parameter("c1", [5], jnp.int32, init=jnp.ones)
                c2 = haiku.get_state("c2", [6], jnp.int32, init=jnp.ones)

                x = jax.nn.relu(x)
                elegy.haiku_summary("relu", jax.nn.relu, x)

                return x
Example #7
File: reduce.py  Project: stjordanis/elegy
    def call(self,
             values: jnp.ndarray,
             sample_weight: tp.Optional[jnp.ndarray] = None) -> jnp.ndarray:
        """
        Accumulates statistics for computing the reduction metric. For example, if `values` is
        [1, 3, 5, 7] and reduction=SUM_OVER_BATCH_SIZE, then the value of `result()` is 4. If
        `sample_weight` is specified as [1, 1, 0, 0], then the value of `result()` would be 2.
        
        Arguments:
            values: Per-example value.
            sample_weight: Optional weighting of each example. Defaults to 1.
        
        Returns:
            Array with the cumulative reduction.
        """
        total = hk.get_state("total",
                             shape=[],
                             dtype=self._dtype,
                             init=hk.initializers.Constant(0))

        if self._reduction in (
                Reduction.SUM_OVER_BATCH_SIZE,
                Reduction.WEIGHTED_MEAN,
        ):
            count = hk.get_state("count",
                                 shape=[],
                                 dtype=jnp.int32,
                                 init=hk.initializers.Constant(0))
        else:
            count = None

        value, total, count = reduce(
            total=total,
            count=count,
            values=values,
            reduction=self._reduction,
            sample_weight=sample_weight,
            dtype=self._dtype,
        )

        hk.set_state("total", total)

        if count is not None:
            hk.set_state("count", count)

        return value
Example #8
def weight_with_spectral_norm(x: jnp.ndarray,
                              out_dim: int,
                              name_suffix: str = "",
                              w_init: Callable = None,
                              b_init: Callable = None,
                              is_training: bool = True,
                              update_params: bool = True,
                              use_bias: bool = True,
                              force_in_dim: Optional = None,
                              max_singular_value: float = 0.99,
                              max_power_iters: int = 1,
                              **kwargs):
    in_dim, dtype = x.shape[-1], x.dtype
    if force_in_dim:
        in_dim = force_in_dim

    w = hk.get_parameter(f"w_{name_suffix}", (out_dim, in_dim),
                         dtype,
                         init=w_init)
    if use_bias:
        b = hk.get_parameter(f"b_{name_suffix}", (out_dim, ),
                             dtype,
                             init=b_init)

    u = hk.get_state(f"u_{name_suffix}", (out_dim, ),
                     dtype,
                     init=hk.initializers.RandomNormal())
    v = hk.get_state(f"v_{name_suffix}", (in_dim, ),
                     dtype,
                     init=hk.initializers.RandomNormal())
    w, u, v = sn.spectral_norm_apply(w, u, v, max_singular_value,
                                     max_power_iters, update_params)

    running_init_fn = not hk_base.params_frozen()
    if running_init_fn:
        w, u, v = sn.spectral_norm_apply(w, u, v, max_singular_value, None,
                                         True)

    if is_training == True or running_init_fn:
        hk.set_state(f"u_{name_suffix}", u)
        hk.set_state(f"v_{name_suffix}", v)

    if use_bias:
        return w, b
    return w
Example #9
    def _get_hyperplanes(self, side_info_size):
        """Get (or initialize) hyperplane weights and bias."""

        hyp_w_init = self._hyp_w_init or NormalizedRandomNormal(
            stddev=1., normalize_axis=1)
        hyperplanes = hk.get_state("hyperplanes",
                                   shape=(self._output_size, self._context_dim,
                                          side_info_size),
                                   init=hyp_w_init)

        hyp_b_init = self._hyp_b_init or hk.initializers.RandomNormal(
            stddev=0.05)
        hyperplane_bias = hk.get_state("hyperplane_bias",
                                       shape=(self._output_size,
                                              self._context_dim),
                                       init=hyp_b_init)

        return hyperplanes, hyperplane_bias
Example #10
            def __call__(self, x):
                b1 = haiku.get_parameter("b1", [3], jnp.int32, init=jnp.ones)
                b2 = haiku.get_state("b2", [4], jnp.int32, init=jnp.ones)

                x = ModuleC()(x)

                x = jax.nn.relu(x)
                elegy.haiku_summary("relu", jax.nn.relu, x)

                return x
Example #11
            def __call__(self, x):
                a1 = haiku.get_parameter("a1", [1], jnp.int32, init=jnp.ones)
                a2 = haiku.get_state("a2", [2], jnp.int32, init=jnp.ones)

                x = ModuleB()(x)

                x = jax.nn.relu(x)
                elegy.haiku_summary("relu", jax.nn.relu, x)

                return x
Example #12
            def __call__(self, x):

                n = haiku.get_state(
                    "n", shape=[], dtype=jnp.int32, init=lambda *args: np.array(0)
                )
                w = haiku.get_parameter("w", [], init=lambda *args: np.array(2.0))

                haiku.set_state("n", n + 1)

                return x * w
Example #13
  def __call__(self, sample: jnp.DeviceArray) -> Tuple[Array, Array]:
    if len(sample.shape) > 1:
      raise ValueError("sample must be a rank 0 or 1 DeviceArray.")

    count = hk.get_state("count", shape=(), dtype=jnp.int32, init=jnp.zeros)
    mean = hk.get_state(
        "mean", shape=sample.shape, dtype=jnp.float32, init=jnp.zeros)
    m2 = hk.get_state(
        "m2", shape=sample.shape, dtype=jnp.float32, init=jnp.zeros)

    count += 1
    delta = sample - mean
    mean += delta / count
    delta_2 = sample - mean
    m2 += delta * delta_2

    hk.set_state("count", count)
    hk.set_state("mean", mean)
    hk.set_state("m2", m2)

    stddev = jnp.sqrt(m2 / count)
    return mean, stddev
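
As an aside, the update in this example is Welford's online algorithm. A standalone pure-JAX sketch of the same recurrence (not from the source), checked against batch statistics on an arbitrary sample:

import jax
import jax.numpy as jnp

def welford_step(carry, sample):
    count, mean, m2 = carry
    count = count + 1
    delta = sample - mean
    mean = mean + delta / count
    m2 = m2 + delta * (sample - mean)
    return (count, mean, m2), None

samples = jax.random.normal(jax.random.PRNGKey(0), (1000,))
init = (jnp.zeros((), jnp.int32), jnp.zeros(()), jnp.zeros(()))
(count, mean, m2), _ = jax.lax.scan(welford_step, init, samples)

print(mean, jnp.sqrt(m2 / count))     # streaming estimates
print(samples.mean(), samples.std())  # batch reference (population std)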
Example #14
def weight_with_spectral_norm(x: jnp.ndarray,
                              out_dim: int,
                              name_suffix: str = "",
                              w_init: Callable = None,
                              b_init: Callable = None,
                              is_training: bool = True,
                              update_params: bool = True,
                              use_bias: bool = True,
                              **kwargs):
    in_dim, dtype = x.shape[-1], x.dtype

    def w_init_whiten(shape, dtype):
        w = w_init(shape, dtype)
        return util.whiten(w) * 0.9

    w = hk.get_parameter(f"w_{name_suffix}", (out_dim, in_dim),
                         dtype,
                         init=w_init_whiten)
    if use_bias:
        b = hk.get_parameter(f"b_{name_suffix}", (out_dim, ),
                             dtype,
                             init=b_init)

    u = hk.get_state(f"u_{name_suffix}", (out_dim, ),
                     dtype,
                     init=hk.initializers.RandomNormal())
    v = hk.get_state(f"v_{name_suffix}", (in_dim, ),
                     dtype,
                     init=hk.initializers.RandomNormal())
    w, u, v = sn.spectral_norm_apply(w, u, v, 0.99, 5, update_params)
    if is_training == True:
        hk.set_state(f"u_{name_suffix}", u)
        hk.set_state(f"v_{name_suffix}", v)

    if use_bias:
        return w, b
    return w
Example #15
def conv_weight_with_spectral_norm(x: jnp.ndarray,
                                   kernel_shape: Sequence[int],
                                   out_channel: int,
                                   name_suffix: str = "",
                                   w_init: Callable = None,
                                   b_init: Callable = None,
                                   use_bias: bool = True,
                                   is_training: bool = True,
                                   **conv_kwargs):
    batch_size, H, W, C = x.shape
    w_shape = kernel_shape + (C, out_channel)

    def w_init_whiten(shape, dtype):
        w = w_init(shape, dtype)
        return w * 0.7

    w = hk.get_parameter(f"w_{name_suffix}",
                         w_shape,
                         x.dtype,
                         init=w_init_whiten)
    if use_bias:
        b = hk.get_parameter(f"b_{name_suffix}", (out_channel, ), x.dtype, init=b_init)

    u = hk.get_state(f"u_{name_suffix}",
                     kernel_shape + (out_channel, ),
                     init=hk.initializers.RandomNormal())
    w, u = sn.spectral_norm_conv_apply(w, u, conv_kwargs["stride"],
                                       conv_kwargs["padding"], 0.9, 1)
    if is_training == True:
        hk.set_state(f"u_{name_suffix}", u)

    if use_bias:
        return w, b
    return w
Example #16
    def _update_cache(self,
                      key: jnp.ndarray,
                      value: jnp.ndarray,
                      cache_steps: Optional[int] = None,
                      axis: int = 1) -> jnp.ndarray:
        """Update the cache stored in hk.state."""
        cache_shape = list(value.shape)
        value_steps = cache_shape[axis]
        if cache_steps is not None:
            cache_shape[axis] += cache_steps
        cache = hk.get_state(key,
                             shape=cache_shape,
                             dtype=value.dtype,
                             init=jnp.zeros)

        # Overwrite at index 0, then rotate left by `value_steps` so the oldest
        # cached entries come first and what was just inserted ends up last.
        value = jax.lax.dynamic_update_slice(
            cache, value, jnp.zeros(len(cache_shape), dtype=jnp.int32))
        value = jnp.roll(value, -value_steps, axis)
        hk.set_state(key, value)
        return value
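
To make the cache update concrete, here is a standalone sketch (not from the source) of the same overwrite-then-roll trick on a small 1-D buffer with made-up contents:

import jax
import jax.numpy as jnp

cache_steps, value_steps = 4, 2
cache = jnp.arange(1, cache_steps + value_steps + 1, dtype=jnp.float32)  # stand-in contents
new_block = jnp.array([10.0, 11.0])

updated = jax.lax.dynamic_update_slice(cache, new_block, (0,))
updated = jnp.roll(updated, -value_steps, axis=0)
print(updated)  # old tail first, freshly inserted block last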
Example #17
    def masked_call(self,
                    inputs: Mapping[str, jnp.ndarray],
                    rng: jnp.ndarray = None,
                    sample: Optional[bool] = False,
                    **kwargs) -> Mapping[str, jnp.ndarray]:
        """ Perform coupling by masking the input
    """
        k1, k2, k3 = random.split(rng, 3)

        # Generate the mask
        def mask_init(shape, dtype):
            if len(shape) == 3:
                H, W, C = shape
                X, Y, Z = jnp.meshgrid(jnp.arange(H), jnp.arange(W),
                                       jnp.arange(C))
                if self.split_kind == "checkerboard":
                    mask = (X + Y + Z) % 2
                elif self.split_kind == "channel":
                    mask = (X, Y, Z)[self.axis] > shape[self.axis] // 2
            else:
                dim, = shape
                if self.split_kind == "checkerboard":
                    mask = jnp.arange(dim) % 2
                elif self.split_kind == "channel":
                    mask = jnp.arange(dim) > dim // 2
            return mask.astype(dtype)

        x_shape = self.unbatched_input_shapes["x"]
        mask = hk.get_state("mask", shape=x_shape, dtype=bool, init=mask_init)
        nmask = ~mask

        x = inputs["x"]
        if self.use_condition:
            assert "condition" in inputs
            condition = inputs["condition"]
        else:
            condition = None

        # Mask the input
        x_mask = x * mask
        x_nmask = x * nmask

        # Initialize the network
        out_shape = self.get_out_shape(x_mask)
        self.network = self.get_network(out_shape)

        if sample == False:
            # zb = f(xb; theta)
            if self.apply_to_both_halves:
                z_nmask, log_det_b = self._transform(x_nmask,
                                                     sample=False,
                                                     mask=nmask,
                                                     rng=k1)
            else:
                z_nmask, log_det_b = x_nmask, 0.0

            # za = f(xa; NN(xb))
            network_out = self.apply_conditioner_network(
                k2, x_nmask, condition, **kwargs)
            z_mask, log_det_a = self._transform(x,
                                                params=network_out,
                                                sample=False,
                                                mask=mask,
                                                rng=k3)
        else:
            # xb = f^{-1}(zb; theta).  (x and z are swapped so that the code is a bit cleaner)
            if self.apply_to_both_halves:
                z_nmask, log_det_b = self._transform(x_nmask,
                                                     sample=True,
                                                     mask=nmask,
                                                     rng=k1)
            else:
                z_nmask, log_det_b = x_nmask, 0.0

            # xa = f^{-1}(za; NN(xb)).
            network_out = self.apply_conditioner_network(
                k2, z_nmask, condition, **kwargs)
            x_in = z_nmask + x_mask
            z_mask, log_det_a = self._transform(x_in,
                                                params=network_out,
                                                sample=True,
                                                mask=mask,
                                                rng=k3)

        # Apply the other half of the mask to the output
        z = z_nmask + z_mask
        log_det = log_det_a + log_det_b

        outputs = {"x": z, "log_det": log_det}
        return outputs
Example #18
File: priors.py  Project: jxzhangjhu/NuX
    def call(self,
             inputs: Mapping[str, jnp.ndarray],
             rng: PRNGKey,
             sample: Optional[bool] = False,
             reconstruction: Optional[bool] = False,
             **kwargs) -> Mapping[str, jnp.ndarray]:
        x = inputs["x"]
        outputs = {}
        x_shape = self.get_unbatched_shapes(sample)["x"]
        sum_axes = tuple(-jnp.arange(1, 1 + len(x_shape)))
        x_flat = x.reshape(self.batch_shape + (-1, ))
        y = inputs.get("y", jnp.ones(self.batch_shape, dtype=jnp.int32) * -1)

        # Keep these fixed.  Learning doesn't make much difference apparently.
        means = hk.get_state("means",
                             shape=(self.n_classes, x_flat.shape[-1]),
                             dtype=x.dtype,
                             init=hk.initializers.RandomNormal())
        log_diag_covs = hk.get_state("log_diag_covs",
                                     shape=(self.n_classes, x_flat.shape[-1]),
                                     dtype=x.dtype,
                                     init=jnp.zeros)

        @partial(jax.vmap, in_axes=(0, 0, None))
        def diag_gaussian(mean, log_diag_cov, x_flat):
            dx = x_flat - mean
            log_pdf = jnp.dot(dx * jnp.exp(-log_diag_cov), dx)
            log_pdf += log_diag_cov.sum()
            log_pdf += x_flat.size * jnp.log(2 * jnp.pi)
            return -0.5 * log_pdf

        log_pdfs = self.auto_batch(partial(diag_gaussian, means,
                                           log_diag_covs))(x_flat)

        # # Compute the log pdfs of each mixture component
        # normal = dists.Normal(means, jnp.exp(log_diag_covs))
        # log_pdfs = self.auto_batch(normal.log_prob)(x_flat)
        # log_pdfs = log_pdfs.sum(axis=-1)

        if sample == False:
            # Compute p(x,y) = p(x|y)p(y) if we have a label, p(x) otherwise
            def log_prob(y, log_pdfs):
                return jax.lax.cond(
                    y >= 0, lambda a: log_pdfs[y] + jnp.log(self.n_classes),
                    lambda a: logsumexp(log_pdfs) - jnp.log(self.n_classes),
                    None)

            outputs["log_pz"] = self.auto_batch(log_prob)(y, log_pdfs)
            outputs["x"] = x

        else:
            if reconstruction:
                outputs = {"x": x, "log_pz": jnp.array(0.0)}
            else:
                # Sample from all of the clusters
                # xs = normal.sample(rng)
                xs = random.normal(rng, x_flat.shape)

                def sample(log_pdfs, y, rng):
                    def no_label(y):
                        y = random.randint(rng,
                                           minval=0,
                                           maxval=self.n_classes,
                                           shape=(1, ))[0]
                        # y = dists.CategoricalLogits(jnp.zeros(self.n_classes)).sample(rng, (1,))[0]
                        return y, logsumexp(log_pdfs) - jnp.log(self.n_classes)

                    def with_label(y):
                        return y, log_pdfs[y] - jnp.log(self.n_classes)

                    # Either sample or use a specified cluster
                    return jax.lax.cond(y < 0, no_label, with_label, y)

                n_keys = util.list_prod(self.batch_shape)
                rngs = random.split(rng,
                                    n_keys).reshape(self.batch_shape + (-1, ))
                y, log_pz = self.auto_batch(sample)(log_pdfs, y, rngs)

                # Take a specific cluster
                outputs = {"x": xs[y].reshape(x.shape), "log_pz": log_pz}

        outputs["prediction"] = jnp.argmax(log_pdfs)

        return outputs
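
As a sanity check on the diag_gaussian formula in this example (a standalone sketch, not from the source; the test values are arbitrary), the quadratic-form expression should match the sum of per-dimension normal log-densities with standard deviation exp(0.5 * log_diag_cov):

import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

def diag_gaussian(mean, log_diag_cov, x_flat):
    # Same per-component density as in the example, without the vmap over classes.
    dx = x_flat - mean
    log_pdf = jnp.dot(dx * jnp.exp(-log_diag_cov), dx)
    log_pdf += log_diag_cov.sum()
    log_pdf += x_flat.size * jnp.log(2 * jnp.pi)
    return -0.5 * log_pdf

key = jax.random.PRNGKey(0)
mean, log_diag_cov, x = [jax.random.normal(k, (5,)) for k in jax.random.split(key, 3)]
print(diag_gaussian(mean, log_diag_cov, x))
print(norm.logpdf(x, mean, jnp.exp(0.5 * log_diag_cov)).sum())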
Example #19
def apply_sn(*,
             mvp,
             mvpT,
             w_shape,
             b_shape,
             out_shape,
             dtype,
             w_init,
             b_init,
             name_suffix,
             is_training,
             use_bias,
             max_singular_value,
             max_power_iters,
             use_proximal_gradient=False,
             monitor_progress=False,
             monitor_iters=20,
             return_sigma=False,
             **kwargs):

    w_exists = util.check_if_parameter_exists(f"w_{name_suffix}")

    w = hk.get_parameter(f"w_{name_suffix}", w_shape, dtype, init=w_init)
    u = hk.get_state(f"u_{name_suffix}",
                     out_shape,
                     dtype,
                     init=hk.initializers.RandomNormal())
    if use_proximal_gradient == False:
        zeta = hk.get_state(f"zeta_{name_suffix}",
                            out_shape,
                            dtype,
                            init=hk.initializers.RandomNormal())
        state = (u, zeta)
    else:
        state = (u, )

    if use_proximal_gradient == False:
        estimate_max_singular_value = jax.jit(sn.max_singular_value,
                                              static_argnums=(0, 1))
    else:
        estimate_max_singular_value = jax.jit(sn.max_singular_value_no_grad,
                                              static_argnums=(0, 1))

    if w_exists == False:
        max_power_iters = 1000

    if monitor_progress:
        estimates = []

    for i in range(max_power_iters):
        sigma, *state = estimate_max_singular_value(mvp, mvpT, w, *state)
        if monitor_progress:
            estimates.append(sigma)

    if monitor_progress:
        sigma_for_test = sigma
        state_for_test = state
        for i in range(monitor_iters - max_power_iters):
            sigma_for_test, *state_for_test = estimate_max_singular_value(
                mvp, mvpT, w, *state_for_test)
            estimates.append(sigma_for_test)

        estimates = jnp.array(estimates)

        sigma_for_test = jax.lax.stop_gradient(sigma_for_test)
        state_for_test = jax.lax.stop_gradient(state_for_test)

    state = jax.lax.stop_gradient(state)

    if is_training == True or w_exists == False:
        u = state[0]
        hk.set_state(f"u_{name_suffix}", u)
        if use_proximal_gradient == False:
            zeta = state[1]
            hk.set_state(f"zeta_{name_suffix}", zeta)

    if return_sigma == False:
        factor = jnp.where(max_singular_value < sigma,
                           max_singular_value / sigma, 1.0)
        w = w * factor
        w_ret = w
    else:
        w_ret = (w, sigma)

    if use_bias:
        b = hk.get_parameter(f"b_{name_suffix}", b_shape, dtype, init=b_init)
        ret = (w_ret, b)
    else:
        ret = w_ret

    if monitor_progress:
        ret = (ret, estimates)

    return ret
Example #20
File: maf.py  Project: jxzhangjhu/NuX
    def call(self,
             inputs: Mapping[str, jnp.ndarray],
             rng: jnp.ndarray = None,
             sample: Optional[bool] = False,
             **kwargs) -> Mapping[str, jnp.ndarray]:
        def initialize_input_sel(shape, dtype):
            dim = shape[-1]
            if self.method == "random":
                rng = hk.next_rng_key()
                input_sel = random.randint(rng,
                                           shape=(dim, ),
                                           minval=1,
                                           maxval=dim + 1)
            else:
                input_sel = jnp.arange(1, dim + 1)
            return input_sel

        # Initialize the input selection
        dim = inputs["x"].shape[-1]
        input_sel = hk.get_state("input_sel", (dim, ),
                                 jnp.int32,
                                 init=initialize_input_sel)

        # Create the MADE network that will generate the parameters for the MAF.
        made = net.MADE(input_sel,
                        dim,
                        self.hidden_layer_sizes,
                        self.method,
                        nonlinearity=self.nonlinearity,
                        triangular_jacobian=False)

        if sample == False:
            x = inputs["x"]

            mu, alpha = made(x, rng)
            z = (x - mu) * jnp.exp(-alpha)
            log_det = -alpha.sum(axis=-1) * jnp.ones(self.batch_shape)
            outputs = {"x": z, "log_det": log_det}
        else:
            z = inputs["x"]

            def inverse(z):
                x = jnp.zeros_like(z)

                # We need to build output a dimension at a time
                def carry_body(carry, inputs):
                    x, idx = carry, inputs
                    mu, alpha = made(x, rng)
                    w = mu + z * jnp.exp(alpha)
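                    # Note: jax.ops.index_update was removed in newer JAX
                    # releases; the modern equivalent of the next line is
                    # `x = x.at[idx].set(w[idx])`.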
                    x = jax.ops.index_update(x, idx, w[idx])
                    return x, alpha[idx]

                indices = jnp.nonzero(
                    input_sel == (1 + jnp.arange(x.shape[0])[:, None]))[1]
                x, alpha_diag = jax.lax.scan(carry_body, x, indices)
                log_det = -alpha_diag.sum(axis=-1)
                return x, log_det

            x, log_det = self.auto_batch(inverse)(z)
            outputs = {"x": x, "log_det": log_det}

        return outputs
Example #21
def get_state_tree(name, init):
  return hk.get_state(name, (), jnp.float32, init=lambda *_: Box(init())).value
Example #22
 def u0(self):
   return hk.get_state("u0")
Example #23
 def sigma(self):
   return hk.get_state("sigma", shape=(), init=jnp.ones)
Example #24
  def __call__(
      self,
      value,
      update_stats: bool = True,
      error_on_non_matrix: bool = False,
  ) -> jnp.ndarray:
    """Performs Spectral Normalization and returns the new value.
    Args:
      value: The array-like object on which you would like to perform
        spectral normalization.
      update_stats: A boolean defaulting to True. Regardless of this arg, this
        function will return the normalized input. When
        `update_stats` is True, the internal state of this object will also be
        updated to reflect the input value. When `update_stats` is False the
        internal stats will remain unchanged.
      error_on_non_matrix: Spectral normalization is only defined on matrices.
        By default, this module will return scalars unchanged and flatten
        higher-order tensors in their leading dimensions. Setting this flag to
        True will instead throw errors in those cases.
    Returns:
      The input value normalized by its first singular value.
    Raises:
      ValueError: If `error_on_non_matrix` is True and `value` has ndims > 2.
    """
    value = jnp.asarray(value)
    value_shape = value.shape

    # Handle scalars.
    if value.ndim <= 1:
      raise ValueError("Spectral normalization is not well defined for "
                       "scalar or vector inputs.")
    # Handle higher-order tensors.
    elif value.ndim > 2:
      if error_on_non_matrix:
        raise ValueError(
            f"Input is {value.ndim}D but error_on_non_matrix is True")
      else:
        value = jnp.reshape(value, [-1, value.shape[-1]])

    u0 = hk.get_state("u0", [1, value.shape[-1]], value.dtype,
                      init=hk.initializers.RandomNormal())

    # Power iteration for the weight's singular value.
    for _ in range(self.n_steps):
      v0 = _l2_normalize(jnp.matmul(u0, value.transpose([1, 0])), eps=self.eps)
      u0 = _l2_normalize(jnp.matmul(v0, value), eps=self.eps)

    u0 = jax.lax.stop_gradient(u0)
    v0 = jax.lax.stop_gradient(v0)

    sigma = jnp.matmul(jnp.matmul(v0, value), jnp.transpose(u0))[0, 0]

    value /= sigma
    value *= self.val
    value_bar = value.reshape(value_shape)

    if update_stats:
      hk.set_state("u0", u0)
      hk.set_state("sigma", sigma)

    return value_bar
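
This __call__ closely mirrors Haiku's built-in hk.SpectralNorm (an assumption; the self.val rescaling suggests a project-specific variant). A minimal usage sketch with the built-in module, where the matrix and step count are arbitrary; after normalization the largest singular value should come out close to 1:

import haiku as hk
import jax
import jax.numpy as jnp

def normalize(w):
    return hk.SpectralNorm(n_steps=5)(w)

normalize = hk.transform_with_state(normalize)

w = 3.0 * jax.random.normal(jax.random.PRNGKey(0), (8, 8))
params, state = normalize.init(jax.random.PRNGKey(1), w)
w_bar, state = normalize.apply(params, state, None, w)

# Largest singular value of the normalized matrix should be close to 1.
print(jnp.linalg.svd(w_bar, compute_uv=False)[0])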