Example #1
def test_odeint_adjoint(sensitivity, solver, interpolator, stiffness):
    f = VanDerPol(stiffness)
    x = torch.randn(1024, 2, requires_grad=True)
    t0 = time.time()
    prob = ODEProblem(f, sensitivity=sensitivity, interpolator=interpolator, solver=solver, atol=1e-4, rtol=1e-4, atol_adjoint=1e-4, rtol_adjoint=1e-4)
    t_eval, sol_torchdyn = prob.odeint(x, t_span)
    t_end1 = time.time() - t0

    t0 = time.time()
    sol_torchdiffeq = torchdiffeq.odeint_adjoint(f, x, t_span, method='dopri5', atol=1e-4, rtol=1e-4)
    t_end2 = time.time() - t0

    true_sol = torchdiffeq.odeint_adjoint(f, x, t_span, method='dopri5', atol=1e-9, rtol=1e-9)

    t0 = time.time()
    grad1 = torch.autograd.grad(sol_torchdyn[-1].sum(), x)[0]
    t_end1 = time.time() - t0

    t0 = time.time()
    grad2 = torch.autograd.grad(sol_torchdiffeq[-1].sum(), x)[0]
    t_end2 = time.time() - t0

    grad_true = torch.autograd.grad(true_sol[-1].sum(), x)[0]

    err1 = (grad1-grad_true).abs().sum(1)
    err2 = (grad2-grad_true).abs().sum(1)
    assert (err1 <= 1e-3).all() and (err1.mean() <= err2.mean())
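The fixtures used by this test (VanDerPol and the module-level t_span) are not shown on this page; a minimal sketch of what they might look like, assuming a batched Van der Pol vector field with `stiffness` as the damping coefficient:

import torch
import torch.nn as nn

# Hypothetical stand-ins for the fixtures above; the state is (position, velocity).
class VanDerPol(nn.Module):
    def __init__(self, stiffness=1.0):
        super().__init__()
        self.stiffness = stiffness

    def forward(self, t, x):
        pos, vel = x[..., 0], x[..., 1]
        dpos = vel
        dvel = self.stiffness * (1 - pos ** 2) * vel - pos
        return torch.stack([dpos, dvel], dim=-1)

# Module-level evaluation grid assumed by the test.
t_span = torch.linspace(0.0, 1.0, 10)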
Example #2
    def test_seminorm(self):
        torch.manual_seed(3456786)  # test can be flaky
        for dtype in DTYPES:
            for device in DEVICES:
                for method in ADAPTIVE_METHODS:
                    if method == 'adaptive_heun':
                        # Adaptive heun is consistently an awful choice with seminorms, it seems. My guess is that it's
                        # consistently overconfident with its step sizes, and that having seminorms turned off means
                        # that it actually gets it right more often.
                        continue
                    if dtype == torch.float32 and method == 'dopri8':
                        continue

                    with self.subTest(dtype=dtype,
                                      device=device,
                                      method=method):

                        x0 = torch.tensor([1.0, 2.0],
                                          device=device,
                                          dtype=dtype)
                        t = torch.tensor([0., 1.0], device=device, dtype=dtype)

                        norm_f = _NeuralF(width=256,
                                          oscillate=True).to(device, dtype)
                        out = torchdiffeq.odeint_adjoint(norm_f,
                                                         x0,
                                                         t,
                                                         atol=3e-7,
                                                         method=method)
                        norm_f.nfe = 0
                        out.sum().backward()

                        seminorm_f = _NeuralF(width=256, oscillate=True).to(
                            device, dtype)
                        with torch.no_grad():
                            for norm_param, seminorm_param in zip(
                                    norm_f.parameters(),
                                    seminorm_f.parameters()):
                                seminorm_param.copy_(norm_param)
                        out = torchdiffeq.odeint_adjoint(
                            seminorm_f,
                            x0,
                            t,
                            atol=1e-6,
                            method=method,
                            adjoint_options=dict(norm='seminorm'))
                        seminorm_f.nfe = 0
                        out.sum().backward()

                        self.assertLessEqual(seminorm_f.nfe, norm_f.nfe)
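`_NeuralF` is a fixture from torchdiffeq's test suite and is not reproduced here; a rough stand-in consistent with how it is used above (width/oscillate arguments and an nfe counter) might be:

import torch
import torch.nn as nn

class _NeuralF(nn.Module):
    """Small MLP vector field with a function-evaluation counter (sketch)."""
    def __init__(self, width, oscillate):
        super().__init__()
        self.linears = nn.Sequential(nn.Linear(2, width), nn.Tanh(),
                                     nn.Linear(width, 2), nn.Tanh())
        self.nfe = 0                 # incremented on every solver call
        self.oscillate = oscillate

    def forward(self, t, x):
        self.nfe += 1
        out = self.linears(x)
        if self.oscillate:
            out = out * t.mul(20).sin()   # assumed time-dependent modulation
        return out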
Example #3
 def forward(self, x):
     self.integration_time = self.integration_time.type_as(x)
     if self.adjoint:
         out = odeint_adjoint(self.odefunc, x, self.integration_time, rtol=1e-4, atol=1e-4)
     else:
         out = odeint(self.odefunc, x, self.integration_time, rtol=1e-4, atol=1e-4)
     return out[1]
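The attributes this forward relies on (odefunc, integration_time, adjoint) come from the enclosing module; a minimal wrapper consistent with it, with the class name and constructor assumed, could be:

import torch
import torch.nn as nn
from torchdiffeq import odeint, odeint_adjoint

class ODEBlock(nn.Module):
    def __init__(self, odefunc, adjoint=False):
        super().__init__()
        self.odefunc = odefunc            # nn.Module implementing forward(t, x)
        self.adjoint = adjoint
        # Solve from t=0 to t=1; forward() casts this to the input's dtype.
        self.integration_time = torch.tensor([0.0, 1.0])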
Example #4
    def forward(self, first_point, time_steps_to_predict, backwards=False):
        """
		# Decode the trajectory through ODE Solver
		"""
        n_traj_samples, n_traj = first_point.size()[0], first_point.size()[1]
        n_dims = first_point.size()[-1]
        if not self.ode_adjoint:
            pred_y = odeint(self.ode_func,
                            first_point,
                            time_steps_to_predict,
                            rtol=self.odeint_rtol,
                            atol=self.odeint_atol,
                            method=self.ode_method)
        else:
            pred_y = odeint_adjoint(self.ode_func,
                                    first_point,
                                    time_steps_to_predict,
                                    rtol=self.odeint_rtol,
                                    atol=self.odeint_atol,
                                    method=self.ode_method)
        pred_y = pred_y.permute(1, 2, 0, 3)

        assert torch.mean(torch.abs(pred_y[:, :, 0, :] - first_point)) < 0.001
        assert (pred_y.size()[0] == n_traj_samples)
        assert (pred_y.size()[1] == n_traj)

        return pred_y
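The permute above relies on torchdiffeq stacking the time dimension first; a small self-contained check of that shape convention (dynamics and sizes here are illustrative):

import torch
from torchdiffeq import odeint

f = lambda t, y: torch.zeros_like(y)            # trivial dynamics, shapes only
first_point = torch.randn(4, 8, 3)              # (n_traj_samples, n_traj, n_dims)
t = torch.linspace(0.0, 1.0, 5)                 # 5 prediction times

pred_y = odeint(f, first_point, t)              # (5, 4, 8, 3): time comes first
pred_y = pred_y.permute(1, 2, 0, 3)             # (4, 8, 5, 3): time moved to dim 2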
Example #5
    def forward(self, x, eval_times=None):
        """Solves ODE starting from x.
        Parameters
        ----------
        x : torch.Tensor
            Shape (batch_size, self.odefunc.data_dim)
        eval_times : None or torch.Tensor
            If None, returns solution of ODE at final time t=1. If torch.Tensor
            then returns full ODE trajectory evaluated at points in eval_times.
        """
        # Forward pass corresponds to solving the ODE, so reset the function
        # evaluation counter
        self.odefunc.nfe = 0

        if eval_times is None:
            integration_time = torch.tensor([0, 1]).float().type_as(x)
        else:
            integration_time = eval_times.type_as(x)

        if self.odefunc.augment_dim > 0:
            if self.is_conv:
                # Add augmentation
                batch_size, channels, height, width = x.shape
                aug = torch.zeros(batch_size,
                                  self.odefunc.augment_dim,
                                  height,
                                  width,
                                  device=x.device)
                # Shape (batch_size, channels + augment_dim, height, width)
                x_aug = torch.cat([x, aug], 1)
            else:
                # Add augmentation
                aug = torch.zeros(x.shape[0],
                                  self.odefunc.augment_dim,
                                  device=x.device)
                # Shape (batch_size, data_dim + augment_dim)
                x_aug = torch.cat([x, aug], 1)
        else:
            x_aug = x

        if self.adjoint:
            out = odeint_adjoint(self.odefunc,
                                 x_aug,
                                 integration_time,
                                 rtol=self.tol,
                                 atol=self.tol,
                                 method=self.method,
                                 options={'max_num_steps': self.max_num_steps})
        else:
            out = odeint(self.odefunc,
                         x_aug,
                         integration_time,
                         rtol=self.tol,
                         atol=self.tol,
                         method=self.method,
                         options={'max_num_steps': self.max_num_steps})

        if eval_times is None:
            return out[1]  # Return only final time
        else:
            return out
Example #6
    def test_adjoint(self):
        for ode in problems.PROBLEMS.keys():
            f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE, reverse=True, ode=ode)

            y = torchdiffeq.odeint_adjoint(f, y0, t_points, method='dopri5')
            with self.subTest(ode=ode):
                self.assertLess(rel_error(sol, y), error_tol)
Example #7
 def _adjoint(self, x):
     return torchdiffeq.odeint_adjoint(self.defunc,
                                       x,
                                       self.s_span,
                                       rtol=self.st['rtol'],
                                       atol=self.st['atol'],
                                       method=self.st['solver'])
Example #8
    def reverse(self, z, c):
        _logpz = torch.zeros(z.shape[0], 1).to(z)
        states = (z, _logpz, c)

        if self.train_T:
            integration_times = torch.stack([
                torch.tensor(0.0).to(z),
                self.sqrt_end_time * self.sqrt_end_time
            ]).to(z)
        else:
            integration_times = torch.tensor([0., self.T],
                                             requires_grad=False).to(z)

        integration_times = self._flip(integration_times, 0)

        # Refresh the odefunc statistics.
        self.odefunc.before_odeint()
        state_t = odeint_adjoint(
            self.odefunc,
            states,
            integration_times.to(z),
            atol=self.test_atol,
            rtol=self.test_rtol,
            method=self.test_solver,
        )

        if len(integration_times) == 2:
            state_t = tuple(s[1] for s in state_t)

        x_t, _, c = state_t[:3]

        return x_t, c
Example #9
    def forward(self, x, eval_times=None):

        dt = self.T / self.time_steps

        if eval_times is None:
            integration_time = torch.tensor([0, self.T]).float().type_as(x)
        else:
            integration_time = eval_times.type_as(x)

        if self.dynamics.augment_dim > 0:
            x = x.view(x.size(0), -1)
            aug = torch.zeros(x.shape[0],
                              self.dynamics.augment_dim).to(self.device)
            x_aug = torch.cat([x, aug], 1)
        else:
            x_aug = x

        if self.adjoint:
            out = odeint_adjoint(self.dynamics,
                                 x_aug,
                                 integration_time,
                                 method='euler',
                                 options={'step_size': dt})
        else:
            out = odeint(self.dynamics,
                         x_aug,
                         integration_time,
                         method='euler',
                         options={'step_size': dt})
        if eval_times is None:
            return out[1]
        else:
            return out
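Unlike the adaptive examples above, this block fixes the step size through the options dict; the same pattern in isolation (function, horizon, and step are illustrative):

import torch
from torchdiffeq import odeint

f = lambda t, y: -y                              # simple decay for illustration
y0 = torch.tensor([1.0])
t = torch.tensor([0.0, 1.0])

# Fixed-step Euler: the solver marches over a uniform grid of width 0.05.
y = odeint(f, y0, t, method='euler', options={'step_size': 0.05})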
Example #10
    def forward(self, x, eval_times=None, only_last=True):

        if eval_times is None:
            integration_time = torch.tensor([0, 1]).float().type_as(x)
        else:
            integration_time = eval_times.type_as(x)

        if self.odefunc.augment_dim > 0:

            batch_size, channels, height, width = x.shape
            aug = torch.zeros(batch_size, self.odefunc.augment_dim, height,
                              width, device=x.device)
            x_aug = torch.cat([x, aug], 1)

        else:
            x_aug = x

        out = odeint_adjoint(self.odefunc,
                             x_aug,
                             integration_time,
                             rtol=self.tol,
                             atol=self.tol,
                             method='dopri5',
                             options={'max_num_steps': 1000})

        if only_last:
            return out[1]  # Return only final time
        else:
            return out
Example #11
    def forward(self, x, c, log_det):
        _logpx = torch.zeros(x.shape[0], 1).to(x)
        states = (x, _logpx, c)

        if self.train_T:
            integration_times = torch.tensor(
                [0.0, self.sqrt_end_time * self.sqrt_end_time]).to(x)
        else:
            integration_times = torch.tensor([0.0, self.T]).to(x)

        # Refresh the odefunc statistics.
        self.odefunc.before_odeint()

        state_t = odeint_adjoint(
            self.odefunc,
            states,
            integration_times.to(x),
            atol=[self.atol, self.atol, self.atol],
            rtol=[self.rtol, self.rtol, self.rtol],
            method=self.solver,
            options=self.solver_options,
        )

        if len(integration_times) == 2:
            state_t = tuple(s[1] for s in state_t)

        z_t, logpz_t, c = state_t[:3]
        log_det = log_det - logpz_t

        return z_t, c, log_det
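The call above integrates a tuple state with one tolerance per component; a self-contained sketch of that pattern (dynamics, sizes, and tolerance values are made up):

import torch
import torch.nn as nn
from torchdiffeq import odeint_adjoint

class ToyCNF(nn.Module):
    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.tensor(0.1))

    def forward(self, t, states):
        z, logp, c = states
        dz = -self.scale * z + c
        dlogp = -self.scale * torch.ones_like(logp)   # placeholder divergence term
        dc = torch.zeros_like(c)                       # conditioning stays constant
        return dz, dlogp, dc

z0, logp0, c0 = torch.randn(16, 2), torch.zeros(16, 1), torch.randn(16, 2)
t = torch.tensor([0.0, 1.0])
z_t, logp_t, c_t = odeint_adjoint(ToyCNF(), (z0, logp0, c0), t,
                                  atol=[1e-5, 1e-5, 1e-5],
                                  rtol=[1e-5, 1e-5, 1e-5])
# Each output keeps the time dimension first, e.g. z_t.shape == (2, 16, 2).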
Example #12
    def forward(self, x, eval_times=None):
        # Forward pass corresponds to solving ODE, so reset number of function
        # evaluations counter
        self.odefunc.nfe = 0

        if eval_times is None:
            integration_time = torch.tensor([0, 1]).float().type_as(x)
        else:
            integration_time = eval_times.type_as(x)

        if self.adjoint:
            out = odeint_adjoint(self.odefunc,
                                 x,
                                 integration_time,
                                 rtol=self.tol,
                                 atol=self.tol,
                                 method='dopri5',
                                 options={'max_num_steps': MAX_NUM_STEPS})
        else:
            out = odeint(self.odefunc,
                         x,
                         integration_time,
                         rtol=self.tol,
                         atol=self.tol,
                         method='dopri5',
                         options={'max_num_steps': MAX_NUM_STEPS})

        if eval_times is None:
            return out[1]  # Return only final time
        else:
            return out
Example #13
    def forward(self, x, eval_times=None, only_last=True):

        if eval_times is None:

            #integration_time = torch.tensor([0, 1]).float().type_as(x)
            integration_time = torch.arange(0, 1, 0.1).float().type_as(x)

        else:

            integration_time = eval_times.type_as(x)

        if self.odefun.augment_dim > 0:

            aug = torch.zeros(x.shape[0], self.odefun.augment_dim, device=x.device)
            x_aug = torch.cat([x, aug], 1)
        else:
            x_aug = x

        out = odeint_adjoint(self.odefun,
                             x_aug,
                             integration_time,
                             rtol=self.tol,
                             atol=self.tol,
                             method='dopri5',
                             options={'max_num_steps': 1000})

        if only_last:
            return out[-1]
        else:
            return out
Example #14
    def _run_ode(self, *xs, dynamics, **kwargs):
        # TODO: kwargs should be parsed to avoid conflicts!
        assert(all(x.shape[0] == xs[0].shape[0] for x in xs[1:]))
        n_batch = xs[0].shape[0]
        logp_init = torch.zeros(n_batch, 1).to(xs[0])
        state = [*xs, logp_init]
        ts = torch.linspace(0.0, self._t_max, self._n_time_steps).to(xs[0])
        kwargs = {**self._kwargs, **kwargs}
        if not self._use_checkpoints:
            from torchdiffeq import odeint_adjoint
            *ys, dlogp = odeint_adjoint(
                dynamics,
                state,
                t=ts,
                method=self._integrator_method,
                rtol=self._integrator_rtol,
                atol=self._integrator_atol,
                options=kwargs
            )
        else:
#             raise NotImplementedError()
            from anode.adjoint import odesolver_adjoint
            state = odesolver_adjoint(dynamics, state, options=kwargs)
        ys = [y[-1] for y in ys]
        dlogp = dlogp[-1]
        return (*ys, dlogp)
Example #15
    def test_adjoint(self):
        """
        Test against dopri5
        """
        f, y0, t_points, _ = construct_problem(TEST_DEVICE)

        func = lambda y0, t_points: torchdiffeq.odeint(
            f, y0, t_points, method='dopri5')
        ys = func(y0, t_points)
        torch.manual_seed(0)
        gradys = torch.rand_like(ys)
        ys.backward(gradys)

        # reg_y0_grad = y0.grad
        reg_t_grad = t_points.grad
        reg_a_grad = f.a.grad
        reg_b_grad = f.b.grad

        f, y0, t_points, _ = construct_problem(TEST_DEVICE)

        func = lambda y0, t_points: torchdiffeq.odeint_adjoint(
            f, y0, t_points, method='dopri5')
        ys = func(y0, t_points)
        ys.backward(gradys)

        # adj_y0_grad = y0.grad
        adj_t_grad = t_points.grad
        adj_a_grad = f.a.grad
        adj_b_grad = f.b.grad

        # self.assertLess(max_abs(reg_y0_grad - adj_y0_grad), eps)
        self.assertLess(max_abs(reg_t_grad - adj_t_grad), eps)
        self.assertLess(max_abs(reg_a_grad - adj_a_grad), eps)
        self.assertLess(max_abs(reg_b_grad - adj_b_grad), eps)
Example #16
    def test_seminorm(self):
        for dtype in DTYPES:
            for device in DEVICES:
                for method in ADAPTIVE_METHODS:

                    with self.subTest(dtype=dtype,
                                      device=device,
                                      method=method):

                        if dtype == torch.float32:
                            tol = 1e-6
                        else:
                            tol = 1e-8

                        x0 = torch.tensor([1.0, 2.0],
                                          device=device,
                                          dtype=dtype)
                        t = torch.tensor([0., 1.0], device=device, dtype=dtype)

                        ode_f = _NeuralF(width=1024,
                                         oscillate=True).to(device, dtype)

                        out = torchdiffeq.odeint_adjoint(ode_f,
                                                         x0,
                                                         t,
                                                         atol=tol,
                                                         rtol=tol,
                                                         method=method)
                        ode_f.nfe = 0
                        out.sum().backward()
                        default_nfe = ode_f.nfe

                        out = torchdiffeq.odeint_adjoint(
                            ode_f,
                            x0,
                            t,
                            atol=tol,
                            rtol=tol,
                            method=method,
                            adjoint_options=dict(norm='seminorm'))
                        ode_f.nfe = 0
                        out.sum().backward()
                        seminorm_nfe = ode_f.nfe

                        self.assertLessEqual(seminorm_nfe, default_nfe)
Example #17
 def forward(self, y0, t=None, **kwargs):
     if self.t is not None and t is None:
         t = self.t
     elif self.t is None and t is not None:
         pass
     else:
         raise ValueError(
             '`t` must be given either when constructing NeuralODE or when calling it'
         )
     if kwargs.get('event_fn') is None:
         solution = odeint_adjoint(self.func, y0=y0, t=t, **kwargs)
         if self.last:
             solution = solution[-1]
         return solution
     else:
         event_t, solution = odeint_adjoint(self.func, y0=y0, t=t, **kwargs)
         return event_t, solution
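The else branch above returns (event_t, solution) when an event_fn is supplied; a self-contained illustration of that calling convention (dynamics, threshold, and method are illustrative):

import torch
import torch.nn as nn
from torchdiffeq import odeint_adjoint

class Decay(nn.Module):
    def __init__(self):
        super().__init__()
        self.rate = nn.Parameter(torch.tensor(1.0))

    def forward(self, t, y):
        return -self.rate * y

def event_fn(t, y):
    return y.sum() - 0.5        # integration stops when this crosses zero

y0 = torch.tensor([1.0])
t = torch.tensor([0.0, 5.0])
event_t, solution = odeint_adjoint(Decay(), y0, t, event_fn=event_fn, method='dopri5')
# event_t is close to ln(2), the time at which exp(-t) drops to 0.5.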
Example #18
 def evolve(self, h, time_diff):
     t = torch.tensor([0, time_diff.item()],
                      dtype=time_diff.dtype,
                      device=time_diff.device)
     out = torchdiffeq.odeint_adjoint(func=self.func,
                                      y0=h,
                                      t=t,
                                      method='rk4')
     return out[1]
Example #19
 def _torchdiffeq_adjoint(self, x):
     return torchdiffeq.odeint_adjoint(
         self.defunc,
         x,
         self.s_span,
         rtol=self.rtol,
         atol=self.atol,
         **self.solver,
         adjoint_options=dict(norm=make_norm(x)))[-1]
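make_norm(x) is defined elsewhere in that codebase; following the custom adjoint-norm pattern described in torchdiffeq's documentation, one plausible shape for it (the augmented-state layout and the rms norm are assumptions) is:

import torch

def rms_norm(tensor):
    return tensor.pow(2).mean().sqrt()

def make_norm(state):
    # The adjoint solver hands the norm a flat augmented state assumed to be
    # laid out as [t, y, adj_y, *adj_params]; measure only y and adj_y
    # (a seminorm-style choice that ignores the parameter gradients).
    state_size = state.numel()

    def norm(aug_state):
        y = aug_state[1:1 + state_size]
        adj_y = aug_state[1 + state_size:1 + 2 * state_size]
        return max(rms_norm(y), rms_norm(adj_y))

    return norm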
Example #20
    def test_adjoint(self):
        for device in DEVICES:
            for method in METHODS:

                with self.subTest(device=device, method=method):
                    f, y0, t_points, _ = construct_problem(device=device)
                    func = lambda y0, t_points: torchdiffeq.odeint_adjoint(
                        f, y0, t_points, method=method)
                    self.assertTrue(
                        torch.autograd.gradcheck(func, (y0, t_points)))
Example #21
    def forward(self, x):
        if self.norm:
            adjoint_options = dict(norm=common.make_norm(x))
        else:
            adjoint_options = None

        z = torchdiffeq.odeint_adjoint(self.func, x[:, 0], self.times, rtol=self.rtol, atol=self.atol,
                                       adjoint_options=adjoint_options)

        loss = torch.nn.functional.mse_loss(x, z.transpose(0, 1))
        return loss
Example #22
    def forward(self, x):
        self.integration_time = self.integration_time.type_as(x)

        # Actual ODE solver used in forward propagation
        out = odeint_adjoint(self.odefunc,
                             x,
                             self.integration_time,
                             rtol=tol,
                             atol=tol)

        return out[1]
Example #23
    def forward_batched(self, x:torch.Tensor, nn:int, indices:list, timestamps:set):
        """ Modified forward for ODE batches with different integration times """
        timestamps = torch.tensor(sorted(timestamps))  # sets are unordered; odeint needs a monotonic grid
        if self.adjoint_flag:
            out = torchdiffeq.odeint_adjoint(self.odefunc, x, timestamps,
                                             rtol=self.rtol, atol=self.atol, method=self.method)
        else:
            out = torchdiffeq.odeint(self.odefunc, x, timestamps,
                                     rtol=self.rtol, atol=self.atol, method=self.method)

        out = self._build_batch(out, nn, indices).reshape(x.shape)
        return out
Example #24
    def forward(self, x:torch.Tensor, T:int=1):
        self.integration_time = torch.tensor([0, T]).float()
        self.integration_time = self.integration_time.type_as(x)

        if self.adjoint_flag:
            out = torchdiffeq.odeint_adjoint(self.odefunc, x, self.integration_time,
                                             rtol=self.rtol, atol=self.atol, method=self.method)
        else:
            out = torchdiffeq.odeint(self.odefunc, x, self.integration_time,
                                     rtol=self.rtol, atol=self.atol, method=self.method)
            
        return out[-1]
Example #25
    def integrate(self,
                  t0,
                  t1,
                  x,
                  logpx,
                  tol=None,
                  method=None,
                  norm=None,
                  intermediate_states=0):
        """
        Args:
            t0: (N,)
            t1: (N,)
            x: (N, ...)
            logpx: (N,)
        """
        self.nfe = 0

        tol = tol or self.tol
        method = method or self.method
        e = torch.randn_like(x)[:, :self.dim]
        energy = torch.zeros(1).to(x)
        jacnorm = torch.zeros(1).to(x)
        initial_state = (t0, t1, e, x, logpx, energy, jacnorm)

        if intermediate_states > 1:
            tt = torch.linspace(self.start_time, self.end_time,
                                intermediate_states).to(t0)
        else:
            tt = torch.tensor([self.start_time, self.end_time]).to(t0)

        solution = odeint_adjoint(
            self,
            initial_state,
            tt,
            rtol=tol,
            atol=tol,
            method=method,
        )

        if intermediate_states > 1:
            y = solution[3]
            _, _, _, _, logpy, energy, jacnorm = tuple(s[-1] for s in solution)
        else:
            _, _, _, y, logpy, energy, jacnorm = tuple(s[-1] for s in solution)

        regularization = (self.energy_regularization *
                          (energy - energy.detach()) +
                          self.jacnorm_regularization *
                          (jacnorm - jacnorm.detach()))

        return y, logpy + regularization  # hacky method to introduce regularization.
Example #26
    def test_adjoint(self):
        for reverse in (False, True):
            for dtype in DTYPES:
                for device in DEVICES:
                    for ode in PROBLEMS:
                        if ode == 'linear':
                            eps = 2e-3
                        else:
                            eps = 1e-4

                        with self.subTest(reverse=reverse, dtype=dtype, device=device, ode=ode):
                            f, y0, t_points, sol = construct_problem(dtype=dtype, device=device, ode=ode,
                                                                     reverse=reverse)
                            y = torchdiffeq.odeint_adjoint(f, y0, t_points)
                            self.assertLess(rel_error(sol, y), eps)
Example #27
    def trajectory(self,
                   x: torch.Tensor,
                   s_span: torch.Tensor,
                   method='odeint',
                   **kwargs):
        if method == 'adjoint':
            solution = odeint_adjoint(self.func, x, s_span, **kwargs)
        elif method == 'odeint':
            solution = odeint(self.func, x, s_span, **kwargs)
        else:
            raise ValueError(
                '`method` should be either `adjoint` or `odeint`'
            )

        return solution
Example #28
    def forward(self, X, i=-1):
        '''
        Input:
            X - torch Tensor of size (N, C, W, H).
            i - int index of the time point in self.t at which to return the ODE estimate.

        Output:
            X_ti - torch Tensor of size (N, C', W', H'). Estimate of X at time point t[i].
        '''
        return torchdiffeq.odeint_adjoint(self.f,
                                          X,
                                          self.t,
                                          rtol=TOL,
                                          atol=TOL)[i]
Example #29
 def forward(self, vt, x):
     integration_time_vector = vt.type_as(x)
     if self.adjoint:
         out = ode.odeint_adjoint(self.odefunc,
                                  x,
                                  integration_time_vector,
                                  rtol=self.rtol,
                                  atol=self.atol,
                                  method=self.method)
     else:
         out = ode.odeint(self.odefunc,
                          x,
                          integration_time_vector,
                          rtol=self.rtol,
                          atol=self.atol,
                          method=self.method)
     # return out[-1]
     return out[-1] if self.terminal else out  # 100 * 400 * 10
Example #30
    def test_adjoint(self):
        f, y0, t_points, sol = construct_problem(device="cpu", ode="constant")

        def event_fn(t, y):
            return torch.sum(y - sol[-1])

        t, y = torchdiffeq.odeint_adjoint(f,
                                          y0,
                                          t_points[0:2],
                                          event_fn=event_fn,
                                          method="dopri5")
        y = y[-1]
        self.assertLess(rel_error(sol[-1], y), 1e-4)
        self.assertLess(rel_error(t_points[-1], t), 1e-4)

        # Make sure adjoint mode backward code can still be run.
        t.backward(retain_graph=True)
        y.sum().backward()