def __call__(self, input, hidden):
  def update(state, x):
    nhid = self.nhid
    # Unpack the previous hidden and cell states.
    h, cell = state[0], state[1]
    h = h.squeeze()
    cell = cell.squeeze()
    # Joint pre-activations for all gates from the input and recurrent paths.
    x_components = self.i2h(x)
    h_components = self.h2h(h)
    preactivations = x_components + h_components
    # Forget, input and output gates share one sigmoid over the first 3*nhid units.
    gates_together = jax.nn.sigmoid(preactivations[:, 0:3 * nhid])
    forget_gate = gates_together[:, 0:nhid]
    input_gate = gates_together[:, nhid:2 * nhid]
    output_gate = gates_together[:, 2 * nhid:3 * nhid]
    # Candidate cell state from the remaining nhid units.
    new_cell = jnp.tanh(preactivations[:, 3 * nhid:4 * nhid])
    cell = forget_gate * cell + input_gate * new_cell
    h = output_gate * jnp.tanh(cell)
    new_state = jnp.stack([h, cell])
    return new_state, h

  # Scan the LSTM update over the leading (time) axis of `input`.
  new_state, hidden_list = hk.scan(update, hidden, input)
  hidden_stacked = jnp.stack(hidden_list)
  return hidden_stacked, new_state
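# --- Hedged usage sketch (not from the original source). It shows how a module
# with a __call__ like the one above could be assembled and run end to end with
# hk.transform. The class name ScannedLSTM, the hk.Linear projections and all
# shapes are illustrative assumptions, not the original module's definitions.
import jax
import jax.numpy as jnp
import haiku as hk

class ScannedLSTM(hk.Module):

  def __init__(self, nhid, name=None):
    super().__init__(name=name)
    self.nhid = nhid
    self.i2h = hk.Linear(4 * nhid)  # input-to-hidden projection
    self.h2h = hk.Linear(4 * nhid)  # hidden-to-hidden projection

  def __call__(self, inputs, hidden):
    def update(state, x):
      h, cell = state
      pre = self.i2h(x) + self.h2h(h)
      gates = jax.nn.sigmoid(pre[:, :3 * self.nhid])
      forget_gate = gates[:, :self.nhid]
      input_gate = gates[:, self.nhid:2 * self.nhid]
      output_gate = gates[:, 2 * self.nhid:]
      cell = forget_gate * cell + input_gate * jnp.tanh(pre[:, 3 * self.nhid:])
      h = output_gate * jnp.tanh(cell)
      return (h, cell), h

    # inputs: [T, B, D]; hidden: ((B, H), (B, H)).
    state, hs = hk.scan(update, hidden, inputs)
    return hs, state

def forward(inputs, hidden):
  return ScannedLSTM(nhid=16)(inputs, hidden)

fwd = hk.transform(forward)
T, B, D, H = 5, 2, 8, 16
inputs = jnp.ones([T, B, D])
hidden = (jnp.zeros([B, H]), jnp.zeros([B, H]))
params = fwd.init(jax.random.PRNGKey(0), inputs, hidden)
hs, state = fwd.apply(params, None, inputs, hidden)  # hs: [T, B, H]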
def u_f(xs):
  mod = module_fn()

  def s(carry, x):
    y = mod(x)
    return carry, y

  _, ys = hk.scan(s, (), xs)
  return ys
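# --- Hedged usage sketch (not from the original source): one way to close over
# module_fn and run a function like u_f under hk.transform. The make_u_f
# wrapper, hk.Linear(4) and the input shapes are illustrative assumptions.
import jax
import jax.numpy as jnp
import haiku as hk

def make_u_f(module_fn):
  def u_f(xs):
    mod = module_fn()

    def s(carry, x):
      return carry, mod(x)

    # Empty carry: hk.scan is used purely to map `mod` over the leading axis.
    _, ys = hk.scan(s, (), xs)
    return ys

  return u_f

fwd = hk.transform(make_u_f(lambda: hk.Linear(4)))
xs = jnp.ones([8, 3])                 # [steps, features]
params = fwd.init(jax.random.PRNGKey(0), xs)
ys = fwd.apply(params, None, xs)      # [8, 4]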
def __call__(self, x_0: Array, time: float) -> Array:
  rng_beta = hk.next_rng_key()
  # Vector of zeros
  beta_mean_vector = jnp.zeros((self.beta_dims * 2,))
  # Covariance matrix for the betas and gammas
  beta_covariance_top_left = self.delta_t**3 / 3 * jnp.eye(self.beta_dims)
  beta_covariance_top_right = self.delta_t**2 / 2 * jnp.eye(self.beta_dims)
  beta_covariance_bottom_right = self.delta_t * jnp.eye(self.beta_dims)
  beta_covariance_top = jnp.concatenate(
      (beta_covariance_top_left, beta_covariance_top_right), axis=1)
  beta_covariance_bottom = jnp.concatenate(
      (beta_covariance_top_right, beta_covariance_bottom_right), axis=1)
  beta_covariance = jnp.concatenate(
      (beta_covariance_top, beta_covariance_bottom), axis=0)
  delta_gamma_beta = numpyro.sample(
      "delta_gamma_beta",
      MultivariateNormal(loc=beta_mean_vector,
                         covariance_matrix=beta_covariance),
      rng_key=rng_beta)
  delta_gamma = delta_gamma_beta[:self.beta_dims]
  delta_beta = delta_gamma_beta[self.beta_dims:]
  drift_0 = self.drift(x_0, time)
  diff = self.diffusion(x_0, time)
  diff_plus = self.diffusion(x_0, time + self.delta_t)
  init_x_1 = x_0 + drift_0 * self.delta_t + jnp.matmul(diff, delta_beta)
  init_x_1 += 1. / self.delta_t * jnp.matmul(
      diff_plus - diff, delta_beta * self.delta_t - delta_gamma)

  def scan_fn(carry, s):
    x_1 = carry
    x_0_plus = (x_0 + drift_0 * self.delta_t / self.beta_dims +
                diff[:, s] * jnp.sqrt(self.delta_t))
    x_0_minus = (x_0 + drift_0 * self.delta_t / self.beta_dims -
                 diff[:, s] * jnp.sqrt(self.delta_t))
    drift_0_plus = self.drift(x_0_plus, time + self.delta_t)
    drift_0_minus = self.drift(x_0_minus, time + self.delta_t)
    x_1 += 0.25 * self.delta_t * (drift_0_plus + drift_0_minus)
    x_1 -= 0.5 * drift_0 * self.delta_t
    x_1 += (1. / (2 * jnp.sqrt(self.delta_t)) *
            (drift_0_plus - drift_0_minus) * delta_gamma[s])
    return x_1, None

  final_x_1, _ = hk.scan(scan_fn, init_x_1, jnp.arange(self.beta_dims))
  return final_x_1
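# --- Side note (not from the original source): the block covariance built above
# appears to be the standard joint covariance of a Brownian increment and its
# time integral over one step, Cov = [[dt^3/3, dt^2/2], [dt^2/2, dt]] per
# dimension. A compact, equivalent construction using a Kronecker product, with
# illustrative values for delta_t and beta_dims:
import jax.numpy as jnp

delta_t, beta_dims = 0.01, 3
base = jnp.array([[delta_t**3 / 3, delta_t**2 / 2],
                  [delta_t**2 / 2, delta_t]])
beta_covariance = jnp.kron(base, jnp.eye(beta_dims))  # [2*beta_dims, 2*beta_dims]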
def __call__(self,
             x_0: Array,
             y_values: Array,
             t_mask: Array,
             training: int,
             t_0: float = 0.) -> Tuple[Array, Array, Array, Array]:

  def scan_fn(carry, it):
    # Carry: current x, teacher-forced y path, fully generated y path, time.
    x_t, y_t_path, y_t_generated, t = carry
    y_t_true, mask = it
    # Advance x one step and approximate the induced change in y via a
    # second-order expansion of the mapping.
    x_t_new = self.solver_x(x_t, t)
    d_x = x_t_new - x_t
    first_derivative = self.mapping.first_derivative(x_t, t)
    d_y = (
        self.mapping.time_derivative(x_t, t)(jnp.array([1.])) * self.delta_t +
        first_derivative(d_x) +
        0.5 * jnp.einsum("bc,c->b", self.mapping.hessian(x_t, t)(d_x), d_x))
    # During training, use the true observation where the mask is set and the
    # propagated path elsewhere; at generation time, use the generated path.
    y_to_use = (
        (mask * y_t_true + jnp.abs(mask - 1) * y_t_path) * training +
        jnp.abs(training - 1) * y_t_generated)
    noise = self.brownian_noise()
    y_t_new = y_to_use + d_y + jnp.abs(training - 1) * noise
    y_t_generated_new = y_t_generated + d_y + noise
    t = t + self.delta_t
    return ((x_t_new, y_t_new, y_t_generated_new, t),
            (x_t_new, y_t_new, y_t_generated_new, t))

  _, (final_paths_x, final_paths_y, final_paths_y_generated, final_t_seq) = \
      hk.scan(scan_fn, (x_0, y_values[0], y_values[0], t_0),
              (y_values[:-1], t_mask[:-1]))

  # Prepend the initial values to the scanned trajectories.
  t_seq = jnp.tile(jnp.array([t_0]), [y_values.shape[0]])
  t_seq = t_seq.at[1:].set(final_t_seq)
  paths_x = jnp.tile(jnp.expand_dims(x_0, axis=0), [y_values.shape[0], 1])
  paths_x = paths_x.at[1:].set(final_paths_x)
  paths_y = jnp.tile(jnp.expand_dims(y_values[0], axis=0),
                     [y_values.shape[0], 1])
  paths_y_generated = jnp.tile(jnp.expand_dims(y_values[0], axis=0),
                               [y_values.shape[0], 1])
  paths_y = paths_y.at[1:].set(final_paths_y)
  paths_y_generated = paths_y_generated.at[1:].set(final_paths_y_generated)
  return t_seq, paths_x, paths_y, paths_y_generated
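# --- Hedged sketch (not from the original source) of the pattern used above for
# assembling full paths: scan produces one output per increment, and the initial
# value is prepended by writing into a preallocated buffer with .at[1:].set(...).
# The cumulative-sum example below is purely illustrative.
import jax
import jax.numpy as jnp

def cumulative(x0, increments):
  def step(carry, dx):
    new = carry + dx
    return new, new

  _, tail = jax.lax.scan(step, x0, increments)     # values after each increment
  path = jnp.tile(x0[None], [increments.shape[0] + 1, 1])
  return path.at[1:].set(tail)                     # row 0 keeps x0

path = cumulative(jnp.zeros(2), jnp.ones([4, 2]))  # shape [5, 2]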
def __call__(self, x, *args_ys):
  count = self._count
  if hk.running_init():
    # At initialization time, we run just one layer but add an extra first
    # dimension to every initialized tensor, making sure to use different
    # random keys for different slices.
    def creator(next_creator, shape, dtype, init, context):
      del context

      def multi_init(shape, dtype):
        assert shape[0] == count
        key = hk.maybe_next_rng_key()

        def rng_context_init(slice_idx):
          slice_key = maybe_fold_in(key, slice_idx)
          with maybe_with_rng(slice_key):
            return init(shape[1:], dtype)

        return jax.vmap(rng_context_init)(jnp.arange(count))

      return next_creator((count,) + tuple(shape), dtype, multi_init)

    def getter(next_getter, value, context):
      trailing_dims = len(context.original_shape) + 1
      sliced_value = jax.lax.index_in_dim(
          value, index=0, axis=value.ndim - trailing_dims, keepdims=False)
      return next_getter(sliced_value)

    with hk.experimental.custom_creator(
        creator), hk.experimental.custom_getter(getter):
      if len(args_ys) == 1 and args_ys[0] is None:
        args0 = (None,)
      else:
        args0 = [
            jax.lax.dynamic_index_in_dim(ys, 0, keepdims=False)
            for ys in args_ys
        ]
      x, z = self._call_wrapped(x, *args0)
      if z is None:
        return x, z

      # Broadcast state to hold each layer state.
      def broadcast_state(layer_state):
        return jnp.broadcast_to(layer_state,
                                [count] + list(layer_state.shape))

      zs = jax.tree_util.tree_map(broadcast_state, z)
      return x, zs
  else:
    # Use scan during apply, threading through random seed so that it's
    # unique for each layer.
    def layer(carry: LayerStackCarry, scanned: LayerStackScanned):
      rng = carry.rng

      def getter(next_getter, value, context):
        # Getter slices the full param at the current loop index.
        trailing_dims = len(context.original_shape) + 1
        assert value.shape[value.ndim - trailing_dims] == count, (
            f'Attempting to use a parameter stack of size '
            f'{value.shape[value.ndim - trailing_dims]} for a LayerStack of '
            f'size {count}.')
        sliced_value = jax.lax.dynamic_index_in_dim(
            value, scanned.i, axis=value.ndim - trailing_dims, keepdims=False)
        return next_getter(sliced_value)

      with hk.experimental.custom_getter(getter):
        if rng is None:
          out_x, z = self._call_wrapped(carry.x, *scanned.args_ys)
        else:
          rng, rng_ = jax.random.split(rng)
          with hk.with_rng(rng_):
            out_x, z = self._call_wrapped(carry.x, *scanned.args_ys)
      return LayerStackCarry(x=out_x, rng=rng), z

    carry = LayerStackCarry(x=x, rng=hk.maybe_next_rng_key())
    scanned = LayerStackScanned(i=jnp.arange(count, dtype=jnp.int32),
                                args_ys=args_ys)
    carry, zs = hk.scan(
        layer, carry, scanned, length=count, unroll=self._unroll)
    return carry.x, zs
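# --- Hedged usage sketch (not from the original source). This __call__ looks
# like the core of a layer-stack implementation; assuming it backs a public
# wrapper such as Haiku's hk.experimental.layer_stack, typical usage is roughly:
import jax
import jax.numpy as jnp
import haiku as hk

def forward(x):
  def block(x):
    return jax.nn.relu(hk.Linear(x.shape[-1])(x))

  # Apply 4 copies of `block` sequentially; parameters are stored with a
  # leading stack dimension of size 4 and sliced per layer inside hk.scan.
  return hk.experimental.layer_stack(4)(block)(x)

fwd = hk.transform(forward)
x = jnp.ones([2, 8])
params = fwd.init(jax.random.PRNGKey(0), x)  # e.g. linear weights: [4, 8, 8]
out = fwd.apply(params, None, x)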
def mapped_fn(*args):
  # Expand in_axes and determine the loop range.
  in_axes_ = _expand_axes(in_axes, args)
  in_sizes = jax.tree_multimap(_maybe_get_size, args, in_axes_)
  flat_sizes = jax.tree_flatten(in_sizes)[0]
  in_size = max(flat_sizes)
  assert all(i in {in_size, -1} for i in flat_sizes)

  num_extra_shards = (in_size - 1) // shard_size

  # Fix up the last shard size if necessary.
  last_shard_size = in_size % shard_size
  last_shard_size = shard_size if last_shard_size == 0 else last_shard_size

  def apply_fun_to_slice(slice_start, slice_size):
    input_slice = jax.tree_multimap(
        lambda array, axis: _maybe_slice(array, slice_start, slice_size, axis),
        args, in_axes_)
    return fun(*input_slice)

  remainder_shape_dtype = hk.eval_shape(
      partial(apply_fun_to_slice, 0, last_shard_size))
  out_dtypes = jax.tree_map(lambda x: x.dtype, remainder_shape_dtype)
  out_shapes = jax.tree_map(lambda x: x.shape, remainder_shape_dtype)
  out_axes_ = _expand_axes(out_axes, remainder_shape_dtype)

  if num_extra_shards > 0:
    regular_shard_shape_dtype = hk.eval_shape(
        partial(apply_fun_to_slice, 0, shard_size))
    shard_shapes = jax.tree_map(lambda x: x.shape, regular_shard_shape_dtype)

    def make_output_shape(axis, shard_shape, remainder_shape):
      return shard_shape[:axis] + (
          shard_shape[axis] * num_extra_shards +
          remainder_shape[axis],) + shard_shape[axis + 1:]

    out_shapes = jax.tree_multimap(make_output_shape, out_axes_, shard_shapes,
                                   out_shapes)

  # Calls dynamic_update_slice_in_dim with a different argument order.
  # This is here since tree_multimap only works with positional arguments.
  def dynamic_update_slice_in_dim(full_array, update, axis, i):
    return jax.lax.dynamic_update_slice_in_dim(full_array, update, i, axis)

  def compute_shard(outputs, slice_start, slice_size):
    slice_out = apply_fun_to_slice(slice_start, slice_size)
    update_slice = partial(dynamic_update_slice_in_dim, i=slice_start)
    return jax.tree_multimap(update_slice, outputs, slice_out, out_axes_)

  def scan_iteration(outputs, i):
    new_outputs = compute_shard(outputs, i, shard_size)
    return new_outputs, ()

  slice_starts = jnp.arange(0, in_size - shard_size + 1, shard_size)

  def allocate_buffer(dtype, shape):
    return jnp.zeros(shape, dtype=dtype)

  outputs = jax.tree_multimap(allocate_buffer, out_dtypes, out_shapes)

  if slice_starts.shape[0] > 0:
    outputs, _ = hk.scan(scan_iteration, outputs, slice_starts)

  if last_shard_size != shard_size:
    remainder_start = in_size - last_shard_size
    outputs = compute_shard(outputs, remainder_start, last_shard_size)

  return outputs
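# --- Hedged, self-contained sketch (not from the original source) of the
# sharding idea above, simplified to a single array argument mapped over axis 0
# whose leading size is an exact multiple of shard_size; sharded_apply and fun
# are illustrative names.
import jax
import jax.numpy as jnp

def sharded_apply(fun, shard_size):
  def mapped_fn(x):
    in_size = x.shape[0]
    num_shards = in_size // shard_size

    # Output shape/dtype of one shard, obtained without running the function.
    shard_out = jax.eval_shape(fun, x[:shard_size])
    out = jnp.zeros((in_size,) + shard_out.shape[1:], shard_out.dtype)

    def scan_iteration(outputs, start):
      chunk = jax.lax.dynamic_slice_in_dim(x, start, shard_size, axis=0)
      outputs = jax.lax.dynamic_update_slice_in_dim(
          outputs, fun(chunk), start, axis=0)
      return outputs, ()

    starts = jnp.arange(num_shards) * shard_size
    out, _ = jax.lax.scan(scan_iteration, out, starts)
    return out

  return mapped_fn

f = sharded_apply(lambda x: x * 2.0, shard_size=4)
y = f(jnp.ones([12, 3]))  # processed 4 rows at a time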