Example 1
    def __call__(self, input, hidden):
        def update(state, x):
            nhid = self.nhid

            h, cell = state[0], state[1]
            h = h.squeeze()
            cell = cell.squeeze()

            x_components = self.i2h(x)
            h_components = self.h2h(h)

            preactivations = x_components + h_components

            # One sigmoid over the first 3 * nhid preactivations yields the
            # forget/input/output gates; the last nhid slice feeds the
            # candidate cell through tanh below.
            gates_together = jax.nn.sigmoid(preactivations[:, 0:3 * nhid])
            forget_gate = gates_together[:, 0:nhid]
            input_gate = gates_together[:, nhid:2 * nhid]
            output_gate = gates_together[:, 2 * nhid:3 * nhid]
            new_cell = jnp.tanh(preactivations[:, 3 * nhid:4 * nhid])

            cell = forget_gate * cell + input_gate * new_cell
            h = output_gate * jnp.tanh(cell)

            new_state = jnp.stack([h, cell])
            return new_state, h

        # hk.scan already stacks the per-step outputs along the leading
        # (time) axis, so the second return value needs no further stacking.
        new_state, hidden_stacked = hk.scan(update, hidden, input)
        return hidden_stacked, new_state
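The carry threaded through hk.scan above is whatever jnp.stack([h, cell]) produces each step. A minimal sketch of building a matching initial carry, assuming a [2, batch, nhid] layout (an assumption; the squeeze calls above hint the original may carry extra singleton dimensions):

import jax.numpy as jnp

def initial_state(batch_size: int, nhid: int) -> jnp.ndarray:
    # Assumed carry layout matching jnp.stack([h, cell]): [2, batch, nhid].
    h = jnp.zeros((batch_size, nhid))
    cell = jnp.zeros((batch_size, nhid))
    return jnp.stack([h, cell])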
Example 2
  def u_f(xs):
    mod = module_fn()

    def s(carry, x):
      # The wrapped module is applied once per step; the empty carry is
      # threaded through unchanged.
      y = mod(x)
      return carry, y

    _, ys = hk.scan(s, (), xs)
    return ys
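hk.scan, like any Haiku stateful operation, has to run inside an hk.transform. A self-contained sketch of the full pattern (the cumulative-sum step is illustrative, not the original module_fn):

import haiku as hk
import jax
import jax.numpy as jnp

def cumsum_fn(xs):
    # hk.scan mirrors jax.lax.scan while threading Haiku params/state/rng.
    def step(carry, x):
        carry = carry + x
        return carry, carry
    _, ys = hk.scan(step, jnp.zeros(()), xs)
    return ys

model = hk.transform(cumsum_fn)
params = model.init(jax.random.PRNGKey(0), jnp.arange(4.0))
print(model.apply(params, None, jnp.arange(4.0)))  # [0. 1. 3. 6.]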
Example 3
    def __call__(self, x_0: Array, time: float) -> Array:
        rng_beta = hk.next_rng_key()

        # Zero mean for the joint (delta_gamma, delta_beta) Gaussian
        beta_mean_vector = jnp.zeros((self.beta_dims * 2, ))

        # Covariance matrix for the betas and gammas
        beta_covariance_top_left = self.delta_t**3 / 3 * jnp.eye(
            self.beta_dims)
        beta_covariance_top_right = self.delta_t**2 / 2 * jnp.eye(
            self.beta_dims)
        beta_covariance_bottom_right = self.delta_t * jnp.eye(self.beta_dims)
        beta_covariance_top = jnp.concatenate(
            (beta_covariance_top_left, beta_covariance_top_right), axis=1)
        beta_covariance_bottom = jnp.concatenate(
            (beta_covariance_top_right, beta_covariance_bottom_right), axis=1)
        beta_covariance = jnp.concatenate(
            (beta_covariance_top, beta_covariance_bottom), axis=0)

        delta_gamma_beta = numpyro.sample(
            "delta_gamma_beta",
            MultivariateNormal(loc=beta_mean_vector,
                               covariance_matrix=beta_covariance),
            rng_key=rng_beta)

        delta_gamma = delta_gamma_beta[:self.beta_dims]
        delta_beta = delta_gamma_beta[self.beta_dims:]

        drift_0 = self.drift(x_0, time)
        diff = self.diffusion(x_0, time)
        diff_plus = self.diffusion(x_0, time + self.delta_t)

        init_x_1 = x_0 + drift_0 * self.delta_t + jnp.matmul(diff, delta_beta)
        init_x_1 += 1. / self.delta_t * jnp.matmul(
            diff_plus - diff, delta_beta * self.delta_t - delta_gamma)

        def scan_fn(carry, s):
            x_1 = carry
            x_0_plus = \
                x_0 + drift_0 * self.delta_t / self.beta_dims + \
                diff[:, s] * jnp.sqrt(self.delta_t)
            x_0_minus = \
                x_0 + drift_0 * self.delta_t / self.beta_dims - \
                diff[:, s] * jnp.sqrt(self.delta_t)

            drift_0_plus = self.drift(x_0_plus, time + self.delta_t)
            drift_0_minus = self.drift(x_0_minus, time + self.delta_t)

            x_1 += 0.25 * self.delta_t * (drift_0_plus + drift_0_minus)
            x_1 -= 0.5 * drift_0 * self.delta_t
            x_1 += 1. / (2 * jnp.sqrt(self.delta_t)) * \
                (drift_0_plus - drift_0_minus) * delta_gamma[s]
            return x_1, None

        final_x_1, _ = hk.scan(scan_fn, init_x_1, jnp.arange(self.beta_dims))

        return final_x_1
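The block matrix assembled above is the joint covariance of a Brownian increment and its time integral over one step: Var(integral of W) = dt^3/3, Cov(integral of W, W) = dt^2/2, Var(W) = dt. The same construction, written as a sketch with jnp.block (an equivalent form, not the original code):

import jax.numpy as jnp

def brownian_joint_covariance(delta_t: float, dims: int) -> jnp.ndarray:
    # Joint covariance of (time integral of W, W) per independent dimension.
    eye = jnp.eye(dims)
    return jnp.block([
        [delta_t**3 / 3 * eye, delta_t**2 / 2 * eye],
        [delta_t**2 / 2 * eye, delta_t * eye],
    ])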
Example 4
    def __call__(self,
                 x_0: Array,
                 y_values: Array,
                 t_mask: Array,
                 training: int,
                 t_0: float = 0.) -> Tuple[Array, Array, Array, Array]:
        def scan_fn(carry, it):
            x_t, y_t_path, y_t_generated, t = carry
            y_t_true, mask = it

            x_t_new = self.solver_x(x_t, t)
            d_x = x_t_new - x_t

            first_derivative = self.mapping.first_derivative(x_t, t)

            d_y = \
                self.mapping.time_derivative(x_t, t)(jnp.array([1.])) * self.delta_t + \
                first_derivative(d_x) + \
                0.5 * jnp.einsum("bc,c->b", self.mapping.hessian(x_t, t)(d_x), d_x)

            # Teacher forcing: while training, observed steps (mask == 1)
            # take the true y and unobserved steps continue the path; at
            # evaluation (training == 0) the generated path is used instead.
            y_to_use = \
                (mask * y_t_true + jnp.abs(mask - 1) * y_t_path) * training + \
                jnp.abs(training - 1) * y_t_generated

            noise = self.brownian_noise()
            y_t_new = y_to_use + d_y + jnp.abs(training - 1) * noise
            y_t_generated_new = y_t_generated + d_y + noise

            t = t + self.delta_t
            return (x_t_new, y_t_new, y_t_generated_new, t), \
                   (x_t_new, y_t_new, y_t_generated_new, t)

        _, (final_paths_x, final_paths_y, final_paths_y_generated, final_t_seq) = \
            hk.scan(scan_fn, (x_0, y_values[0], y_values[0], t_0), (y_values[:-1], t_mask[:-1]))

        t_seq = jnp.tile(jnp.array([t_0]), [y_values.shape[0]])
        t_seq = t_seq.at[1:].set(final_t_seq)
        paths_x = jnp.tile(jnp.expand_dims(x_0, axis=0),
                           [y_values.shape[0], 1])
        paths_x = paths_x.at[1:].set(final_paths_x)
        paths_y = jnp.tile(jnp.expand_dims(y_values[0], axis=0),
                           [y_values.shape[0], 1])
        paths_y_generated = jnp.tile(jnp.expand_dims(y_values[0], axis=0),
                                     [y_values.shape[0], 1])
        paths_y = paths_y.at[1:].set(final_paths_y)
        paths_y_generated = paths_y_generated.at[1:].set(
            final_paths_y_generated)

        return t_seq, paths_x, paths_y, paths_y_generated
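The tile-then-.at[1:].set pattern above prepends the initial values to the scanned outputs; for the array paths an equivalent one-liner (a sketch, not the original code) is a concatenate:

import jax.numpy as jnp

def prepend_initial(initial, scanned):
    # Same result as tiling `initial` to the full length and writing
    # `scanned` into rows 1: with .at[1:].set(...).
    return jnp.concatenate([jnp.expand_dims(initial, axis=0), scanned],
                           axis=0)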
Example 5
    def __call__(self, x, *args_ys):
        count = self._count
        if hk.running_init():
            # At initialization time, we run just one layer but add an extra first
            # dimension to every initialized tensor, making sure to use different
            # random keys for different slices.
            def creator(next_creator, shape, dtype, init, context):
                del context

                def multi_init(shape, dtype):
                    assert shape[0] == count
                    key = hk.maybe_next_rng_key()

                    def rng_context_init(slice_idx):
                        slice_key = maybe_fold_in(key, slice_idx)
                        with maybe_with_rng(slice_key):
                            return init(shape[1:], dtype)

                    return jax.vmap(rng_context_init)(jnp.arange(count))

                return next_creator((count, ) + tuple(shape), dtype,
                                    multi_init)

            def getter(next_getter, value, context):
                # Only one layer actually ran at init, so strip the leading
                # stack dimension before the module sees the parameter.
                trailing_dims = len(context.original_shape) + 1
                sliced_value = jax.lax.index_in_dim(
                    value, index=0, axis=value.ndim - trailing_dims,
                    keepdims=False)
                return next_getter(sliced_value)

            with hk.experimental.custom_creator(creator), \
                    hk.experimental.custom_getter(getter):
                if len(args_ys) == 1 and args_ys[0] is None:
                    args0 = (None, )
                else:
                    args0 = [
                        jax.lax.dynamic_index_in_dim(ys, 0, keepdims=False)
                        for ys in args_ys
                    ]
                x, z = self._call_wrapped(x, *args0)
                if z is None:
                    return x, z

                # Broadcast state to hold each layer state.
                def broadcast_state(layer_state):
                    return jnp.broadcast_to(layer_state, [
                        count,
                    ] + list(layer_state.shape))

                zs = jax.tree_util.tree_map(broadcast_state, z)
                return x, zs
        else:
            # Use scan during apply, threading through random seed so that it's
            # unique for each layer.
            def layer(carry: LayerStackCarry, scanned: LayerStackScanned):
                rng = carry.rng

                def getter(next_getter, value, context):
                    # Getter slices the full param at the current loop index.
                    trailing_dims = len(context.original_shape) + 1
                    assert value.shape[value.ndim - trailing_dims] == count, (
                        f'Attempting to use a parameter stack of size '
                        f'{value.shape[value.ndim - trailing_dims]} for a LayerStack of '
                        f'size {count}.')

                    sliced_value = jax.lax.dynamic_index_in_dim(
                        value,
                        scanned.i,
                        axis=value.ndim - trailing_dims,
                        keepdims=False)
                    return next_getter(sliced_value)

                with hk.experimental.custom_getter(getter):
                    if rng is None:
                        out_x, z = self._call_wrapped(carry.x,
                                                      *scanned.args_ys)
                    else:
                        rng, rng_ = jax.random.split(rng)
                        with hk.with_rng(rng_):
                            out_x, z = self._call_wrapped(
                                carry.x, *scanned.args_ys)
                return LayerStackCarry(x=out_x, rng=rng), z

            carry = LayerStackCarry(x=x, rng=hk.maybe_next_rng_key())
            scanned = LayerStackScanned(i=jnp.arange(count, dtype=jnp.int32),
                                        args_ys=args_ys)

            carry, zs = hk.scan(layer,
                                carry,
                                scanned,
                                length=count,
                                unroll=self._unroll)
            return carry.x, zs
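This is essentially the core of Haiku's layer_stack: parameters are created once with a leading stack axis at init, then sliced per loop index inside hk.scan at apply time. A hedged usage sketch via the public hk.experimental.layer_stack wrapper (the sizes and block body are illustrative):

import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
    # Four identical blocks; their parameters share a leading [4, ...]
    # stack axis and the apply pass runs them with a scan.
    block = lambda h: jax.nn.relu(hk.Linear(h.shape[-1])(h))
    return hk.experimental.layer_stack(4)(block)(x)

model = hk.transform(forward)
x = jnp.ones((8, 16))
params = model.init(jax.random.PRNGKey(0), x)
y = model.apply(params, jax.random.PRNGKey(1), x)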
Example 6
  def mapped_fn(*args):
    # Expand in_axes to match the argument tree and determine the loop range.
    in_axes_ = _expand_axes(in_axes, args)

    in_sizes = jax.tree_util.tree_map(_maybe_get_size, args, in_axes_)
    flat_sizes = jax.tree_util.tree_flatten(in_sizes)[0]
    in_size = max(flat_sizes)
    assert all(i in {in_size, -1} for i in flat_sizes)

    num_extra_shards = (in_size - 1) // shard_size

    # The final shard may be partial; if the total size divides evenly, the
    # last shard is simply a full one.
    last_shard_size = in_size % shard_size
    last_shard_size = shard_size if last_shard_size == 0 else last_shard_size

    def apply_fun_to_slice(slice_start, slice_size):
      input_slice = jax.tree_util.tree_map(
          lambda array, axis: _maybe_slice(array, slice_start, slice_size,
                                           axis),
          args, in_axes_)
      return fun(*input_slice)

    remainder_shape_dtype = hk.eval_shape(
        partial(apply_fun_to_slice, 0, last_shard_size))
    out_dtypes = jax.tree_util.tree_map(lambda x: x.dtype,
                                        remainder_shape_dtype)
    out_shapes = jax.tree_util.tree_map(lambda x: x.shape,
                                        remainder_shape_dtype)
    out_axes_ = _expand_axes(out_axes, remainder_shape_dtype)

    if num_extra_shards > 0:
      regular_shard_shape_dtype = hk.eval_shape(
          partial(apply_fun_to_slice, 0, shard_size))
      shard_shapes = jax.tree_util.tree_map(lambda x: x.shape,
                                            regular_shard_shape_dtype)

      def make_output_shape(axis, shard_shape, remainder_shape):
        return shard_shape[:axis] + (
            shard_shape[axis] * num_extra_shards +
            remainder_shape[axis],) + shard_shape[axis + 1:]

      out_shapes = jax.tree_util.tree_map(make_output_shape, out_axes_,
                                          shard_shapes, out_shapes)

    # dynamic_update_slice_in_dim with reordered arguments so that the loop
    # index can be bound with partial; tree_map passes leaves positionally.
    def dynamic_update_slice_in_dim(full_array, update, axis, i):
      return jax.lax.dynamic_update_slice_in_dim(full_array, update, i, axis)

    def compute_shard(outputs, slice_start, slice_size):
      slice_out = apply_fun_to_slice(slice_start, slice_size)
      update_slice = partial(
          dynamic_update_slice_in_dim, i=slice_start)
      return jax.tree_util.tree_map(update_slice, outputs, slice_out,
                                    out_axes_)

    def scan_iteration(outputs, i):
      new_outputs = compute_shard(outputs, i, shard_size)
      return new_outputs, ()

    slice_starts = jnp.arange(0, in_size - shard_size + 1, shard_size)

    def allocate_buffer(dtype, shape):
      return jnp.zeros(shape, dtype=dtype)

    outputs = jax.tree_util.tree_map(allocate_buffer, out_dtypes, out_shapes)

    if slice_starts.shape[0] > 0:
      outputs, _ = hk.scan(scan_iteration, outputs, slice_starts)

    if last_shard_size != shard_size:
      remainder_start = in_size - last_shard_size
      outputs = compute_shard(outputs, remainder_start, last_shard_size)

    return outputs
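The scan covers all full shards and the trailing branch patches in the remainder, so fun only ever sees fixed-size slices and peak memory stays bounded by shard_size rather than in_size. A simplified self-contained sketch of the same idea for one array mapped over axis 0 (assuming, unlike the code above, that the length divides evenly):

import jax
import jax.numpy as jnp

def chunked_map(fun, xs, shard_size):
    # Apply `fun` to fixed-size slices of xs along axis 0 inside a scan.
    def body(carry, start):
        chunk = jax.lax.dynamic_slice_in_dim(xs, start, shard_size, axis=0)
        return carry, fun(chunk)
    starts = jnp.arange(0, xs.shape[0], shard_size)
    _, out = jax.lax.scan(body, (), starts)
    # [num_shards, shard_size, ...] -> [xs.shape[0], ...]
    return out.reshape((-1,) + out.shape[2:])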