Example #1
def single_mlp(inner_name: str):
    """Creates a single MLP performing the update."""
    mlp = hk.nets.MLP(output_sizes=output_sizes,
                      name=inner_name,
                      activation=activation)
    mlp = jraph.concatenated_args(mlp)
    if normalization_type == 'layer_norm':
        norm = hk.LayerNorm(axis=-1,
                            create_scale=True,
                            create_offset=True,
                            name=f'{inner_name}_layer_norm')
    elif normalization_type == 'batch_norm':
        batch_norm = hk.BatchNorm(
            create_scale=True,
            create_offset=True,
            decay_rate=0.9,
            name=f'{inner_name}_batch_norm',
            cross_replica_axis=None if hk.running_init() else 'i',
        )
        norm = lambda x: batch_norm(x, is_training)
    elif normalization_type == 'none':
        return mlp
    else:
        raise ValueError(
            f'Unknown normalization type {normalization_type}')
    return jraph.concatenated_args(hk.Sequential([mlp, norm]))
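For context, a minimal usage sketch (not from the original source): single_mlp closes over output_sizes, activation, normalization_type and is_training from its enclosing scope, so it must be called inside an hk.transform-ed function. Everything below other than single_mlp itself is an illustrative assumption.

import haiku as hk
import jax
import jax.numpy as jnp
import jraph

# Assumed closed-over configuration.
output_sizes = (32, 32)
activation = jax.nn.relu
normalization_type = 'layer_norm'
is_training = True

def forward(node_feats, edge_feats):
    update_fn = single_mlp('edge_update')
    # jraph.concatenated_args concatenates all inputs along the last axis
    # before applying the wrapped module.
    return update_fn(node_feats, edge_feats)

net = hk.transform(forward)
a = jnp.ones([4, 8])
params = net.init(jax.random.PRNGKey(0), a, a)
out = net.apply(params, None, a, a)  # shape [4, 32]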
Example #2
    def simulate(
        self,
        y0: jnp.ndarray,
        dt: Union[float, jnp.ndarray],
        num_steps_forward: int,
        num_steps_backward: int,
        include_y0: bool,
        return_stats: bool = True,
        **nets_kwargs
    ) -> Union[jnp.ndarray, Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]]:
        """Simulates the continuous dynamics of the ODE specified by the network.

        Args:
            y0: Initial state of the system.
            dt: The size of the time intervals at which to evolve the system.
            num_steps_forward: Number of steps to make into the future.
            num_steps_backward: Number of steps to make into the past.
            include_y0: Whether to include the initial state in the result.
            return_stats: Whether to return additional statistics.
            **nets_kwargs: Keyword arguments to pass to the networks.

        Returns:
            * The state of the system, evolved for as many steps into the past
              and future as specified by the arguments, in chronological order.
            * Optionally, a dictionary of additional statistics; for the moment
              this is just an empty dictionary.
        """
        if hk.running_init():
            return self.core(y0, **nets_kwargs)
        yt = integrators.solve_ivp_dt_two_directions(
            fun=lambda t, y: self.core(y, **nets_kwargs),
            y0=y0,
            t0=0.0,
            dt=dt,
            method=self.integrator_method,
            num_steps_forward=num_steps_forward,
            num_steps_backward=num_steps_backward,
            include_y0=include_y0,
            steps_per_dt=self.steps_per_dt,
            ode_int_kwargs=self.ode_int_kwargs)
        # Make time axis second
        yt = jax.tree_map(lambda x: jnp.swapaxes(x, 0, 1), yt)
        if return_stats:
            return yt, dict()
        else:
            return yt
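Note the hk.running_init() branch: at parameter-initialization time the model calls self.core exactly once, which creates all network parameters without tracing the integrator. Below is a self-contained sketch of what the two-direction unroll in integrators.solve_ivp_dt_two_directions plausibly does, consistent with the call above; Euler steps stand in for the configurable method, and the real integrators module may differ.

import jax.numpy as jnp

def solve_two_directions_sketch(fun, y0, dt, num_steps_forward,
                                num_steps_backward, include_y0=True):
    backward, y = [], y0
    for _ in range(num_steps_backward):
        y = y - dt * fun(0.0, y)       # Euler step into the past
        backward.append(y)
    forward, y = [], y0
    for _ in range(num_steps_forward):
        y = y + dt * fun(0.0, y)       # Euler step into the future
        forward.append(y)
    states = backward[::-1] + ([y0] if include_y0 else []) + forward
    return jnp.stack(states, axis=0)   # time axis first, swapped later

yt = solve_two_directions_sketch(lambda t, y: -y, jnp.ones([2]), 0.1,
                                 num_steps_forward=3, num_steps_backward=2)
print(yt.shape)  # (6, 2): two past states, y0, three future states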
Example #3
    def simulate(
        self,
        y0: jnp.ndarray,
        num_steps_forward: int,
        include_y0: bool,
        return_stats: bool = True,
        **nets_kwargs
    ) -> Union[jnp.ndarray, Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]]:
        """Simulates the dynamics of the discrete system.

        Args:
            y0: Initial state of the system.
            num_steps_forward: Number of steps to make into the future.
            include_y0: Whether to include the initial state in the result.
            return_stats: Whether to return additional statistics.
            **nets_kwargs: Keyword arguments to pass to the networks.

        Returns:
            * The state of the system, evolved for as many steps into the
              future as specified by the arguments, in chronological order.
            * Optionally, a dictionary of additional statistics; for the moment
              this is just an empty dictionary.
        """
        if num_steps_forward < 1:
            raise ValueError("It is required to unroll at least one step.")
        nets_kwargs.pop("dt", None)
        nets_kwargs.pop("num_steps_backward", None)
        if hk.running_init():
            return self.core(y0, **nets_kwargs)

        def step(y, _):
            if self.residual:
                y_next = y + self.core(y, **nets_kwargs)
            else:
                y_next = self.core(y, **nets_kwargs)
            return y_next, y_next

        if self.use_scan:
            _, yt = jax.lax.scan(step,
                                 init=y0,
                                 xs=None,
                                 length=num_steps_forward)
            if include_y0:
                yt = jnp.concatenate([y0[None], yt], axis=0)
            # Make time axis second
            yt = jax.tree_map(lambda x: jnp.swapaxes(x, 0, 1), yt)
        else:
            yt = [y0]
            for _ in range(num_steps_forward):
                yt.append(step(yt[-1], None)[0])
            if not include_y0:
                yt = yt[1:]
            if len(yt) == 1:
                yt = yt[0][:, None]
            else:
                yt = jax.tree_multimap(lambda *args: jnp.stack(args, 1), *yt)
        if return_stats:
            return yt, dict()
        else:
            return yt
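The use_scan branch relies on jax.lax.scan with xs=None, which traces step once and iterates it length times while threading the carry; the per-step outputs are stacked with time leading, hence the later swapaxes. A minimal self-contained sketch of that pattern (the arithmetic stands in for the learned core network):

import jax
import jax.numpy as jnp

def step(y, _):
    y_next = 0.5 * y + 1.0   # stand-in for self.core
    return y_next, y_next    # (next carry, per-step output)

y0 = jnp.ones([3])
_, yt = jax.lax.scan(step, init=y0, xs=None, length=4)
print(yt.shape)  # (4, 3): time axis leading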
Example #4
    def simulate(self,
                 y0: phase_space.PhaseSpace,
                 dt: Union[float, jnp.ndarray],
                 num_steps_forward: int,
                 num_steps_backward: int,
                 include_y0: bool,
                 return_stats: bool = True,
                 **nets_kwargs) -> _PhysicsSimulationOutput:
        """Simulates the continuous dynamics of the physical system.

        Args:
            y0: Initial state of the system.
            dt: The size of the time intervals at which to evolve the system.
            num_steps_forward: Number of steps to make into the future.
            num_steps_backward: Number of steps to make into the past.
            include_y0: Whether to include the initial state in the result.
            return_stats: Whether to return additional statistics.
            **nets_kwargs: Keyword arguments to pass to the networks.

        Returns:
            * The state of the system, evolved for as many steps into the past
              and future as specified by the arguments, in chronological order.
            * Optionally, a dictionary of additional statistics; for the moment
              this contains only the energy of the system at each evaluation
              point.
        """
        # Define the dynamics
        if self.simulation_space == "velocity":
            dy_dt = lambda t_, y: self.velocity_and_acceleration(  # pylint: disable=g-long-lambda
                y.q, y.p, **nets_kwargs)
            # Special Haiku magic to avoid tracer issues
            if hk.running_init():
                return self.lagrangian(y0, **nets_kwargs)
        else:
            hamiltonian = lambda t_, y: self.hamiltonian(y, **nets_kwargs)
            dy_dt = phase_space.poisson_bracket_with_q_and_p(hamiltonian)
            if hk.running_init():
                return self.hamiltonian(y0, **nets_kwargs)

        # Optionally switch coordinate frame
        if self.input_space == "velocity" and self.simulation_space == "momentum":
            p = self.momentum_from_velocity(y0.q, y0.p, **nets_kwargs)
            y0 = phase_space.PhaseSpace(y0.q, p)
        if self.input_space == "momentum" and self.simulation_space == "velocity":
            q_dot = self.velocity_from_momentum(y0.q, y0.p, **nets_kwargs)
            y0 = phase_space.PhaseSpace(y0.q, q_dot)

        yt = integrators.solve_ivp_dt_two_directions(
            fun=dy_dt,
            y0=y0,
            t0=0.0,
            dt=dt,
            method=self.integrator_method,
            num_steps_forward=num_steps_forward,
            num_steps_backward=num_steps_backward,
            include_y0=include_y0,
            steps_per_dt=self.steps_per_dt,
            ode_int_kwargs=self.ode_int_kwargs)
        # Make time axis second
        yt = jax.tree_map(lambda x: jnp.swapaxes(x, 0, 1), yt)

        # Compute energies for the full trajectory
        yt_energy = jax.tree_map(utils.merge_first_dims, yt)
        if self.simulation_space == "momentum":
            energy = self.energy_from_momentum(yt_energy, **nets_kwargs)
        else:
            energy = self.energy_from_velocity(yt_energy, **nets_kwargs)
        energy = energy.reshape(yt.q.shape[:2])

        # Optionally switch back to input coordinate frame
        if self.input_space == "velocity" and self.simulation_space == "momentum":
            q_dot = self.velocity_from_momentum(yt.q, yt.p, **nets_kwargs)
            yt = phase_space.PhaseSpace(yt.q, q_dot)
        if self.input_space == "momentum" and self.simulation_space == "velocity":
            p = self.momentum_from_velocity(yt.q, yt.p, **nets_kwargs)
            yt = phase_space.PhaseSpace(yt.q, p)

        # Compute energy deficit
        t = energy.shape[-1]
        non_zero_diffs = float((t * (t - 1)) // 2)
        energy_deficits = jnp.abs(energy[..., None, :] - energy[..., None])
        avg_deficit = jnp.sum(energy_deficits, axis=(-2, -1)) / non_zero_diffs
        max_deficit = jnp.max(energy_deficits)

        # Return the states and energies
        if return_stats:
            return yt, dict(avg_energy_deficit=avg_deficit,
                            max_energy_deficit=max_deficit)
        else:
            return yt
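In the momentum branch, phase_space.poisson_bracket_with_q_and_p turns a Hamiltonian into the vector field of Hamilton's equations, dq/dt = dH/dp and dp/dt = -dH/dq. A self-contained sketch of such a transform, assuming PhaseSpace is a plain (q, p) container (the real phase_space module may differ):

import jax
import jax.numpy as jnp
from typing import NamedTuple

class PhaseSpace(NamedTuple):
    q: jnp.ndarray
    p: jnp.ndarray

def poisson_bracket_with_q_and_p(hamiltonian):
    def dy_dt(t, y):
        # Differentiate the (summed) Hamiltonian w.r.t. q and p.
        grad_q, grad_p = jax.grad(
            lambda q, p: jnp.sum(hamiltonian(t, PhaseSpace(q, p))),
            argnums=(0, 1))(y.q, y.p)
        return PhaseSpace(q=grad_p, p=-grad_q)
    return dy_dt

# Example: harmonic oscillator H = (q**2 + p**2) / 2.
h = lambda t, y: 0.5 * (y.q ** 2 + y.p ** 2)
dy_dt = poisson_bracket_with_q_and_p(h)
y0 = PhaseSpace(q=jnp.array([1.0]), p=jnp.array([0.0]))
print(dy_dt(0.0, y0))  # dq/dt = 0, dp/dt = -1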
Example #5
    def __call__(self, x, *args_ys):
        count = self._count
        if hk.running_init():
            # At initialization time, we run just one layer but add an extra first
            # dimension to every initialized tensor, making sure to use different
            # random keys for different slices.
            def creator(next_creator, shape, dtype, init, context):
                del context

                def multi_init(shape, dtype):
                    assert shape[0] == count
                    key = hk.maybe_next_rng_key()

                    def rng_context_init(slice_idx):
                        slice_key = maybe_fold_in(key, slice_idx)
                        with maybe_with_rng(slice_key):
                            return init(shape[1:], dtype)

                    return jax.vmap(rng_context_init)(jnp.arange(count))

                return next_creator((count,) + tuple(shape), dtype, multi_init)

            def getter(next_getter, value, context):
                trailing_dims = len(context.original_shape) + 1
                sliced_value = jax.lax.index_in_dim(
                    value, index=0, axis=value.ndim - trailing_dims,
                    keepdims=False)
                return next_getter(sliced_value)

            with hk.experimental.custom_creator(
                    creator), hk.experimental.custom_getter(getter):
                if len(args_ys) == 1 and args_ys[0] is None:
                    args0 = (None,)
                else:
                    args0 = [
                        jax.lax.dynamic_index_in_dim(ys, 0, keepdims=False)
                        for ys in args_ys
                    ]
                x, z = self._call_wrapped(x, *args0)
                if z is None:
                    return x, z

                # Broadcast state to hold each layer state.
                def broadcast_state(layer_state):
                    return jnp.broadcast_to(
                        layer_state, [count] + list(layer_state.shape))

                zs = jax.tree_util.tree_map(broadcast_state, z)
                return x, zs
        else:
            # Use scan during apply, threading through random seed so that it's
            # unique for each layer.
            def layer(carry: LayerStackCarry, scanned: LayerStackScanned):
                rng = carry.rng

                def getter(next_getter, value, context):
                    # Getter slices the full param at the current loop index.
                    trailing_dims = len(context.original_shape) + 1
                    assert value.shape[value.ndim - trailing_dims] == count, (
                        f'Attempting to use a parameter stack of size '
                        f'{value.shape[value.ndim - trailing_dims]} for a LayerStack of '
                        f'size {count}.')

                    sliced_value = jax.lax.dynamic_index_in_dim(
                        value,
                        scanned.i,
                        axis=value.ndim - trailing_dims,
                        keepdims=False)
                    return next_getter(sliced_value)

                with hk.experimental.custom_getter(getter):
                    if rng is None:
                        out_x, z = self._call_wrapped(carry.x,
                                                      *scanned.args_ys)
                    else:
                        rng, rng_ = jax.random.split(rng)
                        with hk.with_rng(rng_):
                            out_x, z = self._call_wrapped(
                                carry.x, *scanned.args_ys)
                return LayerStackCarry(x=out_x, rng=rng), z

            carry = LayerStackCarry(x=x, rng=hk.maybe_next_rng_key())
            scanned = LayerStackScanned(i=jnp.arange(count, dtype=jnp.int32),
                                        args_ys=args_ys)

            carry, zs = hk.scan(layer,
                                carry,
                                scanned,
                                length=count,
                                unroll=self._unroll)
            return carry.x, zs
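This is essentially the mechanism behind Haiku's hk.experimental.layer_stack: at init time one layer's parameters are created with an extra leading count dimension, and at apply time hk.scan slices out one layer per iteration. A minimal usage sketch with the public Haiku API:

import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
    @hk.experimental.layer_stack(num_layers=4)
    def block(x):
        return jax.nn.relu(hk.Linear(16)(x))
    return block(x)

net = hk.transform(forward)
x = jnp.ones([2, 16])
params = net.init(jax.random.PRNGKey(0), x)
# Every parameter carries a leading stack dimension of size 4,
# e.g. the Linear weight has shape (4, 16, 16).
print(jax.tree_util.tree_map(lambda p: p.shape, params))
out = net.apply(params, None, x)  # applies the four stacked layers in turn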