def __init__(self, r=0., sigma2=0.):
    super(ObsOp_subsampleGaussian, self).__init__()
    assert sigma2 >= 0.
    self.sigma2 = as_tensor(sigma2)
    self.sigma = torch.sqrt(self.sigma2)
    # additive Gaussian observation noise
    self.ndistr = torch.distributions.normal.Normal(loc=0., scale=self.sigma)
    assert 0. <= r <= 1.
    self.r = as_tensor(r)
    # Bernoulli mask: each entry is observed with probability 1 - r
    self.mdistr = torch.distributions.Bernoulli(probs=1 - self.r)
    self.mask = 1.
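# Usage sketch (illustrative only; not from the original source). Values are
# hypothetical, and it assumes the ObsOp base class provides a forward() that
# applies the Bernoulli mask and adds the Gaussian noise set up above.
def _example_obsop_subsample():
    obs_op = ObsOp_subsampleGaussian(r=0.5, sigma2=1.0)  # drop ~half the entries, unit noise variance
    x = torch.randn(1, 2, 36)  # hypothetical (N, J+1, K) state with J=1, K=36
    return obs_op.forward(x)   # masked, noisy observation of x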
def get_rollout_fun(dg_train, model_forward, prediction_task):
    J, K = dg_train.data.shape[1] - 1, dg_train.data.shape[2]

    if prediction_task == 'update_with_past':
        raise NotImplementedError  # branch disabled; code below is unreachable and kept for reference
        assert isinstance(dg_train, DatasetRelPredPast)
        std_out = as_tensor(dg_train.std_out)
        mean_out = as_tensor(dg_train.mean_out)

        def model_simulate(y0, dy0, n_steps):
            x = np.empty((n_steps + 1, *y0.shape[1:]))
            x[0] = y0.copy()
            xx = as_tensor(x[0])
            dx = as_tensor(dy0.copy())
            for i in range(1, n_steps + 1):
                xxo = xx * 1.  # keep a copy of the previous state
                xx = std_out * model_forward(
                    torch.cat((xx.reshape(1, J + 1, K), dx), axis=1)
                ) + mean_out + xx.reshape(1, J + 1, -1)
                dx = xx - xxo
                x[i] = xx.detach().cpu().numpy().copy()
            return x

    elif prediction_task == 'update':
        raise NotImplementedError  # branch disabled; code below is unreachable and kept for reference
        assert isinstance(dg_train, DatasetRelPred)

        def model_simulate(y0, dy0, n_steps):
            x = np.empty((n_steps + 1, *y0.shape[1:]))
            x[0] = y0.copy()
            xx = as_tensor(x[0]).reshape(1, 1, -1)
            for i in range(1, n_steps + 1):
                xx = model_forward(xx.reshape(1, J + 1, -1)) + xx.reshape(1, J + 1, -1)
                x[i] = xx.detach().cpu().numpy().copy()
            return x

    elif prediction_task == 'state':
        assert isinstance(dg_train, Dataset)

        def model_simulate(y0, dy0, n_steps):
            x = np.empty((n_steps + 1, *y0.shape[1:]))
            x[0] = y0.copy()
            xx = as_tensor(x[0]).reshape(1, 1, -1)
            for i in range(1, n_steps + 1):
                xx = model_forward(xx.reshape(1, J + 1, -1))
                x[i] = xx.detach().cpu().numpy().copy()
            return x

    return model_simulate
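# Usage sketch (illustrative only; not from the original source): get_rollout_fun
# returns a closure that autoregressively rolls the emulator forward from an
# initial state. Shapes follow the 'state' branch above; n_steps is hypothetical.
def _example_rollout(dg_train, model_forward):
    simulate = get_rollout_fun(dg_train, model_forward, prediction_task='state')
    y0 = dg_train.data[:1]                      # initial state, shape (1, J+1, K)
    return simulate(y0, dy0=None, n_steps=100)  # trajectory of shape (101, J+1, K)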
def __init__(self, frq=4, sigma2=0.):
    super(ObsOp_rotsampleGaussian, self).__init__()
    assert sigma2 >= 0.
    self.sigma2 = as_tensor(sigma2)
    self.sigma = torch.sqrt(self.sigma2)
    self.ndistr = torch.distributions.normal.Normal(loc=0., scale=self.sigma)
    assert 0. <= frq
    self.frq = frq
    self.mask = 1.
    self.ridx = 0
def __init__(self, model_forwarder, prediction_task, K, J, N, T=1, x_init=None):
    super(Rollout, self).__init__()
    self.model_forwarder = model_forwarder
    self.prediction_task = prediction_task
    self.T = T
    x_init = np.random.normal(size=(N, K * (J + 1))) if x_init is None else x_init
    assert x_init.ndim in [2, 3]
    self.X = torch.nn.Parameter(as_tensor(x_init))
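# Sketch (illustrative only; not from the original source): Rollout stores N
# candidate initial states as a trainable torch Parameter, so a standard
# optimizer can adjust them directly. The K, J, N, T values are hypothetical.
def _example_rollout_module(model_forwarder):
    ro = Rollout(model_forwarder, prediction_task='state', K=36, J=1, N=4, T=10)
    optimizer = torch.optim.LBFGS(params=[ro.X], lr=1.)  # optimize the initial states
    return ro, optimizer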
def load_model_from_exp_conf(res_dir, conf):
    net_kwargs = {
        'filters': conf['filters'],
        'kernel_sizes': conf['kernel_sizes'],
        'filters_ks1_init': conf['filters_ks1_init'],
        'filters_ks1_inter': conf['filters_ks1_inter'],
        'filters_ks1_final': conf['filters_ks1_final'],
        'additiveResShortcuts': conf['additiveResShortcuts'],
        'direct_shortcut': conf['direct_shortcut'],
        'dropout_rate': conf['dropout_rate'],
        'layerNorm': conf['layerNorm'],
        'init_net': conf['init_net'],
        'K_net': conf['K_net'],
        'J_net': conf['J_net'],
        'F_net': conf['F_net'],
        'dt_net': conf['dt_net'],
        'alpha_net': conf['alpha_net'],
        'model_forwarder': conf['model_forwarder'],
        'padding_mode': conf['padding_mode']
    }
    model, model_forward = named_network(model_name=conf['model_name'],
                                         n_input_channels=conf['J'] + 1,
                                         n_output_channels=conf['J'] + 1,
                                         seq_length=conf['seq_length'],
                                         **net_kwargs)

    test_input = np.random.normal(size=(10, conf['seq_length'] * (conf['J'] + 1), conf['K']))
    print(f'model output shape to test input of shape {test_input.shape}',
          model.forward(as_tensor(test_input)).shape)
    print('total #parameters: ',
          np.sum([np.prod(item.shape) for item in model.state_dict().values()]))

    lead_time, exp_id = conf['lead_time'], conf['exp_id']
    save_dir = res_dir + 'models/' + exp_id + '/'
    model_fn = f'{exp_id}_dt{lead_time}.pt'
    print('save_dir + model_fn', save_dir + model_fn)
    model.load_state_dict(torch.load(save_dir + model_fn,
                                     map_location=torch.device(device)))

    try:
        training_outputs = np.load(save_dir + '_training_outputs' + '.npy',
                                   allow_pickle=True)[()]
    except FileNotFoundError:
        training_outputs = None
        print('WARNING: could not load diagnostic outputs from model training!')

    return model, model_forward, training_outputs
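# Sketch (illustrative only; not from the original source): a minimal conf dict
# for load_model_from_exp_conf. The key names are exactly those read above;
# every value is a hypothetical placeholder and must match the trained experiment.
def _example_model_conf():
    return {
        'model_name': 'convnet', 'exp_id': 'exp00', 'lead_time': 1,
        'seq_length': 1, 'K': 36, 'J': 1,
        'filters': [32], 'kernel_sizes': [3],
        'filters_ks1_init': [], 'filters_ks1_inter': [], 'filters_ks1_final': [],
        'additiveResShortcuts': False, 'direct_shortcut': False,
        'dropout_rate': 0., 'layerNorm': 'BN', 'init_net': 'rand',
        'K_net': 36, 'J_net': 1, 'F_net': 10., 'dt_net': 0.01, 'alpha_net': 0.5,
        'model_forwarder': 'rk4_default', 'padding_mode': 'circular',
    }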
def set_state(self, x_init):
    self.X = torch.nn.Parameter(as_tensor(x_init))
def solve_4dvar(y, m, T_obs, T_win, T_shift, x_init, model_pars, obs_pars,
                optimizer_pars, res_dir):

    # extract key variable names from input dicts
    T, N, J, K = m.shape
    J -= 1
    assert y.shape == (T, N, (J + 1) * K)

    if x_init is None:
        x_init = get_init(sortL96intoChannels(y[0], J=J).detach().cpu(),
                          m[0].detach().cpu(),
                          method='interpolate')
    assert x_init.shape == (N, J + 1, K)

    # get model
    model, model_forwarder, args = get_model(model_pars, res_dir=res_dir, exp_dir='')

    model_observer = obs_pars['obs_operator'](**obs_pars['obs_operator_args'])

    prior = torch.distributions.normal.Normal(loc=torch.zeros((1, J + 1, K)),
                                              scale=1. * torch.ones((1, J + 1, K)))

    gen = GenModel(model_forwarder, model_observer, prior, T=T_win, x_init=None)

    priors = None
    assert len(T_obs) == T  # easy to generalize, but for now only support one obs per time point
    n_starts = (T - T_win) // T_shift + 1
    out, losses, times = [], [], []
    for n in range(n_starts):
        print('\n')
        print(f'optimizing window number {n+1} / {n_starts}')
        print('\n')

        # observations that fall into the current assimilation window
        idx = np.where(np.logical_and(n * T_shift + T_win > T_obs,
                                      T_obs >= n * T_shift))[0]
        assert len(idx) > 0  # atm not supporting empty integration window

        opt_res = optim_initial_state(gen,
                                      T_rollouts=[T_win],
                                      T_obs=[T_obs[idx] - n * T_shift],
                                      N=N,
                                      n_chunks=1,
                                      optimizer_pars=optimizer_pars,
                                      x_inits=[x_init],
                                      targets=[y[idx]],
                                      grndtrths=None,
                                      loss_masks=[m[idx]])
        x_sols, loss_vals, time_vals, _ = opt_res
        out.append(sortL96intoChannels(x_sols[0], J=J))
        losses.append(loss_vals)
        times.append(time_vals)

        priors = [SimplePrior(K=K, J=J, loc=as_tensor(out[-1][n]), scale=1.)
                  for n in range(N)]

        # shift the window: roll the current estimate forward by T_shift steps
        x_init = gen._forward(x=as_tensor(out[-1]),
                              T_obs=[T_shift - 1])[0].detach().cpu().numpy()

    return np.stack(out, axis=0), losses, times
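# Usage sketch (illustrative only; not from the original source): calling
# solve_4dvar with a sliding assimilation window. Shapes follow the asserts
# above (y: (T, N, (J+1)*K), m: (T, N, J+1, K)); the window sizes are hypothetical.
def _example_solve_4dvar(y, m, model_pars, obs_pars, optimizer_pars, res_dir):
    T = m.shape[0]
    return solve_4dvar(y, m,
                       T_obs=np.arange(T),  # one observation per time step
                       T_win=10,            # assimilation window length
                       T_shift=5,           # shift between consecutive windows
                       x_init=None,         # interpolate x_init from the first observation
                       model_pars=model_pars, obs_pars=obs_pars,
                       optimizer_pars=optimizer_pars, res_dir=res_dir)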
def optim_initial_state(gen, T_rollouts, T_obs, N, n_chunks, optimizer_pars,
                        x_inits, targets, grndtrths=None, loss_masks=None,
                        priors=None, f_init=None):

    sample_shape = gen.prior.sample().shape  # (..., J+1, K)
    J, K = sample_shape[-2] - 1, sample_shape[-1]

    n_steps = optimizer_pars['n_steps']
    x_sols = np.zeros((n_chunks, N, K * (J + 1)))
    loss_vals = np.inf * np.ones((n_steps * n_chunks, N))
    time_vals = time.time() * np.ones((n_steps * n_chunks, N))
    state_mses = np.inf * np.ones((n_chunks, N))

    loss_masks = ([torch.ones((N, J + 1, K)) for i in range(n_chunks)]
                  if loss_masks is None else loss_masks)
    assert len(loss_masks) == n_chunks

    if priors is None:
        class Const_prior(object):
            """flat prior: contributes nothing to the loss"""
            def __init__(self):
                pass
            def log_prob(self, x):
                return 0.
        priors = [Const_prior() for n in range(N)]
    assert len(priors) == N

    i_ = 0
    for j in range(n_chunks):
        print('\n')
        print(f'optimizing over chunk #{j+1} out of {n_chunks}')
        print('\n')

        target = sortL96intoChannels(as_tensor(targets[j]), J=J)
        loss_mask = loss_masks[j]
        assert len(target) == len(T_obs[j]) and len(loss_mask) == len(T_obs[j])

        for n in range(N):
            print('\n')
            print(f'optimizing over initial state #{n+1} / {N}')
            print('\n')

            gen.set_state(x_inits[j][n:n + 1])
            gen.set_rollout_len(T_rollouts[j])

            optimizer = torch.optim.LBFGS(
                params=[gen.X],
                lr=optimizer_pars['lr'],
                max_iter=optimizer_pars['max_iter'],
                max_eval=optimizer_pars['max_eval'],
                tolerance_grad=optimizer_pars['tolerance_grad'],
                tolerance_change=optimizer_pars['tolerance_change'],
                history_size=optimizer_pars['history_size'],
                line_search_fn='strong_wolfe')

            for i_n in range(n_steps):
                with torch.no_grad():
                    loss = -gen.log_prob(y=target[:, n:n + 1],
                                         m=loss_mask[:, n:n + 1],
                                         T_obs=T_obs[j])
                    loss = loss - priors[n].log_prob(gen.X)
                    if i_n == 0:
                        print('initial loss: ', loss)

                if torch.any(torch.isnan(loss)):
                    loss_vals[i_ + i_n, n] = loss.detach().cpu().numpy()
                    time_vals[i_ + i_n, n] = time.time() - time_vals[i_ + i_n, n]
                    print(('{:.4f}'.format(time_vals[i_ + i_n, n]), loss_vals[i_ + i_n, n]))
                    print('NaN loss - skipping iteration')
                    print('optimizer.state', optimizer.state[gen.X])
                    continue

                def closure():
                    loss = -gen.log_prob(y=target[:, n:n + 1],
                                         m=loss_mask[:, n:n + 1],
                                         T_obs=T_obs[j])
                    loss = loss - priors[n].log_prob(gen.X)
                    if torch.is_grad_enabled():
                        optimizer.zero_grad()
                    if loss.requires_grad:
                        loss.backward()
                    return loss

                optimizer.step(closure)
                loss_vals[i_ + i_n, n] = loss.detach().cpu().numpy()
                time_vals[i_ + i_n, n] = time.time() - time_vals[i_ + i_n, n]
                print(('{:.4f}'.format(time_vals[i_ + i_n, n]), loss_vals[i_ + i_n, n]))

            x_sols[j][n] = sortL96fromChannels(gen.X.detach().cpu().numpy().copy())
            state_mses[j][n] = np.inf if grndtrths is None else \
                ((x_sols[j][n] - grndtrths[j][n]) ** 2).mean()

        # if solving recursively, define next target as current initial state estimate
        if j < n_chunks - 1 and targets[j + 1] is None:
            targets[j + 1] = x_sols[j].copy().reshape(1, *x_sols[j].shape)

        i_ += n_steps

        with torch.no_grad():
            if grndtrths is not None:
                print('Eucl. distance to initial value',
                      mse_loss_fullyObs(x_sols[j], grndtrths[j]))
            print('Eucl. distance to x_init',
                  mse_loss_fullyObs(x_sols[j], sortL96fromChannels(x_inits[j])))

        if j < n_chunks - 1 and x_inits[j + 1] is None:
            x_inits[j + 1] = sortL96intoChannels(x_sols[j], J=J).copy()
            if f_init is not None:
                x_inits[j + 1] = f_init(x_inits[j + 1])

    # correcting time stamps for solving multiple trials sequentially
    for j in range(n_chunks):
        top_new = time_vals[(j + 1) * n_steps - 1, N - 1]
        for n in range(1, N)[::-1]:
            # correct for the fact that n-1 other problems were solved before for this j
            time_vals[j * n_steps + np.arange(n_steps), n] -= time_vals[(j + 1) * n_steps - 1, n - 1]
            if j > 0:  # continue from j-1 for this n
                time_vals[j * n_steps + np.arange(n_steps), n] += time_vals[j * n_steps - 1, n]
        if j > 0:
            # for first trial (n=0) of this j, clear previous time and continue from j-1
            time_vals[j * n_steps + np.arange(n_steps), 0] -= top_old
            time_vals[j * n_steps + np.arange(n_steps), 0] += time_vals[j * n_steps - 1, 0]
        top_old = top_new

    return x_sols, loss_vals, time_vals, state_mses
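# Sketch (illustrative only; not from the original source): the optimizer_pars
# dict consumed by optim_initial_state. Key names mirror the L-BFGS construction
# above; the values here are hypothetical.
_example_optimizer_pars = {
    'optimizer': 'LBFGS',
    'n_steps': 50,            # outer optimization steps per initial state
    'lr': 1.,
    'max_iter': 100,          # L-BFGS iterations per .step() call
    'max_eval': None,
    'tolerance_grad': 1e-7,
    'tolerance_change': 1e-9,
    'history_size': 10,
}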
def solve_initstate(system_pars, model_pars, optimizer_pars, setup_pars,
                    optimiziation_schemes, res_dir, data_dir, fn=None):

    # extract key variable names from input dicts
    K, J = system_pars['K'], system_pars['J']
    T, dt, N_trials = system_pars['T'], system_pars['dt'], system_pars['N_trials']
    n_starts, T_rollout, T_pred = setup_pars['n_starts'], setup_pars['T_rollout'], setup_pars['T_pred']
    n_chunks, n_chunks_recursive = setup_pars['n_chunks'], setup_pars['n_chunks_recursive']

    N = len(n_starts)
    recursions_per_chunks = n_chunks_recursive // n_chunks

    # chunk counts must divide the rollout length exactly
    assert recursions_per_chunks == n_chunks_recursive / n_chunks
    assert T_rollout // n_chunks_recursive == T_rollout / n_chunks_recursive
    assert T_rollout // n_chunks == T_rollout / n_chunks

    if optimiziation_schemes['LBFGS_full_chunks']:
        assert optimiziation_schemes['LBFGS_chunks']  # requirement for init
    if optimiziation_schemes['LBFGS_full_backsolve']:
        assert optimiziation_schemes['backsolve']  # requirement for init

    # get model
    model, model_forwarder, args = get_model(model_pars, res_dir=res_dir, exp_dir='')

    # ### instantiate observation operator
    model_observer = system_pars['obs_operator'](**system_pars['obs_operator_args'])

    # ### define prior over initial states
    prior = torch.distributions.normal.Normal(loc=torch.zeros((1, J + 1, K)),
                                              scale=1. * torch.ones((1, J + 1, K)))

    # ### define generative model for observed data
    gen = GenModel(model_forwarder, model_observer, prior, T=T_rollout, x_init=None)

    # prepare function output
    model_forwarder_str, optimizer_str = args['model_forwarder'], optimizer_pars['optimizer']
    obs_operator_str = model_observer.__class__.__name__
    exp_id = model_pars['exp_id']

    # output dictionary
    res = {
        'exp_id': exp_id,
        'K': K,
        'J': J,
        'T': T,
        'dt': dt,
        'back_solve_dt_fac': system_pars['back_solve_dt_fac'],
        'F': system_pars['F'],
        'h': system_pars['h'],
        'b': system_pars['b'],
        'c': system_pars['c'],
        'conf_exp': args['conf_exp'],
        'model_forwarder': model_pars['model_forwarder'],  # should still be string
        'dt_net': model_pars['dt_net'],
        'n_starts': n_starts,
        'T_rollout': T_rollout,
        'T_pred': T_pred,
        'n_chunks': n_chunks,
        'n_chunks_recursive': n_chunks_recursive,
        'recursions_per_chunks': recursions_per_chunks,
        'n_steps': optimizer_pars['n_steps'],
        'n_steps_tot': optimizer_pars['n_steps'] * n_chunks,
        'optimizer_pars': optimizer_pars,
        'optimiziation_schemes': optimiziation_schemes,
        'obs_operator': obs_operator_str,
        'obs_operator_args': system_pars['obs_operator_args']
    }

    # ### get data for 'typical' L96 state sequences
    out, datagen_dict = get_data(K=K, J=J, T=T, dt=dt, N_trials=N_trials,
                                 F=res['F'], h=res['h'], b=res['b'], c=res['c'],
                                 resimulate=True, solver=rk4_default,
                                 save_sim=False, data_dir=data_dir)

    grndtrths = [out[n_starts] for j in range(n_chunks_recursive)]
    res['initial_states'] = np.stack([sortL96intoChannels(z, J=J) for z in grndtrths])

    T_obs = [[(j + 1) * (T_rollout // n_chunks) - 1 for j in range(n_ + 1)]
             for n_ in range(n_chunks)]
    print('T_obs[-1]', T_obs[-1])
    res['targets'] = np.stack(
        [sortL96intoChannels(out[n_starts + t + 1], J=J) for t in T_obs[-1]],
        axis=0)

    # ## Generate observed data: (sub-)sample noisy observations
    res['test_state'] = sortL96intoChannels(out[n_starts + T_pred],
                                            J=J).reshape(1, len(n_starts), J + 1, K)
    res['test_state_obs'] = gen._sample_obs(as_tensor(res['test_state']))  # sets the loss masks!
    res['test_state_obs'] = sortL96fromChannels(res['test_state_obs'].detach().cpu().numpy())
    res['test_state_mask'] = torch.stack(gen.masks, dim=0).detach().cpu().numpy()

    res['targets_obs'] = gen._sample_obs(as_tensor(res['targets']))  # sets the loss masks!
    res['targets_obs'] = sortL96fromChannels(res['targets_obs'].detach().cpu().numpy())
    res['loss_mask'] = torch.stack(gen.masks, dim=0).detach().cpu().numpy()

    if fn is None:
        fn = 'results/data_assimilation/fullyobs_initstate_tests_'
    fn = fn + f'exp{exp_id}_{model_forwarder_str}_{optimizer_str}_{obs_operator_str}'

    print('\n')
    print('storing intermediate results')
    print('\n')
    np.save(res_dir + fn, arr=res)

    # ### define setup for optimization
    T_rollouts = np.arange(1, n_chunks + 1) * (T_rollout // n_chunks)
    grndtrths = [out[n_starts] for j in range(n_chunks)]
    targets = [res['targets_obs'][:len(T_obs[j])] for j in range(n_chunks)]
    loss_masks = [torch.stack(gen.masks[:len(T_obs[j])], dim=0) for j in range(n_chunks)]
    grndtrths_chunks = [out[n_starts] for j in range(n_chunks_recursive)]

    # ## L-BFGS, solve across full rollout time recursively,
    #    initialize from forward solver in reverse
    if optimiziation_schemes['LBFGS_recurse_chunks']:
        print('\n')
        print('L-BFGS, solve across full rollout time recursively, '
              'initialize from forward solver in reverse')
        print('\n')

        assert len(T_obs) == 1  # only allow single observation at end of interval
        assert len(res['targets_obs']) == 1

        # functions for explicitly solving backwards
        class Model_eb(torch.nn.Module):
            """wraps the emulator with a sign flip, so forward integration runs backwards in time"""
            def __init__(self, model):
                super(Model_eb, self).__init__()
                self.model = model
            def forward(self, x):
                return -self.model.forward(x)

        if res['model_forwarder'] == 'rk4_default':
            Model_forwarder = Model_forwarder_rk4default
        elif res['model_forwarder'] == 'predictor_corrector':
            Model_forwarder = Model_forwarder_predictorCorrector
        model_forwarder_eb = Model_forwarder(model=Model_eb(model),
                                             dt=dt / res['back_solve_dt_fac'])

        def exp_fsr(x_init):  # forward-solve in reverse with the emulator from model_exp_id
            x = as_tensor(x_init)
            for t in range(res['back_solve_dt_fac'] * T_rollout // n_chunks_recursive):
                x = model_forwarder_eb.forward(x)
            return x.detach().cpu().numpy()

        x_init = get_init(sortL96intoChannels(res['targets_obs'], J=J)[0],
                          res['loss_mask'][0],
                          method='interpolate')
        x_inits = [exp_fsr(x_init)] + [None for j in range(n_chunks_recursive)]

        opt_res = optim_initial_state(
            gen,
            T_rollouts=np.arange(1, n_chunks_recursive + 1) * (T_rollout // n_chunks_recursive),
            T_obs=[[j] for j in range(recursions_per_chunks)],
            N=N,
            n_chunks=n_chunks_recursive,
            optimizer_pars=optimizer_pars,
            x_inits=x_inits,
            targets=list(np.repeat(targets, recursions_per_chunks, axis=0)),
            grndtrths=[out[n_starts + T_rollout - (j + 1) * (T_rollout // n_chunks_recursive)]
                       for j in range(n_chunks_recursive)],
            loss_masks=list(np.repeat(loss_masks, recursions_per_chunks, axis=0)),
            f_init=exp_fsr)

        res['x_sols_LBFGS_recurse_chunks'] = opt_res[0]
        res['loss_vals_LBFGS_recurse_chunks'] = opt_res[1]
        res['time_vals_LBFGS_recurse_chunks'] = opt_res[2]
        res['state_mses_LBFGS_recurse_chunks'] = opt_res[3]

        print('\n')
        print('storing results')
        print('\n')
        np.save(res_dir + fn, arr=res)

    # ## L-BFGS, solve across single chunks recursively, initialize from last chunk
    if optimiziation_schemes['LBFGS_chunks']:
        print('\n')
        print('L-BFGS, solve across single chunks recursively, initialize from last chunk')
        print('\n')

        assert len(T_obs) == 1  # only allow single observation at end of interval
        assert len(res['targets_obs']) == 1

        x_inits = [None for i in range(n_chunks_recursive)]
        x_inits[0] = get_init(sortL96intoChannels(res['targets_obs'], J=J)[0],
                              res['loss_mask'][0],
                              method='interpolate')

        opt_res = optim_initial_state(
            gen,
            T_rollouts=np.ones(n_chunks_recursive, dtype=int) * (T_rollout // n_chunks_recursive),
            T_obs=[[0] for j in range(n_chunks_recursive)],
            N=N,
            n_chunks=n_chunks_recursive,
            optimizer_pars=optimizer_pars,
            x_inits=x_inits,
            targets=[res['targets_obs']] + [None for i in range(n_chunks_recursive - 1)],
            grndtrths=[out[n_starts + T_rollout - (j + 1) * (T_rollout // n_chunks_recursive)]
                       for j in range(n_chunks_recursive)],
            loss_masks=[torch.stack(gen.masks, dim=0)] +
                       [torch.ones((1, N, J + 1, K)) for i in range(n_chunks_recursive - 1)])

        res['x_sols_LBFGS_chunks'] = opt_res[0]
        res['loss_vals_LBFGS_chunks'] = opt_res[1]
        res['time_vals_LBFGS_chunks'] = opt_res[2]
        res['state_mses_LBFGS_chunks'] = opt_res[3]

        print('\n')
        print('storing results')
        print('\n')
        np.save(res_dir + fn, arr=res)

    # ## L-BFGS, solve across full rollout time in one go, initialize from chunked approach
    if optimiziation_schemes['LBFGS_full_chunks']:
        print('\n')
        print('L-BFGS, solve across full rollout time in one go, initialize from chunked approach')
        print('\n')

        x_inits = res['x_sols_LBFGS_chunks'][recursions_per_chunks - 1:][::recursions_per_chunks]
        x_inits = [sortL96intoChannels(z, J=J).copy() for z in x_inits]

        opt_res = optim_initial_state(gen,
                                      T_rollouts=T_rollouts,
                                      T_obs=T_obs,
                                      N=N,
                                      n_chunks=n_chunks,
                                      optimizer_pars=optimizer_pars,
                                      x_inits=x_inits,
                                      targets=targets,
                                      grndtrths=grndtrths,
                                      loss_masks=loss_masks)

        res['x_sols_LBFGS_full_chunks'] = opt_res[0]
        res['loss_vals_LBFGS_full_chunks'] = opt_res[1]
        res['time_vals_LBFGS_full_chunks'] = opt_res[2]
        res['state_mses_LBFGS_full_chunks'] = opt_res[3]

        print('\n')
        print('storing results')
        print('\n')
        np.save(res_dir + fn, arr=res)

    # ## numerical forward solve in reverse
    if optimiziation_schemes['backsolve']:
        print('\n')
        print('numerical forward solve in reverse')
        print('\n')

        # functions for explicitly solving backwards
        from L96sim.L96_base import f1, f2, pf2

        dX_dt = np.empty(K * (J + 1), dtype=dtype_np)
        if J > 0:
            def fun_eb(t, x):
                return -f2(x, res['F'], res['h'], res['b'], res['c'], dX_dt, K, J)
        else:
            def fun_eb(t, x):
                return -f1(x, res['F'], dX_dt, K)

        def explicit_backsolve(x_init, times_eb, fun_eb):
            x_sols = np.zeros_like(x_init)
            for i__ in range(x_init.shape[0]):
                out2 = rk4_default(fun=fun_eb, y0=x_init[i__], times=times_eb)
                x_sols[i__] = out2[-1].copy()
            return x_sols

        res['state_mses_backsolve'] = np.zeros((n_chunks_recursive, len(n_starts)))
        res['x_sols_backsolve'] = np.zeros((n_chunks_recursive, len(n_starts), K * (J + 1)))
        res['loss_vals_backsolve'] = np.zeros((n_chunks_recursive, len(n_starts)))

        print('backward solving')
        res['time_vals_backsolve'] = time.time() * np.ones((n_chunks_recursive, len(n_starts)))

        x_init = get_init(sortL96intoChannels(res['targets_obs'][0], J=J),
                          res['loss_mask'][0],
                          method='interpolate')
        for j in range(recursions_per_chunks):
            times_eb = dt * np.linspace(
                0, T_rollouts[0] // recursions_per_chunks,
                res['back_solve_dt_fac'] * T_rollouts[0] // recursions_per_chunks + 1)
            print('x_init.shape', sortL96fromChannels(x_init).shape)
            res['x_sols_backsolve'][j] = explicit_backsolve(sortL96fromChannels(x_init),
                                                            times_eb, fun_eb)
            x_init = sortL96intoChannels(res['x_sols_backsolve'][j].copy(), J=J)
            print('x_init.shape - out', x_init.shape)
            res['time_vals_backsolve'][j] = res['time_vals_backsolve'][0]
            x_target = out[n_starts + T_rollouts[0] - (j + 1) * (T_rollouts[0] // recursions_per_chunks)]
            res['state_mses_backsolve'][j] = ((res['x_sols_backsolve'][j] - x_target) ** 2).mean(axis=1)
            res['time_vals_backsolve'][j] = time.time() - res['time_vals_backsolve'][0]

        for j in range(recursions_per_chunks, n_chunks_recursive):
            res['x_sols_backsolve'][j] = res['x_sols_backsolve'][recursions_per_chunks - 1]
            res['time_vals_backsolve'][j] = res['time_vals_backsolve'][recursions_per_chunks - 1]
            res['state_mses_backsolve'][j] = ((res['x_sols_backsolve'][j] - out[n_starts]) ** 2).mean(axis=1)

        print('\n')
        print('storing results')
        print('\n')
        np.save(res_dir + fn, arr=res)

    # ## L-BFGS, solve across full rollout time in one go, initialize from backward solution
    if optimiziation_schemes['LBFGS_full_backsolve']:
        print('\n')
        print('L-BFGS, solve across full rollout time in one go, initialize from backward solution')
        print('\n')

        x_inits = res['x_sols_backsolve'][recursions_per_chunks - 1:][::recursions_per_chunks]
        x_inits = [sortL96intoChannels(z, J=J).copy() for z in x_inits]

        opt_res = optim_initial_state(gen,
                                      T_rollouts=T_rollouts,
                                      T_obs=T_obs,
                                      N=N,
                                      n_chunks=n_chunks,
                                      optimizer_pars=optimizer_pars,
                                      x_inits=x_inits,
                                      targets=targets,
                                      grndtrths=grndtrths,
                                      loss_masks=loss_masks)

        res['x_sols_LBFGS_full_backsolve'] = opt_res[0]
        res['loss_vals_LBFGS_full_backsolve'] = opt_res[1]
        res['time_vals_LBFGS_full_backsolve'] = opt_res[2]
        res['state_mses_LBFGS_full_backsolve'] = opt_res[3]

        print('\n')
        print('storing results')
        print('\n')
        np.save(res_dir + fn, arr=res)

    # ## L-BFGS, solve across full rollout time in one go
    # - warning, this can be excruciatingly slow and hard to converge!
    if optimiziation_schemes['LBFGS_full_persistence']:
        print('\n')
        print('L-BFGS, solve across full rollout time in one go')
        print('\n')

        x_init = get_init(sortL96intoChannels(res['targets_obs'], J=J)[0],
                          res['loss_mask'][0],
                          method='interpolate')
        x_inits = [x_init for j in range(n_chunks)]

        opt_res = optim_initial_state(gen,
                                      T_rollouts=T_rollouts,
                                      T_obs=T_obs,
                                      N=N,
                                      n_chunks=n_chunks,
                                      optimizer_pars=optimizer_pars,
                                      x_inits=x_inits,
                                      targets=targets,
                                      grndtrths=grndtrths,
                                      loss_masks=loss_masks)

        res['x_sols_LBFGS_full_persistence'] = opt_res[0]
        res['loss_vals_LBFGS_full_persistence'] = opt_res[1]
        res['time_vals_LBFGS_full_persistence'] = opt_res[2]
        res['state_mses_LBFGS_full_persistence'] = opt_res[3]

        print('\n')
        print('storing results')
        print('\n')
        np.save(res_dir + fn, arr=res)

    print('\n')
    print('done')
    print('\n')
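# Sketch (illustrative only; not from the original source): the
# optimiziation_schemes flags queried by solve_initstate above. Note the
# initialization dependencies asserted at the top of the function
# (LBFGS_full_chunks requires LBFGS_chunks, LBFGS_full_backsolve requires
# backsolve); the values here are hypothetical.
_example_optimiziation_schemes = {
    'LBFGS_chunks': True,
    'LBFGS_full_chunks': True,        # initialized from the LBFGS_chunks solution
    'backsolve': True,
    'LBFGS_full_backsolve': True,     # initialized from the backsolve solution
    'LBFGS_full_persistence': False,  # warned above to be very slow to converge
    'LBFGS_recurse_chunks': False,
}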
def run_exp_4DVar(exp_id, datadir, res_dir, T_win, T_shift, B, K, J, T, N_trials,
                  dt, spin_up_time, l96_F, l96_h, l96_b, l96_c, obs_operator,
                  obs_operator_r, obs_operator_sig2, obs_operator_frq,
                  model_exp_id, model_forwarder, optimizer, n_steps, lr,
                  max_iter, max_eval, tolerance_grad, tolerance_change,
                  history_size):

    fetch_commit = subprocess.Popen(['git', 'rev-parse', 'HEAD'],
                                    shell=False, stdout=subprocess.PIPE)
    commit_id = fetch_commit.communicate()[0].strip().decode("utf-8")
    fetch_commit.kill()

    # backwards compatibility to older exps with T_shift=T_win
    T_shift = T_win if T_shift < 0 else T_shift

    model_pars = {
        'exp_id': model_exp_id,
        'model_forwarder': model_forwarder,
        'K_net': K,
        'J_net': J,
        'dt_net': dt
    }

    optimizer_pars = {
        'optimizer': optimizer,
        'n_steps': n_steps,
        'lr': lr,
        'max_iter': max_iter,
        'max_eval': None if max_eval < 0 else max_eval,
        'tolerance_grad': tolerance_grad,
        'tolerance_change': tolerance_change,
        'history_size': history_size
    }

    if model_forwarder == 'rk4_default':
        model_forwarder = rk4_default
    if model_forwarder == 'predictor_corrector':
        model_forwarder = predictor_corrector

    out, datagen_dict = get_data(K=K, J=J, T=T + spin_up_time, dt=dt,
                                 N_trials=N_trials, F=l96_F, h=l96_h, b=l96_b,
                                 c=l96_c, resimulate=True, solver=model_forwarder,
                                 save_sim=False, data_dir='')
    out = sortL96intoChannels(out.transpose(1, 0, 2)[int(spin_up_time / dt):], J=J)
    print('out.shape', out.shape)

    model, model_forwarder, args = get_model(model_pars, res_dir=res_dir, exp_dir='')

    obs_pars = {}
    if obs_operator == 'ObsOp_subsampleGaussian':
        obs_pars['obs_operator'] = ObsOp_subsampleGaussian
        obs_pars['obs_operator_args'] = {'r': obs_operator_r, 'sigma2': obs_operator_sig2}
    elif obs_operator == 'ObsOp_identity':
        obs_pars['obs_operator'] = ObsOp_identity
        obs_pars['obs_operator_args'] = {}
    elif obs_operator == 'ObsOp_rotsampleGaussian':
        obs_pars['obs_operator'] = ObsOp_rotsampleGaussian
        obs_pars['obs_operator_args'] = {'frq': obs_operator_frq, 'sigma2': obs_operator_sig2}
    else:
        raise NotImplementedError()
    model_observer = obs_pars['obs_operator'](**obs_pars['obs_operator_args'])

    prior = torch.distributions.normal.Normal(loc=torch.zeros((1, J + 1, K)),
                                              scale=B * torch.ones((1, J + 1, K)))

    gen = GenModel(model_forwarder, model_observer, prior)
    y = sortL96fromChannels(gen._sample_obs(as_tensor(out)))  # sets the loss masks!
    m = torch.stack(gen.masks, dim=0)

    save_dir = 'results/data_assimilation/' + exp_id + '/'
    mkdir_from_path(res_dir + save_dir)
    open(res_dir + save_dir + commit_id + '.txt', 'w').close()  # log commit id
    fn = save_dir + 'res'

    print('4D-VAR')
    x_sols, losses, times = solve_4dvar(y, m,
                                        T_obs=np.arange(y.shape[0]),
                                        T_win=T_win,
                                        T_shift=T_shift,
                                        x_init=None,
                                        model_pars=model_pars,
                                        obs_pars=obs_pars,
                                        optimizer_pars=optimizer_pars,
                                        res_dir=res_dir)

    np.save(res_dir + save_dir + 'out',
            {'out': out,
             'y': y.detach().cpu().numpy(),
             'm': m.detach().cpu().numpy(),
             'x_sols': x_sols,
             'losses': losses,
             'times': times,
             'T_win': T_win,
             'T_shift': T_shift})

    print('x_sols.shape', x_sols.shape)
    print('done')
def run_exp(exp_id, datadir, res_dir, K, K_local, J, T, N_trials, dt, n_local,
            prediction_task, lead_time, seq_length, train_frac, validation_frac,
            spin_up_time, model_name, loss_fun, weight_decay, batch_size,
            batch_size_eval, max_epochs, eval_every, max_patience, lr, lr_min,
            lr_decay, max_lr_patience, only_eval, normalize_data, **net_kwargs):

    fetch_commit = subprocess.Popen(['git', 'rev-parse', 'HEAD'],
                                    shell=False, stdout=subprocess.PIPE)
    commit_id = fetch_commit.communicate()[0].strip().decode("utf-8")
    fetch_commit.kill()

    # load data
    fn_data = f'out_K{K}_J{J}_T{T}_N{N_trials}_dt0_{str(dt)[2:]}'
    out = np.load(datadir + fn_data + '.npy')
    print('data.shape', out.shape)
    assert (out.shape[1] - 1) * dt == T

    K_local = K if K_local < 0 else K_local
    DatasetClass = sel_dataset_class(prediction_task, N_trials, local=(K_local < K))

    test_frac = 1. - (train_frac + validation_frac)
    assert test_frac >= 0.
    spin_up = int(spin_up_time / dt)

    dg_args = {
        'data': out,
        'J': J,
        'offset': lead_time,
        'normalize': bool(normalize_data)
    }
    if DatasetClass == DatasetMultiTrial_shattered:
        dg_args['K_local'] = K_local
        dg_args['n_local'] = n_local

    dg_train = DatasetClass(start=spin_up,
                            end=int(np.floor(T / dt * train_frac)),
                            **dg_args)
    dg_val = DatasetClass(start=int(np.ceil(T / dt * train_frac)),
                          end=int(np.ceil(T / dt * (train_frac + validation_frac))),
                          **dg_args)

    print('N training data:', len(dg_train))
    print('N validation data:', len(dg_val))

    batch_size_eval = batch_size if batch_size_eval < 1 else batch_size_eval
    validation_loader = torch.utils.data.DataLoader(dg_val,
                                                    batch_size=batch_size_eval,
                                                    drop_last=False,
                                                    num_workers=0)
    train_loader = torch.utils.data.DataLoader(dg_train,
                                               batch_size=batch_size,
                                               drop_last=True,
                                               num_workers=0)

    ## define model
    print('net_kwargs', net_kwargs)
    model_fn = f'{exp_id}_dt{lead_time}.pt'
    model, model_forward = named_network(model_name=model_name,
                                         n_input_channels=J + 1,
                                         n_output_channels=J + 1,
                                         seq_length=seq_length,
                                         **net_kwargs)
    try:
        print('model.layer1.weights', model.layer1.weight.shape)
    except AttributeError:
        pass

    if K_local < K:
        test_input = np.random.normal(size=(10, seq_length * (J + 1), K_local + 3 * n_local))
    else:
        test_input = np.random.normal(size=(10, seq_length * (J + 1), K))
    print(f'model output shape to test input of shape {test_input.shape}',
          model_forward(as_tensor(test_input)).shape)
    print('total #parameters: ',
          np.sum([np.prod(item.shape) for item in model.state_dict().values()]))

    ## train model
    save_dir = res_dir + 'models/' + exp_id + '/'
    if only_eval:
        print('loading model from disk')
        model.load_state_dict(torch.load(save_dir + model_fn,
                                         map_location=torch.device(device)))
    else:  # actually train
        mkdir_from_path(save_dir)
        print('saving model state_dict to ' + save_dir + model_fn)
        open(save_dir + commit_id + '.txt', 'w').close()  # log commit id
        output_fn = '_training_outputs'

        extra_args = {}
        if loss_fun == 'local_mse':
            extra_args = {
                'n_local': n_local,
                'pad_local': (2, 2) if J == 1 else (2, 1)  # relevant local area for the L96 diff. eq.
            }
        print('loss_fun', loss_fun)
        loss_fun = loss_function(loss_fun, extra_args=extra_args)
        print('loss_fun', loss_fun)

        training_outputs = train_model(model, train_loader, validation_loader,
                                       device, model_forward,
                                       loss_fun=loss_fun,
                                       weight_decay=weight_decay,
                                       max_epochs=max_epochs,
                                       max_patience=max_patience,
                                       lr=lr,
                                       lr_min=lr_min,
                                       lr_decay=lr_decay,
                                       max_lr_patience=max_lr_patience,
                                       eval_every=eval_every,
                                       verbose=True,
                                       save_dir=save_dir,
                                       model_fn=model_fn,
                                       output_fn=output_fn)

        print('saving full model to ' + save_dir + model_fn[:-3] + '_full_model.pt')
        torch.save(model, save_dir + model_fn[:-3] + '_full_model.pt')
def run_exp_parametrization(exp_id, datadir, res_dir, parametrization, n_hiddens,
                            kernel_size, K, J, T, dt, spin_up_time, l96_F, l96_h,
                            l96_b, l96_c, train_frac, validation_frac, offset,
                            model_exp_id, model_forwarder, loss_fun, batch_size,
                            eval_every, lr, lr_min, lr_decay, weight_decay,
                            max_epochs, max_patience, max_lr_patience):

    fetch_commit = subprocess.Popen(['git', 'rev-parse', 'HEAD'],
                                    shell=False, stdout=subprocess.PIPE)
    commit_id = fetch_commit.communicate()[0].strip().decode("utf-8")
    fetch_commit.kill()

    # load model
    model_pars = {
        'exp_id': model_exp_id,
        'model_forwarder': model_forwarder,
        'K_net': K,
        'J_net': 0,
        'dt_net': dt
    }
    if model_forwarder == 'rk4_default':
        model_forwarder = rk4_default
    elif model_forwarder == 'predictor_corrector':
        model_forwarder = predictor_corrector
    model, model_forwarder, args = get_model(model_pars, res_dir=res_dir, exp_dir='')

    # instantiate parametrizations
    if parametrization == 'linear':
        param_train = Parametrization_lin(a=as_tensor(np.array([-0.75])),
                                          b=as_tensor(np.array([-0.4])))
        param_offline = Parametrization_lin(a=as_tensor(np.array([-0.75])),
                                            b=as_tensor(np.array([-0.4])))
    elif parametrization == 'nn':
        param_train = Parametrization_nn(n_hiddens=n_hiddens, kernel_size=kernel_size)
        param_offline = Parametrization_nn(n_hiddens=n_hiddens, kernel_size=kernel_size)
        # make sure they share initialization:
        param_offline.load_state_dict(copy.deepcopy(param_train.state_dict()))
    else:
        raise NotImplementedError()

    # freeze the emulator; only the parametrization gets trained
    for p in model.parameters():
        p.requires_grad = False

    model_parametrized = Parametrized_twoLevel_L96(emulator=model,
                                                   parametrization=param_train)
    model_forwarder_parametrized = Model_forwarder_rk4default(model=model_parametrized,
                                                              dt=dt)

    print('torch.nn.Parameters of parametrization require grad: ')
    for p in model_forwarder_parametrized.model.param.parameters():
        print(p.requires_grad)
    print('torch.nn.Parameters of emulator require grad: ')
    for p in model_forwarder_parametrized.model.emulator.parameters():
        print(p.requires_grad)

    if len(offset) > 1:  # multi-step predictions
        print('multi-step predictions')
        gm = GenModel(model_forwarder=model_forwarder_parametrized,
                      model_observer=ObsOp_identity(),
                      prior=SimplePrior(J=0, K=K))

        class MultiStepForwarder(torch.nn.Module):
            """stacks model predictions at multiple lead times"""
            def __init__(self, model, offset):
                super(MultiStepForwarder, self).__init__()
                self.model = model
                self.offset = offset
            def forward(self, x):
                return torch.stack(gm._forward(x=x, T_obs=self.offset), dim=1)

        model_forwarder_parametrized = MultiStepForwarder(model=gm,
                                                          offset=np.asarray(offset))

    print('model forwarder', model_forwarder_parametrized)
    if parametrization == 'linear':
        print('initialized a', model_parametrized.param.a)
        print('initialized b', model_parametrized.param.b)
    elif parametrization == 'nn':
        print('initialized first-layer weights',
              model_parametrized.param.layers[0].weight)

    # ground-truth two-level L96 model (based on Numba implementation):
    dX_dt = np.empty(K * (J + 1), dtype=dtype_np)
    if J > 0:
        def fun(t, x):
            return f2(x, l96_F, l96_h, l96_b, l96_c, dX_dt, K, J)
    else:
        def fun(t, x):
            return f1(x, l96_F, dX_dt, K)

    class Torch_solver(torch.nn.Module):
        """numerical solver (from numpy/numba/Julia) wrapped as a torch module"""
        def __init__(self, fun):
            super(Torch_solver, self).__init__()  # required for a usable nn.Module
            self.fun = fun
        def forward(self, x):
            x = sortL96fromChannels(x.detach().cpu().numpy()).flatten()
            return as_tensor(sortL96intoChannels(np.atleast_2d(self.fun(0., x)), J=J))

    model_forwarder_np = Model_forwarder_rk4default(Torch_solver(fun), dt=dt)

    # create some training data from the true two-level L96
    X_init = 0.5 + np.random.randn(1, K * (J + 1)) * 1.0
    X_init = l96_F * X_init.astype(dtype=dtype_np) / np.maximum(J, 50)

    def model_simulate(y0, dy0, n_steps):
        x = np.empty((n_steps + 1, *y0.shape[1:]), dtype=dtype_np)
        x[0] = y0.copy()
        xx = as_tensor(x[0]).reshape(1, 1, -1)
        for i in range(1, n_steps + 1):
            xx = model_forwarder_np(xx.reshape(1, J + 1, -1))
            x[i] = xx.detach().cpu().numpy().copy()
        return x

    T_dur = int(T / dt)
    spin_up = int(spin_up_time / dt)
    print('simulating high-res (two-level L96) data')
    data_full = model_simulate(y0=sortL96intoChannels(X_init, J=J),
                               dy0=None,
                               n_steps=T_dur + spin_up)
    print('full data shape: ', data_full.shape)
    assert np.all(np.isfinite(data_full))

    # offline training of parametrization
    print('offline training')
    X = sortL96intoChannels(data_full[:, 0, :], J=model_pars['J_net'])
    y = -l96_h * l96_c * sortL96intoChannels(data_full[:, 1:, :].mean(axis=1),
                                             J=model_pars['J_net'])

    dg_train = Dataset_offline(data=(X, y),
                               start=spin_up,
                               end=spin_up + int(np.floor(T_dur * train_frac)) - np.max(offset))
    print('len dg_train', len(dg_train))
    train_loader = torch.utils.data.DataLoader(dg_train,
                                               batch_size=batch_size,
                                               drop_last=True,
                                               num_workers=0)
    dg_val = Dataset_offline(data=(X, y),
                             start=spin_up + int(np.floor(T_dur * train_frac)),
                             end=spin_up + int(np.ceil(T_dur * (train_frac + validation_frac))) - np.max(offset))
    print('len dg_val', len(dg_val))
    validation_loader = torch.utils.data.DataLoader(dg_val,
                                                    batch_size=batch_size,
                                                    drop_last=False,
                                                    num_workers=0)

    print('starting optimization of parametrization')
    training_outputs_offline = train_model(model=param_offline,
                                           train_loader=train_loader,
                                           validation_loader=validation_loader,
                                           device=device,
                                           model_forward=param_offline,
                                           loss_fun=loss_function(loss_fun=loss_fun,
                                                                  extra_args={}),
                                           lr=lr,
                                           lr_min=lr_min,
                                           lr_decay=lr_decay,
                                           weight_decay=weight_decay,
                                           max_epochs=max_epochs,
                                           max_patience=max_patience,
                                           max_lr_patience=max_lr_patience,
                                           eval_every=eval_every)

    if parametrization == 'linear':
        print('learned a', param_offline.a)
        print('learned b', param_offline.b)
    elif parametrization == 'nn':
        print('learned first-layer weights', param_offline.layers[0].weight)

    # online training of parametrization
    print('online training')
    # the two-level model simulates fast and slow variables; we only take the slow ones for training!
    data = data_full[:, 0, :]
    data = data.reshape(1, *data.shape)  # N x T x K*(J+1)
    print('training data shape: ', data.shape)

    DatasetClass = sel_dataset_class(prediction_task='state',
                                     N_trials=1,
                                     local=False,
                                     offset=offset)
    print('dataset class', DatasetClass)
    print('len(offset)', len(offset))

    assert train_frac + validation_frac <= 1.
    dg_dict = {
        'data': data,
        'J': 0,
        'offset': offset[0] if len(offset) == 1 else offset,
        'normalize': False
    }
    dg_train = DatasetClass(start=spin_up,
                            end=spin_up + int(np.floor(T_dur * train_frac)) - np.max(offset),
                            **dg_dict)
    print('len dg_train', len(dg_train))
    train_loader = torch.utils.data.DataLoader(dg_train,
                                               batch_size=batch_size,
                                               drop_last=True,
                                               num_workers=0)
    dg_val = DatasetClass(start=spin_up + int(np.floor(T_dur * train_frac)),
                          end=spin_up + int(np.ceil(T_dur * (train_frac + validation_frac))) - np.max(offset),
                          **dg_dict)
    print('len dg_val', len(dg_val))
    validation_loader = torch.utils.data.DataLoader(dg_val,
                                                    batch_size=batch_size,
                                                    drop_last=False,
                                                    num_workers=0)

    print('starting optimization of parametrization')
    training_outputs = train_model(model=model_forwarder_parametrized,
                                   train_loader=train_loader,
                                   validation_loader=validation_loader,
                                   device=device,
                                   model_forward=model_forwarder_parametrized,
                                   loss_fun=loss_function(loss_fun=loss_fun,
                                                          extra_args={}),
                                   lr=lr,
                                   lr_min=lr_min,
                                   lr_decay=lr_decay,
                                   weight_decay=weight_decay,
                                   max_epochs=max_epochs,
                                   max_patience=max_patience,
                                   max_lr_patience=max_lr_patience,
                                   eval_every=eval_every)

    if parametrization == 'linear':
        print('learned a', model_parametrized.param.a)
        print('learned b', model_parametrized.param.b)
    elif parametrization == 'nn':
        print('learned first-layer weights', model_parametrized.param.layers[0].weight)

    save_dir = 'results/parametrization/' + exp_id + '/'
    mkdir_from_path(res_dir + save_dir)
    open(res_dir + save_dir + commit_id + '.txt', 'w').close()  # log commit id

    state_dict = param_train.state_dict()
    for key, value in state_dict.items():
        state_dict[key] = value.detach().cpu().numpy()
    state_dict_offline = param_offline.state_dict()
    for key, value in state_dict_offline.items():
        state_dict_offline[key] = value.detach().cpu().numpy()

    np.save(res_dir + save_dir + 'out',
            {'data_full': data_full,
             'X_init': data_full[-1].reshape(1, -1),
             'param_train_state_dict': state_dict,
             'param_offline_state_dict': state_dict_offline,
             'X': X,
             'y': y})

    print('done')