Exemplo n.º 1
0
 def __call__(self, tangent_func: phase_space.SymplecticTangentFunction,
              t: jnp.ndarray, y: phase_space.PhaseSpace,
              dt: jnp.ndarray) -> phase_space.PhaseSpace:
     """Performs a single step of the composition (splitting) integrator.

     For every pair `(c, d)` drawn in lockstep from `self.momentum_coefficients`
     and `self.position_coefficients`, the momentum is first advanced by
     `c * dt` using the `.p` component returned by `tangent_func`, and then the
     position is advanced by `d * dt` using the `.q` component. Sub-steps whose
     coefficient is exactly zero are skipped, so `tangent_func` is not
     evaluated for them.

     Args:
       tangent_func: Called as `tangent_func(time, PhaseSpace(q, p))`; the
         result's `.p` is used as dp/dt and its `.q` as dq/dt.
       t: Current time, either 0-d or with dims aligned to the leading dims
         of the state.
       y: Current phase-space state.
       dt: Step size, either 0-d or with dims aligned to the leading dims of
         the state.

     Returns:
       The phase-space state after one step of size `dt`.
     """
     position, momentum = y.q, y.p
     # `y` is deleted on purpose: any later use of it would silently read the
     # state from the start of the step instead of the updated one.
     del y

     state_ndim = position.ndim

     def pad_trailing(x):
         # Batch dims are aligned from the left (the reverse of numpy's
         # right-aligned broadcasting), so append singleton trailing axes.
         if x.ndim == 0:
             return x
         return x.reshape(x.shape + (1, ) * (state_ndim - x.ndim))

     dt = pad_trailing(dt)
     t = pad_trailing(t)

     # Momentum and position keep separate clocks, because in every sub-step
     # each of them advances by its own fraction of `dt`.
     momentum_time = t
     position_time = t
     coefficient_pairs = zip(self.momentum_coefficients,
                             self.position_coefficients)
     for momentum_coef, position_coef in coefficient_pairs:
         if momentum_coef != 0.0:
             momentum_rate = tangent_func(
                 momentum_time, phase_space.PhaseSpace(position, momentum)).p
             momentum = momentum + momentum_coef * dt * momentum_rate
             momentum_time = momentum_time + momentum_coef * dt
         if position_coef != 0.0:
             position_rate = tangent_func(
                 position_time, phase_space.PhaseSpace(position, momentum)).q
             position = position + position_coef * dt * position_rate
             position_time = position_time + position_coef * dt
     return phase_space.PhaseSpace(position=position, momentum=momentum)
Exemplo n.º 2
0
 def local_lagrangian(*q_and_q_dot):
     """Returns the Lagrangian summed over all of its entries.

     Args:
       *q_and_q_dot: Positional arguments forwarded unchanged to
         `phase_space.PhaseSpace` (presumably the position and velocity).

     Returns:
       `jnp.sum` of `self.lagrangian` evaluated at the assembled state. This
       is a single value, which makes it convenient to differentiate.
     """
     # `self` and `kwargs` are captured from the enclosing scope.
     state = phase_space.PhaseSpace(*q_and_q_dot)
     lagrangian_values = self.lagrangian(state, **kwargs)
     # Reducing to a single value lets gradients be taken directly.
     return jnp.sum(lagrangian_values)
Exemplo n.º 3
0
    def simulate(self,
                 y0: phase_space.PhaseSpace,
                 dt: Union[float, jnp.ndarray],
                 num_steps_forward: int,
                 num_steps_backward: int,
                 include_y0: bool,
                 return_stats: bool = True,
                 **nets_kwargs) -> _PhysicsSimulationOutput:
        """Simulates the continuous dynamics of the physical system.

        Args:
          y0: Initial state of the system, expressed in `self.input_space`
            coordinates.
          dt: The size of the time intervals at which to evolve the system.
          num_steps_forward: Number of steps to make into the future.
          num_steps_backward: Number of steps to make into the past.
          include_y0: Whether to include the initial state in the result.
          return_stats: Whether to return additional statistics.
          **nets_kwargs: Keyword arguments to pass to the networks.

        Returns:
          * The state of the system evolved as many steps as specified by the
            arguments into the past and future, all in chronological order and
            converted back to `self.input_space` coordinates. Time is axis 1
            of every leaf (axis 0 is presumably the batch axis - confirm with
            callers).
          * If `return_stats` is True, additionally a dictionary with:
            - `avg_energy_deficit`: for each index of the leading axis, the
              mean of |E_i - E_j| over all pairs of distinct time points.
            - `max_energy_deficit`: the largest |E_i - E_j| over all time
              points and all leading-axis entries (a single value).
          While Haiku is initialising (`hk.running_init()`), the Lagrangian
          (velocity simulation space) or Hamiltonian (momentum simulation
          space) of `y0` is returned instead, to avoid tracer issues.
        """
        # Define the dynamics
        if self.simulation_space == "velocity":
            dy_dt = lambda t_, y: self.velocity_and_acceleration(  # pylint: disable=g-long-lambda
                y.q, y.p, **nets_kwargs)
            # Special Haiku magic to avoid tracer issues
            if hk.running_init():
                return self.lagrangian(y0, **nets_kwargs)
        else:
            hamiltonian = lambda t_, y: self.hamiltonian(y, **nets_kwargs)
            dy_dt = phase_space.poisson_bracket_with_q_and_p(hamiltonian)
            if hk.running_init():
                return self.hamiltonian(y0, **nets_kwargs)

        # Optionally switch coordinate frame
        if self.input_space == "velocity" and self.simulation_space == "momentum":
            p = self.momentum_from_velocity(y0.q, y0.p, **nets_kwargs)
            y0 = phase_space.PhaseSpace(y0.q, p)
        if self.input_space == "momentum" and self.simulation_space == "velocity":
            q_dot = self.velocity_from_momentum(y0.q, y0.p, **nets_kwargs)
            y0 = phase_space.PhaseSpace(y0.q, q_dot)

        yt = integrators.solve_ivp_dt_two_directions(
            fun=dy_dt,
            y0=y0,
            t0=0.0,
            dt=dt,
            method=self.integrator_method,
            num_steps_forward=num_steps_forward,
            num_steps_backward=num_steps_backward,
            include_y0=include_y0,
            steps_per_dt=self.steps_per_dt,
            ode_int_kwargs=self.ode_int_kwargs)
        # Make time axis second. `jax.tree_util.tree_map` replaces the
        # deprecated (and in newer JAX removed) `jax.tree_map` alias.
        yt = jax.tree_util.tree_map(lambda x: jnp.swapaxes(x, 0, 1), yt)

        # Compute energies for the full trajectory
        yt_energy = jax.tree_util.tree_map(utils.merge_first_dims, yt)
        if self.simulation_space == "momentum":
            energy = self.energy_from_momentum(yt_energy, **nets_kwargs)
        else:
            energy = self.energy_from_velocity(yt_energy, **nets_kwargs)
        energy = energy.reshape(yt.q.shape[:2])

        # Optionally switch back to input coordinate frame
        if self.input_space == "velocity" and self.simulation_space == "momentum":
            q_dot = self.velocity_from_momentum(yt.q, yt.p, **nets_kwargs)
            yt = phase_space.PhaseSpace(yt.q, q_dot)
        if self.input_space == "momentum" and self.simulation_space == "velocity":
            p = self.momentum_from_velocity(yt.q, yt.p, **nets_kwargs)
            yt = phase_space.PhaseSpace(yt.q, p)

        # Compute energy deficits between every pair of time points.
        num_times = energy.shape[-1]
        # energy_deficits[..., i, j] == |E_j - E_i|
        energy_deficits = jnp.abs(energy[..., None, :] - energy[..., None])
        # The matrix is symmetric with a zero diagonal, so its full sum counts
        # every unordered pair twice. Divide by the T * (T - 1) ordered pairs
        # (dividing by T * (T - 1) / 2 would double the mean). Guard T == 1,
        # where there are no pairs and the sum is zero, against 0 / 0.
        num_ordered_pairs = float(max(num_times * (num_times - 1), 1))
        avg_deficit = jnp.sum(energy_deficits, axis=(-2, -1)) / num_ordered_pairs
        # NOTE: reduced over every axis (including the leading one), unlike
        # avg_deficit, which is kept per leading-axis entry.
        max_deficit = jnp.max(energy_deficits)

        # Return the states and energies
        if return_stats:
            return yt, dict(avg_energy_deficit=avg_deficit,
                            max_energy_deficit=max_deficit)
        else:
            return yt
Exemplo n.º 4
0
 def local_hamiltonian(p_):
     """Returns the Hamiltonian at `(q, p_)` summed over all of its entries.

     Args:
       p_: Momentum to pair with the captured position `q`.

     Returns:
       `jnp.sum` of `self.hamiltonian` evaluated at `PhaseSpace(q, p_)`. This
       is a single value, which makes it convenient to differentiate with
       respect to `p_`.
     """
     # `q`, `self` and `kwargs` are captured from the enclosing scope.
     state = phase_space.PhaseSpace(q, p_)
     hamiltonian_values = self.hamiltonian(state, **kwargs)
     # Reducing to a single value lets gradients be taken directly.
     return jnp.sum(hamiltonian_values)