def single_mlp(inner_name: str):
  """Creates a single MLP performing the update."""
  mlp = hk.nets.MLP(
      output_sizes=output_sizes,
      name=inner_name,
      activation=activation)
  mlp = jraph.concatenated_args(mlp)
  if normalization_type == 'layer_norm':
    norm = hk.LayerNorm(
        axis=-1,
        create_scale=True,
        create_offset=True,
        name=f'{inner_name}_layer_norm')
  elif normalization_type == 'batch_norm':
    batch_norm = hk.BatchNorm(
        create_scale=True,
        create_offset=True,
        decay_rate=0.9,
        name=f'{inner_name}_batch_norm',
        cross_replica_axis=None if hk.running_init() else 'i',
    )
    norm = lambda x: batch_norm(x, is_training)
  elif normalization_type == 'none':
    return mlp
  else:
    raise ValueError(f'Unknown normalization type {normalization_type}.')
  return jraph.concatenated_args(hk.Sequential([mlp, norm]))
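
# A minimal usage sketch (hypothetical feature sizes; assumes the module-level
# `hk`, `jax`, `jnp` and `jraph` imports used above). It shows the pattern that
# `single_mlp` relies on: `jraph.concatenated_args` turns an MLP over a single
# array into an update function that accepts several feature arrays (e.g.
# edge, sender and receiver features) and concatenates them on the last axis.
def _example_concatenated_args_update():
  def forward(edges, senders, receivers):
    mlp = jraph.concatenated_args(hk.nets.MLP(output_sizes=[32, 8]))
    return mlp(edges, senders, receivers)

  net = hk.without_apply_rng(hk.transform(forward))
  edges = jnp.ones([5, 4])      # 5 edges with 4 features each
  senders = jnp.ones([5, 3])    # sender node features
  receivers = jnp.ones([5, 3])  # receiver node features
  params = net.init(jax.random.PRNGKey(0), edges, senders, receivers)
  return net.apply(params, edges, senders, receivers)  # shape [5, 8]
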
def simulate(
    self,
    y0: jnp.ndarray,
    dt: Union[float, jnp.ndarray],
    num_steps_forward: int,
    num_steps_backward: int,
    include_y0: bool,
    return_stats: bool = True,
    **nets_kwargs
) -> Union[jnp.ndarray, Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]]:
  """Simulates the continuous dynamics of the ODE specified by the network.

  Args:
    y0: Initial state of the system.
    dt: The size of the time intervals at which to evolve the system.
    num_steps_forward: Number of steps to make into the future.
    num_steps_backward: Number of steps to make into the past.
    include_y0: Whether to include the initial state in the result.
    return_stats: Whether to return additional statistics.
    **nets_kwargs: Keyword arguments to pass to the networks.

  Returns:
    * The state of the system evolved as many steps as specified by the
      arguments into the past and future, all in chronological order.
    * Optionally return a dictionary of additional statistics. For the
      moment this is just an empty dictionary.
  """
  if hk.running_init():
    return self.core(y0, **nets_kwargs)
  yt = integrators.solve_ivp_dt_two_directions(
      fun=lambda t, y: self.core(y, **nets_kwargs),
      y0=y0,
      t0=0.0,
      dt=dt,
      method=self.integrator_method,
      num_steps_forward=num_steps_forward,
      num_steps_backward=num_steps_backward,
      include_y0=include_y0,
      steps_per_dt=self.steps_per_dt,
      ode_int_kwargs=self.ode_int_kwargs)
  # Make time axis second
  yt = jax.tree_map(lambda x: jnp.swapaxes(x, 0, 1), yt)
  if return_stats:
    return yt, dict()
  else:
    return yt
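
# A minimal sketch (hypothetical names, with fixed-step Euler integration
# standing in for `integrators.solve_ivp_dt_two_directions`) of the
# `hk.running_init()` trick used above: at init time the core network is
# called once so that all of its parameters get created, while the integrator
# is only traced at apply time.
def _example_init_shortcut(y0, dt=0.1, num_steps=10):
  core = hk.nets.MLP(output_sizes=[32, y0.shape[-1]])
  if hk.running_init():
    return core(y0)  # a single call is enough to create the parameters

  def euler_step(y, _):
    y_next = y + dt * core(y)
    return y_next, y_next

  # `hk.scan` (rather than `jax.lax.scan`) threads Haiku state through the loop.
  _, yt = hk.scan(euler_step, y0, None, length=num_steps)
  return jnp.swapaxes(yt, 0, 1)  # [batch, time, ...]
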
def simulate(
    self,
    y0: jnp.ndarray,
    num_steps_forward: int,
    include_y0: bool,
    return_stats: bool = True,
    **nets_kwargs
) -> Union[jnp.ndarray, Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]]:
  """Simulates the dynamics of the discrete system.

  Args:
    y0: Initial state of the system.
    num_steps_forward: Number of steps to make into the future.
    include_y0: Whether to include the initial state in the result.
    return_stats: Whether to return additional statistics.
    **nets_kwargs: Keyword arguments to pass to the networks.

  Returns:
    * The state of the system evolved as many steps as specified by the
      arguments into the future, all in chronological order.
    * Optionally return a dictionary of additional statistics. For the
      moment this is just an empty dictionary.
  """
  if num_steps_forward < 1:
    raise ValueError("It is required to unroll at least one step.")
  nets_kwargs.pop("dt", None)
  nets_kwargs.pop("num_steps_backward", None)
  if hk.running_init():
    return self.core(y0, **nets_kwargs)

  def step(*args):
    y, _ = args
    if self.residual:
      y_next = y + self.core(y, **nets_kwargs)
    else:
      y_next = self.core(y, **nets_kwargs)
    return y_next, y_next

  if self.use_scan:
    _, yt = jax.lax.scan(step, init=y0, xs=None, length=num_steps_forward)
    if include_y0:
      yt = jnp.concatenate([y0[None], yt], axis=0)
    # Make time axis second
    yt = jax.tree_map(lambda x: jnp.swapaxes(x, 0, 1), yt)
  else:
    yt = [y0]
    for _ in range(num_steps_forward):
      yt.append(step(yt[-1], None)[0])
    if not include_y0:
      yt = yt[1:]
    if len(yt) == 1:
      yt = yt[0][:, None]
    else:
      yt = jax.tree_multimap(lambda *args: jnp.stack(args, 1), *yt)
  if return_stats:
    return yt, dict()
  else:
    return yt
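
# A minimal pure-JAX sketch of the residual unroll above (hypothetical
# `core_fn`, no Haiku parameters): `jax.lax.scan` carries the state and stacks
# every step's output along a new leading time axis, the initial state is
# optionally prepended, and the result is transposed to [batch, time, ...].
def _example_residual_unroll(core_fn, y0, num_steps, include_y0=True):
  def step(y, _):
    y_next = y + core_fn(y)  # residual update: y_{t+1} = y_t + f(y_t)
    return y_next, y_next

  _, yt = jax.lax.scan(step, y0, None, length=num_steps)
  if include_y0:
    yt = jnp.concatenate([y0[None], yt], axis=0)
  return jnp.swapaxes(yt, 0, 1)

# For instance, a linearly decaying system with a batch of two states:
#   _example_residual_unroll(lambda y: -0.1 * y, jnp.ones([2, 3]), 5)
# returns an array of shape [2, 6, 3] (the 6 includes y0).
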
def simulate(self,
             y0: phase_space.PhaseSpace,
             dt: Union[float, jnp.ndarray],
             num_steps_forward: int,
             num_steps_backward: int,
             include_y0: bool,
             return_stats: bool = True,
             **nets_kwargs) -> _PhysicsSimulationOutput:
  """Simulates the continuous dynamics of the physical system.

  Args:
    y0: Initial state of the system.
    dt: The size of the time intervals at which to evolve the system.
    num_steps_forward: Number of steps to make into the future.
    num_steps_backward: Number of steps to make into the past.
    include_y0: Whether to include the initial state in the result.
    return_stats: Whether to return additional statistics.
    **nets_kwargs: Keyword arguments to pass to the networks.

  Returns:
    * The state of the system evolved as many steps as specified by the
      arguments into the past and future, all in chronological order.
    * Optionally return a dictionary of additional statistics. For the
      moment these are the average and maximum energy deficit of the
      system over the simulated trajectory.
  """
  # Define the dynamics
  if self.simulation_space == "velocity":
    dy_dt = lambda t_, y: self.velocity_and_acceleration(  # pylint: disable=g-long-lambda
        y.q, y.p, **nets_kwargs)
    # Special Haiku magic to avoid tracer issues
    if hk.running_init():
      return self.lagrangian(y0, **nets_kwargs)
  else:
    hamiltonian = lambda t_, y: self.hamiltonian(y, **nets_kwargs)
    dy_dt = phase_space.poisson_bracket_with_q_and_p(hamiltonian)
    if hk.running_init():
      return self.hamiltonian(y0, **nets_kwargs)

  # Optionally switch coordinate frame
  if self.input_space == "velocity" and self.simulation_space == "momentum":
    p = self.momentum_from_velocity(y0.q, y0.p, **nets_kwargs)
    y0 = phase_space.PhaseSpace(y0.q, p)
  if self.input_space == "momentum" and self.simulation_space == "velocity":
    q_dot = self.velocity_from_momentum(y0.q, y0.p, **nets_kwargs)
    y0 = phase_space.PhaseSpace(y0.q, q_dot)

  yt = integrators.solve_ivp_dt_two_directions(
      fun=dy_dt,
      y0=y0,
      t0=0.0,
      dt=dt,
      method=self.integrator_method,
      num_steps_forward=num_steps_forward,
      num_steps_backward=num_steps_backward,
      include_y0=include_y0,
      steps_per_dt=self.steps_per_dt,
      ode_int_kwargs=self.ode_int_kwargs)
  # Make time axis second
  yt = jax.tree_map(lambda x: jnp.swapaxes(x, 0, 1), yt)

  # Compute energies for the full trajectory
  yt_energy = jax.tree_map(utils.merge_first_dims, yt)
  if self.simulation_space == "momentum":
    energy = self.energy_from_momentum(yt_energy, **nets_kwargs)
  else:
    energy = self.energy_from_velocity(yt_energy, **nets_kwargs)
  energy = energy.reshape(yt.q.shape[:2])

  # Optionally switch back to the input coordinate frame
  if self.input_space == "velocity" and self.simulation_space == "momentum":
    q_dot = self.velocity_from_momentum(yt.q, yt.p, **nets_kwargs)
    yt = phase_space.PhaseSpace(yt.q, q_dot)
  if self.input_space == "momentum" and self.simulation_space == "velocity":
    p = self.momentum_from_velocity(yt.q, yt.p, **nets_kwargs)
    yt = phase_space.PhaseSpace(yt.q, p)

  # Compute the energy deficit over all pairs of time points
  t = energy.shape[-1]
  non_zero_diffs = float((t * (t - 1)) // 2)
  energy_deficits = jnp.abs(energy[..., None, :] - energy[..., None])
  avg_deficit = jnp.sum(energy_deficits, axis=(-2, -1)) / non_zero_diffs
  max_deficit = jnp.max(energy_deficits, axis=(-2, -1))

  # Return the states and energies
  if return_stats:
    return yt, dict(avg_energy_deficit=avg_deficit,
                    max_energy_deficit=max_deficit)
  else:
    return yt
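
# A minimal sketch of the Hamiltonian branch above, with the repo's
# `phase_space.poisson_bracket_with_q_and_p` replaced by explicit autodiff.
# Hamilton's equations read dq/dt = dH/dp and dp/dt = -dH/dq; for a unit-mass
# harmonic oscillator H(q, p) = (q^2 + p^2) / 2 this gives circular orbits in
# phase space, so the energy-deficit statistics computed above should stay
# near zero for a good integrator.
def _example_hamiltonian_vector_field():
  def hamiltonian(q, p):
    return 0.5 * jnp.sum(p ** 2) + 0.5 * jnp.sum(q ** 2)

  def dy_dt(q, p):
    dh_dq, dh_dp = jax.grad(hamiltonian, argnums=(0, 1))(q, p)
    return dh_dp, -dh_dq  # (dq/dt, dp/dt)

  q, p = jnp.array([1.0]), jnp.array([0.0])
  # A single explicit Euler step; a symplectic method such as leapfrog would
  # conserve the energy much better over long rollouts.
  dq, dp = dy_dt(q, p)
  return q + 0.01 * dq, p + 0.01 * dp
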
def __call__(self, x, *args_ys):
  count = self._count
  if hk.running_init():
    # At initialization time, we run just one layer but add an extra first
    # dimension to every initialized tensor, making sure to use different
    # random keys for different slices.
    def creator(next_creator, shape, dtype, init, context):
      del context

      def multi_init(shape, dtype):
        assert shape[0] == count
        key = hk.maybe_next_rng_key()

        def rng_context_init(slice_idx):
          slice_key = maybe_fold_in(key, slice_idx)
          with maybe_with_rng(slice_key):
            return init(shape[1:], dtype)

        return jax.vmap(rng_context_init)(jnp.arange(count))

      return next_creator((count,) + tuple(shape), dtype, multi_init)

    def getter(next_getter, value, context):
      trailing_dims = len(context.original_shape) + 1
      sliced_value = jax.lax.index_in_dim(
          value, index=0, axis=value.ndim - trailing_dims, keepdims=False)
      return next_getter(sliced_value)

    with hk.experimental.custom_creator(
        creator), hk.experimental.custom_getter(getter):
      if len(args_ys) == 1 and args_ys[0] is None:
        args0 = (None,)
      else:
        args0 = [
            jax.lax.dynamic_index_in_dim(ys, 0, keepdims=False)
            for ys in args_ys
        ]
      x, z = self._call_wrapped(x, *args0)
      if z is None:
        return x, z

      # Broadcast state to hold each layer state.
      def broadcast_state(layer_state):
        return jnp.broadcast_to(
            layer_state, [count] + list(layer_state.shape))

      zs = jax.tree_util.tree_map(broadcast_state, z)
      return x, zs
  else:
    # Use scan during apply, threading through random seed so that it's
    # unique for each layer.
    def layer(carry: LayerStackCarry, scanned: LayerStackScanned):
      rng = carry.rng

      def getter(next_getter, value, context):
        # Getter slices the full param at the current loop index.
        trailing_dims = len(context.original_shape) + 1
        assert value.shape[value.ndim - trailing_dims] == count, (
            f'Attempting to use a parameter stack of size '
            f'{value.shape[value.ndim - trailing_dims]} for a LayerStack of '
            f'size {count}.')
        sliced_value = jax.lax.dynamic_index_in_dim(
            value, scanned.i, axis=value.ndim - trailing_dims, keepdims=False)
        return next_getter(sliced_value)

      with hk.experimental.custom_getter(getter):
        if rng is None:
          out_x, z = self._call_wrapped(carry.x, *scanned.args_ys)
        else:
          rng, rng_ = jax.random.split(rng)
          with hk.with_rng(rng_):
            out_x, z = self._call_wrapped(carry.x, *scanned.args_ys)
      return LayerStackCarry(x=out_x, rng=rng), z

    carry = LayerStackCarry(x=x, rng=hk.maybe_next_rng_key())
    scanned = LayerStackScanned(
        i=jnp.arange(count, dtype=jnp.int32), args_ys=args_ys)
    carry, zs = hk.scan(
        layer, carry, scanned, length=count, unroll=self._unroll)
    return carry.x, zs
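
# A minimal usage sketch via Haiku's public `hk.experimental.layer_stack`
# wrapper (assuming this class is the machinery behind it): parameters for all
# `count` layers are created with an extra leading dimension at init time, and
# a single layer body is scanned over that stack at apply time.
def _example_layer_stack(x):
  def block(h):
    return jax.nn.relu(hk.Linear(h.shape[-1])(h))

  # Four layers with independent parameters, applied sequentially via scan.
  return hk.experimental.layer_stack(4)(block)(x)

# net = hk.without_apply_rng(hk.transform(_example_layer_stack))
# params = net.init(jax.random.PRNGKey(0), jnp.ones([2, 8]))
# out = net.apply(params, jnp.ones([2, 8]))  # shape [2, 8]
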