Example #1
def model_simulate(y0, dy0, n_steps):
    # roll out the learned update model for n_steps from initial state y0,
    # with dy0 the previous state increment; returns the trajectory as a numpy array
    x = np.empty((n_steps + 1, *y0.shape[1:]))
    x[0] = y0.copy()
    xx = as_tensor(x[0])
    dx = as_tensor(dy0.copy())
    for i in range(1, n_steps + 1):
        xxo = xx * 1.  # keep a copy of the current state
        # de-normalize the predicted increment and add it to the current state
        xx = std_out * model_forward(torch.cat((xx.reshape(1, J + 1, K), dx), dim=1)) + mean_out + xx.reshape(1, J + 1, -1)
        dx = xx - xxo
        x[i] = xx.detach().cpu().numpy().copy()
    return x
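
A minimal usage sketch (assuming np, torch, as_tensor, the trained model_forward, the normalization constants std_out/mean_out and the grid sizes J, K are in scope, as in the snippet above; shapes are illustrative):

y0 = np.random.randn(1, J + 1, K)   # initial state, one trial
dy0 = np.random.randn(1, J + 1, K)  # presumably the previous state increment
traj = model_simulate(y0, dy0, n_steps=100)
print(traj.shape)  # (101, J + 1, K)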
Example #2
    def __init__(self, r=0., sigma2=0.):
        super(ObsOp_subsampleGaussian, self).__init__()
        assert sigma2 >= 0.
        self.sigma2 = as_tensor(sigma2)
        self.sigma = torch.sqrt(self.sigma2)
        self.ndistr = torch.distributions.normal.Normal(loc=0.,
                                                        scale=self.sigma)

        assert 0. <= r <= 1.
        self.r = as_tensor(r)
        self.mdistr = torch.distributions.Bernoulli(probs=1 - self.r)
        self.mask = 1.
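
A usage sketch (an assumption; the sampling itself presumably lives in the parent class). Since mdistr = Bernoulli(probs=1 - r), r is the expected fraction of dropped entries and sigma2 the variance of the additive Gaussian noise:

obs_op = ObsOp_subsampleGaussian(r=0.5, sigma2=1.0)  # keep ~50% of entries, unit-variance noise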
Example #3
def get_rollout_fun(dg_train, model_forward, prediction_task):

    J, K = dg_train.data.shape[1]-1, dg_train.data.shape[2]

    if prediction_task == 'update_with_past':

        raise NotImplementedError  # the code below is unreachable and kept for reference only

        assert isinstance(dg_train, DatasetRelPredPast)
        std_out = as_tensor(dg_train.std_out)
        mean_out = as_tensor(dg_train.mean_out)

        def model_simulate(y0, dy0, n_steps):
            x = np.empty((n_steps+1, *y0.shape[1:]))
            x[0] = y0.copy()
            xx = as_tensor(x[0])
            dx = as_tensor(dy0.copy())
            for i in range(1,n_steps+1):
                xxo = xx * 1.
                xx = std_out * model_forward(torch.cat((xx.reshape(1,J+1,K), dx), axis=1)) + mean_out + xx.reshape(1,J+1,-1)
                dx = xx - xxo
                x[i] = xx.detach().cpu().numpy().copy()
            return x

    elif prediction_task == 'update': 

        raise NotImplementedError  # the code below is unreachable and kept for reference only

        assert isinstance(dg_train, DatasetRelPred)
        def model_simulate(y0, dy0, n_steps):
            x = np.empty((n_steps+1, *y0.shape[1:]))
            x[0] = y0.copy()
            xx = as_tensor(x[0]).reshape(1,1,-1)
            for i in range(1,n_steps+1):
                xx = model_forward(xx.reshape(1,J+1,-1)) + xx.reshape(1,J+1,-1)
                x[i] = xx.detach().cpu().numpy().copy()
            return x

    elif prediction_task == 'state': 

        assert isinstance(dg_train, Dataset)
        def model_simulate(y0, dy0, n_steps):
            x = np.empty((n_steps+1, *y0.shape[1:]))
            x[0] = y0.copy()
            xx = as_tensor(x[0]).reshape(1,1,-1)
            for i in range(1,n_steps+1):
                xx = model_forward(xx.reshape(1,J+1,-1))
                x[i] = xx.detach().cpu().numpy().copy()
            return x

    return model_simulate
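
Usage sketch for the returned closure (assuming a Dataset instance dg_train and a trained model_forward; 'state' is the only branch that does not raise NotImplementedError):

model_simulate = get_rollout_fun(dg_train, model_forward, prediction_task='state')
y0 = np.random.randn(1, J + 1, K)  # illustrative initial state
traj = model_simulate(y0, dy0=None, n_steps=50)  # dy0 is ignored in the 'state' branch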
Example #4
def exp_fsr(x_init):
    # explicit forward solve in reverse time, using the emulator trained as model_exp_id
    x = as_tensor(x_init)
    for t in range(res['back_solve_dt_fac'] * T_rollout // n_chunks_recursive):
        x = model_forwarder_eb.forward(x)
    return x.detach().cpu().numpy()
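
The loop integrates the learned model backwards in time by stepping a forwarder built around the negated tendency. The same idea for a generic ODE dx/dt = f(x), sketched here with an explicit Euler step purely for illustration (the snippet itself uses an RK4 or predictor-corrector forwarder):

def euler_reverse(f, x, dt, n_steps):
    # integrating dx/dt = f(x) backwards in time == forward integration of -f
    for _ in range(n_steps):
        x = x + dt * (-f(x))
    return x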
Example #5
def model_simulate(y0, dy0, n_steps):
    # roll out the numerical forwarder for n_steps from initial state y0; dy0 is unused here
    x = np.empty((n_steps + 1, *y0.shape[1:]), dtype=dtype_np)
    x[0] = y0.copy()
    xx = as_tensor(x[0]).reshape(1, 1, -1)
    for i in range(1, n_steps + 1):
        xx = model_forwarder_np(xx.reshape(1, J + 1, -1))
        x[i] = xx.detach().cpu().numpy().copy()
    return x
Example #6
    def __init__(self, frq=4, sigma2=0.):
        super(ObsOp_rotsampleGaussian, self).__init__()
        assert sigma2 >= 0.
        self.sigma2 = as_tensor(sigma2)
        self.sigma = torch.sqrt(self.sigma2)
        self.ndistr = torch.distributions.normal.Normal(loc=0.,
                                                        scale=self.sigma)

        assert 0. <= frq
        self.frq = frq
        self.mask = 1.
        self.ridx = 0
Example #7
    def __init__(self, model_forwarder, prediction_task, K, J, N, T=1, x_init=None):

        super(Rollout, self).__init__()

        self.model_forwarder = model_forwarder
        self.prediction_task = prediction_task

        self.T = T

        x_init = np.random.normal(size=(N,K*(J+1))) if x_init is None else x_init
        assert x_init.ndim in [2,3]
        self.X = torch.nn.Parameter(as_tensor(x_init))
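
Because self.X is registered as a torch.nn.Parameter, the initial state itself is the optimization variable. A minimal sketch of one gradient step (assuming model_forwarder, K and J are in scope; the loss is a stand-in, since the class's forward pass is not shown here):

roll = Rollout(model_forwarder, prediction_task='state', K=K, J=J, N=1)
target = torch.zeros(1, K * (J + 1))     # hypothetical target state
opt = torch.optim.SGD([roll.X], lr=1e-2)
opt.zero_grad()
loss = ((roll.X - target) ** 2).mean()   # stand-in loss on the initial state
loss.backward()
opt.step()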
Example #8
def load_model_from_exp_conf(res_dir, conf):
    
    net_kwargs = {
            'filters': conf['filters'],
            'kernel_sizes': conf['kernel_sizes'],
            'filters_ks1_init': conf['filters_ks1_init'],
            'filters_ks1_inter': conf['filters_ks1_inter'],
            'filters_ks1_final': conf['filters_ks1_final'],
            'additiveResShortcuts' : conf['additiveResShortcuts'],
            'direct_shortcut' : conf['direct_shortcut'],
            'dropout_rate' : conf['dropout_rate'],
            'layerNorm' : conf['layerNorm'],
            'init_net' : conf['init_net'], 
            'K_net' : conf['K_net'], 
            'J_net' : conf['J_net'], 
            'F_net' : conf['F_net'],
            'dt_net' : conf['dt_net'], 
            'alpha_net' : conf['alpha_net'],
            'model_forwarder' : conf['model_forwarder'],
            'padding_mode' : conf['padding_mode']
    }

    model, model_forward = named_network(model_name=conf['model_name'], 
                                         n_input_channels=conf['J']+1, 
                                         n_output_channels=conf['J']+1,
                                         seq_length=conf['seq_length'],
                                         **net_kwargs)

    test_input = np.random.normal(size=(10, conf['seq_length']*(conf['J']+1), conf['K']))
    print(f'model output shape to test input of shape {test_input.shape}', 
          model.forward(as_tensor(test_input)).shape)

    print('total #parameters: ', np.sum([np.prod(item.shape) for item in model.state_dict().values()]))

    lead_time, exp_id = conf['lead_time'], conf['exp_id']
    save_dir = res_dir + 'models/' + exp_id + '/'
    model_fn = f'{exp_id}_dt{lead_time}.pt'
    
    print('save_dir + model_fn', save_dir + model_fn)
    model.load_state_dict(torch.load(save_dir + model_fn, map_location=torch.device(device)))
    
    try:
        training_outputs = np.load(save_dir + '_training_outputs' + '.npy', allow_pickle=True)[()]
    except Exception:
        training_outputs = None
        print('WARNING: could not load diagnostic outputs from model training!')
    
    return model, model_forward, training_outputs
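
As an aside, the parameter count printed above can also be computed directly from the module's parameters (standard PyTorch; equivalent when the state dict holds only trainable parameters and no extra buffers):

n_params = sum(p.numel() for p in model.parameters())
print('total #parameters:', n_params)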
Example #9
    def set_state(self, x_init):

        self.X = torch.nn.Parameter(as_tensor(x_init))
Example #10
def solve_4dvar(y, m, T_obs, T_win, T_shift, x_init, model_pars, obs_pars,
                optimizer_pars, res_dir):
    """
    def solve_4dvar(system_pars, model_pars, optimizer_pars, setup_pars, optimiziation_schemes,
                res_dir, data_dir, fn=None):
    """

    # extract key variable names from input dicts
    T, N, J, K = m.shape
    J -= 1
    assert y.shape == (T, N, (J + 1) * K)

    if x_init is None:
        x_init = get_init(sortL96intoChannels(y[0], J=J).detach().cpu(),
                          m[0].detach().cpu(),
                          method='interpolate')
    assert x_init.shape == (N, J + 1, K)

    # get model
    model, model_forwarder, args = get_model(model_pars,
                                             res_dir=res_dir,
                                             exp_dir='')
    model_observer = obs_pars['obs_operator'](**obs_pars['obs_operator_args'])
    prior = torch.distributions.normal.Normal(loc=torch.zeros((1, J + 1, K)),
                                              scale=1. * torch.ones(
                                                  (1, J + 1, K)))
    gen = GenModel(model_forwarder,
                   model_observer,
                   prior,
                   T=T_win,
                   x_init=None)
    priors = None

    assert len(T_obs) == T  # easy to generalize, but for now we only support one obs per time point
    n_starts = (T - T_win) // T_shift + 1

    out, losses, times = [], [], []
    for n in range(n_starts):

        print('\n')
        print(f'optimizing window number {n+1} / {n_starts}')
        print('\n')

        idx = np.where(
            np.logical_and(n * T_shift + T_win > T_obs,
                           T_obs >= n * T_shift))[0]
        assert len(idx) > 0  # atm not supporting empty integration window

        opt_res = optim_initial_state(gen,
                                      T_rollouts=[T_win],
                                      T_obs=[T_obs[idx] - n * T_shift],
                                      N=N,
                                      n_chunks=1,
                                      optimizer_pars=optimizer_pars,
                                      x_inits=[x_init],
                                      targets=[y[idx]],
                                      grndtrths=None,
                                      loss_masks=[m[idx]],
                                      priors=priors)  # carry over the per-trial priors built below from the previous window

        x_sols, loss_vals, time_vals, _ = opt_res

        out.append(sortL96intoChannels(x_sols[0], J=J))
        losses.append(loss_vals)
        times.append(time_vals)

        priors = [
            SimplePrior(K=K, J=J, loc=as_tensor(out[-1][n]), scale=1.)
            for n in range(N)
        ]

        x_init = gen._forward(x=as_tensor(out[-1]),
                              T_obs=[T_shift - 1])[0].detach().cpu().numpy()

    return np.stack(out, axis=0), losses, times
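
For reference, each window minimizes a standard 4D-Var negative log-posterior over the initial state. In the notation implied by gen.log_prob and the prior term of Example #11 (a sketch, with \mathcal{M}_t the model rollout to time t, \mathcal{H} the observation operator and m_t the loss mask):

J(x_0) = -\sum_{t \in T_{\mathrm{obs}}} \log p\big(y_t \mid m_t \odot \mathcal{H}(\mathcal{M}_t(x_0))\big) \; - \; \log p_{\mathrm{prior}}(x_0)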
Example #11
def optim_initial_state(gen,
                        T_rollouts,
                        T_obs,
                        N,
                        n_chunks,
                        optimizer_pars,
                        x_inits,
                        targets,
                        grndtrths=None,
                        loss_masks=None,
                        priors=None,
                        f_init=None):

    sample_shape = gen.prior.sample().shape  # (..., J+1, K)
    J, K = sample_shape[-2] - 1, sample_shape[-1]
    n_steps = optimizer_pars['n_steps']

    x_sols = np.zeros((n_chunks, N, K * (J + 1)))
    loss_vals = np.inf * np.ones((n_steps * n_chunks, N))
    time_vals = time.time() * np.ones((n_steps * n_chunks, N))
    state_mses = np.inf * np.ones((n_chunks, N))

    loss_masks = [torch.ones(
        (N, J + 1,
         K)) for i in range(n_chunks)] if loss_masks is None else loss_masks
    assert len(loss_masks) == n_chunks

    if priors is None:

        class Const_prior(object):
            def __init__(self):
                pass

            def log_prob(self, x):
                return 0.

        priors = [Const_prior() for n in range(N)]
    assert len(priors) == N

    i_ = 0
    for j in range(n_chunks):

        print('\n')
        print(f'optimizing over chunk #{j+1} out of {n_chunks}')
        print('\n')

        target = sortL96intoChannels(as_tensor(targets[j]), J=J)
        loss_mask = loss_masks[j]

        assert len(target) == len(T_obs[j]) and len(loss_mask) == len(T_obs[j])

        for n in range(N):

            print('\n')
            print(f'optimizing over initial state #{n+1} / {N}')
            print('\n')

            gen.set_state(x_inits[j][n:n + 1])
            gen.set_rollout_len(T_rollouts[j])

            optimizer = torch.optim.LBFGS(
                params=[gen.X],
                lr=optimizer_pars['lr'],
                max_iter=optimizer_pars['max_iter'],
                max_eval=optimizer_pars['max_eval'],
                tolerance_grad=optimizer_pars['tolerance_grad'],
                tolerance_change=optimizer_pars['tolerance_change'],
                history_size=optimizer_pars['history_size'],
                line_search_fn='strong_wolfe')

            for i_n in range(n_steps):

                with torch.no_grad():
                    loss = -gen.log_prob(y=target[:, n:n + 1],
                                         m=loss_mask[:, n:n + 1],
                                         T_obs=T_obs[j])
                    loss = loss - priors[n].log_prob(gen.X)

                    if i_n == 0:
                        print('initial loss: ', loss)
                    if torch.any(torch.isnan(loss)):
                        loss_vals[i_ + i_n, n] = loss.detach().cpu().numpy()
                        time_vals[i_ + i_n,
                                  n] = time.time() - time_vals[i_ + i_n, n]
                        print(('{:.4f}'.format(time_vals[i_ + i_n, n]),
                               loss_vals[i_ + i_n, n]))
                        print('NaN loss - skipping iteration')

                        print('optimizer.state', optimizer.state[gen.X])

                        continue

                def closure():
                    loss = -gen.log_prob(y=target[:, n:n + 1],
                                         m=loss_mask[:, n:n + 1],
                                         T_obs=T_obs[j])
                    loss = loss - priors[n].log_prob(gen.X)
                    if torch.is_grad_enabled():
                        optimizer.zero_grad()
                    if loss.requires_grad:
                        loss.backward()
                    return loss

                optimizer.step(closure)
                loss_vals[i_ + i_n, n] = loss.detach().cpu().numpy()
                time_vals[i_ + i_n, n] = time.time() - time_vals[i_ + i_n, n]
                print(('{:.4f}'.format(time_vals[i_ + i_n, n]), loss_vals[i_ + i_n, n]))

            x_sols[j][n] = sortL96fromChannels(
                gen.X.detach().cpu().numpy().copy())
            state_mses[j][n] = np.inf if grndtrths is None else ((
                x_sols[j][n] - grndtrths[j][n])**2).mean()

        # if solving recursively, define next target as current initial state estimate
        if j < n_chunks - 1 and targets[j + 1] is None:
            targets[j + 1] = x_sols[j].copy().reshape(1, *x_sols[j].shape)

        i_ += n_steps

        with torch.no_grad():
            if grndtrths is not None:
                print('Eucl. distance to initial value',
                      mse_loss_fullyObs(x_sols[j], grndtrths[j]))
            print(
                'Eucl. distance to x_init',
                mse_loss_fullyObs(x_sols[j], sortL96fromChannels(x_inits[j])))

        if j < n_chunks - 1 and x_inits[j + 1] is None:
            x_inits[j + 1] = sortL96intoChannels(x_sols[j], J=J).copy()
            if f_init is not None:
                x_inits[j + 1] = f_init(x_inits[j + 1])

    # correcting time stamps for solving multiple trials sequentially
    for j in range(n_chunks):
        top_new = time_vals[(j + 1) * n_steps - 1, N - 1]
        for n in range(1, N)[::-1]:
            # correct for the fact that n-1 other problems were solved before for this j
            time_vals[j * n_steps + np.arange(n_steps),
                      n] -= time_vals[(j + 1) * n_steps - 1, n - 1]
            if j > 0:
                # continue from j-1 for this n
                time_vals[j * n_steps + np.arange(n_steps),
                          n] += time_vals[j * n_steps - 1, n]
        if j > 0:
            # for first trial (n=0) of this j, clear previous time and continue from j-1
            time_vals[j * n_steps + np.arange(n_steps), 0] -= top_old
            time_vals[j * n_steps + np.arange(n_steps),
                      0] += time_vals[j * n_steps - 1, 0]
        top_old = top_new

    return x_sols, loss_vals, time_vals, state_mses
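
Note the closure pattern above: torch.optim.LBFGS may re-evaluate the objective several times per step, so loss and gradients must be recomputed inside a callable passed to optimizer.step. A self-contained illustration (generic PyTorch, independent of this code base):

import torch

x = torch.nn.Parameter(torch.zeros(3))
opt = torch.optim.LBFGS([x], lr=1.0, line_search_fn='strong_wolfe')

def closure():
    opt.zero_grad()
    loss = ((x - 1.0) ** 2).sum()
    loss.backward()
    return loss

opt.step(closure)  # converges towards x = [1., 1., 1.]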
Example #12
def solve_initstate(system_pars,
                    model_pars,
                    optimizer_pars,
                    setup_pars,
                    optimiziation_schemes,
                    res_dir,
                    data_dir,
                    fn=None):

    # extract key variable names from input dicts
    K, J = system_pars['K'], system_pars['J']
    T, dt, N_trials = system_pars['T'], system_pars['dt'], system_pars[
        'N_trials']

    n_starts, T_rollout, T_pred = setup_pars['n_starts'], setup_pars[
        'T_rollout'], setup_pars['T_pred']
    n_chunks, n_chunks_recursive = setup_pars['n_chunks'], setup_pars[
        'n_chunks_recursive']

    N = len(n_starts)
    recursions_per_chunks = n_chunks_recursive // n_chunks

    assert recursions_per_chunks == n_chunks_recursive / n_chunks

    assert T_rollout // n_chunks_recursive == T_rollout / n_chunks_recursive
    assert T_rollout // n_chunks == T_rollout / n_chunks

    if optimiziation_schemes['LBFGS_full_chunks']:
        assert optimiziation_schemes['LBFGS_chunks']  # requirement for init

    if optimiziation_schemes['LBFGS_full_backsolve']:
        assert optimiziation_schemes['backsolve']  # requirement for init

    # get model
    model, model_forwarder, args = get_model(model_pars,
                                             res_dir=res_dir,
                                             exp_dir='')

    # ### instantiate observation operator
    model_observer = system_pars['obs_operator'](
        **system_pars['obs_operator_args'])

    # ### define prior over initial states
    prior = torch.distributions.normal.Normal(loc=torch.zeros((1, J + 1, K)),
                                              scale=1. * torch.ones(
                                                  (1, J + 1, K)))

    # ### define generative model for observed data
    gen = GenModel(model_forwarder,
                   model_observer,
                   prior,
                   T=T_rollout,
                   x_init=None)

    # prepare function output
    model_forwarder_str, optimizer_str = args[
        'model_forwarder'], optimizer_pars['optimizer']
    obs_operator_str = model_observer.__class__.__name__
    exp_id = model_pars['exp_id']

    # output dictionary
    res = {
        'exp_id': exp_id,
        'K': K,
        'J': J,
        'T': T,
        'dt': dt,
        'back_solve_dt_fac': system_pars['back_solve_dt_fac'],
        'F': system_pars['F'],
        'h': system_pars['h'],
        'b': system_pars['b'],
        'c': system_pars['c'],
        'conf_exp': args['conf_exp'],
        'model_forwarder':
        model_pars['model_forwarder'],  # should still be string
        'dt_net': model_pars['dt_net'],
        'n_starts': n_starts,
        'T_rollout': T_rollout,
        'T_pred': T_pred,
        'n_chunks': n_chunks,
        'n_chunks_recursive': n_chunks_recursive,
        'recursions_per_chunks': recursions_per_chunks,
        'n_steps': optimizer_pars['n_steps'],
        'n_steps_tot': optimizer_pars['n_steps'] * n_chunks,
        'optimizer_pars': optimizer_pars,
        'optimiziation_schemes': optimiziation_schemes,
        'obs_operator': obs_operator_str,
        'obs_operator_args': system_pars['obs_operator_args']
    }

    # ### get data for 'typical' L96 state sequences

    out, datagen_dict = get_data(K=K,
                                 J=J,
                                 T=T,
                                 dt=dt,
                                 N_trials=N_trials,
                                 F=res['F'],
                                 h=res['h'],
                                 b=res['b'],
                                 c=res['c'],
                                 resimulate=True,
                                 solver=rk4_default,
                                 save_sim=False,
                                 data_dir=data_dir)

    grndtrths = [out[n_starts] for j in range(n_chunks_recursive)]
    res['initial_states'] = np.stack(
        [sortL96intoChannels(z, J=J) for z in grndtrths])

    T_obs = [[(j + 1) * (T_rollout // n_chunks) - 1 for j in range(n_ + 1)]
             for n_ in range(n_chunks)]
    print('T_obs[-1]', T_obs[-1])
    res['targets'] = np.stack(
        [sortL96intoChannels(out[n_starts + t + 1], J=J) for t in T_obs[-1]],
        axis=0)

    # ## Generate observed data: (sub-)sample noisy observations
    res['test_state'] = sortL96intoChannels(out[n_starts + T_pred],
                                            J=J).reshape(
                                                1, len(n_starts), J + 1, K)
    res['test_state_obs'] = gen._sample_obs(as_tensor(
        res['test_state']))  # sets the loss masks!
    res['test_state_obs'] = sortL96fromChannels(
        res['test_state_obs'].detach().cpu().numpy())
    res['test_state_mask'] = torch.stack(gen.masks,
                                         dim=0).detach().cpu().numpy()

    res['targets_obs'] = gen._sample_obs(as_tensor(
        res['targets']))  # sets the loss masks!
    res['targets_obs'] = sortL96fromChannels(
        res['targets_obs'].detach().cpu().numpy())
    res['loss_mask'] = torch.stack(gen.masks, dim=0).detach().cpu().numpy()

    if fn is None:
        fn = 'results/data_assimilation/fullyobs_initstate_tests_'
        fn = fn + f'exp{exp_id}_{model_forwarder_str}_{optimizer_str}_{obs_operator_str}'

    print('\n')
    print('storing intermediate results')
    print('\n')
    np.save(res_dir + fn, arr=res)

    # ### define setup for optimization

    T_rollouts = np.arange(1, n_chunks + 1) * (T_rollout // n_chunks)
    grndtrths = [out[n_starts] for j in range(n_chunks)]
    targets = [res['targets_obs'][:len(T_obs[j])] for j in range(n_chunks)]
    loss_masks = [
        torch.stack(gen.masks[:len(T_obs[j])], dim=0) for j in range(n_chunks)
    ]

    grndtrths_chunks = [out[n_starts] for j in range(n_chunks_recursive)]

    # ## L-BFGS, solve across full rollout time recursively, initialize from forward solver in reverse

    if optimiziation_schemes['LBFGS_recurse_chunks']:

        print('\n')
        print(
            'L-BFGS, solve across full rollout time recursively, initialize from forward solver in reverse'
        )
        print('\n')

        assert len(
            T_obs) == 1  # only allow single observation at end of interval
        assert len(res['targets_obs']) == 1

        # functions for explicitly solving backwards
        class Model_eb(torch.nn.Module):
            def __init__(self, model):
                super(Model_eb, self).__init__()
                self.model = model

            def forward(self, x):
                return -self.model.forward(x)

        if res['model_forwarder'] == 'rk4_default':
            Model_forwarder = Model_forwarder_rk4default
        elif res['model_forwarder'] == 'predictor_corrector':
            Model_forwarder = Model_forwarder_predictorCorrector
        else:
            raise NotImplementedError(res['model_forwarder'])

        model_forwarder_eb = Model_forwarder(model=Model_eb(model),
                                             dt=dt / res['back_solve_dt_fac'])

        def exp_fsr(x_init):
            # explicit forward solve in reverse time, using the emulator trained as model_exp_id
            x = as_tensor(x_init)
            for t in range(res['back_solve_dt_fac'] * T_rollout //
                           n_chunks_recursive):
                x = model_forwarder_eb.forward(x)
            return x.detach().cpu().numpy()

        x_init = get_init(sortL96intoChannels(res['targets_obs'], J=J)[0],
                          res['loss_mask'][0],
                          method='interpolate')
        x_inits = [exp_fsr(x_init)] + [None for j in range(n_chunks_recursive)]

        opt_res = optim_initial_state(
            gen,
            T_rollouts=np.arange(1, n_chunks_recursive + 1) *
            (T_rollout // n_chunks_recursive),
            T_obs=[[j] for j in range(recursions_per_chunks)],
            N=N,
            n_chunks=n_chunks_recursive,
            optimizer_pars=optimizer_pars,
            x_inits=x_inits,
            targets=list(np.repeat(targets, recursions_per_chunks, axis=0)),
            grndtrths=[
                out[n_starts + T_rollout - (j + 1) *
                    (T_rollout // n_chunks_recursive)]
                for j in range(n_chunks_recursive)
            ],
            loss_masks=list(
                np.repeat(loss_masks, recursions_per_chunks, axis=0)),
            f_init=exp_fsr)

        res['x_sols_LBFGS_recurse_chunks'] = opt_res[0]
        res['loss_vals_LBFGS_recurse_chunks'] = opt_res[1]
        res['time_vals_LBFGS_recurse_chunks'] = opt_res[2]
        res['state_mses_LBFGS_recurse_chunks'] = opt_res[3]

        print('\n')
        print('storing results')
        print('\n')
        np.save(res_dir + fn, arr=res)

    # ## L-BFGS, solve across single chunks recursively, initialize from last chunk

    if optimiziation_schemes['LBFGS_chunks']:

        print('\n')
        print(
            'L-BFGS, solve across single chunks recursively, initialize from last chunk'
        )
        print('\n')

        assert len(
            T_obs) == 1  # only allow single observation at end of interval
        assert len(res['targets_obs']) == 1

        x_inits = [None for i in range(n_chunks_recursive)]
        x_inits[0] = get_init(sortL96intoChannels(res['targets_obs'], J=J)[0],
                              res['loss_mask'][0],
                              method='interpolate')

        opt_res = optim_initial_state(
            gen,
            T_rollouts=np.ones(n_chunks_recursive, dtype=int) *
            (T_rollout // n_chunks_recursive),
            T_obs=[[0] for j in range(n_chunks_recursive)],
            N=N,
            n_chunks=n_chunks_recursive,
            optimizer_pars=optimizer_pars,
            x_inits=x_inits,
            targets=[res['targets_obs']] +
            [None for i in range(n_chunks_recursive - 1)],
            grndtrths=[
                out[n_starts + T_rollout - (j + 1) *
                    (T_rollout // n_chunks_recursive)]
                for j in range(n_chunks_recursive)
            ],
            loss_masks=[torch.stack(gen.masks, dim=0)] + [
                torch.ones((1, N, J + 1, K))
                for i in range(n_chunks_recursive - 1)
            ])

        res['x_sols_LBFGS_chunks'] = opt_res[0]
        res['loss_vals_LBFGS_chunks'] = opt_res[1]
        res['time_vals_LBFGS_chunks'] = opt_res[2]
        res['state_mses_LBFGS_chunks'] = opt_res[3]

        print('\n')
        print('storing results')
        print('\n')
        np.save(res_dir + fn, arr=res)

    # ## L-BFGS, solve across full rollout time in one go, initialize from chunked approach

    if optimiziation_schemes['LBFGS_full_chunks']:

        print('\n')
        print(
            'L-BFGS, solve across full rollout time in one go, initialize from chunked approach'
        )
        print('\n')

        x_inits = res['x_sols_LBFGS_chunks'][recursions_per_chunks -
                                             1:][::recursions_per_chunks]
        x_inits = [sortL96intoChannels(z, J=J).copy() for z in x_inits]

        opt_res = optim_initial_state(gen,
                                      T_rollouts=T_rollouts,
                                      T_obs=T_obs,
                                      N=N,
                                      n_chunks=n_chunks,
                                      optimizer_pars=optimizer_pars,
                                      x_inits=x_inits,
                                      targets=targets,
                                      grndtrths=grndtrths,
                                      loss_masks=loss_masks)

        res['x_sols_LBFGS_full_chunks'] = opt_res[0]
        res['loss_vals_LBFGS_full_chunks'] = opt_res[1]
        res['time_vals_LBFGS_full_chunks'] = opt_res[2]
        res['state_mses_LBFGS_full_chunks'] = opt_res[3]

        print('\n')
        print('storing results')
        print('\n')
        np.save(res_dir + fn, arr=res)

    # ## numerical forward solve in reverse

    if optimiziation_schemes['backsolve']:

        print('\n')
        print('numerical forward solve in reverse')
        print('\n')

        # functions for explicitly solving backwards
        from L96sim.L96_base import f1, f2, pf2

        dX_dt = np.empty(K * (J + 1), dtype=dtype_np)
        if J > 0:

            def fun_eb(t, x):
                return -f2(x, res['F'], res['h'], res['b'], res['c'], dX_dt, K,
                           J)
        else:

            def fun_eb(t, x):
                return -f1(x, res['F'], dX_dt, K)

        def explicit_backsolve(x_init, times_eb, fun_eb):
            x_sols = np.zeros_like(x_init)
            for i__ in range(x_init.shape[0]):
                out2 = rk4_default(fun=fun_eb, y0=x_init[i__], times=times_eb)
                x_sols[i__] = out2[-1].copy()
            return x_sols

        res['state_mses_backsolve'] = np.zeros(
            (n_chunks_recursive, len(n_starts)))
        res['x_sols_backsolve'] = np.zeros(
            (n_chunks_recursive, len(n_starts), K * (J + 1)))
        res['loss_vals_backsolve'] = np.zeros(
            (n_chunks_recursive, len(n_starts)))

        print('backward solving')
        res['time_vals_backsolve'] = time.time() * np.ones(
            (n_chunks_recursive, len(n_starts)))

        x_init = get_init(sortL96intoChannels(res['targets_obs'][0], J=J),
                          res['loss_mask'][0],
                          method='interpolate')
        for j in range(recursions_per_chunks):
            times_eb = dt * np.linspace(
                0, T_rollouts[0] / recursions_per_chunks,
                res['back_solve_dt_fac'] * T_rollouts[0] //
                recursions_per_chunks + 1)
            print('x_init.shape', sortL96fromChannels(x_init).shape)
            res['x_sols_backsolve'][j] = explicit_backsolve(
                sortL96fromChannels(x_init), times_eb, fun_eb)
            x_init = sortL96intoChannels(res['x_sols_backsolve'][j].copy(),
                                         J=J)
            print('x_init.shape - out', x_init.shape)
            res['time_vals_backsolve'][j] = res['time_vals_backsolve'][0]
            x_target = out[n_starts + T_rollouts[0] - (j + 1) *
                           (T_rollouts[0] // recursions_per_chunks)]
            res['state_mses_backsolve'][j] = ((res['x_sols_backsolve'][j] -
                                               x_target)**2).mean(axis=1)
            res['time_vals_backsolve'][j] = time.time(
            ) - res['time_vals_backsolve'][0]
        for j in range(recursions_per_chunks, n_chunks_recursive):
            res['x_sols_backsolve'][j] = res['x_sols_backsolve'][
                recursions_per_chunks - 1]
            res['time_vals_backsolve'][j] = res['time_vals_backsolve'][
                recursions_per_chunks - 1]
            res['state_mses_backsolve'][j] = ((res['x_sols_backsolve'][j] -
                                               out[n_starts])**2).mean(axis=1)

        print('\n')
        print('storing results')
        print('\n')
        np.save(res_dir + fn, arr=res)

    # ## L-BFGS, solve across full rollout time in one go, initiate from backward solution

    if optimiziation_schemes['LBFGS_full_backsolve']:

        print('\n')
        print(
            'L-BFGS, solve across full rollout time in one go, initiate from backward solution'
        )
        print('\n')

        x_inits = res['x_sols_backsolve'][recursions_per_chunks -
                                          1:][::recursions_per_chunks]
        x_inits = [sortL96intoChannels(z, J=J).copy() for z in x_inits]

        opt_res = optim_initial_state(gen,
                                      T_rollouts=T_rollouts,
                                      T_obs=T_obs,
                                      N=N,
                                      n_chunks=n_chunks,
                                      optimizer_pars=optimizer_pars,
                                      x_inits=x_inits,
                                      targets=targets,
                                      grndtrths=grndtrths,
                                      loss_masks=loss_masks)

        res['x_sols_LBFGS_full_backsolve'] = opt_res[0]
        res['loss_vals_LBFGS_full_backsolve'] = opt_res[1]
        res['time_vals_LBFGS_full_backsolve'] = opt_res[2]
        res['state_mses_LBFGS_full_backsolve'] = opt_res[3]

        print('\n')
        print('storing results')
        print('\n')
        np.save(res_dir + fn, arr=res)

    # ## L-BFGS, solve across full rollout time in one go
    # - warning: this can be excruciatingly slow and hard to converge!

    if optimiziation_schemes['LBFGS_full_persistence']:

        print('\n')
        print('L-BFGS, solve across full rollout time in one go')
        print('\n')

        x_init = get_init(sortL96intoChannels(res['targets_obs'], J=J)[0],
                          res['loss_mask'][0],
                          method='interpolate')
        x_inits = [x_init for j in range(n_chunks)]

        opt_res = optim_initial_state(gen,
                                      T_rollouts=T_rollouts,
                                      T_obs=T_obs,
                                      N=N,
                                      n_chunks=n_chunks,
                                      optimizer_pars=optimizer_pars,
                                      x_inits=x_inits,
                                      targets=targets,
                                      grndtrths=grndtrths,
                                      loss_masks=loss_masks)

        res['x_sols_LBFGS_full_persistence'] = opt_res[0]
        res['loss_vals_LBFGS_full_persistence'] = opt_res[1]
        res['time_vals_LBFGS_full_persistence'] = opt_res[2]
        res['state_mses_LBFGS_full_persistence'] = opt_res[3]

        print('\n')
        print('storing results')
        print('\n')
        np.save(res_dir + fn, arr=res)

    print('\n')
    print('done')
    print('\n')
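
A sketch of the optimiziation_schemes flags this function consumes (key names taken verbatim from the body above; the values shown are just an assumed example configuration):

optimiziation_schemes = {
    'LBFGS_recurse_chunks': False,
    'LBFGS_chunks': True,
    'LBFGS_full_chunks': True,       # asserted to require LBFGS_chunks for initialization
    'backsolve': True,
    'LBFGS_full_backsolve': True,    # asserted to require backsolve for initialization
    'LBFGS_full_persistence': False  # warning: can be excruciatingly slow
}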
Example #13
def run_exp_4DVar(exp_id, datadir, res_dir, T_win, T_shift, B, K, J, T,
                  N_trials, dt, spin_up_time, l96_F, l96_h, l96_b, l96_c,
                  obs_operator, obs_operator_r, obs_operator_sig2,
                  obs_operator_frq, model_exp_id, model_forwarder, optimizer,
                  n_steps, lr, max_iter, max_eval, tolerance_grad,
                  tolerance_change, history_size):

    fetch_commit = subprocess.Popen(['git', 'rev-parse', 'HEAD'],
                                    shell=False,
                                    stdout=subprocess.PIPE)
    commit_id = fetch_commit.communicate()[0].strip().decode("utf-8")
    fetch_commit.kill()

    T_shift = T_win if T_shift < 0 else T_shift  # backwards compatibility to older exps with T_shift=T_win

    model_pars = {
        'exp_id': model_exp_id,
        'model_forwarder': model_forwarder,
        'K_net': K,
        'J_net': J,
        'dt_net': dt
    }

    optimizer_pars = {
        'optimizer': optimizer,
        'n_steps': n_steps,
        'lr': lr,
        'max_iter': max_iter,
        'max_eval': None if max_eval < 0 else max_eval,
        'tolerance_grad': tolerance_grad,
        'tolerance_change': tolerance_change,
        'history_size': history_size
    }

    if model_forwarder == 'rk4_default':
        model_forwarder = rk4_default
    if model_forwarder == 'predictor_corrector':
        model_forwarder = predictor_corrector

    out, datagen_dict = get_data(K=K,
                                 J=J,
                                 T=T + spin_up_time,
                                 dt=dt,
                                 N_trials=N_trials,
                                 F=l96_F,
                                 h=l96_h,
                                 b=l96_b,
                                 c=l96_c,
                                 resimulate=True,
                                 solver=model_forwarder,
                                 save_sim=False,
                                 data_dir='')
    out = sortL96intoChannels(out.transpose(1, 0, 2)[int(spin_up_time / dt):],
                              J=J)
    print('out.shape', out.shape)

    model, model_forwarder, args = get_model(model_pars,
                                             res_dir=res_dir,
                                             exp_dir='')

    obs_pars = {}
    if obs_operator == 'ObsOp_subsampleGaussian':
        obs_pars['obs_operator'] = ObsOp_subsampleGaussian
        obs_pars['obs_operator_args'] = {
            'r': obs_operator_r,
            'sigma2': obs_operator_sig2
        }
    elif obs_operator == 'ObsOp_identity':
        obs_pars['obs_operator'] = ObsOp_identity
        obs_pars['obs_operator_args'] = {}
    elif obs_operator == 'ObsOp_rotsampleGaussian':
        obs_pars['obs_operator'] = ObsOp_rotsampleGaussian
        obs_pars['obs_operator_args'] = {
            'frq': obs_operator_frq,
            'sigma2': obs_operator_sig2
        }
    else:
        raise NotImplementedError()

    model_observer = obs_pars['obs_operator'](**obs_pars['obs_operator_args'])

    prior = torch.distributions.normal.Normal(loc=torch.zeros((1, J + 1, K)),
                                              scale=B * torch.ones(
                                                  (1, J + 1, K)))
    gen = GenModel(model_forwarder, model_observer, prior)

    y = sortL96fromChannels(gen._sample_obs(
        as_tensor(out)))  # sets the loss masks!
    m = torch.stack(gen.masks, dim=0)

    save_dir = 'results/data_assimilation/' + exp_id + '/'
    mkdir_from_path(res_dir + save_dir)

    open(res_dir + save_dir + commit_id + '.txt', 'w').close()  # touch an empty file recording the commit id
    fn = save_dir + 'res'

    print('4D-VAR')
    x_sols, losses, times = solve_4dvar(y,
                                        m,
                                        T_obs=np.arange(y.shape[0]),
                                        T_win=T_win,
                                        T_shift=T_shift,
                                        x_init=None,
                                        model_pars=model_pars,
                                        obs_pars=obs_pars,
                                        optimizer_pars=optimizer_pars,
                                        res_dir=res_dir)

    np.save(
        res_dir + save_dir + 'out', {
            'out': out,
            'y': y.detach().cpu().numpy(),
            'm': m.detach().cpu().numpy(),
            'x_sols': x_sols,
            'losses': losses,
            'times': times,
            'T_win': T_win,
            'T_shift': T_shift
        })

    print('x_sols.shape', x_sols.shape)
    print('done')
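
The sliding-window bookkeeping delegated to solve_4dvar (Example #10) follows directly from T_win and T_shift; a small sketch of the index arithmetic:

T, T_win, T_shift = 100, 10, 5
n_starts = (T - T_win) // T_shift + 1  # number of assimilation windows
windows = [(n * T_shift, n * T_shift + T_win) for n in range(n_starts)]
print(n_starts, windows[:3])  # 19 [(0, 10), (5, 15), (10, 20)]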
Example #14
def run_exp(exp_id, datadir, res_dir, K, K_local, J, T, N_trials, dt, n_local,
            prediction_task, lead_time, seq_length, train_frac,
            validation_frac, spin_up_time, model_name, loss_fun, weight_decay,
            batch_size, batch_size_eval, max_epochs, eval_every, max_patience,
            lr, lr_min, lr_decay, max_lr_patience, only_eval, normalize_data,
            **net_kwargs):

    fetch_commit = subprocess.Popen(['git', 'rev-parse', 'HEAD'],
                                    shell=False,
                                    stdout=subprocess.PIPE)
    commit_id = fetch_commit.communicate()[0].strip().decode("utf-8")
    fetch_commit.kill()

    # load data
    fn_data = f'out_K{K}_J{J}_T{T}_N{N_trials}_dt0_{str(dt)[2:]}'
    out = np.load(datadir + fn_data + '.npy')
    print('data.shape', out.shape)
    assert (out.shape[1] - 1) * dt == T

    K_local = K if K_local < 0 else K_local
    DatasetClass = sel_dataset_class(prediction_task,
                                     N_trials,
                                     local=(K_local < K))
    test_frac = 1. - (train_frac + validation_frac)
    assert test_frac >= 0.
    spin_up = int(spin_up_time / dt)
    dg_args = {
        'data': out,
        'J': J,
        'offset': lead_time,
        'normalize': bool(normalize_data)
    }
    if DatasetClass == DatasetMultiTrial_shattered:
        dg_args['K_local'] = K_local
        dg_args['n_local'] = n_local

    dg_train = DatasetClass(start=spin_up,
                            end=int(np.floor(T / dt * train_frac)),
                            **dg_args)
    dg_val = DatasetClass(
        start=int(np.ceil(T / dt * train_frac)),
        end=int(np.ceil(T / dt * (train_frac + validation_frac))),
        **dg_args)

    print('N training data:', len(dg_train))
    print('N validation data:', len(dg_val))

    batch_size_eval = batch_size if batch_size_eval < 1 else batch_size_eval
    validation_loader = torch.utils.data.DataLoader(dg_val,
                                                    batch_size=batch_size_eval,
                                                    drop_last=False,
                                                    num_workers=0)
    train_loader = torch.utils.data.DataLoader(dg_train,
                                               batch_size=batch_size,
                                               drop_last=True,
                                               num_workers=0)

    ## define model
    print('net_kwargs', net_kwargs)
    model_fn = f'{exp_id}_dt{lead_time}.pt'

    model, model_forward = named_network(model_name=model_name,
                                         n_input_channels=J + 1,
                                         n_output_channels=J + 1,
                                         seq_length=seq_length,
                                         **net_kwargs)
    try:
        print('model.layer1.weights', model.layer1.weight.shape)
    except AttributeError:
        pass
    if K_local < K:
        test_input = np.random.normal(size=(10, seq_length * (J + 1),
                                            K_local + 3 * n_local))
    else:
        test_input = np.random.normal(size=(10, seq_length * (J + 1), K))
    print(f'model output shape to test input of shape {test_input.shape}',
          model_forward(as_tensor(test_input)).shape)
    print(
        'total #parameters: ',
        np.sum([np.prod(item.shape) for item in model.state_dict().values()]))

    ## train model
    save_dir = res_dir + 'models/' + exp_id + '/'
    if only_eval:
        print('loading model from disk')
        model.load_state_dict(
            torch.load(save_dir + model_fn, map_location=torch.device(device)))

    else:  # actually train

        mkdir_from_path(save_dir)
        print('saving model state_dict to ' + save_dir + model_fn)
        open(save_dir + commit_id + '.txt', 'w').close()  # touch an empty file recording the commit id

        output_fn = '_training_outputs'
        extra_args = {}
        if loss_fun == 'local_mse':
            extra_args = {
                'n_local': n_local,
                'pad_local': (2, 2) if J == 1 else
                (2, 1)  # relevant local area for L96 diff.eq. 
            }
        print('loss_fun', loss_fun)
        loss_fun = loss_function(loss_fun, extra_args=extra_args)
        print('loss_fun', loss_fun)
        training_outputs = train_model(model,
                                       train_loader,
                                       validation_loader,
                                       device,
                                       model_forward,
                                       loss_fun=loss_fun,
                                       weight_decay=weight_decay,
                                       max_epochs=max_epochs,
                                       max_patience=max_patience,
                                       lr=lr,
                                       lr_min=lr_min,
                                       lr_decay=lr_decay,
                                       max_lr_patience=max_lr_patience,
                                       eval_every=eval_every,
                                       verbose=True,
                                       save_dir=save_dir,
                                       model_fn=model_fn,
                                       output_fn=output_fn)
        print('saving full model to ' + save_dir + model_fn[:-3] +
              '_full_model.pt')
        torch.save(model, save_dir + model_fn[:-3] + '_full_model.pt')
Example #15
def run_exp_parametrization(exp_id, datadir, res_dir, parametrization,
                            n_hiddens, kernel_size, K, J, T, dt, spin_up_time,
                            l96_F, l96_h, l96_b, l96_c, train_frac,
                            validation_frac, offset, model_exp_id,
                            model_forwarder, loss_fun, batch_size, eval_every,
                            lr, lr_min, lr_decay, weight_decay, max_epochs,
                            max_patience, max_lr_patience):

    fetch_commit = subprocess.Popen(['git', 'rev-parse', 'HEAD'],
                                    shell=False,
                                    stdout=subprocess.PIPE)
    commit_id = fetch_commit.communicate()[0].strip().decode("utf-8")
    fetch_commit.kill()

    # load model
    model_pars = {
        'exp_id': model_exp_id,
        'model_forwarder': model_forwarder,
        'K_net': K,
        'J_net': 0,
        'dt_net': dt
    }

    if model_forwarder == 'rk4_default':
        model_forwarder = rk4_default
    elif model_forwarder == 'predictor_corrector':
        model_forwarder = predictor_corrector

    model, model_forwarder, args = get_model(model_pars,
                                             res_dir=res_dir,
                                             exp_dir='')

    # instantiate parametrizations
    if parametrization == 'linear':
        param_train = Parametrization_lin(a=as_tensor(np.array([-0.75])),
                                          b=as_tensor(np.array([-0.4])))
        param_offline = Parametrization_lin(a=as_tensor(np.array([-0.75])),
                                            b=as_tensor(np.array([-0.4])))
    elif parametrization == 'nn':
        param_train = Parametrization_nn(n_hiddens=n_hiddens,
                                         kernel_size=kernel_size)
        param_offline = Parametrization_nn(n_hiddens=n_hiddens,
                                           kernel_size=kernel_size)
        # make sure they share initialization:
        param_offline.load_state_dict(copy.deepcopy(param_train.state_dict()))
    else:
        raise NotImplementedError()
    for p in model.parameters():
        p.requires_grad = False
    model_parametrized = Parametrized_twoLevel_L96(emulator=model,
                                                   parametrization=param_train)
    model_forwarder_parametrized = Model_forwarder_rk4default(
        model=model_parametrized, dt=dt)

    print('torch.nn.Parameters of parametrization require grad: ')
    for p in model_forwarder_parametrized.model.param.parameters():
        print(p.requires_grad)
    print('torch.nn.Parameters of emulator require grad: ')
    for p in model_forwarder_parametrized.model.emulator.parameters():
        print(p.requires_grad)

    if len(offset) > 1:  # multi-step predictions
        print('multi-step predictions')
        gm = GenModel(model_forwarder=model_forwarder_parametrized,
                      model_observer=ObsOp_identity(),
                      prior=SimplePrior(J=0, K=K))

        class MultiStepForwarder(torch.nn.Module):
            def __init__(self, model, offset):
                super(MultiStepForwarder, self).__init__()
                self.model = model
                self.offset = offset

            def forward(self, x):
                # delegate to the wrapped GenModel and stack the selected time steps
                return torch.stack(self.model._forward(x=x, T_obs=self.offset),
                                   dim=1)

        model_forwarder_parametrized = MultiStepForwarder(
            model=gm, offset=np.asarray(offset))
        print('model forwarder', model_forwarder_parametrized)

    if parametrization == 'linear':
        print('initialized a', model_parametrized.param.a)
        print('initialized b', model_parametrized.param.b)
    elif parametrization == 'nn':
        print('initialized first-layer weights',
              model_parametrized.param.layers[0].weight)

    # ground-truth two-level L96 model (based on Numba implementation):
    dX_dt = np.empty(K * (J + 1), dtype=dtype_np)
    if J > 0:

        def fun(t, x):
            return f2(x, l96_F, l96_h, l96_b, l96_c, dX_dt, K, J)
    else:

        def fun(t, x):
            return f1(x, l96_F, dX_dt, K)

    class Torch_solver(torch.nn.Module):
        # wraps the numerical solver (from numpy/numba) as a torch module
        def __init__(self, fun):
            super(Torch_solver, self).__init__()
            self.fun = fun

        def forward(self, x):
            x = sortL96fromChannels(x.detach().cpu().numpy()).flatten()
            return as_tensor(
                sortL96intoChannels(np.atleast_2d(self.fun(0., x)), J=J))

    model_forwarder_np = Model_forwarder_rk4default(Torch_solver(fun), dt=dt)

    # create some training data from the true two-level L96
    X_init = 0.5 + np.random.randn(1, K * (J + 1)) * 1.0
    X_init = l96_F * X_init.astype(dtype=dtype_np) / np.maximum(J, 50)

    def model_simulate(y0, dy0, n_steps):
        x = np.empty((n_steps + 1, *y0.shape[1:]), dtype=dtype_np)
        x[0] = y0.copy()
        xx = as_tensor(x[0]).reshape(1, 1, -1)
        for i in range(1, n_steps + 1):
            xx = model_forwarder_np(xx.reshape(1, J + 1, -1))
            x[i] = xx.detach().cpu().numpy().copy()
        return x

    T_dur = int(T / dt)
    spin_up = int(spin_up_time / dt)
    print('simulating high-res (two-level L96) data')
    data_full = model_simulate(y0=sortL96intoChannels(X_init, J=J),
                               dy0=None,
                               n_steps=T_dur + spin_up)
    print('full data shape: ', data_full.shape)
    assert np.all(np.isfinite(data_full))

    # offline training of parametrization

    print('offline training')
    X = sortL96intoChannels(data_full[:, 0, :], J=model_pars['J_net'])
    y = -l96_h * l96_c * sortL96intoChannels(data_full[:, 1:, :].mean(axis=1),
                                             J=model_pars['J_net'])
    dg_train = Dataset_offline(
        data=(X, y),
        start=spin_up,
        end=spin_up + int(np.floor(T_dur * train_frac)) - np.max(offset))
    print('len dg_train', len(dg_train))
    train_loader = torch.utils.data.DataLoader(dg_train,
                                               batch_size=batch_size,
                                               drop_last=True,
                                               num_workers=0)
    dg_val = Dataset_offline(
        data=(X, y),
        start=spin_up + int(np.floor(T_dur * train_frac)),
        end=spin_up + int(np.ceil(T_dur * (train_frac + validation_frac))) -
        np.max(offset))
    print('len dg_val', len(dg_val))
    validation_loader = torch.utils.data.DataLoader(dg_val,
                                                    batch_size=batch_size,
                                                    drop_last=False,
                                                    num_workers=0)

    print('starting optimization of parametrization')
    training_outputs_offline = train_model(model=param_offline,
                                           train_loader=train_loader,
                                           validation_loader=validation_loader,
                                           device=device,
                                           model_forward=param_offline,
                                           loss_fun=loss_function(
                                               loss_fun=loss_fun,
                                               extra_args={}),
                                           lr=lr,
                                           lr_min=lr_min,
                                           lr_decay=lr_decay,
                                           weight_decay=weight_decay,
                                           max_epochs=max_epochs,
                                           max_patience=max_patience,
                                           max_lr_patience=max_lr_patience,
                                           eval_every=eval_every)

    if parametrization == 'linear':
        print('learned a', param_offline.a)
        print('learned b', param_offline.b)
    elif parametrization == 'nn':
        print('learned first-layer weights',
              param_offline.layers[0].weight)

    # online training of parametrization

    print('online training')
    # the two-level model simulates both fast and slow variables; we only use the slow ones for training!
    data = data_full[:, 0, :]
    data = data.reshape(1, *data.shape)  # N x T x K*(J+1)
    print('training data shape: ', data.shape)

    DatasetClass = sel_dataset_class(prediction_task='state',
                                     N_trials=1,
                                     local=False,
                                     offset=offset)
    print('dataset class', DatasetClass)
    print('len(offset)', len(offset))
    assert train_frac + validation_frac <= 1.

    dg_dict = {
        'data': data,
        'J': 0,
        'offset': offset[0] if len(offset) == 1 else offset,
        'normalize': False
    }

    dg_train = DatasetClass(start=spin_up,
                            end=spin_up + int(np.floor(T_dur * train_frac)) -
                            np.max(offset),
                            **dg_dict)
    print('len dg_train', len(dg_train))
    train_loader = torch.utils.data.DataLoader(dg_train,
                                               batch_size=batch_size,
                                               drop_last=True,
                                               num_workers=0)
    dg_val = DatasetClass(
        start=spin_up + int(np.floor(T_dur * train_frac)),
        end=spin_up + int(np.ceil(T_dur * (train_frac + validation_frac))) -
        np.max(offset),
        **dg_dict)
    print('len dg_val', len(dg_val))
    validation_loader = torch.utils.data.DataLoader(dg_val,
                                                    batch_size=batch_size,
                                                    drop_last=False,
                                                    num_workers=0)

    print('starting optimization of parametrization')
    training_outputs = train_model(model=model_forwarder_parametrized,
                                   train_loader=train_loader,
                                   validation_loader=validation_loader,
                                   device=device,
                                   model_forward=model_forwarder_parametrized,
                                   loss_fun=loss_function(loss_fun=loss_fun,
                                                          extra_args={}),
                                   lr=lr,
                                   lr_min=lr_min,
                                   lr_decay=lr_decay,
                                   weight_decay=weight_decay,
                                   max_epochs=max_epochs,
                                   max_patience=max_patience,
                                   max_lr_patience=max_lr_patience,
                                   eval_every=eval_every)

    if parametrization == 'linear':
        print('learned a', model_parametrized.param.a)
        print('learned b', model_parametrized.param.b)
    elif parametrization == 'nn':
        print('learned first-layer weights',
              model_parametrized.param.layers[0].weight)

    save_dir = 'results/parametrization/' + exp_id + '/'
    mkdir_from_path(res_dir + save_dir)

    open(res_dir + save_dir + commit_id + '.txt', 'w').close()  # touch an empty file recording the commit id

    state_dict = param_train.state_dict()
    for key, value in state_dict.items():
        state_dict[key] = value.detach().cpu().numpy()

    state_dict_offline = param_offline.state_dict()
    for key, value in state_dict_offline.items():
        state_dict_offline[key] = value.detach().cpu().numpy()

    np.save(
        res_dir + save_dir + 'out', {
            'data_full': data_full,
            'X_init': data_full[-1].reshape(1, -1),
            'param_train_state_dict': state_dict,
            'param_offline_state_dict': state_dict_offline,
            'X': X,
            'y': y
        })
    print('done')
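
The offline target y = -h*c*mean(Y) above is the subgrid tendency of the two-level Lorenz '96 system; in one common formulation (conventions for the b-scaling of the coupling vary, so this is indicative only):

\frac{dX_k}{dt} = -X_{k-1}\,(X_{k-2} - X_{k+1}) - X_k + F - \frac{h c}{b} \sum_{j=1}^{J} Y_{j,k}

\frac{dY_{j,k}}{dt} = -c\,b\,Y_{j+1,k}\,(Y_{j+2,k} - Y_{j-1,k}) - c\,Y_{j,k} + \frac{h c}{b}\, X_k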
Example #16
def forward(self, x):
    # torch -> flat numpy, evaluate the numerical tendency f(0, x), back to torch
    x = sortL96fromChannels(x.detach().cpu().numpy()).flatten()
    return as_tensor(
        sortL96intoChannels(np.atleast_2d(self.fun(0., x)), J=J))