def forward(self, X, Ndata, L=1, inst_enc=False, method='dopri5', dt=0.1):
    '''
    Encode the first frame, integrate L latent trajectories (one BNN draw
    each), decode every time step and evaluate the ELBO.

    Input
        X        - input images [N,T,nc,d,d]
        Ndata    - number of sequences in the dataset (required for elbo)
        L        - number of Monte Carlo draws (from BNN)
        inst_enc - whether instant encoding is used or not; when True every
                   frame is also encoded and an extra inst_KL term is
                   subtracted from the ELBO
        method   - numerical integration method passed to odeint
        dt       - time spacing between consecutive frames (the output grid
                   is dt*arange(T)); it is not passed to the solver as a step
                   size
    Returns
        Xrec      - reconstructions for every latent sample - [L,N,T,nc,d,d]
        qz0_m     - mean of the initial latent embedding - [N,2q]
        qz0_logv  - log variance of the initial latent embedding - [N,2q]
        ztL       - latent trajectories - [L,N,T,2q]
        elbo      - lhood - kl_z (- inst_KL) - beta*kl_w
        lhood     - reconstruction likelihood
        kl_z      - KL of the latent embedding
        beta*kl_w - KL of the BNN weights, scaled by self.beta
    '''
    # encode
    [N,T,nc,d,d] = X.shape
    h = self.encoder(X[:,0])  # only the first frame seeds the initial state
    qz0_m, qz0_logv = self.fc1(h), self.fc2(h) # N,2q & N,2q
    q = qz0_m.shape[1]//2
    # latent samples
    eps = torch.randn_like(qz0_m) # N,2q
    # NOTE(review): if qz0_logv is a log-variance, the standard deviation
    # would be exp(0.5*qz0_logv); confirm against self.elbo before changing.
    z0 = qz0_m + eps*torch.exp(qz0_logv) # N,2q
    # log-density of the noise sample; integrated alongside z0 below
    logp0 = self.mvn.log_prob(eps) # N
    # ODE
    t = dt * torch.arange(T,dtype=torch.float).to(z0.device)
    ztL = []
    logpL = []
    # sample L trajectories
    for l in range(L):
        f = self.bnn.draw_f() # draw a differential function
        # the lambda is consumed immediately by odeint, so binding f late is safe
        oderhs = lambda t,vs: self.ode2vae_rhs(t,vs,f) # make the ODE forward function
        zt,logp = odeint(oderhs,(z0,logp0),t,method=method) # T,N,2q & T,N
        ztL.append(zt.permute([1,0,2]).unsqueeze(0)) # 1,N,T,2q
        logpL.append(logp.permute([1,0]).unsqueeze(0)) # 1,N,T
    ztL = torch.cat(ztL,0) # L,N,T,2q
    logpL = torch.cat(logpL) # L,N,T
    # decode
    # only the second half [q:] of each latent state is decoded
    # (presumably the position half of a (velocity, position) split)
    st_muL = ztL[:,:,:,q:] # L,N,T,q
    s = self.fc3(st_muL.contiguous().view([L*N*T,q]) ) # L*N*T,h_dim
    Xrec = self.decoder(s) # L*N*T,nc,d,d
    Xrec = Xrec.view([L,N,T,nc,d,d]) # L,N,T,nc,d,d
    # likelihood and elbo
    if inst_enc:
        # encode every frame independently for the instant-encoding KL term
        h = self.encoder(X.contiguous().view([N*T,nc,d,d]))
        qz_enc_m, qz_enc_logv = self.fc1(h), self.fc2(h) # N*T,2q & N*T,2q
        lhood, kl_z, kl_w, inst_KL = \
            self.elbo(qz0_m, qz0_logv, ztL, logpL, X, Xrec, Ndata, qz_enc_m, qz_enc_logv)
        elbo = lhood - kl_z - inst_KL - self.beta*kl_w
    else:
        lhood, kl_z, kl_w = self.elbo(qz0_m, qz0_logv, ztL, logpL, X, Xrec, Ndata)
        elbo = lhood - kl_z - self.beta*kl_w
    return Xrec, qz0_m, qz0_logv, ztL, elbo, lhood, kl_z, self.beta*kl_w
def test_large_norm(self):
    """A norm scaled up by 10 must never need fewer function evaluations.

    Two identically-weighted ``_NeuralF`` instances are integrated with every
    adaptive method; the one using the inflated error norm must report at
    least as many evaluations (``nfe``) as the baseline.
    """
    def norm(tensor):
        return tensor.abs().max()

    def large_norm(tensor):
        return 10 * tensor.abs().max()

    for dtype in DTYPES:
        for device in DEVICES:
            for method in ADAPTIVE_METHODS:
                # dopri8 is not exercised in single precision.
                if method == 'dopri8' and dtype == torch.float32:
                    continue
                with self.subTest(dtype=dtype, device=device, method=method):
                    x0 = torch.tensor([1.0, 2.0], device=device, dtype=dtype)
                    t = torch.tensor([0., 1.0], device=device, dtype=dtype)

                    baseline = _NeuralF(width=10, oscillate=True).to(device, dtype)
                    torchdiffeq.odeint(baseline, x0, t, method=method,
                                       options=dict(norm=norm))

                    # Copy the baseline weights so the runs differ only in the norm.
                    scaled = _NeuralF(width=10, oscillate=True).to(device, dtype)
                    with torch.no_grad():
                        for src, dst in zip(baseline.parameters(), scaled.parameters()):
                            dst.copy_(src)
                    torchdiffeq.odeint(scaled, x0, t, method=method,
                                       options=dict(norm=large_norm))

                    self.assertLessEqual(baseline.nfe, scaled.nfe)
def test_adaptive_heun(self):
    """adaptive_heun must solve a tuple-valued state to within ``eps``.

    The error is measured as the maximum *absolute* deviation; a signed
    ``max`` would ignore negative deviations of any size and let the test
    pass for a wrong solution.
    """
    f, y0, t_points, sol = construct_problem(TEST_DEVICE)

    # Integrate two copies of the same problem packed as a tuple state.
    tuple_f = lambda t, y: (f(t, y[0]), f(t, y[1]))
    tuple_y0 = (y0, y0)

    tuple_y = torchdiffeq.odeint(tuple_f, tuple_y0, t_points, method='adaptive_heun')
    max_error0 = (sol - tuple_y[0]).abs().max()
    max_error1 = (sol - tuple_y[1]).abs().max()
    self.assertLess(max_error0, eps)
    self.assertLess(max_error1, eps)
def forward(self, O, R, X, Msrc, Mtgt, tol=None):
    """Integrate the head of ``O`` over ``self.integration_time``, holding its tail fixed.

    The first ``self.d_P`` columns of ``O`` are the ODE state. The remaining
    columns are handed to ``self.odefunc.set_fixed`` together with ``Msrc``
    and ``Mtgt`` before integration.

    Args:
        O: input tensor, indexed as at least 2-D (columns along dim 1).
        R, X: not used here; kept for interface compatibility.
        Msrc, Mtgt: forwarded to ``self.odefunc.set_fixed``.
        tol: relative and absolute solver tolerance. ``None`` (the default)
            falls back to ``self.tol``.

    Returns:
        The state at the last point of ``self.integration_time``.
    """
    if tol is None:
        tol = self.tol
    Otail = O[:, self.d_P:]
    Ohead = O[:, :self.d_P]
    self.integration_time = self.integration_time.type_as(O)
    self.odefunc.set_fixed(Otail, Msrc, Mtgt)
    P = odeint(self.odefunc, Ohead, self.integration_time, rtol=tol, atol=tol)
    return P[-1]
def forward(self, x):
    """Integrate ``x`` over ``self.integration_time`` and return the state at index 1.

    The parametrised methods ('rk4_param', 'rk3_param') go through
    ``odeint_plus`` with ``self.tableau``. Every other method uses ``odeint``
    with the global ``args.tol`` as both tolerances.
    """
    self.integration_time = self.integration_time.type_as(x)
    times = self.integration_time
    if self.method not in ('rk4_param', 'rk3_param'):
        solution = odeint(self.odefunc, x, times,
                          rtol=args.tol, atol=args.tol, method=self.method,
                          options={'step_size': self.step_size})
    else:
        solution = odeint_plus(self.odefunc, x, times, method=self.method,
                               options={'tableau': self.tableau,
                                        'step_size': self.step_size})
    return solution[1]
def MLE2(x0, F, ts, **kwargs):
    """Integrate the ``LyapunovDynamics(F)`` system started from ``x0``.

    Each state is flattened and augmented with a random unit-norm direction
    and one zero-initialized column. The concatenation is integrated over
    ``ts`` without gradient tracking.

    Args:
        x0: initial states; all dimensions after the first are flattened.
        F: dynamics wrapped by ``LyapunovDynamics``.
        ts: time points handed to ``odeint``.
        **kwargs: forwarded to ``odeint`` (e.g. ``method``, ``rtol``).

    Returns:
        The full ``odeint`` solution of the concatenated ``[x, q, lr]`` state.
        NOTE(review): despite the name ``maximal_exponent`` used originally,
        no ``[..., -1]`` slice is applied; confirm callers expect the whole state.
    """
    with torch.no_grad():
        dynamics = LyapunovDynamics(F)
        flat_x0 = x0.reshape(x0.shape[0], -1)
        # Random direction, normalized per row.
        direction = torch.randn_like(flat_x0)
        direction /= (direction ** 2).sum(-1, keepdims=True).sqrt()
        accumulator = torch.zeros(flat_x0.shape[0], 1,
                                  dtype=flat_x0.dtype, device=flat_x0.device)
        state0 = torch.cat([flat_x0, direction, accumulator], dim=-1)
        trajectory = odeint(dynamics, state0, ts, **kwargs)
    return trajectory
def get_true_y(u):
    """Draw a random initial state and integrate ``Lambda(u=u)`` over the global ``t``.

    Returns:
        ``(true_y0, true_y)``. ``true_y0`` is a ``(1, 2)`` tensor drawn
        uniformly from ``[-0.5, 0.5)``. ``true_y`` is the dopri5 solution,
        computed without gradient tracking.
    """
    y_start = torch.rand(1, 2) - 0.5
    with torch.no_grad():
        trajectory = odeint(Lambda(u=u), y_start, t, method='dopri5')
    return y_start, trajectory
def forward(params, length_seconds=10, fs=300, debug=False):
    """Simulate ``length_seconds`` of dynamics sampled at ``fs`` Hz and convert with ``calc_ecg``.

    Args:
        params: parameter container passed to ``func_gen`` and ``calc_ecg``.
            ``params.size`` sets how many trajectories are integrated together.
        length_seconds: simulated duration in seconds. Non-integer durations
            are supported: the sample count is rounded to the nearest integer,
            because ``np.linspace`` rejects a float ``num``.
        fs: sampling frequency in Hz.
        debug: when True, also return the raw ODE solution.

    Returns:
        ``signal``, or ``(signal, ys)`` when ``debug`` is True.
    """
    func_rk = func_gen(params)
    size = params.size
    # Every trajectory starts from the same initial state.
    x0 = torch.FloatTensor([[0, 0, 0.1, 0] for _ in range(size)])
    num_samples = int(round(length_seconds * fs))
    ts = torch.FloatTensor(np.linspace(0, length_seconds, num=num_samples))
    ys = odeint(func_rk, x0, ts, method='rk4')
    signal = calc_ecg(ys, params)
    return (signal, ys) if debug else signal
def closure():
    """Optimizer closure: reset gradients, integrate the batch, backprop the loss.

    Only the first state component (index 0 of the last axis) of the
    prediction and the target is compared.

    Returns:
        The loss tensor (presumably consumed by ``optimizer.step(closure)``).
    """
    optimizer.zero_grad()
    prediction = odeint(func, batch_y0.squeeze(), batch_t, method=args.method)
    target = batch_y.squeeze()
    loss_ = lossfunc(prediction[:, :, 0], target[:, :, 0])
    loss_.backward()
    return loss_
def forward(self, x):
    """Integrate ``x`` over ``self.integration_time`` and return the state at index 1.

    The solver and tolerances come from the global ``args``
    (``args.method``, ``args.tol``).
    """
    self.integration_time = self.integration_time.type_as(x)
    solution = odeint(self.odefunc, x, self.integration_time,
                      rtol=args.tol, atol=args.tol, method=args.method)
    return solution[1]
def get_one_step_prediction(model, x0, dt, device):
    """Predict the state ``dt`` time units after the initial condition ``x0``.

    Args:
        model: ODE right-hand side integrated with fixed-step RK4.
        x0: 1-D initial condition; must be exactly ``np.ndarray``
            (subclasses fail the type assertion).
        dt: scalar time step.
        device: torch device the integration runs on.

    Returns:
        NumPy array ``x_hat[0]``, where ``x_hat`` is the predicted state at
        time ``dt`` (same shape as the converted ``x0``).
    """
    assert type(x0) is np.ndarray
    x0 = np_to_integratable_type_1D(x0, device)
    times = torch.tensor([0., dt], requires_grad=True, dtype=torch.float32).to(device)
    x_hat = odeint(model, x0, times, method='rk4')[1]
    assert x_hat.shape == x0.shape
    return x_hat[0].detach().cpu().numpy()
def ode_sample(self, batch_size=64, tau=1., t=None, y=None, dt=1e-2):
    """Integrate samples with fixed-step RK4, using ``self`` as the ODE function.

    Puts ``self.module`` into eval mode first.

    Args:
        batch_size, tau: passed to ``self.sample_t1_marginal`` when ``y`` is None.
        t: integration times; defaults to ``[-self.t1, -self.t0]`` on ``self.device``.
        y: starting samples; drawn from ``self.sample_t1_marginal`` when None.
        dt: RK4 step size.

    Returns:
        The ``torchdiffeq.odeint`` solution at every time in ``t``.
    """
    self.module.eval()
    if t is None:
        t = torch.tensor([-self.t1, -self.t0], device=self.device)
    if y is None:
        y = self.sample_t1_marginal(batch_size, tau)
    return torchdiffeq.odeint(self, y, t, method="rk4", options={"step_size": dt})
def forward(self, x):
    """Integrate ``x`` over ``self.integration_time`` and swap the first two axes.

    Uses rtol = atol = 1e-3. The solution is indexed time-first, and the
    permutation ``(1, 0, 2, 3)`` assumes a 4-D solution.

    Returns:
        The full trajectory with its first two axes swapped.
    """
    self.integration_time = self.integration_time.type_as(x)
    trajectory = odeint(self.odefunc, x, self.integration_time, rtol=1e-3, atol=1e-3)
    return trajectory.permute(1, 0, 2, 3)
def forward(self, x, t=None):
    """Integrate ``x`` over ``t`` (default ``self.integration_time``); return the state at index 1.

    The full solution is kept on ``self.output``.
    """
    time = self.integration_time if t is None else t
    self.output = odeint(self.odefunc, x, time, rtol=self.rtol, atol=self.atol)
    return self.output[1]
def forward(self, x, t=None):
    """Integrate ``x`` over ``t`` (default ``self.t``) and return the state at index 1.

    The full solution is kept on ``self.outputs``.
    """
    times = t if t is not None else self.t
    self.outputs = odeint(self.odefunc, x, times, rtol=self.rtol, atol=self.atol)
    return self.outputs[1]
def test_adams(self):
    """The adams method must solve a tuple-valued state to within ``eps``.

    The error is measured as the maximum *absolute* deviation; a signed
    ``max`` would ignore negative deviations of any size.
    """
    f, y0, t_points, sol = problem()

    # Integrate two copies of the same problem packed as a tuple state.
    tuple_f = lambda t, y: (f(t, y[0]), f(t, y[1]))
    tuple_y0 = (y0, y0)

    tuple_y = torchdiffeq.odeint(tuple_f, tuple_y0, t_points, method='adams')
    max_error0 = (sol - tuple_y[0]).abs().max()
    max_error1 = (sol - tuple_y[1]).abs().max()
    self.assertLess(max_error0, eps)
    self.assertLess(max_error1, eps)
def closure():
    """Optimizer closure: integrate from ``init_state`` over ``t`` and backprop a penalized loss.

    The loss compares the first state component of the prediction with
    ``true_y`` and adds ``10000 * relu(-func.w0_learnable)``. The penalty is
    zero while ``w0_learnable`` is non-negative and grows linearly below zero.

    Returns:
        The loss tensor.
    """
    optimizer.zero_grad()
    prediction = odeint(func, init_state, t, method=args.method)
    data_term = lossfunc(prediction[:, :, 0], true_y[:, :, 0])
    penalty = 10000 * torch.nn.functional.relu(-func.w0_learnable)
    loss_ = data_term + penalty
    loss_.backward()
    return loss_
def integrate(self, init_data, probe_points):
    """Integrate ``self.odefunc`` from ``init_data`` and return the solution at ``probe_points``.

    Runs without gradient tracking. It also switches ``self.model`` to eval
    mode and does not switch it back.
    """
    import torchdiffeq as tde

    with torch.no_grad():
        self.model.eval()
        logging.info('calculating integration...')
        solution = tde.odeint(
            func=self.odefunc,
            y0=init_data,
            t=probe_points,
            rtol=self.tol,
            atol=self.tol,
            method=self.odesolver,
            options=None,
        )
    return solution
def forward(self, x, mask):
    """Evolve ``x`` with the masked ODE layer and normalize the result.

    The mask is installed on ``self.ode_layer`` before integration. The state
    at index 1 of ``self.integration_time`` is passed through ``self.norm``.
    """
    self.integration_time = self.integration_time.type_as(x)
    self.ode_layer.set_mask(mask)
    solution = odeint(self.ode_layer, x, self.integration_time,
                      rtol=self.tol, atol=self.tol)
    return self.norm(solution[1])
def predict_horizon(self, state, action_sequence):
    """Roll the learned dynamics forward over ``action_sequence``.

    Each step appends the scalar action to the current state, integrates the
    augmented system over ``[0, 0.2]`` and keeps the first ``len(state)``
    components of the result as the next state.

    Args:
        state: 1-D initial state (anything accepted by ``torch.Tensor``).
        action_sequence: sequence of scalar actions, one per step.

    Returns:
        NumPy array of shape ``(len(state), len(action_sequence))``. Column
        ``i`` is the state after applying action ``i``.
    """
    state = torch.Tensor(state)
    # Slice the next state by the input's dimension rather than a fixed 4,
    # so the rollout works for any state size (identical for 4-D states).
    state_dim = state.shape[0]
    horizon = len(action_sequence)
    states = torch.zeros((state_dim, horizon))
    step_times = torch.Tensor([0, 0.2])
    for i, a in enumerate(action_sequence):
        s_augmented = torch.cat((state, torch.Tensor([a]).float()))
        ns = odeint(self, s_augmented, step_times)[1, :state_dim]
        states[:, i] = ns
        state = ns
    return states.detach().numpy()
def forward(ctx, h0, flat_params, s_span):
    # Forward pass of a custom autograd Function: integrate self.func from h0
    # over self.s_span and save the tensors needed by the backward pass.
    # NOTE(review): `self` is not a parameter; it must come from an enclosing
    # scope (presumably this Function is defined inside a method) — confirm.
    # NOTE(review): the `flat_params` and `s_span` arguments are not read:
    # integration uses self.s_span, and self.s_span / self.flat_params are
    # what gets saved — confirm they are the same objects passed in.
    # The solve runs under no_grad, so autograd records nothing here;
    # gradients are expected to come from the matching custom backward.
    with torch.no_grad():
        sol = odeint(self.func, h0, self.s_span, rtol=self.rtol, atol=self.atol, method=self.method, options=self.options)
    ctx.save_for_backward(self.s_span, self.flat_params, sol)
    # Return the whole trajectory or only the final state.
    sol = sol if self.return_traj else sol[-1]
    return sol
def test_odeint(self):
    """Integrating over only the first time point must reproduce ``sol[0]`` to within 1e-12."""
    for reverse in (False, True):
        for dtype in DTYPES:
            for device in DEVICES:
                for method in METHODS:
                    for ode in PROBLEMS:
                        with self.subTest(reverse=reverse, dtype=dtype, device=device,
                                          ode=ode, method=method):
                            f, y0, t_points, sol = construct_problem(
                                dtype=dtype, device=device, ode=ode, reverse=reverse)
                            first_time = t_points[0:1]
                            y = torchdiffeq.odeint(f, y0, first_time, method=method)
                            max_abs_error = (sol[0] - y).abs().max()
                            self.assertLess(max_abs_error, 1e-12)
def forward_batched(self, x: torch.Tensor, nn: int, indices: list, timestamps: set):
    """Modified forward for ODE batches with different integration times.

    All distinct timestamps of the batch are integrated in one solve. The
    solution is then rearranged by ``self._build_batch`` and reshaped to
    ``x.shape``.

    Args:
        x: batch of initial states.
        nn, indices: forwarded to ``self._build_batch``.
        timestamps: distinct integration times of the batch.

    Returns:
        Tensor with the same shape as ``x``.

    NOTE(review): odeint needs strictly monotonic times, but ``list(timestamps)``
    follows the set's iteration order, which is not guaranteed to be sorted.
    ``indices`` presumably refers to that same order, so sorting here would
    need a matching change in the caller — confirm.
    """
    # Put the time grid on x's dtype/device instead of a fixed float32 CPU tensor.
    timestamps = torch.Tensor(list(timestamps)).type_as(x)
    solver = torchdiffeq.odeint_adjoint if self.adjoint_flag else torchdiffeq.odeint
    out = solver(self.odefunc, x, timestamps,
                 rtol=self.rtol, atol=self.atol, method=self.method)
    out = self._build_batch(out, nn, indices).reshape(x.shape)
    return out
def forward(self, input, dt=0.1):
    """Encode the history with the RNN, evolve its hidden state with the ODE net, and decode.

    Args:
        input: tuple ``(pre_x, pre_y, forward_x)``. ``pre_x``/``pre_y`` are the
            observed history (time along dim 0), concatenated along dim 2 for
            the encoder. ``forward_x`` holds the inputs over the prediction window.
        dt: time spacing between consecutive steps.

    Returns:
        ``self.recursive_predict`` applied to the hidden states at every
        prediction step (the initial time point is dropped).
    """
    pre_x, pre_y, forward_x = input
    target_device = forward_x.device
    if pre_x.shape[0] == 0:
        # No history: substitute one all-zero step so the encoder has input.
        pre_x = torch.zeros((1, pre_x.shape[1], pre_x.shape[2])).to(target_device)
        pre_y = torch.zeros((1, pre_y.shape[1], pre_y.shape[2])).to(target_device)
    _, hn = self.rnn_encoder(torch.cat([pre_x, pre_y], dim=2))  # hn (1, batch_size, hidden_num)

    n_steps = forward_x.shape[0]
    t = torch.linspace(0, dt * n_steps, n_steps + 1).to(pre_x.device)
    # The interpolated input starts at the last observed step.
    self.ode_net.input_interpolation = Interpolation(
        t, torch.cat([pre_x[-1:], forward_x], dim=0), method=self.interpolation)

    if self.adjoint:
        from torchdiffeq import odeint_adjoint as odeint
    else:
        from torchdiffeq import odeint

    self.ode_net.cell.call_times = 0
    forward_hn_all = odeint(self.ode_net, hn[0], t, rtol=self.rtol, atol=self.atol,
                            method=self.solver)[1:]
    return self.recursive_predict(forward_hn_all, max_len=1000)
def numerically_integrate(
        self, state, u=0., T=1, time_steps=200, method='dopri5'):
    """Numerical integration of the dynamical system, used as a baseline.

    Integrates ``self.forward(x, u)`` with a constant input ``u`` from
    ``state``. The integration runs over ``time_steps`` evenly spaced times
    in ``[0, T]`` on the global ``device``.

    Returns:
        CPU tensor ``solution[:, 0, :]``: index 0 of dim 1 at every time
        (presumably the first batch element).
    """
    times = torch.linspace(0, T, time_steps).to(device)

    def rhs(_, x):
        # The dynamics ignore the integration time.
        return self.forward(x, u)

    trajectory = odeint(rhs, state.to(device), times, method=method)
    return trajectory.cpu()[:, 0, :]
def _integral_autograd(self, x):
    """Integrate ``x`` augmented with one zero-initialized accumulator column.

    A zero column with ``x``'s dtype and device is appended along dim 1.
    ``self._integral_autograd_defunc`` is then integrated over
    ``self.s_span`` with the tolerances and solver stored in ``self.st``.

    Raises:
        AssertionError: if ``self.st['cost']`` is not set (note that
            assertions are stripped under ``python -O``).
    """
    assert self.st['cost'], 'Cost nn.Module needs to be specified for integral adjoint'
    # Build the accumulator in x's dtype. A hard-coded float32 column would
    # upcast half-precision x through torch.cat, or fail on older PyTorch.
    xi0 = torch.zeros(x.shape[0], 1, dtype=x.dtype, device=x.device)
    x = torch.cat([x, xi0], 1)
    return torchdiffeq.odeint(self._integral_autograd_defunc, x, self.s_span,
                              rtol=self.st['rtol'], atol=self.st['atol'],
                              method=self.st['solver'])
def forward(self, x: torch.Tensor, T: int = 1):
    """Integrate ``x`` from time 0 to ``T`` and return the final state.

    ``self.integration_time`` is reset to ``[0, T]`` in ``x``'s dtype/device.
    ``odeint_adjoint`` is used instead of ``odeint`` when ``self.adjoint_flag``
    is set.
    """
    self.integration_time = torch.tensor([0, T]).float().type_as(x)
    solve = torchdiffeq.odeint_adjoint if self.adjoint_flag else torchdiffeq.odeint
    solution = solve(self.odefunc, x, self.integration_time,
                     rtol=self.rtol, atol=self.atol, method=self.method)
    return solution[-1]
def get_obs(self, x):
    """Embed ``x`` with ``fc_obs_in``, evolve it with ``odefunc_obs``, project with ``fc_state_out``.

    The state at index 1 of ``self.integration_time`` is projected.
    """
    latent = self.fc_obs_in(x)
    self.integration_time = self.integration_time.type_as(latent)
    latent = odeint(self.odefunc_obs, latent, self.integration_time,
                    rtol=self.tol, atol=self.tol)[1]
    return self.fc_state_out(latent)
def test_dopri8(self):
    """dopri8 at tight tolerances must match every reference problem within ``error_tol``."""
    rtol, atol = 1e-12, 1e-14
    for ode in problems.PROBLEMS.keys():
        f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE, ode=ode)
        y = torchdiffeq.odeint(f, y0, t_points, method='dopri8', rtol=rtol, atol=atol)
        with self.subTest(ode=ode):
            self.assertLess(rel_error(sol, y), error_tol)
def invert(self, x, time):
    """Solve the ODE backwards in time, from ``time`` to 0, starting at ``x``.

    Stores the time span on ``self.inverse_time``. Afterwards, records
    ``self.odefunc.nfe`` on ``self.cost``.

    Returns:
        The state at time 0.
    """
    self.inverse_time = torch.tensor([time, 0]).float().type_as(x)
    solution = odeint(self.odefunc, x, self.inverse_time, rtol=self.rtol, atol=self.atol)
    self.cost = self.odefunc.nfe
    return solution[1]
device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu')
# NOTE(review): `device` is not applied to the tensors created below; they
# live on the default device — confirm whether .to(device) is intended.

true_y0 = torch.tensor([[2., 0.]])  # (1, 2)
t = torch.linspace(0., 25., args.data_size)
true_A = torch.tensor([[-0.1, 2.0], [-2.0, -0.1]])


class Lambda(nn.Module):
    """Ground-truth dynamics: dy/dt = (y ** 3) @ true_A."""

    def forward(self, t, y):
        return torch.mm(y**3, true_A)


with torch.no_grad():
    # One trajectory from the (1, 2) initial state: shape (data_size, 1, 2).
    true_y = odeint(Lambda(), true_y0, t, method='dopri5')


def get_batch():
    """Sample ``args.batch_size`` distinct sub-trajectories of length ``args.batch_time``.

    Returns:
        batch_y0: (M, 1, D) starting states.
        batch_t: (T,) time points.
        batch_y: (T, M, 1, D) states along each sub-trajectory.
    """
    s = torch.from_numpy(np.random.choice(np.arange(args.data_size - args.batch_time, dtype=np.int64), args.batch_size, replace=False))
    batch_y0 = true_y[s]  # (M, 1, D)
    batch_t = t[:args.batch_time]  # (T)
    batch_y = torch.stack([true_y[s + i] for i in range(args.batch_time)], dim=0)  # (T, M, 1, D)
    return batch_y0, batch_t, batch_y


def makedirs(dirname):
    """Create ``dirname`` (and any missing parents) unless it already exists."""
    # exist_ok avoids the race between a separate existence check and creation.
    os.makedirs(dirname, exist_ok=True)