def train(args):
    # set random seed
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # init model and optimizer
    if args.verbose:
        print("Training baseline model:" if args.baseline else "Training HNN model:")
    output_dim = args.input_dim if args.baseline else 2
    nn_model = MLP(args.input_dim, args.hidden_dim, output_dim, args.nonlinearity)
    model = HNN(args.input_dim, differentiable_model=nn_model,
                field_type=args.field_type, baseline=args.baseline)
    optim = torch.optim.Adam(model.parameters(), args.learn_rate, weight_decay=0)

    # arrange data
    data = get_dataset(args.name, args.save_dir, verbose=True)
    x = torch.tensor(data['coords'], requires_grad=True, dtype=torch.float32)
    test_x = torch.tensor(data['test_coords'], requires_grad=True, dtype=torch.float32)
    dxdt = torch.Tensor(data['dcoords'])
    test_dxdt = torch.Tensor(data['test_dcoords'])

    # vanilla train loop
    stats = {'train_loss': [], 'test_loss': []}
    for step in range(args.total_steps + 1):
        # train step
        ixs = torch.randperm(x.shape[0])[:args.batch_size]
        dxdt_hat = model.time_derivative(x[ixs])
        dxdt_hat += args.input_noise * torch.randn(*x[ixs].shape)  # add noise, maybe
        loss = L2_loss(dxdt[ixs], dxdt_hat)
        loss.backward()
        grad = torch.cat([p.grad.flatten() for p in model.parameters()]).clone()
        optim.step()
        optim.zero_grad()

        # run test data
        test_ixs = torch.randperm(test_x.shape[0])[:args.batch_size]
        test_dxdt_hat = model.time_derivative(test_x[test_ixs])
        test_dxdt_hat += args.input_noise * torch.randn(*test_x[test_ixs].shape)  # add noise, maybe
        test_loss = L2_loss(test_dxdt[test_ixs], test_dxdt_hat)

        # logging
        stats['train_loss'].append(loss.item())
        stats['test_loss'].append(test_loss.item())
        if args.verbose and step % args.print_every == 0:
            print("step {}, train_loss {:.4e}, test_loss {:.4e}, grad norm {:.4e}, grad std {:.4e}"
                  .format(step, loss.item(), test_loss.item(), grad @ grad, grad.std()))

    train_dxdt_hat = model.time_derivative(x)
    train_dist = (dxdt - train_dxdt_hat)**2
    test_dxdt_hat = model.time_derivative(test_x)
    test_dist = (test_dxdt - test_dxdt_hat)**2
    print('Final train loss {:.4e} +/- {:.4e}\nFinal test loss {:.4e} +/- {:.4e}'
          .format(train_dist.mean().item(), train_dist.std().item() / np.sqrt(train_dist.shape[0]),
                  test_dist.mean().item(), test_dist.std().item() / np.sqrt(test_dist.shape[0])))

    return model, stats

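# ---------------------------------------------------------------------------
# Usage sketch (added for illustration, not part of the original script).
# It shows one hypothetical way to drive the trainer above from an
# argparse-style namespace; the attribute names mirror what train() reads,
# but every value assigned below is an assumption, not a repo default.
# ---------------------------------------------------------------------------
def _example_run():
    from argparse import Namespace
    args = Namespace(seed=0, verbose=True, baseline=False,
                     input_dim=2, hidden_dim=200, nonlinearity='tanh',
                     field_type='solenoidal', learn_rate=1e-3,
                     name='pend', save_dir='.', total_steps=2000,
                     batch_size=200, input_noise=0.0, print_every=200)
    model, stats = train(args)
    return model, stats
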
def train(args):
    # import ODENet
    # from torchdiffeq import odeint
    from torchdiffeq import odeint_adjoint as odeint

    device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu')

    # reproducibility: set random seed
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # init model and optimizer
    if args.verbose:
        print("Start training with num of points = {} and solver {}.".format(args.num_points, args.solver))

    if args.structure == False and args.baseline == True:
        nn_model = MLP(args.input_dim, 600, args.input_dim, args.nonlinearity).to(device)
        model = SymODEN_R(args.input_dim, H_net=nn_model, device=device, baseline=True)
    elif args.structure == False and args.baseline == False:
        H_net = MLP(args.input_dim, 400, 1, args.nonlinearity).to(device)
        g_net = MLP(int(args.input_dim / 2), 200, int(args.input_dim / 2)).to(device)
        model = SymODEN_R(args.input_dim, H_net=H_net, g_net=g_net, device=device, baseline=False)
    elif args.structure == True and args.baseline == False:
        M_net = MLP(int(args.input_dim / 2), 300, int(args.input_dim / 2))
        V_net = MLP(int(args.input_dim / 2), 50, 1).to(device)
        g_net = MLP(int(args.input_dim / 2), 200, int(args.input_dim / 2)).to(device)
        model = SymODEN_R(args.input_dim, M_net=M_net, V_net=V_net, g_net=g_net,
                          device=device, baseline=False, structure=True).to(device)
    else:
        raise RuntimeError('argument *baseline* and *structure* cannot both be true')

    num_parm = get_model_parm_nums(model)
    print('model contains {} parameters'.format(num_parm))
    optim = torch.optim.Adam(model.parameters(), args.learn_rate, weight_decay=1e-4)

    # arrange data
    us = [-2.0, -1.0, 0.0, 1.0, 2.0]
    # us = [0.0]
    data = get_dataset(seed=args.seed, timesteps=45, save_dir=args.save_dir, rad=args.rad, us=us, samples=50)

    train_x, t_eval = arrange_data(data['x'], data['t'], num_points=args.num_points)
    test_x, t_eval = arrange_data(data['test_x'], data['t'], num_points=args.num_points)

    train_x = torch.tensor(train_x, requires_grad=True, dtype=torch.float32).to(device)
    test_x = torch.tensor(test_x, requires_grad=True, dtype=torch.float32).to(device)
    t_eval = torch.tensor(t_eval, requires_grad=True, dtype=torch.float32).to(device)

    # training loop
    stats = {'train_loss': [], 'test_loss': [], 'forward_time': [], 'backward_time': [], 'nfe': []}
    for step in range(args.total_steps + 1):
        train_loss = 0
        test_loss = 0
        for i in range(train_x.shape[0]):
            t = time.time()
            train_x_hat = odeint(model, train_x[i, 0, :, :], t_eval, method=args.solver)
            forward_time = time.time() - t
            train_loss_mini = L2_loss(train_x[i, :, :, :], train_x_hat)
            train_loss = train_loss + train_loss_mini

            t = time.time()
            train_loss_mini.backward()
            optim.step()
            optim.zero_grad()
            backward_time = time.time() - t

            # run test data
            test_x_hat = odeint(model, test_x[i, 0, :, :], t_eval, method=args.solver)
            test_loss_mini = L2_loss(test_x[i, :, :, :], test_x_hat)
            test_loss = test_loss + test_loss_mini

        # logging
        stats['train_loss'].append(train_loss.item())
        stats['test_loss'].append(test_loss.item())
        stats['forward_time'].append(forward_time)
        stats['backward_time'].append(backward_time)
        stats['nfe'].append(model.nfe)
        if args.verbose and step % args.print_every == 0:
            print("step {}, train_loss {:.4e}, test_loss {:.4e}".format(step, train_loss.item(), test_loss.item()))

    # calculate loss mean and std for each trajectory
    train_x, t_eval = data['x'], data['t']
    test_x, t_eval = data['test_x'], data['t']
    train_x = torch.tensor(train_x, requires_grad=True, dtype=torch.float32).to(device)
    test_x = torch.tensor(test_x, requires_grad=True, dtype=torch.float32).to(device)
    t_eval = torch.tensor(t_eval, requires_grad=True, dtype=torch.float32).to(device)

    train_loss = []
    test_loss = []
    for i in range(train_x.shape[0]):
        train_x_hat = odeint(model, train_x[i, 0, :, :], t_eval, method=args.solver)
        train_loss.append((train_x[i, :, :, :] - train_x_hat)**2)

        # run test data
        test_x_hat = odeint(model, test_x[i, 0, :, :], t_eval, method=args.solver)
        test_loss.append((test_x[i, :, :, :] - test_x_hat)**2)

    train_loss = torch.cat(train_loss, dim=1)
    train_loss_per_traj = torch.sum(train_loss, dim=(0, 2))
    test_loss = torch.cat(test_loss, dim=1)
    test_loss_per_traj = torch.sum(test_loss, dim=(0, 2))

    print('Final trajectory train loss {:.4e} +/- {:.4e}\nFinal trajectory test loss {:.4e} +/- {:.4e}'
          .format(train_loss_per_traj.mean().item(), train_loss_per_traj.std().item(),
                  test_loss_per_traj.mean().item(), test_loss_per_traj.std().item()))

    stats['traj_train_loss'] = train_loss_per_traj.detach().cpu().numpy()
    stats['traj_test_loss'] = test_loss_per_traj.detach().cpu().numpy()

    return model, stats

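# ---------------------------------------------------------------------------
# Assumption: the L2_loss helper used throughout these trainers is not shown
# in this excerpt. A minimal sketch consistent with how it is called (a scalar
# mean squared error between two tensors of matching shape) would be:
# ---------------------------------------------------------------------------
def L2_loss(u, v):
    # mean of the element-wise squared differences
    return (u - v).pow(2).mean()
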
def train(args):
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    dtype = torch.get_default_dtype()
    torch.set_grad_enabled(False)

    # set random seed
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # init model and optimizer
    model = dgnet.DGNet(args.input_dim, args.hidden_dim, nonlinearity=args.nonlinearity,
                        friction=args.friction, model=args.model, solver=args.solver)
    model = model.to(device)
    optim = torch.optim.Adam(model.parameters(), args.learn_rate, weight_decay=1e-5)

    # arrange data
    data = get_dataset('pend-real', args.save_dir)
    train_x = torch.tensor(data['x'], requires_grad=True, device=device, dtype=dtype)
    test_x = torch.tensor(data['test_x'], requires_grad=True, device=device, dtype=dtype)
    train_dxdt = torch.tensor(data['dx'], device=device, dtype=dtype)
    test_dxdt = torch.tensor(data['test_dx'], device=device, dtype=dtype)

    input_dim = train_x.shape[-1]
    x1 = train_x[:-1].detach()
    x2 = train_x[1:].detach()
    dxdt = train_dxdt[:-1].clone()
    dt = 1 / 6.

    # vanilla train loop
    stats = {'train_loss': [], 'test_loss': []}
    for step in range(args.total_steps + 1):
        with torch.enable_grad():
            # train step
            dxdt_hat = model.discrete_time_derivative(x1, dt=dt, x2=x2)
            loss = L2_loss(dxdt, dxdt_hat)
            optim.zero_grad()
            loss.backward()
            optim.step()

        # run validation
        if args.solver == 'implicit':
            # skipped because the implicit solver takes too long
            test_loss = torch.tensor(float('nan'))
        else:
            test_dxdt_hat = model.discrete_time_derivative(test_x, dt=dt)
            test_loss = L2_loss(test_dxdt, test_dxdt_hat)

        # logging
        stats['train_loss'].append(loss.item())
        stats['test_loss'].append(test_loss.item())
        if args.verbose and step % args.print_every == 0:
            print("step {}, train_loss {:.4e}, test_loss {:.4e}".format(step, loss.item(), test_loss.item()))
            if args.friction:
                print("friction g =", model.g.detach().cpu().numpy())

    dxdt_hat = model.discrete_time_derivative(train_x, dt=dt)
    train_dist = (train_dxdt - dxdt_hat)**2
    test_dxdt_hat = model.discrete_time_derivative(test_x, dt=dt)
    test_dist = (test_dxdt - test_dxdt_hat)**2
    print('Final train loss {:.4e}\nFinal test loss {:.4e}'.format(
        train_dist.mean().item(), test_dist.mean().item()))
    stats['final_train_loss'] = train_dist.mean().item()
    stats['final_test_loss'] = test_dist.mean().item()

    # energy error
    from data import hamiltonian_fn
    t_eval = np.squeeze(data['test_t'] - data['test_t'].min())
    t_span = [t_eval.min(), t_eval.max()]
    x0 = test_x[0]
    true_orbits = test_x.detach().cpu().numpy()
    model_orbits = model.get_orbit(x0, t_eval=t_eval).detach().cpu().numpy()
    true_energies = np.stack([hamiltonian_fn(c) for c in true_orbits])
    model_energies = np.stack([hamiltonian_fn(c) for c in model_orbits])

    stats['true_orbits'] = true_orbits
    stats['model_orbits'] = model_orbits
    stats['true_energies'] = true_energies
    stats['model_energies'] = model_energies

    distance_energy = (true_energies - model_energies)**2
    distance_state = (true_orbits - model_orbits)**2
    stats['energy_mse_mean'] = np.mean(distance_energy)
    print("energy MSE {:.4e}".format(stats['energy_mse_mean']))
    stats['state_mse_mean'] = np.mean(distance_state)
    print("state MSE {:.4e}".format(stats['state_mse_mean']))

    return model, stats

def train(args):
    # set random seed
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # init model and optimizer
    if args.verbose:
        print("Training baseline model:" if args.baseline else "Training HNN model:")
    output_dim = args.input_dim if args.baseline else 2
    nn_model = MLP(args.input_dim, 400, output_dim, args.nonlinearity)
    model = HNN(args.input_dim, differentiable_model=nn_model,
                field_type=args.field_type, baseline=args.baseline)
    optim = torch.optim.Adam(model.parameters(), args.learn_rate, weight_decay=1e-4)

    # the data API is different here: to keep the comparison fair, generate the
    # data the same way as in SymODEN and compute the time derivatives from the
    # generated trajectories
    us = [0.0]
    data = get_dataset(seed=args.seed, save_dir=args.save_dir, rad=args.rad, us=us, samples=50, timesteps=45)

    # arrange data
    train_x, t_eval = data['x'][0, :, :, 0:2], data['t']
    test_x, t_eval = data['test_x'][0, :, :, 0:2], data['t']

    # estimate time derivatives by forward differences
    train_dxdt = (train_x[1:, :, :] - train_x[:-1, :, :]) / (t_eval[1] - t_eval[0])
    test_dxdt = (test_x[1:, :, :] - test_x[:-1, :, :]) / (t_eval[1] - t_eval[0])

    train_x = train_x[0:-1, :, :].reshape((-1, 2))
    test_x = test_x[0:-1, :, :].reshape((-1, 2))
    test_dxdt = test_dxdt.reshape((-1, 2))
    train_dxdt = train_dxdt.reshape((-1, 2))

    x = torch.tensor(train_x, requires_grad=True, dtype=torch.float32)
    test_x = torch.tensor(test_x, requires_grad=True, dtype=torch.float32)
    dxdt = torch.Tensor(train_dxdt)
    test_dxdt = torch.Tensor(test_dxdt)

    # vanilla train loop
    stats = {'train_loss': [], 'test_loss': []}
    for step in range(args.total_steps + 1):
        # train step
        dxdt_hat = model.rk4_time_derivative(x) if args.use_rk4 else model.time_derivative(x)
        loss = L2_loss(dxdt, dxdt_hat)
        loss.backward()
        optim.step()
        optim.zero_grad()

        # run test data
        test_dxdt_hat = model.rk4_time_derivative(test_x) if args.use_rk4 else model.time_derivative(test_x)
        test_loss = L2_loss(test_dxdt, test_dxdt_hat)

        # logging
        stats['train_loss'].append(loss.item())
        stats['test_loss'].append(test_loss.item())
        if args.verbose and step % args.print_every == 0:
            print("step {}, train_loss {:.4e}, test_loss {:.4e}".format(step, loss.item(), test_loss.item()))

    train_dxdt_hat = model.time_derivative(x)
    train_dist = (dxdt - train_dxdt_hat)**2
    test_dxdt_hat = model.time_derivative(test_x)
    test_dist = (test_dxdt - test_dxdt_hat)**2
    print('Final train loss {:.4e} +/- {:.4e}\nFinal test loss {:.4e} +/- {:.4e}'
          .format(train_dist.mean().item(), train_dist.std().item() / np.sqrt(train_dist.shape[0]),
                  test_dist.mean().item(), test_dist.std().item() / np.sqrt(test_dist.shape[0])))

    return model, stats

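# ---------------------------------------------------------------------------
# Illustration (added, not from the original code): the trainer above
# estimates time derivatives with a forward difference, (x[t+1] - x[t]) / dt.
# The toy function below applies the same estimate to a sine trajectory and
# returns the worst-case error, which shrinks roughly linearly with dt.
# ---------------------------------------------------------------------------
def _forward_difference_demo(timesteps=45):
    import numpy as np
    t = np.linspace(0, 1, timesteps + 1)
    x = np.sin(2 * np.pi * t)                               # toy trajectory
    dxdt_est = (x[1:] - x[:-1]) / (t[1] - t[0])             # forward difference, as above
    dxdt_true = 2 * np.pi * np.cos(2 * np.pi * t[:-1])      # analytic derivative
    return np.abs(dxdt_est - dxdt_true).max()
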
def train(args):
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    dtype = torch.get_default_dtype()
    torch.set_grad_enabled(False)

    # set random seed
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # init model and optimizer
    model = dgnet.DGNet(args.input_dim, args.hidden_dim, nonlinearity=args.nonlinearity,
                        model=args.model, solver=args.solver)
    model = model.to(device)
    optim = torch.optim.Adam(model.parameters(), args.learn_rate, weight_decay=0)

    # arrange data
    t_span = 25
    length = 500
    dt = t_span / (length - 1)
    data = get_dataset(args.name, args.save_dir, verbose=True, timesteps=length, t_span=[0, t_span])
    train_x = torch.tensor(data['coords'], requires_grad=True, device=device, dtype=dtype)
    test_x = torch.tensor(data['test_coords'], requires_grad=True, device=device, dtype=dtype)
    train_dxdt = torch.tensor(data['dcoords'], device=device, dtype=dtype)
    test_dxdt = torch.tensor(data['test_dcoords'], device=device, dtype=dtype)

    x_reshaped = train_x.view(-1, length, args.input_dim)
    x1 = x_reshaped[:, :-1].contiguous().view(-1, args.input_dim)
    x2 = x_reshaped[:, 1:].contiguous().view(-1, args.input_dim)
    dxdt = ((x2 - x1) / dt).detach()

    # vanilla train loop
    stats = {'train_loss': [], 'test_loss': []}
    for step in range(args.total_steps + 1):
        with torch.enable_grad():
            # train step
            ixs = torch.randperm(x1.shape[0])[:args.batch_size]
            dxdt_hat = model.discrete_time_derivative(x1[ixs], dt=dt, x2=x2[ixs])
            dxdt_hat += args.input_noise * torch.randn_like(x1[ixs])  # add noise, maybe
            loss = L2_loss(dxdt[ixs], dxdt_hat)
            loss.backward()
            optim.step()
            optim.zero_grad()

        # run test data
        test_ixs = torch.randperm(test_x.shape[0], device=device)[:args.batch_size]
        test_dxdt_hat = model.time_derivative(test_x[test_ixs])
        test_dxdt_hat += args.input_noise * torch.randn_like(test_x[test_ixs])  # add noise, maybe
        test_loss = L2_loss(test_dxdt[test_ixs], test_dxdt_hat)

        # logging
        stats['train_loss'].append(loss.item())
        stats['test_loss'].append(test_loss.item())
        if args.verbose and step % args.print_every == 0:
            print("step {}, train_loss {:.4e}, test_loss {:.4e}"
                  .format(step, loss.item(), test_loss.item()))
            if args.friction:
                print("friction g =", model.g.detach().cpu().numpy())

    train_dxdt_hat = model.time_derivative(train_x)
    train_dist = (train_dxdt - train_dxdt_hat)**2
    test_dxdt_hat = model.time_derivative(test_x)
    test_dist = (test_dxdt - test_dxdt_hat)**2
    print('Final train loss {:.4e}\nFinal test loss {:.4e}'
          .format(train_dist.mean().item(), test_dist.mean().item()))
    stats['final_train_loss'] = train_dist.mean().item()
    stats['final_test_loss'] = test_dist.mean().item()

    # energy error
    from data import get_orbit, random_config, total_energy
    t_span = [0, t_span]
    t_points = length
    trials = 5 * 3
    true_energies, model_energies, true_orbits, model_orbits = [], [], [], []
    for trial_ix in range(trials):
        np.random.seed(trial_ix)
        state = random_config()

        # true trajectory -> energy
        orbit, settings = get_orbit(state, t_points=t_points, t_span=t_span)
        true_orbits.append(orbit)
        true_energies.append(total_energy(orbit))

        # model trajectory -> energy
        mass = state[:, :1]
        x0 = state[:, 1:].T.reshape(-1)
        model_orbit = model.get_orbit(x0, t_eval=np.linspace(t_span[0], t_span[1], t_points))
        model_orbit = model_orbit.reshape(model_orbit.shape[0], 4, 2)
        model_orbit = np.concatenate([np.ones((model_orbit.shape[0], 1, 2)), model_orbit], axis=1)
        model_orbit = model_orbit.transpose(2, 1, 0)
        model_orbits.append(model_orbit)
        model_energies.append(total_energy(model_orbit))

    true_energies = np.stack(true_energies)
    model_energies = np.stack(model_energies)
    true_orbits = np.stack(true_orbits)
    model_orbits = np.stack(model_orbits)

    stats['true_orbits'] = true_orbits
    stats['model_orbits'] = model_orbits
    stats['true_energies'] = true_energies
    stats['model_energies'] = model_energies

    distance_energy = (true_energies - model_energies)**2
    distance_state = (true_orbits - model_orbits)**2
    stats['energy_mse_mean'] = np.mean(distance_energy)
    print("energy MSE {:.4e}".format(stats['energy_mse_mean']))
    stats['state_mse_mean'] = np.mean(distance_state)
    print("state MSE {:.4e}".format(stats['state_mse_mean']))

    return model, stats

def train(args):
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    n_available_GPUs = torch.cuda.device_count()
    dtype = torch.get_default_dtype()
    torch.set_grad_enabled(False)

    # set random seed
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # arrange data
    data = get_dataset(args.name, args.save_dir, verbose=True, device='cpu', test_split=0.1)
    train_u = torch.tensor(data['u'], requires_grad=True, device=device, dtype=dtype)
    test_u = torch.tensor(data['test_u'], requires_grad=True, device=device, dtype=dtype)
    train_dudt = torch.tensor(data['dudt'], device=device, dtype=dtype)
    test_dudt = torch.tensor(data['test_dudt'], device=device, dtype=dtype)
    t_eval = data['t_eval']
    dt = data['dt']
    M = test_u.shape[-1]

    train_shape_origin = train_u.shape
    test_shape_origin = test_u.shape
    u1 = train_u[:, :-1].contiguous().view(-1, 1, train_u.shape[-1])
    u2 = train_u[:, 1:].contiguous().view(-1, 1, train_u.shape[-1])
    dudt = ((u2 - u1) / dt).detach()
    train_u = train_u.view(-1, 1, train_u.shape[-1])
    test_u = test_u.view(-1, 1, test_u.shape[-1])
    train_dudt = train_dudt.view(-1, 1, train_dudt.shape[-1])
    test_dudt = test_dudt.view(-1, 1, test_dudt.shape[-1])

    # init model and optimizer
    alpha = 2 if args.name.startswith('ch') else 1
    model = dgnet.DGNetPDE1d(args.input_dim, args.hidden_dim, nonlinearity=args.nonlinearity,
                             model=args.model, solver=args.solver, name=args.name,
                             dx=data['dx'], alpha=alpha)
    print(model)
    model = model.to(device)

    stats = {'train_loss': [], 'test_loss': []}

    import glob
    files = glob.glob('{}.tar'.format(args.result_path))
    if len(files) > 0:
        f = files[0]
        path_tar = f
        model.load_state_dict(torch.load(path_tar, map_location=device))
        path_pkl = f.replace('.tar', '.pkl')
        stats = from_pickle(path_pkl)
        args.total_steps = 0
        print('Model successfully loaded from {}'.format(path_tar))
    if args.load:
        path_tar = '{}.tar'.format(args.result_path).replace('_long', '')
        model.load_state_dict(torch.load(path_tar, map_location=device))
        args.total_steps = 0
        print('Model successfully loaded from {}'.format(path_tar))

    optim = torch.optim.Adam(model.parameters(), args.learn_rate, weight_decay=0)

    # vanilla train loop
    for step in range(args.total_steps):
        # train step
        idx = torch.randperm(u1.shape[0])[:args.batch_size]
        with torch.enable_grad():
            if n_available_GPUs > 1:
                dudt_hat = torch.nn.parallel.data_parallel(
                    model, u1[idx],
                    module_kwargs={'dt': dt, 'x2': u2[idx], 'func': 'discrete_time_derivative'})
            else:
                dudt_hat = model.discrete_time_derivative(u1[idx], dt=dt, x2=u2[idx])
            loss = L2_loss(dudt[idx], dudt_hat)
            optim.zero_grad()
            loss.backward()
            optim.step()

        # run test data
        test_idx = torch.randperm(test_u.shape[0])[:args.batch_size]
        test_dudt_hat = model.time_derivative(test_u[test_idx])
        test_loss = L2_loss(test_dudt[test_idx], test_dudt_hat)

        # logging
        stats['train_loss'].append(loss.item())
        stats['test_loss'].append(test_loss.item())
        if args.verbose and step % args.print_every == 0:
            print("step {}, train_loss {:.4e}, test_loss {:.4e}".format(step, loss.item(), test_loss.item()))

    if len(train_u) > 0:
        train_dudt_hat = torch.cat([
            model.time_derivative(train_u[idx:idx + args.batch_size])
            for idx in range(0, len(train_u), args.batch_size)
        ], dim=0)
        train_dist = (train_dudt - train_dudt_hat)**2
        test_dudt_hat = torch.cat([
            model.time_derivative(test_u[idx:idx + args.batch_size])
            for idx in range(0, len(test_u), args.batch_size)
        ], dim=0)
        test_dist = (test_dudt - test_dudt_hat)**2
        print('Final train loss {:.4e}\nFinal test loss {:.4e}'.format(
            train_dist.mean().item(), test_dist.mean().item()))
        stats['final_train_loss'] = train_dist.mean().item()
        stats['final_test_loss'] = test_dist.mean().item()
    else:
        stats['final_train_loss'] = 0.0
        stats['final_test_loss'] = 0.0

    # sequence generator
    os.makedirs('{}/results/'.format(args.save_dir), exist_ok=True)
    print('Generating test sequences')
    train_u = train_u.view(*train_shape_origin)
    test_u = test_u.view(*test_shape_origin)

    import matplotlib as mpl
    mpl.use('Agg')
    import matplotlib.pyplot as plt

    test_u_truth = []
    test_u_model = []
    for idx in range(len(test_u)):
        print('Generating a sequence {}/{}'.format(idx, len(test_u)), end='\r')
        u_truth = test_u[idx].squeeze(1).detach().cpu().numpy()
        u_model = model.get_orbit(
            x0=test_u[idx, :1], t_eval=t_eval).squeeze(2).squeeze(1).detach().cpu().numpy()
        test_u_truth.append(u_truth)
        test_u_model.append(u_model)

        energy_truth = data['model'].get_energy(u_truth)
        energy_model = data['model'].get_energy(u_model)
        if args.model != 'node':
            energy_model_truth = model(
                torch.from_numpy(u_truth).unsqueeze(-2).to(device)
            ).squeeze(2).squeeze(1).detach().cpu().numpy() * data['dx']
            energy_model_model = model(
                torch.from_numpy(u_model).unsqueeze(-2).to(device)
            ).squeeze(2).squeeze(1).detach().cpu().numpy() * data['dx']
        mass_truth = u_truth.sum(-1)
        mass_model = u_model.sum(-1)

        if args.name.startswith('ch'):
            vmax = 1
            vmin = -1
        else:
            vmax = max(np.abs(u_truth).max(), np.abs(u_model).max())
            vmin = -vmax

        fig, (ax1, ax2, ax3, ax4) = plt.subplots(4, 1, sharex=True, figsize=(9., 15.), facecolor='white')
        ax1.imshow(u_truth.T, interpolation='nearest', vmin=vmin, vmax=vmax, cmap='seismic')
        ax1.set_aspect('auto')
        ax1.set_yticks((-0.5, M - 0.5))
        ax1.set_yticklabels((0, 1))
        ax2.imshow(u_model.T, interpolation='nearest', vmin=vmin, vmax=vmax, cmap='seismic')
        ax2.set_aspect('auto')
        ax2.set_yticks((-0.5, M - 0.5))
        ax2.set_yticklabels((0, 1))
        ax3.plot([], [], color='white', label='energy')
        if args.model != 'node':
            ax3.plot(energy_model_truth - energy_model_truth[0], dashes=[2, 2], color='C0')
            ax3.plot(energy_model_model - energy_model_model[0], dashes=[2, 2], color='C1')
        ax3.plot(energy_truth - energy_truth[0], color='C0', label='ground truth')
        ax3.plot(energy_model - energy_model[0], color='C1', label=args.model)
        ax3.legend()
        ax4.plot([], [], color='white', label='mass')
        ax4.plot(mass_truth, color='C0')
        ax4.plot(mass_model, color='C1')
        ax4.set_xticks(t_eval[::len(t_eval) // 5] / dt)
        ax4.set_xticklabels(t_eval[::len(t_eval) // 5])
        ax4.set_xlabel('time')
        fig.savefig('{}_plot{:02d}.png'.format(args.result_path, idx))
        plt.close()

    test_u_truth = np.stack(test_u_truth, axis=0)[:, 1:]
    test_u_model = np.stack(test_u_model, axis=0)[:, 1:]
    energy_truth = data['model'].get_energy(test_u_truth)
    energy_model = data['model'].get_energy(test_u_model)
    print('energy MSE model', ((energy_truth - energy_model)**2).mean())
    stats['energy_mse_mean'] = ((energy_truth - energy_model)**2).mean()
    print('state MSE model', ((test_u_truth - test_u_model)**2).mean())
    stats['state_mse_mean'] = ((test_u_truth - test_u_model)**2).mean()
    stats['test_u_truth'] = test_u_truth
    stats['test_u_model'] = test_u_model
    stats['energy_truth'] = energy_truth
    stats['energy_model'] = energy_model

    if args.model != 'node':
        energy_model_truth = model(
            torch.from_numpy(test_u_truth).reshape(-1, 1, test_u_truth.shape[-1]).to(device)
        ).detach().cpu().numpy().reshape(*test_u_truth.shape[:-1])
        energy_model_model = model(
            torch.from_numpy(test_u_model).reshape(-1, 1, test_u_model.shape[-1]).to(device)
        ).detach().cpu().numpy().reshape(*test_u_model.shape[:-1])
        stats['energy_model_truth'] = energy_model_truth
        stats['energy_model_model'] = energy_model_model

    return model, stats

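# ---------------------------------------------------------------------------
# Sketch (added): the PDE trainer above *loads* a '{result_path}.tar' state
# dict and a matching '.pkl' stats file when they exist. A plausible
# counterpart for writing that pair after training is shown below; only
# torch.save and the standard pickle module are assumed, and the helper name
# is hypothetical (the repo's own serialization utility is not shown here).
# ---------------------------------------------------------------------------
def _save_checkpoint(model, stats, result_path):
    import pickle
    import torch
    torch.save(model.state_dict(), '{}.tar'.format(result_path))
    with open('{}.pkl'.format(result_path), 'wb') as handle:
        pickle.dump(stats, handle)
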
def train(args):
    # set random seed
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # init model and optimizer
    if args.verbose:
        print("Training baseline model:" if args.baseline else "Training HNN model:")
    output_dim = args.input_dim if args.baseline else 2
    nn_model = MLP(args.input_dim, args.hidden_dim, output_dim, args.nonlinearity)
    model = HNN(args.input_dim, differentiable_model=nn_model,
                field_type=args.field_type, baseline=args.baseline)
    optim = torch.optim.Adam(model.parameters(), args.learn_rate, weight_decay=1e-4)

    # arrange data
    data = get_dataset(seed=args.seed)
    x = torch.tensor(data['x'], requires_grad=True, dtype=torch.float32)
    test_x = torch.tensor(data['test_x'], requires_grad=True, dtype=torch.float32)
    dxdt = torch.Tensor(data['dx'])
    test_dxdt = torch.Tensor(data['test_dx'])

    # vanilla train loop
    stats = {'train_loss': [], 'test_loss': []}
    for step in range(args.total_steps + 1):
        # train step
        dxdt_hat = model.rk4_time_derivative(x) if args.use_rk4 else model.time_derivative(x)
        loss = L2_loss(dxdt, dxdt_hat)
        loss.backward()
        optim.step()
        optim.zero_grad()

        # run test data
        test_dxdt_hat = model.rk4_time_derivative(test_x) if args.use_rk4 else model.time_derivative(test_x)
        test_loss = L2_loss(test_dxdt, test_dxdt_hat)

        # logging
        stats['train_loss'].append(loss.item())
        stats['test_loss'].append(test_loss.item())
        if args.verbose and step % args.print_every == 0:
            print("step {}, train_loss {:.4e}, test_loss {:.4e}".format(step, loss.item(), test_loss.item()))

    train_dxdt_hat = model.time_derivative(x)
    train_dist = (dxdt - train_dxdt_hat)**2
    test_dxdt_hat = model.time_derivative(test_x)
    test_dist = (test_dxdt - test_dxdt_hat)**2
    print('Final train loss {:.4e} +/- {:.4e}\nFinal test loss {:.4e} +/- {:.4e}'
          .format(train_dist.mean().item(), train_dist.std().item() / np.sqrt(train_dist.shape[0]),
                  test_dist.mean().item(), test_dist.std().item() / np.sqrt(test_dist.shape[0])))

    return model, stats

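# ---------------------------------------------------------------------------
# Background sketch (added): args.use_rk4 above switches the trainer to an
# RK4-based derivative estimate. The generic fourth-order Runge-Kutta update
# for dx/dt = f(x) is shown below; the repo's rk4_time_derivative is assumed
# to follow this classical scheme internally.
# ---------------------------------------------------------------------------
def rk4_step(f, x, dt):
    k1 = f(x)
    k2 = f(x + dt / 2 * k1)
    k3 = f(x + dt / 2 * k2)
    k4 = f(x + dt * k3)
    return x + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
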
def train(args):
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    dtype = torch.get_default_dtype()
    torch.set_grad_enabled(False)

    # set random seed
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # init model and optimizer
    model = dgnet.DGNet(args.input_dim, args.hidden_dim, nonlinearity=args.nonlinearity,
                        model=args.model, solver=args.solver)
    model = model.to(device)
    optim = torch.optim.Adam(model.parameters(), args.learn_rate, weight_decay=1e-4)

    # arrange data
    t_span = 20
    length = 100
    dt = t_span / (length - 1)
    data = get_dataset(seed=args.seed, noise_std=args.noise, t_span=[0, t_span], timescale=length / t_span)
    train_x = torch.tensor(data['x'], requires_grad=True, device=device, dtype=dtype)
    test_x = torch.tensor(data['test_x'], requires_grad=True, device=device, dtype=dtype)
    train_dxdt = torch.tensor(data['dx'], device=device, dtype=dtype)
    test_dxdt = torch.tensor(data['test_dx'], device=device, dtype=dtype)

    input_dim = train_x.shape[-1]
    x_reshaped = train_x.view(-1, length, input_dim)
    x1 = x_reshaped[:, :-1].contiguous().view(-1, input_dim)
    x2 = x_reshaped[:, 1:].contiguous().view(-1, input_dim)
    dxdt = ((x2 - x1) / dt).detach()

    # vanilla train loop
    stats = {'train_loss': [], 'test_loss': []}
    for step in range(args.total_steps + 1):
        with torch.enable_grad():
            # train step
            dxdt_hat = model.discrete_time_derivative(x1, dt=dt, x2=x2)
            loss = L2_loss(dxdt, dxdt_hat)
            loss.backward()
            optim.step()
            optim.zero_grad()

        # run test data
        test_dxdt_hat = model.time_derivative(test_x)
        test_loss = L2_loss(test_dxdt, test_dxdt_hat)

        # logging
        stats['train_loss'].append(loss.item())
        stats['test_loss'].append(test_loss.item())
        if args.verbose and step % args.print_every == 0:
            print("step {}, train_loss {:.4e}, test_loss {:.4e}".format(step, loss.item(), test_loss.item()))

    train_dxdt_hat = model.time_derivative(train_x)
    train_dist = (train_dxdt - train_dxdt_hat)**2
    test_dxdt_hat = model.time_derivative(test_x)
    test_dist = (test_dxdt - test_dxdt_hat)**2
    print('Final train loss {:.4e}\nFinal test loss {:.4e}'
          .format(train_dist.mean().item(), test_dist.mean().item()))
    stats['final_train_loss'] = train_dist.mean().item()
    stats['final_test_loss'] = test_dist.mean().item()

    # energy error
    def integrate_models(x0=np.asarray([1, 0]), t_span=[0, 5], t_eval=None):
        from data import dynamics_fn
        import scipy.integrate
        rtol = 1e-12

        # integrate the ground-truth vector field
        true_x = scipy.integrate.solve_ivp(fun=dynamics_fn, t_span=t_span, y0=x0,
                                           t_eval=t_eval, rtol=rtol)['y'].T

        # integrate along the model vector field
        model_x = model.get_orbit(x0, t_eval, tol=rtol)
        return true_x, model_x

    def energy_loss(true_x, integrated_x):
        true_energy = (true_x**2).sum(1)
        integration_energy = (integrated_x**2).sum(1)
        return np.mean((true_energy - integration_energy)**2)

    t_span = [0, t_span]
    trials = 5 * 3
    t_eval = np.linspace(t_span[0], t_span[1], length)
    losses = {'model_energy': [], 'model_state': []}
    true_orbits = []
    model_orbits = []
    for i in range(trials):
        # sample each coordinate uniformly from (-0.8, 0.8), then push it 0.2 away
        # from zero so that its magnitude lies in (0.2, 1.0)
        x0 = np.random.rand(2) * 1.6 - .8
        x0 += 0.2 * np.sign(x0) * np.ones_like(x0)
        true_x, model_x = integrate_models(x0=x0, t_span=t_span, t_eval=t_eval)
        true_orbits.append(true_x)
        model_orbits.append(model_x)
        losses['model_energy'] += [energy_loss(true_x, model_x)]
        losses['model_state'] += [((true_x - model_x)**2).mean()]
        print('{:.2f}% done'.format(100 * float(i) / trials), end='\r')

    stats['true_orbits'] = np.stack(true_orbits)
    stats['model_orbits'] = np.stack(model_orbits)
    stats['true_energies'] = (stats['true_orbits']**2).sum(-1)
    stats['model_energies'] = (stats['model_orbits']**2).sum(-1)

    losses = {k: np.array(v) for k, v in losses.items()}
    stats['energy_mse_mean'] = np.mean(losses['model_energy'])
    print("energy MSE {:.4e}".format(stats['energy_mse_mean']))
    stats['state_mse_mean'] = np.mean(losses['model_state'])
    print("state MSE {:.4e}".format(stats['state_mse_mean']))

    return model, stats

def train(args):
    if torch.cuda.is_available() and not args.cpu:
        device = torch.device("cuda:0")
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.empty_cache()
        print("Running on the GPU")
    else:
        device = torch.device("cpu")
        print("Running on the CPU")

    # set random seed
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    print("{} {}".format(args.folder, args.speed))
    print("Training scaled model:" if args.scaled else "Training noisy model:")
    print('{} pairs of coords in latent space'.format(args.latent_dim))

    # using universal autoencoder, pre-encode the training points
    autoencoder = MLPAutoencoder(args.input_dim_ae, args.hidden_dim, args.latent_dim * 2, nonlinearity='relu')
    full_model = PixelHNN(args.latent_dim * 2, args.hidden_dim, autoencoder=autoencoder,
                          nonlinearity=args.nonlinearity, baseline=args.baseline)
    path = "{}/saved_models/{}.tar".format(args.save_dir, args.ae_path)
    full_model.load_state_dict(torch.load(path))
    full_model.eval()
    autoencoder_model = full_model.autoencoder

    # get dataset (no test data for now)
    data = get_dataset(args.folder, args.speed, scaled=args.scaled, split=args.split_data,
                       experiment_dir=args.experiment_dir, tensor=True)

    gcoords = autoencoder_model.encode(data).cpu().detach().numpy()
    x = torch.tensor(gcoords, dtype=torch.float, requires_grad=True)
    dx_np = full_model.time_derivative(
        torch.tensor(gcoords, dtype=torch.float, requires_grad=True)).cpu().detach().numpy()
    dx = torch.tensor(dx_np, dtype=torch.float)

    nnmodel = MLP(args.input_dim, args.hidden_dim, args.output_dim)
    model = HNN(2, nnmodel)
    model.to(device)
    optim = torch.optim.Adam(model.parameters(), args.learn_rate, weight_decay=args.weight_decay)

    # vanilla ae train loop
    stats = {'train_loss': [], 'test_loss': []}
    for step in range(args.total_steps + 1):
        # train step
        ixs = torch.randperm(x.shape[0])[:args.batch_size]
        x_train, dxdt = x[ixs].to(device), dx[ixs].to(device)
        dxdt_hat = model.time_derivative(x_train)
        loss = L2_loss(dxdt, dxdt_hat)
        loss.backward()
        optim.step()
        optim.zero_grad()

        stats['train_loss'].append(loss.item())
        if step % args.print_every == 0:
            print("step {}, train_loss {:.4e}".format(step, loss.item()))

    # train_dist = hnn_ae_loss(x, x_next, model, return_scalar=False)
    # print('Final train loss {:.4e} +/- {:.4e}'
    #       .format(train_dist.mean().item(), train_dist.std().item() / np.sqrt(train_dist.shape[0])))

    return model

def train(args):
    # set random seed
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # init model and optimizer
    if args.verbose:
        print("Training baseline model:" if args.baseline else "Training HNN model:")
    S_net = MLP(int(args.input_dim / 2), 140, int(args.input_dim / 2)**2, args.nonlinearity)
    U_net = MLP(int(args.input_dim / 2), 140, 1, args.nonlinearity)
    model = Lagrangian(int(args.input_dim / 2), S_net, U_net, dt=1e-3)
    num_parm = get_model_parm_nums(model)
    print('model contains {} parameters'.format(num_parm))
    optim = torch.optim.Adam(model.parameters(), args.learn_rate, weight_decay=1e-4)

    # arrange data
    data = get_lag_dataset(seed=args.seed)
    x = torch.tensor(data['x'], requires_grad=False, dtype=torch.float32)
    # append zero control
    u = torch.zeros_like(x[:, 0]).unsqueeze(-1)
    x = torch.cat((x, u), -1)

    test_x = torch.tensor(data['test_x'], requires_grad=False, dtype=torch.float32)
    # append zero control (sized to the test set)
    test_u = torch.zeros_like(test_x[:, 0]).unsqueeze(-1)
    test_x = torch.cat((test_x, test_u), -1)

    dxdt = torch.Tensor(data['dx'])
    test_dxdt = torch.Tensor(data['test_dx'])

    # vanilla train loop
    stats = {'train_loss': [], 'test_loss': []}
    for step in range(args.total_steps + 1):
        # train step
        dq, dp, du = model.time_derivative(x).split(1, 1)
        dxdt_hat = torch.cat((dq, dp), -1)
        loss = L2_loss(dxdt, dxdt_hat)
        loss.backward()
        optim.step()
        optim.zero_grad()

        # run test data
        dq_test, dp_test, du_test = model.time_derivative(test_x).split(1, 1)
        test_dxdt_hat = torch.cat((dq_test, dp_test), -1)
        test_loss = L2_loss(test_dxdt, test_dxdt_hat)

        # logging
        stats['train_loss'].append(loss.item())
        stats['test_loss'].append(test_loss.item())
        if args.verbose and step % args.print_every == 0:
            print("step {}, train_loss {:.4e}, test_loss {:.4e}".format(step, loss.item(), test_loss.item()))

    train_dq, train_dp, train_du = model.time_derivative(x).split(1, 1)
    train_dxdt_hat = torch.cat((train_dq, train_dp), -1)
    train_dist = (dxdt - train_dxdt_hat)**2
    test_dq, test_dp, test_du = model.time_derivative(test_x).split(1, 1)
    test_dxdt_hat = torch.cat((test_dq, test_dp), -1)
    test_dist = (test_dxdt - test_dxdt_hat)**2
    print('Final train loss {:.4e} +/- {:.4e}\nFinal test loss {:.4e} +/- {:.4e}'
          .format(train_dist.mean().item(), train_dist.std().item() / np.sqrt(train_dist.shape[0]),
                  test_dist.mean().item(), test_dist.std().item() / np.sqrt(test_dist.shape[0])))

    return model, stats

def train(args):
    device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu')

    # reproducibility: set random seed
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # init model and optimizer
    if args.verbose:
        print("Start training with num of points = {} and solver {}.".format(args.num_points, args.solver))

    if args.structure == False and args.baseline == True:
        nn_model = MLP(args.input_dim, 600, args.input_dim, args.nonlinearity)
        model = SymODEN_R(args.input_dim, H_net=nn_model, device=device, baseline=True)
    elif args.structure == False and args.baseline == False:
        H_net = MLP(args.input_dim, 400, 1, args.nonlinearity)
        g_net = MLP(int(args.input_dim / 2), 200, int(args.input_dim / 2))
        model = SymODEN_R(args.input_dim, H_net=H_net, g_net=g_net, device=device, baseline=False)
    elif args.structure == True and args.baseline == False:
        M_net = MLP(int(args.input_dim / 2), 300, int(args.input_dim / 2))
        V_net = MLP(int(args.input_dim / 2), 50, 1)
        g_net = MLP(int(args.input_dim / 2), 200, int(args.input_dim / 2))
        model = SymODEN_R(args.input_dim, M_net=M_net, V_net=V_net, g_net=g_net,
                          device=device, baseline=False, structure=True)
    else:
        raise RuntimeError('argument *baseline* and *structure* cannot both be true')

    num_parm = get_model_parm_nums(model)
    print('model contains {} parameters'.format(num_parm))
    optim = torch.optim.Adam(model.parameters(), args.learn_rate, weight_decay=1e-4)

    # modified to use the hnn dataset
    data = get_dataset(seed=args.seed)
    x = torch.tensor(data['x'], requires_grad=True, dtype=torch.float32)  # [1125, 2] Bx2
    # append zero control
    u = torch.zeros_like(x[:, 0]).unsqueeze(-1)
    x = torch.cat((x, u), -1)

    test_x = torch.tensor(data['test_x'], requires_grad=True, dtype=torch.float32)
    # append zero control (sized to the test set)
    test_u = torch.zeros_like(test_x[:, 0]).unsqueeze(-1)
    test_x = torch.cat((test_x, test_u), -1)

    dxdt = torch.Tensor(data['dx'])  # [1125, 2] Bx2
    test_dxdt = torch.Tensor(data['test_dx'])

    # training loop
    stats = {'train_loss': [], 'test_loss': []}
    for step in range(args.total_steps + 1):
        # train step (modified to match hnn)
        dq, dp, du = model.time_derivative(x).split(1, 1)
        dxdt_hat = torch.cat((dq, dp), -1)
        loss = L2_loss(dxdt, dxdt_hat)
        loss.backward()
        optim.step()
        optim.zero_grad()

        # run test data
        dq_test, dp_test, du_test = model.time_derivative(test_x).split(1, 1)
        test_dxdt_hat = torch.cat((dq_test, dp_test), -1)
        test_loss = L2_loss(test_dxdt, test_dxdt_hat)

        # logging
        stats['train_loss'].append(loss.item())
        stats['test_loss'].append(test_loss.item())
        if args.verbose and step % args.print_every == 0:
            print("step {}, train_loss {:.4e}, test_loss {:.4e}".format(step, loss.item(), test_loss.item()))

    train_dq, train_dp, train_du = model.time_derivative(x).split(1, 1)
    train_dxdt_hat = torch.cat((train_dq, train_dp), -1)
    train_dist = (dxdt - train_dxdt_hat)**2
    test_dq, test_dp, test_du = model.time_derivative(test_x).split(1, 1)
    test_dxdt_hat = torch.cat((test_dq, test_dp), -1)
    test_dist = (test_dxdt - test_dxdt_hat)**2
    print('Final train loss {:.4e} +/- {:.4e}\nFinal test loss {:.4e} +/- {:.4e}'
          .format(train_dist.mean().item(), train_dist.std().item() / np.sqrt(train_dist.shape[0]),
                  test_dist.mean().item(), test_dist.std().item() / np.sqrt(test_dist.shape[0])))

    return model, stats

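# ---------------------------------------------------------------------------
# Assumption: get_model_parm_nums, used by several trainers above, is not
# defined in this excerpt. A minimal sketch that matches how its result is
# printed (a single parameter count) would be:
# ---------------------------------------------------------------------------
def get_model_parm_nums(model):
    # total number of scalar parameters in the module
    return sum(p.numel() for p in model.parameters())
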
def train(args):
    # set random seed
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)

    # init model and optimizer
    if args.verbose:
        print("Training baseline model:" if args.baseline else "Training HNN model:")
    output_dim = args.input_dim if args.baseline else 2
    nn_model = MLPAutoencoder(args.input_dim, args.hidden_dim, args.latent_dim, args.nonlinearity)
    nn_model.to(device)
    model = HNN(args.input_dim, differentiable_model=nn_model, field_type=args.field_type,
                baseline=args.baseline, device=device)
    model.to(device)
    optim = torch.optim.Adam(model.parameters(), args.learn_rate, weight_decay=0)

    # arrange data
    X = np.load('statrectinputs.npy')
    Y = np.load('statrectoutputs.npy')
    Y[~np.isfinite(Y)] = 0
    xm, xd = give_min_and_dist(X)
    ym, yd = give_min_and_dist(Y)
    X = scale(X, xm, xd)
    Y = scale(Y, ym, yd)

    # first 80% for training, last 20% held out for testing
    n_egs = X.shape[0]
    n_train = int(0.8 * n_egs)
    x = torch.tensor(X[:n_train, :], requires_grad=True, dtype=torch.float32, device=device)
    dxdt = torch.tensor(Y[:n_train, :], dtype=torch.float32, device=device)
    test_x = torch.tensor(X[n_train:, :], requires_grad=True, dtype=torch.float32, device=device)
    test_dxdt = torch.tensor(Y[n_train:, :], dtype=torch.float32, device=device)

    # vanilla train loop
    stats = {'train_loss': [], 'test_loss': []}
    for step in range(args.total_steps + 1):
        # train step
        ixs = torch.randperm(x.shape[0])[:args.batch_size]
        dxdt_hat = model.time_derivative(x[ixs])
        loss = L2_loss(dxdt[ixs], dxdt_hat)
        loss.backward()
        grad = torch.cat([p.grad.flatten() for p in model.parameters()]).clone()
        optim.step()
        optim.zero_grad()

        # run test data
        test_ixs = torch.randperm(test_x.shape[0])[:args.batch_size]
        test_dxdt_hat = model.time_derivative(test_x[test_ixs])
        # test_dxdt_hat += args.input_noise * torch.randn(*test_x[test_ixs].shape)  # add noise, maybe
        test_loss = L2_loss(test_dxdt[test_ixs], test_dxdt_hat)

        # logging
        stats['train_loss'].append(loss.item())
        stats['test_loss'].append(test_loss.item())
        if args.verbose and step % args.print_every == 0:
            print("step {}, train_loss {:.4e}, test_loss {:.4e}, grad norm {:.4e}, grad std {:.4e}"
                  .format(step, loss.item(), test_loss.item(), grad @ grad, grad.std()))

    # visualize the learned latent representation for a random subset of points
    ixs = torch.randperm(x.shape[0])[:10000]
    x_sub = x[ixs]
    enc = model.encoding(x_sub).detach().cpu().numpy()
    print(x_sub.shape)
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    img = ax.scatter(enc[:, 0], enc[:, 3], enc[:, 2], c=enc[:, 1], cmap='hot')
    fig.colorbar(img)
    plt.savefig('lrep.png')

    # roll out a short orbit from a fixed initial condition
    y0 = torch.tensor([0.4, 0.3, 1 / np.sqrt(2), 1 / np.sqrt(2)], dtype=torch.float32)
    update_fn = lambda t, y0: model_update(t, y0, model)
    orbit, settings = get_orbit(y0, t_points=10, t_span=[0, 10], update_fn=update_fn)
    print(orbit)
    plt.scatter(orbit[:, 0], orbit[:, 1])
    plt.savefig('orbit.png')

    return model, stats

def train(args):
    if torch.cuda.is_available() and not args.cpu:
        device = torch.device("cuda:0")
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.empty_cache()
        print("Running on the GPU")
    else:
        device = torch.device("cpu")
        print("Running on the CPU")

    # set random seed
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # get dataset (no test data for now)
    angular_velo, acc_1, acc_2, sound = get_dataset_split(
        args.folder, args.speed, scaled=args.scaled, experiment_dir=args.experiment_dir, tensor=True)
    sub_col = {
        0: [angular_velo, 1, 'v'],
        1: [acc_1, 3, 'a1'],
        2: [acc_2, 3, 'a2'],
        3: [sound, 1, 's'],
    }
    col2use = sub_col[args.sub_columns][0]

    # using universal autoencoder, pre-encode the training points
    autoencoder = MLPAutoencoder(sub_col[args.sub_columns][1], args.hidden_dim, args.latent_dim * 2,
                                 dropout_rate=args.dropout_rate_ae)
    full_model = PixelHNN(args.latent_dim * 2, args.hidden_dim, autoencoder=autoencoder,
                          nonlinearity=args.nonlinearity, baseline=args.baseline,
                          dropout_rate=args.dropout_rate)
    path = "{}/saved_models/{}-{}.tar".format(args.save_dir, args.ae_path, sub_col[args.sub_columns][2])
    full_model.load_state_dict(torch.load(path))
    full_model.eval()
    autoencoder_model = full_model.autoencoder

    gcoords = autoencoder_model.encode(col2use).cpu().detach().numpy()
    x = torch.tensor(gcoords, dtype=torch.float, requires_grad=True)
    dx_np = full_model.time_derivative(
        torch.tensor(gcoords, dtype=torch.float, requires_grad=True)).cpu().detach().numpy()
    dx = torch.tensor(dx_np, dtype=torch.float)

    nnmodel = MLP(args.input_dim, args.hidden_dim, args.output_dim)
    model = HNN(2, nnmodel)
    model.to(device)
    optim = torch.optim.Adam(model.parameters(), args.learn_rate, weight_decay=args.weight_decay)

    print("Data from {} {}, column: {}".format(args.folder, args.speed, sub_col[args.sub_columns][2]))

    # x = torch.tensor(col2use[:-1], dtype=torch.float)
    # x_next = torch.tensor(col2use[1:], dtype=torch.float)
    #
    # autoencoder = MLPAutoencoder(sub_col[args.sub_columns][1], args.hidden_dim, args.latent_dim * 2,
    #                              dropout_rate=args.dropout_rate_ae)
    # model = PixelHNN(args.latent_dim * 2, args.hidden_dim,
    #                  autoencoder=autoencoder, nonlinearity=args.nonlinearity,
    #                  baseline=args.baseline, dropout_rate=args.dropout_rate)
    # model.to(device)
    # optim = torch.optim.Adam(model.parameters(), args.learn_rate, weight_decay=args.weight_decay)

    # vanilla ae train loop
    stats = {'train_loss': []}
    for step in range(args.total_steps + 1):
        # train step
        ixs = torch.randperm(x.shape[0])[:args.batch_size]
        x_train, dxdt = x[ixs].to(device), dx[ixs].to(device)
        dxdt_hat = model.time_derivative(x_train)
        loss = L2_loss(dxdt, dxdt_hat)
        loss.backward()
        optim.step()
        optim.zero_grad()

        stats['train_loss'].append(loss.item())
        if step % args.print_every == 0:
            print("step {}, train_loss {:.4e}".format(step, loss.item()))

    # train_dist = hnn_ae_loss(x, x_next, model, return_scalar=False)
    # print('Final train loss {:.4e} +/- {:.4e}'
    #       .format(train_dist.mean().item(), train_dist.std().item() / np.sqrt(train_dist.shape[0])))

    return model

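# ---------------------------------------------------------------------------
# Usage sketch (added): the two autoencoder-based trainers above return only
# the trained HNN. A caller would typically persist it the same way the
# pre-trained autoencoder was stored (a '.tar' state dict under
# 'saved_models/'); the helper name and path layout below are illustrative
# assumptions, not part of the original code.
# ---------------------------------------------------------------------------
def _save_trained_hnn(model, save_dir, tag):
    import os
    import torch
    os.makedirs('{}/saved_models'.format(save_dir), exist_ok=True)
    torch.save(model.state_dict(), '{}/saved_models/{}.tar'.format(save_dir, tag))
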