Example #1
 def forward(self, X, Ndata, L=1, inst_enc=False, method='dopri5', dt=0.1):
     ''' Input
             X          - input images [N,T,nc,d,d]
             Ndata      - number of sequences in the dataset (required for the ELBO)
             L          - number of Monte Carlo draws (from the BNN)
             inst_enc   - whether instant encoding is used or not
             method     - numerical integration method
             dt         - numerical integration step size
         Returns
             Xrec       - reconstructions from latent samples      - [L,N,T,nc,d,d]
             qz0_m      - mean of the initial latent embedding     - [N,2q]
             qz0_logv   - log variance of the initial embedding    - [N,2q]
             ztL        - sampled latent trajectories              - [L,N,T,2q]
             elbo       - evidence lower bound
             lhood      - reconstruction likelihood
             kl_z       - KL divergence of the latent state
             beta*kl_w  - weighted KL divergence of the BNN weights
     '''
     # encode
     [N,T,nc,d,d] = X.shape
     h = self.encoder(X[:,0])
     qz0_m, qz0_logv = self.fc1(h), self.fc2(h) # N,2q & N,2q
     q = qz0_m.shape[1]//2
     # latent samples
     eps   = torch.randn_like(qz0_m)  # N,2q
     z0    = qz0_m + eps*torch.exp(qz0_logv) # N,2q
     logp0 = self.mvn.log_prob(eps) # N 
     # ODE
     t  = dt * torch.arange(T,dtype=torch.float).to(z0.device)
     ztL   = []
     logpL = []
     # sample L trajectories
     for l in range(L):
         f       = self.bnn.draw_f() # draw a differential function
         oderhs  = lambda t,vs: self.ode2vae_rhs(t,vs,f) # make the ODE forward function
         zt,logp = odeint(oderhs,(z0,logp0),t,method=method) # T,N,2q & T,N
         ztL.append(zt.permute([1,0,2]).unsqueeze(0)) # 1,N,T,2q
         logpL.append(logp.permute([1,0]).unsqueeze(0)) # 1,N,T
     ztL   = torch.cat(ztL,0) # L,N,T,2q
     logpL = torch.cat(logpL) # L,N,T
     # decode
     st_muL = ztL[:,:,:,q:] # L,N,T,q
     s = self.fc3(st_muL.contiguous().view([L*N*T,q]) ) # L*N*T,h_dim
     Xrec = self.decoder(s) # L*N*T,nc,d,d
     Xrec = Xrec.view([L,N,T,nc,d,d]) # L,N,T,nc,d,d
     # likelihood and elbo
     if inst_enc:
         h = self.encoder(X.contiguous().view([N*T,nc,d,d]))
         qz_enc_m, qz_enc_logv = self.fc1(h), self.fc2(h) # N*T,2q & N*T,2q
         lhood, kl_z, kl_w, inst_KL = \
             self.elbo(qz0_m, qz0_logv, ztL, logpL, X, Xrec, Ndata, qz_enc_m, qz_enc_logv)
         elbo = lhood - kl_z - inst_KL - self.beta*kl_w
     else:
         lhood, kl_z, kl_w = self.elbo(qz0_m, qz0_logv, ztL, logpL, X, Xrec, Ndata)
         elbo = lhood - kl_z - self.beta*kl_w
     return Xrec, qz0_m, qz0_logv, ztL, elbo, lhood, kl_z, self.beta*kl_w
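Note that odeint is called above with a tuple state (z0, logp0), so the right-hand side must return a matching tuple of derivatives. A minimal sketch of such an ode2vae_rhs helper is given below purely for illustration; it is an assumption, not the authors' implementation, and the change-of-density term is omitted.

import torch

def ode2vae_rhs_sketch(t, vs, f):
    # vs is the tuple state passed to odeint: (z, logp)
    z, logp = vs                        # z: [N, 2q], logp: [N]
    q = z.shape[1] // 2
    v = z[:, q:]                        # second half of z holds the velocity
    dv = f(z)                           # acceleration from the sampled BNN f
    dz = torch.cat([v, dv], dim=1)      # d[s, v]/dt = [v, a]
    dlogp = torch.zeros_like(logp)      # density-change term omitted in this sketch
    return dz, dlogp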
Example #2
    def test_large_norm(self):
        def norm(tensor):
            return tensor.abs().max()

        def large_norm(tensor):
            return 10 * tensor.abs().max()

        for dtype in DTYPES:
            for device in DEVICES:
                for method in ADAPTIVE_METHODS:
                    if dtype == torch.float32 and method == 'dopri8':
                        continue

                    with self.subTest(dtype=dtype,
                                      device=device,
                                      method=method):
                        x0 = torch.tensor([1.0, 2.0],
                                          device=device,
                                          dtype=dtype)
                        t = torch.tensor([0., 1.0], device=device, dtype=dtype)

                        norm_f = _NeuralF(width=10,
                                          oscillate=True).to(device, dtype)
                        torchdiffeq.odeint(norm_f,
                                           x0,
                                           t,
                                           method=method,
                                           options=dict(norm=norm))
                        large_norm_f = _NeuralF(width=10, oscillate=True).to(
                            device, dtype)
                        with torch.no_grad():
                            for norm_param, large_norm_param in zip(
                                    norm_f.parameters(),
                                    large_norm_f.parameters()):
                                large_norm_param.copy_(norm_param)
                        torchdiffeq.odeint(large_norm_f,
                                           x0,
                                           t,
                                           method=method,
                                           options=dict(norm=large_norm))

                        self.assertLessEqual(norm_f.nfe, large_norm_f.nfe)
Example #3
    def test_adaptive_heun(self):
        f, y0, t_points, sol = construct_problem(TEST_DEVICE)

        tuple_f = lambda t, y: (f(t, y[0]), f(t, y[1]))
        tuple_y0 = (y0, y0)

        tuple_y = torchdiffeq.odeint(tuple_f, tuple_y0, t_points, method='adaptive_heun')
        max_error0 = (sol - tuple_y[0]).max()
        max_error1 = (sol - tuple_y[1]).max()
        self.assertLess(max_error0, eps)
        self.assertLess(max_error1, eps)
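construct_problem is not shown in this listing; a self-contained stand-in with a known closed-form solution (exponential decay, an assumption chosen only so the error can be checked exactly) could look like this:

import torch

def construct_problem_sketch(device='cpu', dtype=torch.float64):
    # dy/dt = -y with y(t) = y0 * exp(-t), so the reference solution is exact
    f = lambda t, y: -y
    y0 = torch.tensor([1.0, 2.0], device=device, dtype=dtype)
    t_points = torch.linspace(0.0, 1.0, 10, device=device, dtype=dtype)
    sol = y0 * torch.exp(-t_points).unsqueeze(-1)   # shape [T, 2]
    return f, y0, t_points, sol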
Example #4
 def forward(self, O, R, X, Msrc, Mtgt, tol=None):
     # fall back to the module's default tolerance when none is passed
     tol = self.tol if tol is None else tol
     Otail = O[:, self.d_P:]
     Ohead = O[:, :self.d_P]
     self.integration_time = self.integration_time.type_as(O)
     self.odefunc.set_fixed(Otail, Msrc, Mtgt)
     P = odeint(self.odefunc,
                Ohead,
                self.integration_time,
                rtol=tol,
                atol=tol)
     return P[-1]
Example #5
 def forward(self, x):
     self.integration_time = self.integration_time.type_as(x)
     
     if self.method in ['rk4_param', 'rk3_param']:
         out = odeint_plus(self.odefunc, x, self.integration_time,
                           method=self.method,
                           options={'tableau': self.tableau, 'step_size': self.step_size})
     else:
         out = odeint(self.odefunc, x, self.integration_time, rtol=args.tol, atol=args.tol,
                      method=self.method, options={'step_size': self.step_size})
             
     return out[1]
Example #6
def MLE2(x0, F, ts, **kwargs):
    """Estimate the maximal Lyapunov exponent by integrating augmented dynamics."""
    with torch.no_grad():
        LD = LyapunovDynamics(F)
        x0 = x0.reshape(x0.shape[0], -1)
        # random unit-norm perturbation direction
        q0 = torch.randn_like(x0)
        q0 /= (q0 ** 2).sum(-1, keepdim=True).sqrt()
        # running log-growth rate, initialised to zero
        lr0 = torch.zeros(x0.shape[0], 1, dtype=x0.dtype, device=x0.device)
        Lx0 = torch.cat([x0, q0, lr0], dim=-1)
        Lxt = odeint(LD, Lx0, ts, **kwargs)
        # the last channel of the augmented state carries the accumulated log-growth
        maximal_exponent = Lxt[..., -1]
    return maximal_exponent
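LyapunovDynamics is assumed above and not shown. A hypothetical sketch of such an augmented right-hand side is given below: the state packs [x, q, r], where x follows F, q is a perturbation direction, and r accumulates the local log-growth rate, so the maximal exponent can be estimated as r(T)/T. The finite-difference Jacobian-vector product and the F(t, x) signature are assumptions for illustration.

import torch

class LyapunovDynamicsSketch(torch.nn.Module):
    def __init__(self, F, eps=1e-4):
        super().__init__()
        self.F, self.eps = F, eps

    def forward(self, t, state):
        d = (state.shape[-1] - 1) // 2
        x, q = state[..., :d], state[..., d:2 * d]
        dx = self.F(t, x)
        # J(x) q approximated with a forward difference of F along q
        Jq = (self.F(t, x + self.eps * q) - dx) / self.eps
        # d(log|q|)/dt = q.Jq / q.q, the instantaneous growth rate
        dr = (q * Jq).sum(-1, keepdim=True) / (q * q).sum(-1, keepdim=True)
        return torch.cat([dx, Jq, dr], dim=-1)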
Example #7
def get_true_y(u):
    true_y0 = torch.rand(1, 2) - 0.5

    with torch.no_grad():
        true_y = odeint(Lambda(u=u), true_y0, t, method='dopri5')

    #print(true_y0)
    #print(u)
    #print(true_y)

    return true_y0, true_y
Example #8
def forward(params, length_seconds=10, fs=300, debug=False):
    func_rk = func_gen(params)
    size = params.size
    x0 = torch.FloatTensor([[0, 0, 0.1, 0] for _ in range(size)])
    ts = torch.FloatTensor(np.linspace(0, length_seconds, num=length_seconds*fs))
    ys = odeint(func_rk, x0, ts, method='rk4')
    signal = calc_ecg(ys, params)
    if debug:
        return signal, ys
    else:
        return signal
Example #9
 def closure():
     optimizer.zero_grad()
     #output = model(input)
     pred_y_ = odeint(func,
                      batch_y0.squeeze(),
                      batch_t,
                      method=args.method)
     #loss = loss_fn(output, target)
     loss_ = lossfunc(pred_y_[:, :, 0], batch_y.squeeze()[:, :, 0])
     loss_.backward()
     return loss_
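Closures like this are what optimizers such as torch.optim.LBFGS expect, since they re-evaluate the objective internally. A hedged usage sketch, assuming func and the batch tensors from the snippet above are in scope:

optimizer = torch.optim.LBFGS(func.parameters(), lr=0.1)
for _ in range(20):
    loss = optimizer.step(closure)   # LBFGS calls closure() repeatedly within one step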
Example #10
 def forward(self, x):
     self.integration_time = self.integration_time.type_as(x)
     out = odeint(
         self.odefunc,
         x,
         self.integration_time,
         rtol=args.tol,
         atol=args.tol,
         method=args.method,
         )
     return out[1]
Example #11
def get_one_step_prediction(model, x0, dt, device):
    """ Given a model, and an initial condition (1D numpy array), predict for some dt (scalar) in the future.
        returns x_hats
    """
    assert type(x0) == np.ndarray
    x0 = np_to_integratable_type_1D(x0, device)
    ts = torch.tensor([0., dt], requires_grad=True, dtype=torch.float32).to(device)
    x_hats = odeint(model, x0, ts, method='rk4')
    x_hat = x_hats[1]
    assert x_hat.shape == x0.shape
    return x_hat[0].detach().cpu().numpy()
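np_to_integratable_type_1D is not defined in this snippet; a minimal stand-in consistent with the assertions above (hypothetical helper that lifts a 1D numpy state to a [1, D] float tensor) could be:

import numpy as np
import torch

def np_to_integratable_type_1D(x0: np.ndarray, device) -> torch.Tensor:
    # add a leading batch dimension so odeint integrates a single trajectory
    return torch.tensor(x0, dtype=torch.float32, device=device).unsqueeze(0)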
Example #12
    def ode_sample(self, batch_size=64, tau=1., t=None, y=None, dt=1e-2):
        self.module.eval()

        t = torch.tensor([-self.t1, -self.t0],
                         device=self.device) if t is None else t
        y = self.sample_t1_marginal(batch_size, tau) if y is None else y
        return torchdiffeq.odeint(self,
                                  y,
                                  t,
                                  method="rk4",
                                  options={"step_size": dt})
Example #13
 def forward(self, x):
     self.integration_time = self.integration_time.type_as(x)
     #out = odeint(self.odefunc, x, self.integration_time, rtol=args.tol, atol=args.tol)
     # out = odeint(self.odefunc, x, self.integration_time, rtol=1e-1, atol=1e-1).permute(1,0,2)
     out = odeint(self.odefunc,
                  x,
                  self.integration_time,
                  rtol=1e-3,
                  atol=1e-3).permute(1, 0, 2, 3)
     # odeint(func, z0, samp_ts).permute(1,0,2)
     return out
Example #14
 def forward(self, x, t=None):
     if t is not None:
         time = t
     else:
         time = self.integration_time
     self.output = odeint(self.odefunc,
                          x,
                          time,
                          rtol=self.rtol,
                          atol=self.atol)
     return self.output[1]
Example #15
 def forward(self, x, t=None):
     if t is None:
         times = self.t
     else:
         times = t
     self.outputs = odeint(self.odefunc,
                           x,
                           times,
                           rtol=self.rtol,
                           atol=self.atol)
     return self.outputs[1]
Example #16
    def test_adams(self):
        f, y0, t_points, sol = problem()

        tuple_f = lambda t, y: (f(t, y[0]), f(t, y[1]))
        tuple_y0 = (y0, y0)

        tuple_y = torchdiffeq.odeint(tuple_f, tuple_y0, t_points, method='adams')
        max_error0 = (sol - tuple_y[0]).max()
        max_error1 = (sol - tuple_y[1]).max()
        self.assertLess(max_error0, eps)
        self.assertLess(max_error1, eps)
Example #17
 def closure():
     optimizer.zero_grad()
     #output = model(input)
     #pred_y_ = odeint(func, batch_y0.squeeze(), batch_t, method=args.method)
     #loss = loss_fn(output, target)
     #loss_ = lossfunc(pred_y_[:,:,0], batch_y.squeeze()[:,:,0])
     pred_y_ = odeint(func, init_state, t, method=args.method)
     loss_ = lossfunc(
         pred_y_[:, :, 0], true_y[:, :, 0]
     ) + 10000 * torch.nn.functional.relu(-func.w0_learnable)
     loss_.backward()
     return loss_
Example #18
 def integrate(self, init_data, probe_points):
     import torchdiffeq as tde
     with torch.no_grad():
         self.model.eval()
         logging.info('calculating integration...')
         return tde.odeint(func=self.odefunc,
                           y0=init_data,
                           t=probe_points,
                           rtol=self.tol,
                           atol=self.tol,
                           method=self.odesolver,
                           options=None)
Example #19
 def forward(self, x, mask):
     "Pass the input (and mask) through each layer in turn."
     #         for layer in self.layers:
     #             x = layer(x, mask)
     self.integration_time = self.integration_time.type_as(x)
     self.ode_layer.set_mask(mask)
     out = odeint(self.ode_layer,
                  x,
                  self.integration_time,
                  rtol=self.tol,
                  atol=self.tol)
     return self.norm(out[1])
Example #20
    def predict_horizon(self, state, action_sequence):
        horizon = len(action_sequence)
        state = torch.Tensor(state)
        states = torch.zeros((state.shape[0], horizon))
        for i, a in enumerate(action_sequence):
            s_augmented = torch.cat((state, torch.Tensor([a]).float()))
            ns = odeint(self, s_augmented, torch.Tensor([0, 0.2]))[1, :4]

            states[:, i] = ns
            state = ns

        return states.detach().numpy()
Example #21
 def forward(ctx, h0, flat_params, s_span):
     with torch.no_grad():
         sol = odeint(self.func,
                      h0,
                      self.s_span,
                      rtol=self.rtol,
                      atol=self.atol,
                      method=self.method,
                      options=self.options)
     ctx.save_for_backward(self.s_span, self.flat_params, sol)
     sol = sol if self.return_traj else sol[-1]
     return sol
Example #22
    def test_odeint(self):
        for reverse in (False, True):
            for dtype in DTYPES:
                for device in DEVICES:
                    for method in METHODS:
                        for ode in PROBLEMS:

                            with self.subTest(reverse=reverse, dtype=dtype, device=device, ode=ode, method=method):
                                f, y0, t_points, sol = construct_problem(dtype=dtype, device=device, ode=ode,
                                                                         reverse=reverse)
                                y = torchdiffeq.odeint(f, y0, t_points[0:1], method=method)
                                self.assertLess((sol[0] - y).abs().max(), 1e-12)
Example #23
    def forward_batched(self, x: torch.Tensor, nn: int, indices: list, timestamps: set):
        """ Modified forward for ODE batches with different integration times """
        # sets are unordered, so sort the timestamps into a monotonic grid for odeint
        timestamps = torch.Tensor(sorted(timestamps))
        if self.adjoint_flag:
            out = torchdiffeq.odeint_adjoint(self.odefunc, x, timestamps,
                                             rtol=self.rtol, atol=self.atol, method=self.method)
        else:
            out = torchdiffeq.odeint(self.odefunc, x, timestamps,
                                     rtol=self.rtol, atol=self.atol, method=self.method)

        out = self._build_batch(out, nn, indices).reshape(x.shape)
        return out
Example #24
    def forward(self, input, dt=0.1):
        pre_x, pre_y, forward_x = input

        if pre_x.shape[0] == 0:
            pre_x = torch.zeros(
                (1, pre_x.shape[1], pre_x.shape[2])).to(forward_x.device)
            pre_y = torch.zeros(
                (1, pre_y.shape[1], pre_y.shape[2])).to(forward_x.device)
        _, hn = self.rnn_encoder(torch.cat(
            [pre_x, pre_y], dim=2))  # hn (1, batch_size, hidden_num)
        t = torch.linspace(0, dt * forward_x.shape[0],
                           forward_x.shape[0] + 1).to(pre_x.device)
        interpolation = Interpolation(t,
                                      torch.cat([pre_x[-1:], forward_x],
                                                dim=0),
                                      method=self.interpolation)
        self.ode_net.input_interpolation = interpolation

        if self.adjoint:
            from torchdiffeq import odeint_adjoint as odeint
        else:
            from torchdiffeq import odeint
        self.ode_net.cell.call_times = 0
        forward_hn_all = odeint(self.ode_net,
                                hn[0],
                                t,
                                rtol=self.rtol,
                                atol=self.atol,
                                method=self.solver)[1:]
        #forward_hn_all = odeint(self.ode_net, hn[0], t, rtol=self.rtol*len(t)/60, atol=self.atol*len(t)/60, method=self.solver)[1:]

        ###############################################################
        # start_hn = hn[0]
        # forward_hn_list = []
        # import time
        # ti = time.time()
        # for i in range(0, len(t)-1, 100):
        #     #cur_t = t[i:min(i+30+1, len(t))]
        #     cur_t_len = min(101, len(t)-i)
        #     cur_t = torch.linspace(0, dt * (cur_t_len-1), cur_t_len)
        #     with torch.no_grad():
        #         forward_hn = odeint(self.ode_net, start_hn, cur_t, rtol=self.rtol, atol=self.atol, method=self.solver)[1:]
        #     forward_hn_list.append(forward_hn)
        #     start_hn = forward_hn[-1]
        #     print(cur_t)
        #     print('{}-{}-{}s-{}-mean:{}-var:{}'.format(
        #         len(t), i, time.time()-ti, self.ode_net.cell.call_times, torch.mean(forward_hn), torch.var(forward_hn))
        #     )
        #     ti = time.time()
        # forward_hn_all = torch.cat(forward_hn_list, dim=0)
        ###############################################################
        estimate_y_all = self.recursive_predict(forward_hn_all, max_len=1000)
        return estimate_y_all
Example #25
    def numerically_integrate(
            self, state, u=0., T=1, time_steps=200, method='dopri5'):
        """
        Numerical integration of the dynamical system, used as a baseline

        """

        t = torch.linspace(0, T, time_steps).to(device)
        state = state.to(device)
        return odeint(
            lambda _, x: self.forward(x, u), state, t, method=method
        ).cpu()[:, 0, :]
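The lambda above adapts an autonomous model (state -> derivative) to the (t, y) signature that odeint expects. A self-contained illustration with toy dynamics (the names here are assumptions, not part of the snippet):

import torch
from torchdiffeq import odeint

dynamics = lambda x: -0.5 * x                     # autonomous: no explicit time dependence
rhs = lambda t, x: dynamics(x)                    # ignore t to match odeint's signature
x0 = torch.randn(3, 2)
t = torch.linspace(0.0, 1.0, 50)
trajectory = odeint(rhs, x0, t, method='dopri5')  # shape [50, 3, 2]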
Example #26
 def _integral_autograd(self, x):
     assert self.st[
         'cost'], 'Cost nn.Module needs to be specified for integral adjoint'
     ξ0 = 0. * torch.ones(1).to(x.device)
     ξ0 = ξ0.repeat(x.shape[0]).unsqueeze(1)
     x = torch.cat([x, ξ0], 1)
     return torchdiffeq.odeint(self._integral_autograd_defunc,
                               x,
                               self.s_span,
                               rtol=self.st['rtol'],
                               atol=self.st['atol'],
                               method=self.st['solver'])
Example #27
    def forward(self, x:torch.Tensor, T:int=1):
        self.integration_time = torch.tensor([0, T]).float()
        self.integration_time = self.integration_time.type_as(x)

        if self.adjoint_flag:
            out = torchdiffeq.odeint_adjoint(self.odefunc, x, self.integration_time,
                                             rtol=self.rtol, atol=self.atol, method=self.method)
        else:
            out = torchdiffeq.odeint(self.odefunc, x, self.integration_time,
                                     rtol=self.rtol, atol=self.atol, method=self.method)
            
        return out[-1]
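The adjoint flag above trades memory for compute: odeint_adjoint reconstructs the trajectory during the backward pass instead of backpropagating through the solver's internal steps. A self-contained sketch of the same pattern with a toy ODE function (all names here are illustrative assumptions):

import torch
import torchdiffeq

class ToyODEFunc(torch.nn.Module):
    def __init__(self, dim=2):
        super().__init__()
        self.net = torch.nn.Linear(dim, dim)

    def forward(self, t, x):
        return torch.tanh(self.net(x))

func = ToyODEFunc()
x = torch.randn(4, 2)
t_span = torch.tensor([0.0, 1.0])
# odeint_adjoint requires an nn.Module so it can collect the parameters to differentiate
out = torchdiffeq.odeint_adjoint(func, x, t_span, rtol=1e-6, atol=1e-6, method='dopri5')
final_state = out[-1]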
Example #28
    def get_obs(self, x):
        out_obs = self.fc_obs_in(x)
        self.integration_time = self.integration_time.type_as(out_obs)
        out_obs = odeint(self.odefunc_obs,
                         out_obs,
                         self.integration_time,
                         rtol=self.tol,
                         atol=self.tol)[1]
        # out_obs = self.odefunc_obs(0, out_obs)
        out_obs = self.fc_state_out(out_obs)

        return out_obs
Example #29
 def test_dopri8(self):
     for ode in problems.PROBLEMS.keys():
         f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE,
                                                           ode=ode)
         y = torchdiffeq.odeint(f,
                                y0,
                                t_points,
                                method='dopri8',
                                rtol=1e-12,
                                atol=1e-14)
         with self.subTest(ode=ode):
             self.assertLess(rel_error(sol, y), error_tol)
Example #30
 def invert(self, x, time):
     '''
     Solves the ODE in reverse time.
     '''
     self.inverse_time = torch.tensor([time, 0]).float().type_as(x)
     out = odeint(self.odefunc,
                  x,
                  self.inverse_time,
                  rtol=self.rtol,
                  atol=self.atol)
     self.cost = self.odefunc.nfe
     return out[1]
Example #31
device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu')

true_y0 = torch.tensor([[2., 0.]])
t = torch.linspace(0., 25., args.data_size)
true_A = torch.tensor([[-0.1, 2.0], [-2.0, -0.1]])


class Lambda(nn.Module):

    def forward(self, t, y):
        return torch.mm(y**3, true_A)


with torch.no_grad():
    true_y = odeint(Lambda(), true_y0, t, method='dopri5')


def get_batch():
    s = torch.from_numpy(np.random.choice(np.arange(args.data_size - args.batch_time, dtype=np.int64), args.batch_size, replace=False))
    batch_y0 = true_y[s]  # (M, D)
    batch_t = t[:args.batch_time]  # (T)
    batch_y = torch.stack([true_y[s + i] for i in range(args.batch_time)], dim=0)  # (T, M, D)
    return batch_y0, batch_t, batch_y
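A hedged sketch of how these batches typically feed a training loop; the ODEFunc module and optimizer settings are assumptions for illustration and are not part of this snippet:

func = ODEFunc()  # any nn.Module mapping (t, y) -> dy/dt
optimizer = torch.optim.Adam(func.parameters(), lr=1e-3)
for itr in range(1000):
    optimizer.zero_grad()
    batch_y0, batch_t, batch_y = get_batch()
    pred_y = odeint(func, batch_y0, batch_t)        # (T, M, 1, D)
    loss = torch.mean(torch.abs(pred_y - batch_y))
    loss.backward()
    optimizer.step()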


def makedirs(dirname):
    if not os.path.exists(dirname):
        os.makedirs(dirname)