def train(self, xs, sys, nepochs, smooth_solver_kwargs=None, learner_criterion_kwargs=None,
          learner_opt_kwargs=None, subset=None, hp_schedulers=None):
    """
    Run the CEEM algorithm for up to `nepochs` epochs.

    Every epoch calls `self.step` (which returns the new trajectory estimate plus
    smoothing and learning metrics), logs those metrics, runs every epoch callback,
    asks the termination callback whether to stop, and flushes the logger. The flush
    also happens on the terminating epoch so its metrics are not lost.

    Args:
        xs (torch.tensor): initial trajectory estimate fed to the first `self.step`
        sys (DiscreteDynamicalSystem): system whose parameters are being learned
        nepochs (int): maximum number of CEEM epochs to run
        smooth_solver_kwargs (dict): forwarded to `self.step`
        learner_criterion_kwargs (dict): forwarded to `self.step`
        learner_opt_kwargs (dict): forwarded to `self.step`
        subset (int): forwarded to `self.step`
        hp_schedulers (dict or sequence): forwarded to `self.step`; a non-dict value
            must contain exactly one entry per learning criterion
    Returns:
        The trajectory estimate produced by the last completed epoch
        (the input `xs` unchanged when `nepochs` is 0).
    """
    if hp_schedulers is not None and not isinstance(hp_schedulers, dict):
        assert len(hp_schedulers) == len(self._learning_criteria), \
            'need one hp_scheduler per learning criterion'

    for epoch in range(nepochs):
        xs, smooth_metrics, learn_metrics = self.step(xs, sys, smooth_solver_kwargs,
                                                      learner_criterion_kwargs,
                                                      learner_opt_kwargs, subset=subset,
                                                      hp_schedulers=hp_schedulers)

        # Log all metrics. Run tensorboard to visualize
        log_kv_or_listkv(smooth_metrics, "smooth")
        log_kv_or_listkv(learn_metrics, "learn")

        for ecall in self._epoch_callbacks:
            ecall(epoch)

        terminate = self._termination_callback(epoch)
        # Dump before breaking so the terminating epoch's metrics are written too.
        logger.dumpkvs()
        if terminate:
            break

    return xs
def train(self, params, callbacks=None):
    """
    Run particle-filter expectation-maximization for `self._max_k` iterations.

    Each iteration:
      * E-step: filter `self._y` with `self._fapf`, draw a smoothed trajectory with
        `FFBSi`, and keep only the newest `self._xlen_cutoff` trajectories (when set);
      * M-step: minimize `-self.recursive_Q(...)` over `params` with `self._optimizer`;
      * log E/M-step times, elapsed time and `Q_MCEM` of the newest trajectory,
        call every callback with the iteration index, and flush the logger.

    Args:
        params (list): parameters optimized in place by `self._optimizer`
        callbacks (list of callables, optional): each called as `callback(k)` at the
            end of iteration k; defaults to no callbacks
    Returns:
        `params`, the same object that was passed in.
    """
    if callbacks is None:
        callbacks = []
    xsms = []
    t_start = timeit.default_timer()
    for k in range(self._max_k):
        ## E-step
        with utils.Timer() as timer:
            xfilt, _, wfilt, _ = self._fapf.filter(self._y)
            xsm = self._fapf.FFBSi(xfilt, wfilt)
            xsms.append(xsm)
            if self._xlen_cutoff and len(xsms) > self._xlen_cutoff:
                xsms = xsms[-self._xlen_cutoff:]
        logger.logkv('train/Etime', timer.dt)

        ## M-step
        with utils.Timer() as timer:
            # The lambda reads `xsms` when the optimizer calls it, i.e. after the E-step above.
            obj = lambda: -self.recursive_Q(xsms, self._y, 0, 0.)
            self._optimizer(obj, params)
        logger.logkv('train/Mtime', timer.dt)
        logger.logkv('train/elapsedtime', timeit.default_timer() - t_start)

        ## log the current value of Q
        Q = float(self._fapf.Q_MCEM(self._y, xsms[-1]))
        logger.logkv('train/Q', Q)

        for callback in callbacks:
            callback(k)

        logger.dumpkvs()

    return params
def train(seed, logdir, sys_seed, k, b):
    """
    Fit a perturbed copy of a Lorenz system with CEEM from noisy simulated observations.

    Simulates `b` trajectories of 128 steps from `default_lorenz_system(k, obsdif=2)`,
    perturbs the true parameters by up to +/-10%, and runs CEEM (LBFGS learner) for up
    to 5000 epochs. Parameter error and test-set derivative error are logged each epoch.
    Training stops once the spread of the last 100 log10 test errors is below 1e-3.

    Args:
        seed (int): seed for the observation noise, the parameter perturbation and the
            initial trajectory guess
        logdir (str): log directory passed to `logger.setup`
        sys_seed (int): seed used while sampling test states and training initial conditions
        k (int): forwarded to `default_lorenz_system`; the initial-state mean has 3 * k entries
        b (int): number of simulated trajectories (batch size B)
    """
    import collections

    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 -- registers the '3d' projection
    import matplotlib.pyplot as plt

    ystd = 0.01
    # ystd = 0.
    torch.set_default_dtype(torch.float64)
    logger.setup(logdir, action='d')

    N = 128
    B = b

    true_system = default_lorenz_system(k, obsdif=2)
    utils.set_rng_seed(sys_seed)
    xdim = true_system.xdim
    ydim = true_system.ydim

    x0mean = torch.tensor([[-6] * k + [-6] * k + [24.] * k]).unsqueeze(0)

    # simulate true_dynamics over IC distribution
    x_test = x0mean.repeat(1024, 1, 1)
    x_test += 5.0 * torch.randn_like(x_test)
    x_test = x_test.detach()
    t_test = torch.zeros(1024, 1)
    tgt_test = true_system.step_derivs(t_test, x_test).detach()

    ## simulate the true system
    xs = [x0mean.repeat(B, 1, 1)]
    xs[0] += 2.5 * torch.randn_like(xs[0])
    with torch.no_grad():
        for _ in range(N - 1):
            xs.append(true_system.step(torch.tensor([0.] * B), xs[-1]))
    xs = torch.cat(xs, dim=1)

    fig = plt.figure()
    for bi in range(B):
        ax = fig.add_subplot(int(np.ceil(B / 2.)), 2, bi + 1, projection='3d')
        for k_ in range(k):
            plot3d(ax, xs[bi, :, k_], xs[bi, :, k + k_], xs[bi, :, 2 * k + k_],
                   linestyle='--', alpha=0.5)
    # All trajectories share one figure; the file keeps its historical name, which
    # carries the index of the last trajectory.
    plt.savefig(os.path.join(logger.get_dir(), 'figs/traj_%d.png' % (B - 1)), dpi=300)
    # plt.show()
    plt.close(fig)

    t = torch.tensor(range(N)).unsqueeze(0).expand(B, -1).to(torch.get_default_dtype())
    y = true_system.observe(t, xs).detach()

    # seed for real now
    utils.set_rng_seed(seed)
    y += ystd * torch.randn_like(y)

    # prep system
    system = deepcopy(true_system)
    true_params = parameters_to_vector(true_system.parameters())
    utils.set_rng_seed(seed)
    # scale every true parameter by a random factor in [0.9, 1.1), i.e. within 10%
    params = true_params * ((torch.rand_like(true_params) - 0.5) / 5. + 1.)
    vector_to_parameters(params, system.parameters())
    params = list(system.parameters())

    # specify smoothing criteria: one observation + dynamics group per trajectory
    smoothing_criteria = []
    for bi in range(B):
        obscrit = GaussianObservationCriterion(1.0 * torch.ones(ydim), t[bi:bi + 1],
                                               y[bi:bi + 1])
        dyncrit = GaussianDynamicsCriterion(1e0 * torch.ones(xdim), t[bi:bi + 1])
        smoothing_criteria.append(GroupSOSCriterion([obscrit, dyncrit]))

    smooth_solver_kwargs = {'verbose': 0, 'tr_rho': 0.01}

    # specify learning criteria
    learning_criteria = [GaussianDynamicsCriterion(1e0 * torch.ones(xdim), t)]
    learning_params = [params]
    # learning_opts = ['scipy_minimize']
    # learner_opt_kwargs = {'method': 'Nelder-Mead', 'tr_rho': 0.1,
    #                       'options':{'adaptive':True}}
    # learner_opt_kwargs = {'method': 'BFGS', 'tr_rho': 0.1,
    #                       'options':{'disp':True}}
    learning_opts = ['torch_minimize']
    # learner_opt_kwargs = {
    #     'method': 'Adam',
    #     'lr': 5e-4,
    #     'tr_rho': 0.1,
    #     'nepochs': 200,
    #     'max_grad_norm': 10.0
    # }
    learner_opt_kwargs = {'method': 'LBFGS'}

    def ecb(epoch):
        """Log the log10 distance between the learned and the true parameter vectors."""
        vparams = parameters_to_vector(system.parameters())
        error = (vparams - true_params).norm().item()
        logger.logkv('test/log10_paramerror', np.log10(error))

    # Rolling window with the most recent log10 test errors used by the stopping rule.
    recent_log10_errors = collections.deque(maxlen=100)

    def tcb(epoch):
        """Log test derivative error; return True once the recent errors have flattened out."""
        with torch.no_grad():
            tgt_test_pr = system.step_derivs(t_test, x_test)
            error = float(torch.nn.functional.mse_loss(tgt_test_pr, tgt_test))
        logger.logkv('test/log10_error', np.log10(error))
        recent_log10_errors.append(np.log10(error))
        if len(recent_log10_errors) < 2:
            # A single sample has zero spread; it says nothing about convergence
            # (and log10(0) would only log -inf).
            return False
        errs = np.array(list(recent_log10_errors))
        convcrit = float(errs.max() - errs.min())  # NaN propagates, as before
        logger.logkv('test/log10_convcrit', np.log10(convcrit))
        return convcrit < 1e-3

    ecb(-1)
    tcb(-1)
    logger.dumpkvs()

    # instantiate CEEM
    ceem = CEEM(smoothing_criteria, learning_criteria, learning_params, learning_opts,
                [ecb], tcb, parallel=min(4, B))

    # run CEEM
    # x0 = torch.zeros_like(xs)
    x0 = xs + torch.randn_like(xs)
    ceem.train(xs=x0, sys=system, nepochs=5000, smooth_solver_kwargs=smooth_solver_kwargs,
               learner_opt_kwargs=learner_opt_kwargs)
    return
def _windowed_rms(net, data, trgt, H, y_std):
    """Mean `compute_rms` loss of `net` over every length-H window of `data`, without gradients.

    Window t feeds data[:, t:t + H] to `net` and scores the prediction against
    trgt[:, t + H - 1], for t = 0 .. data.shape[1] - H (the windows used in training).
    """
    T = data.shape[1]
    losses = []
    with torch.no_grad():
        for t in range(T - H + 1):
            y_pred = net(data[:, t:t + H])
            y = trgt[:, t + H - 1]
            losses.append(float(compute_rms(y.unsqueeze(0), y_pred.unsqueeze(0), y_std)))
    return np.mean(losses)


def train_net(net, train_data, train_trgt, test_data, test_trgt, valid_data, valid_trgt, y_std,
              lr, H, logdir, n_epochs=1000):
    """
    Train `net` to predict the target at the end of each length-H input window.

    Training takes one Adam step per window. The learning rate decays as
    lr * 1000 / (1000 + epoch). Every 100 epochs the validation and test losses are
    computed over all windows, and the weights with the lowest validation loss so far
    are saved to `<logdir>/best_net.th`. Logged metrics are flushed at most every
    2 seconds, plus once at the end.

    Args:
        net (torch.nn.Module): model mapping a window `data[:, t:t + H]` to a prediction
            whose size equals `trgt[:, t + H - 1]` (asserted during training)
        train_data, train_trgt: training inputs / targets; the time axis is dim 1
        test_data, test_trgt: test inputs / targets; the time axis is dim 1
        valid_data, valid_trgt: validation inputs / targets; the time axis is dim 1
        y_std: forwarded to `compute_rms`
        lr (float): initial Adam learning rate
        H (int): window length
        logdir (str): directory where `best_net.th` is written
        n_epochs (int): number of training epochs
    Returns:
        `net` (trained in place; its final weights, not necessarily the best ones).
    """
    T = train_data.shape[1]
    opt = torch.optim.Adam(net.parameters(), lr=lr)
    scheduler_off = 1000.
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        opt, lambda epoch: scheduler_off / (scheduler_off + epoch))

    best_val_loss = np.inf
    t0 = time.time()
    flushed = True  # nothing logged yet, so nothing pending

    for e in range(n_epochs):
        logger.logkv('epoch', e)

        # train: one optimizer step per window
        ll = []
        for t in range(T - H + 1):
            opt.zero_grad()
            u = train_data[:, t:t + H]
            y = train_trgt[:, t + H - 1]
            y_pred = net(u)
            assert y.size() == y_pred.size()
            loss = compute_rms(y.unsqueeze(0), y_pred.unsqueeze(0), y_std)
            loss.backward()
            opt.step()
            ll.append(float(loss))
        mean_train_loss = np.mean(ll)
        logger.logkv('log10_train_loss', np.log10(mean_train_loss))

        scheduler.step()
        for param_group in opt.param_groups:
            logger.logkv('log10_lr', np.log10(param_group['lr']))

        if e % 100 == 0:
            # Evaluate each split over its own length and the same windows as training.
            mean_val_loss = _windowed_rms(net, valid_data, valid_trgt, H, y_std)
            logger.logkv('log10_val_loss', np.log10(mean_val_loss))

            mean_test_loss = _windowed_rms(net, test_data, test_trgt, H, y_std)
            logger.logkv('log10_test_loss', np.log10(mean_test_loss))

            # Save
            if mean_val_loss < best_val_loss:
                torch.save(net.state_dict(), os.path.join(logdir, 'best_net.th'))
                best_val_loss = mean_val_loss

        flushed = time.time() - t0 > 2
        if flushed:
            t0 = time.time()
            logger.dumpkvs()

    # Write whatever the last (throttled) epochs logged but never flushed.
    if not flushed:
        logger.dumpkvs()
    return net
def train(seed, logdir, ystd=0.1, wstd=0.01, sys_seed=4):
    """
    Recover the Lorenz attractor parameters (sigma, rho, beta) with CEEM.

    Simulates 4 noisy trajectories of 128 steps (process noise std `wstd`,
    observation noise std `ystd`). It perturbs the true parameters by up to +/-10%
    and runs CEEM (Nelder-Mead learner) for up to 100 epochs. Training stops early
    once the last 10 parameter errors vary by less than 1e-4.

    Args:
        seed (int): seed for the simulated data and the parameter perturbation
        logdir (str): log directory; also receives `traintrajs.png`
        ystd (float): observation noise standard deviation
        wstd (float): process noise standard deviation
        sys_seed (int): seed used while constructing the true system
    Returns:
        (sigma, rho, beta) of the learned system, as floats.
    """
    import collections

    print('\n\n\n##### SEED %d #####\n\n' % seed)
    torch.set_default_dtype(torch.float64)
    logger.setup(logdir, action='d')

    # Number of timesteps in the trajectory
    T = 128
    # Batch size
    B = 4
    k = 1

    utils.set_rng_seed(sys_seed)
    true_system = default_lorenz_attractor()

    utils.set_rng_seed(seed)

    # simulate the system
    x0mean = torch.tensor([[-6] * k + [-6] * k + [24.] * k])
    x0mean = x0mean.unsqueeze(0).repeat(B, 1, 1)

    # Rollout with noise
    Q = (wstd ** 2) * torch.eye(true_system.xdim)
    R = (ystd ** 2) * torch.eye(true_system.ydim)
    Px0 = 5.0 * torch.eye(true_system.xdim)

    Qpdf = MultivariateNormal(torch.zeros((B, 1, true_system.xdim)), Q.unsqueeze(0).unsqueeze(0))
    Rpdf = MultivariateNormal(torch.zeros((B, 1, true_system.ydim)), R.unsqueeze(0).unsqueeze(0))
    Px0pdf = MultivariateNormal(x0mean, Px0.unsqueeze(0).unsqueeze(0))

    xs = [Px0pdf.sample()]
    ys = [true_system.observe(0, xs[0]) + Rpdf.sample()]
    for ti in range(T - 1):
        tinp = torch.tensor([ti] * B).unsqueeze(1).to(dtype=torch.get_default_dtype())
        xs.append(true_system.step(tinp, xs[-1]) + Qpdf.sample())
        ys.append(true_system.observe(tinp, xs[-1]) + Rpdf.sample())

    x = torch.cat(xs, dim=1)
    y = torch.cat(ys, dim=1)
    t = torch.tensor(range(T)).unsqueeze(0).to(torch.get_default_dtype()).repeat(B, 1)

    fig = plt.figure()
    for b in range(B):
        ax = fig.add_subplot(int(np.ceil(B / 2.)), 2, b + 1, projection='3d')
        for k_ in range(k):
            plot3d(ax, x[b, :, k_], x[b, :, k + k_], x[b, :, 2 * k + k_],
                   linestyle='--', alpha=0.5)
    plt.savefig(opj(logdir, 'traintrajs.png'), dpi=300)
    # Close explicitly: this function runs once per seed, and open figures accumulate.
    plt.close(fig)

    # prep system
    system = deepcopy(true_system)
    true_params = parameters_to_vector(true_system.parameters())
    # scale every true parameter by a random factor in [0.9, 1.1), i.e. within 10%
    params = true_params * ((torch.rand_like(true_params) - 0.5) / 5. + 1.)
    vector_to_parameters(params, system.parameters())
    params = list(system.parameters())

    # specify smoothing criteria; weights are sized like Q and R, from the system dims
    smoothing_criteria = []
    for b in range(B):
        obscrit = GaussianObservationCriterion(torch.ones(true_system.ydim), t[b:b + 1],
                                               y[b:b + 1])
        dyncrit = GaussianDynamicsCriterion(wstd / ystd * torch.ones(true_system.xdim),
                                            t[b:b + 1])
        smoothing_criteria.append(GroupSOSCriterion([obscrit, dyncrit]))

    smooth_solver_kwargs = {'verbose': 0, 'tr_rho': 0.001}

    # specify learning criteria
    learning_criteria = [GaussianDynamicsCriterion(torch.ones(true_system.xdim), t)]
    learning_params = [params]
    learning_opts = ['scipy_minimize']
    learner_opt_kwargs = {'method': 'Nelder-Mead', 'tr_rho': 0.01}

    timer = {'start_time': timeit.default_timer()}

    def ecb(epoch):
        """Log the current parameters, their log10 relative errors and the epoch wall time."""
        logger.logkv('test/rho', float(system._rho))
        logger.logkv('test/sigma', float(system._sigma))
        logger.logkv('test/beta', float(system._beta))

        logger.logkv('test/rho_pcterr_log10',
                     float(torch.log10((true_system._rho - system._rho).abs() / true_system._rho)))
        logger.logkv(
            'test/sigma_pcterr_log10',
            float(torch.log10((true_system._sigma - system._sigma).abs() / true_system._sigma)))
        logger.logkv(
            'test/beta_pcterr_log10',
            float(torch.log10((true_system._beta - system._beta).abs() / true_system._beta)))

        logger.logkv('time/epochtime', timeit.default_timer() - timer['start_time'])
        timer['start_time'] = timeit.default_timer()

    # Rolling window with the 10 most recent parameter errors used by the stopping rule.
    recent_errors = collections.deque(maxlen=10)

    def tcb(epoch):
        """Log the parameter error; return True once the recent errors have flattened out."""
        vparams = parameters_to_vector(system.parameters())
        error = (vparams - true_params).norm().item()
        recent_errors.append(float(error))
        logger.logkv('test/log10_error', np.log10(error))
        if len(recent_errors) < 2:
            # With one sample the spread is 0 < 1e-4, which previously stopped CEEM
            # right after its first epoch.
            return False
        errs = np.array(list(recent_errors))
        convcrit = float(errs.max() - errs.min())  # NaN propagates, as before
        logger.logkv('test/log10_convcrit', np.log10(convcrit))
        return convcrit < 1e-4

    # instantiate CEEM
    ceem = CEEM(smoothing_criteria, learning_criteria, learning_params, learning_opts,
                [ecb], tcb)

    # run CEEM
    x0 = torch.zeros_like(x)

    ecb(-1)
    logger.dumpkvs()

    ceem.train(xs=x0, sys=system, nepochs=100, smooth_solver_kwargs=smooth_solver_kwargs,
               learner_opt_kwargs=learner_opt_kwargs)

    return float(system._sigma), float(system._rho), float(system._beta)
def run(seed, lr, method, noise, damped, smoketest=False):
    """
    Train a structured mechanical model on (smoothed) double-pendulum data.

    Models are fit on smoothed trajectories with Adam and evaluated each epoch by
    the q-ddot MSE on the raw train/val/test splits. The checkpoint with the lowest
    validation q-ddot loss on the smoothed data is saved to
    `<logdir>/ckpts/best_model.th`.

    Args:
        seed (int): seed
        lr (float): init learning rate
        method (str): training method in ['qdd', 'nqqd', 'del+mnorm', 'del+logdet']
        noise (float): amount of noise; one of the keys of `noisedict` below
            (0.01, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 1.0)
        damped (bool): use damped pendulum data
        smoketest (bool): if smoketest, runs 2 epochs
    Raises:
        NotImplementedError: for an unknown `method`
        RuntimeError: if the DEL line search cannot find a non-NaN criterion value
    """
    torch.set_default_dtype(torch.float64)
    torch.manual_seed(seed)

    # load the data
    prefix = 'damped_' if damped else ''
    noisedict = {
        0.01: '0p01',
        0.05: '0p05',
        0.10: '0p10',
        0.20: '0p20',
        0.30: '0p30',
        0.40: '0p40',
        0.50: '0p50',
        1.0: '1p0'
    }
    smoothed_path = './datasets/' + prefix + 'dubpen_%s_smoothed.td' % noisedict[noise]

    data = torch.load('./datasets/%sdubpen_qddot.td' % prefix)
    data_ = torch.load(smoothed_path)

    dt = 0.05

    logdir = 'data/%s_%s_%.1e_%.3f_%d' % ('damped' if damped else 'undamped', method, lr, noise,
                                          seed)
    logger.setup(logdir, action='d')

    # random train / test / validation split of the 16 trajectories
    inds = torch.randperm(16)
    Btr = 8
    Bte = 4
    Bva = 4
    tr_inds = inds[:Btr]
    te_inds = inds[Btr:Btr + Bte]
    va_inds = inds[Btr + Bte:Btr + Bte + Bva]

    # smoothed data (used for training and model selection)
    t_, smq, smdq, smddq = data_[tr_inds][:]
    ttest_, smqtest, smdqtest, smddqtest = data_[te_inds][:]
    tval_, smqval, smdqval, smddqval = data_[va_inds][:]

    # raw data (used for evaluation)
    t, q, dq, ddq = data[tr_inds][:]
    ttest, qtest, dqtest, ddqtest = data[te_inds][:]
    tval, qval, dqval, ddqval = data[va_inds][:]

    _, _, qdim = q.shape

    # create the appropriate dataloader
    if 'del' in method:
        # consecutive position triples (q_{t-1}, q_t, q_{t+1}) from the smoothed data
        smq_B = torch.stack([smq[:, :-2], smq[:, 1:-1], smq[:, 2:]],
                            dim=2).reshape(-1, 3, qdim).detach()
        print(smq_B.shape, smq.shape)
        train_set = TensorDataset(smq_B)
    elif method == 'qdd':
        train_set = TensorDataset(smq.reshape(-1, 1, qdim), smdq.reshape(-1, 1, qdim),
                                  smddq.reshape(-1, 1, qdim))
    elif method == 'nqqd':
        # one-step transitions of the stacked state (q, qdot)
        x = torch.cat([smq, smdq], dim=-1)
        inp = x[:, :-1].reshape(-1, 1, 2 * qdim)
        out = x[:, 1:].reshape(-1, 1, 2 * qdim)
        train_set = TensorDataset(inp, out)
    else:
        raise NotImplementedError
    dataloader = DataLoader(train_set, batch_size=256, shuffle=True)

    # set up logdir and model
    if damped:
        system = ForcedSMM(qdim=qdim, dt=dt)
    else:
        system = StructuredMechanicalModel(qdim=qdim, dt=dt)

    # create the appropriate closure
    def qddcrit(system, smq_, smdq_, smddq_):
        """MSE between the model's q-ddot at (q, qdot) and the target q-ddot."""
        ddq_ = system.compute_qddot(smq_, smdq_, create_graph=True)
        return torch.nn.functional.mse_loss(ddq_, smddq_)

    def nqqdcrit(system, inp, out):
        """MSE between the model's one-step prediction and the next state."""
        out_ = system.step(torch.ones_like(inp)[..., 0], inp)
        return torch.nn.functional.mse_loss(out_, out)

    if 'del' in method:
        dyncrit = DELCriterion(t_)

        if 'logdet' in method:
            bc = LogDetBarrierCriterion
            bcf = LogDetBarrierCriterion.mineig
        else:
            bc = MxNormBarrierCriterion
            bcf = MxNormBarrierCriterion.mmxnorm

        # initialize the barrier criterion, and find an appropriate coefficient between it and DELcrit
        lb = bcf(system, smq).detach() * 0.99  # interior point init
        barriercrit = bc(lb)
        with torch.no_grad():
            dyncritloss = dyncrit(system, smq)
            barriercritloss = barriercrit(system, smq)
        mu = float(dyncritloss / barriercritloss)  # mu makes them ~equal at init
        barriercrit = bc(lb, mu=mu, x_override=smq)

        crit = GroupCriterion([dyncrit, barriercrit])
    elif method == 'qdd':
        crit = qddcrit
    elif method == 'nqqd':
        crit = nqqdcrit
    else:
        raise NotImplementedError

    # setup optimizer, scheduler
    opt = torch.optim.Adam(system.parameters(), lr=lr)
    sched = torch.optim.lr_scheduler.LambdaLR(opt, lambda k: 500. / (500 + k))

    # Upper bound on step halvings; past this the previous iterate itself must be NaN.
    max_line_search_steps = 50

    # train
    best_val_loss = np.inf
    best_val_loss_test_qddot_loss = np.inf
    next_params = ptv(system.parameters()).detach()

    for epoch in range(2 if smoketest else 500):

        # train with SGD
        for batch in dataloader:
            prev_params = next_params
            opt.zero_grad()
            loss = crit(system, *batch)
            loss.backward()
            opt.step()

            if 'del' in method:
                # Backtracking: halve the step until the full-data criterion is not NaN.
                n_ls = 0
                while True:
                    next_params = ptv(system.parameters()).detach()
                    with torch.no_grad():
                        c = crit(system, smq)
                    if not torch.isnan(c):
                        break
                    if n_ls >= max_line_search_steps:
                        raise RuntimeError(
                            'line search failed: criterion still NaN after %d halvings' % n_ls)
                    vtp(prev_params + 0.5 * (next_params - prev_params), system.parameters())
                    n_ls += 1

        sched.step()

        with torch.no_grad():
            val_sqmddqloss = qddcrit(system, smqval, smdqval, smddqval)
            train_qdd_loss = qddcrit(system, q, dq, ddq)
            test_qdd_loss = qddcrit(system, qtest, dqtest, ddqtest)
            val_qdd_loss = qddcrit(system, qval, dqval, ddqval)

        # select best model using validation error
        if val_sqmddqloss < best_val_loss:
            best_val_loss = float(val_sqmddqloss)
            best_val_loss_test_qddot_loss = float(test_qdd_loss)
            torch.save(system.state_dict(),
                       os.path.join(logger.get_dir(), 'ckpts', 'best_model.th'))

        logger.logkv("train/epoch", epoch)
        logger.logkv("train/loss", float(loss))
        logger.logkv("train/log10lr", np.log10(float(opt.param_groups[0]['lr'])))
        logger.logkv("eval/val_sqmddqloss", float(val_sqmddqloss))
        logger.logkv("eval/train_qdd_loss", float(train_qdd_loss))
        logger.logkv("eval/test_qdd_loss", float(test_qdd_loss))
        logger.logkv("eval/val_qdd_loss", float(val_qdd_loss))
        logger.logkv("eval/best_val_loss", float(best_val_loss))
        logger.logkv("eval/best_val_loss_test_qddot_loss", float(best_val_loss_test_qddot_loss))

        logger.dumpkvs()
def train(seed, logdir, sys_seed, k, b):
    """
    Fit a perturbed copy of a Lorenz system with particle-filter SAEM (faPF + SAEMTrainer).

    Uses the same data generation as the CEEM experiment. It simulates `b`
    trajectories of 128 steps from `default_lorenz_system(k, obsdif=2)` and perturbs
    the true parameters by up to +/-10%. It then runs `SAEMTrainer` with 100
    particles for up to 5000 iterations, logging parameter error, test derivative
    error and iteration time.

    Args:
        seed (int): seed for the observation noise and the parameter perturbation
        logdir (str): log directory passed to `logger.setup`
        sys_seed (int): seed used while sampling test states and training initial conditions
        k (int): forwarded to `default_lorenz_system`; the initial-state mean has 3 * k entries
        b (int): number of simulated trajectories (batch size B)
    """
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 -- registers the '3d' projection
    import matplotlib.pyplot as plt

    ystd = 0.01
    # ystd = 0.
    torch.set_default_dtype(torch.float64)
    logger.setup(logdir, action='d')

    N = 128
    B = b

    true_system = default_lorenz_system(k, obsdif=2)
    utils.set_rng_seed(sys_seed)

    x0mean = torch.tensor([[-6] * k + [-6] * k + [24.] * k]).unsqueeze(0)

    # simulate true_dynamics over IC distribution
    x_test = x0mean.repeat(1024, 1, 1)
    x_test += 5.0 * torch.randn_like(x_test)
    x_test = x_test.detach()
    t_test = torch.zeros(1024, 1)
    tgt_test = true_system.step_derivs(t_test, x_test).detach()

    # process / observation / initial-state covariances given to the particle filter
    Q = (0.01**2) * torch.eye(true_system.xdim)
    R = (ystd**2) * torch.eye(true_system.ydim)
    Px0 = 2.5**2 * torch.eye(true_system.xdim)

    ## simulate the true system
    xs = [x0mean.repeat(B, 1, 1)]
    xs[0] += 2.5 * torch.randn_like(xs[0])
    with torch.no_grad():
        for _ in range(N - 1):
            xs.append(true_system.step(torch.tensor([0.] * B), xs[-1]))
    xs = torch.cat(xs, dim=1)

    fig = plt.figure()
    for bi in range(B):
        ax = fig.add_subplot(int(np.ceil(B / 2.)), 2, bi + 1, projection='3d')
        for k_ in range(k):
            plot3d(ax, xs[bi, :, k_], xs[bi, :, k + k_], xs[bi, :, 2 * k + k_],
                   linestyle='--', alpha=0.5)
    # All trajectories share one figure; the file keeps its historical name, which
    # carries the index of the last trajectory.
    plt.savefig(os.path.join(logger.get_dir(), 'figs/traj_%d.png' % (B - 1)), dpi=300)
    # plt.show()
    plt.close(fig)

    t = torch.tensor(range(N)).unsqueeze(0).expand(B, -1).to(torch.get_default_dtype())
    y = true_system.observe(t, xs).detach()

    # seed for real now
    utils.set_rng_seed(seed)
    y += ystd * torch.randn_like(y)

    # prep system
    system = deepcopy(true_system)
    true_params = parameters_to_vector(true_system.parameters())
    utils.set_rng_seed(seed)
    # scale every true parameter by a random factor in [0.9, 1.1), i.e. within 10%
    params = true_params * ((torch.rand_like(true_params) - 0.5) / 5. + 1.)
    vector_to_parameters(params, system.parameters())
    params = list(system.parameters())

    Np = 100  # number of particles
    fapf = faPF(Np, system, Q, R, Px0)

    timer = {'start_time': timeit.default_timer()}

    def ecb(epoch):
        """Log parameter error, test derivative error and the wall time since the last call."""
        logger.logkv('time/epoch', epoch)

        vparams = parameters_to_vector(system.parameters())
        error = (vparams - true_params).norm().item()
        logger.logkv('test/log10_paramerror', np.log10(error))

        logger.logkv('time/epochtime', timeit.default_timer() - timer['start_time'])
        timer['start_time'] = timeit.default_timer()

        with torch.no_grad():
            tgt_test_pr = system.step_derivs(t_test, x_test)
            error = float(torch.nn.functional.mse_loss(tgt_test_pr, tgt_test))
        logger.logkv('test/log10_error', np.log10(error))

    ecb(-1)
    logger.dumpkvs()

    trainer = SAEMTrainer(
        fapf,
        y,
        # gamma_sched=lambda x: HarmonicDecayScheduler(x, a=50.),
        gamma_sched=lambda x: 0.2,
        max_k=5000,
        xlen_cutoff=15,
    )
    trainer.train(params, callbacks=[ecb])
    return