Example #1
0
    def train(self,
              xs,
              sys,
              nepochs,
              smooth_solver_kwargs=None,
              learner_criterion_kwargs=None,
              learner_opt_kwargs=None,
              subset=None,
              hp_schedulers=None):
        """ Runs the CEEM algorithm for up to nepochs epochs

        Args:
          xs (torch.tensor): initial state trajectories, refined each epoch
          sys (DiscreteDynamicalSystem): system being trained
          nepochs (int): Number of epochs to run CEEM algorithm for
          smooth_solver_kwargs (dict): kwargs forwarded to the smoothing solver
          learner_criterion_kwargs (dict): kwargs forwarded to the learning criteria
          learner_opt_kwargs (dict): kwargs forwarded to the learner optimizer
          subset (int): optional subset size passed through to `step`
          hp_schedulers: hyper-parameter schedulers; a non-dict sequence must
            have one entry per learning criterion
        """
        if hp_schedulers is not None:
            # NOTE(review): only non-dict (sequence) schedulers are
            # length-checked here; dict-valued schedulers are assumed to be
            # validated downstream -- confirm intent.
            if not isinstance(hp_schedulers, dict):
                assert len(hp_schedulers) == len(self._learning_criteria)

        for epoch in range(nepochs):
            xs, smooth_metrics, learn_metrics = self.step(
                xs,
                sys,
                smooth_solver_kwargs,
                learner_criterion_kwargs,
                learner_opt_kwargs,
                subset=subset,
                hp_schedulers=hp_schedulers)

            # Log all metrics. Run tensorboard to visualize
            log_kv_or_listkv(smooth_metrics, "smooth")
            log_kv_or_listkv(learn_metrics, "learn")

            for ecall in self._epoch_callbacks:
                ecall(epoch)

            # Evaluate termination BEFORE dumping so the callback's own
            # logkv calls are included, but dump BEFORE breaking so the
            # final epoch's metrics are not silently dropped.
            terminate = self._termination_callback(epoch)

            logger.dumpkvs()

            if terminate:
                break
Example #2
0
    def train(self, params, callbacks=None):
        """Runs SAEM for up to `self._max_k` iterations.

        Each iteration performs an E-step (particle filter + FFBSi backward
        simulation of a smoothed trajectory) followed by an M-step
        (optimizing the recursive Q objective over `params`), logging the
        timing of each phase and the Monte-Carlo Q value.

        Args:
          params: iterable of torch parameters, optimized in place
          callbacks: optional list of callables invoked as `cb(k)` at the
            end of each iteration

        Returns:
          `params` (optimized in place).
        """
        # Avoid the mutable-default-argument pitfall (was `callbacks=[]`).
        callbacks = [] if callbacks is None else callbacks

        xsms = []

        t_start = timeit.default_timer()

        for k in range(self._max_k):

            with utils.Timer() as time:
                ## E-step: filter, then backward-simulate a smoothed trajectory
                xfilt, xfiltr, wfilt, meanll = self._fapf.filter(self._y)
                xsm = self._fapf.FFBSi(xfilt, wfilt)
                xsms.append(xsm)
                # Keep only the most recent trajectories if a cutoff is set.
                if self._xlen_cutoff:
                    if len(xsms) > self._xlen_cutoff:
                        xsms = xsms[-self._xlen_cutoff:]

            logger.logkv('train/Etime', time.dt)

            ## M-step: minimize the negative recursive Q objective

            with utils.Timer() as time:
                obj = lambda: -self.recursive_Q(xsms, self._y, 0, 0.)

                self._optimizer(obj, params)

            logger.logkv('train/Mtime', time.dt)

            logger.logkv('train/elapsedtime', timeit.default_timer() - t_start)

            ## log the current value of Q

            Q = float(self._fapf.Q_MCEM(self._y, xsms[-1]))
            logger.logkv('train/Q', Q)

            for callback in callbacks:
                callback(k)

            logger.dumpkvs()

        return params
Example #3
0
def train(seed, logdir, sys_seed, k, b):
    """Runs CEEM training on a (possibly coupled) Lorenz system.

    Simulates B noisy trajectories of the true system, perturbs a copy of
    its parameters by up to ~10%, then fits them back with CEEM while
    logging parameter error and one-step dynamics error.

    Args:
        seed (int): RNG seed for observation noise / parameter perturbation
        logdir (str): output directory for `logger`
        sys_seed (int): RNG seed used while setting up the true system
        k (int): number of Lorenz attractors; state dimension is 3 * k
        b (int): batch size (number of trajectories)
    """
    from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection
    import matplotlib.pyplot as plt

    # observation noise standard deviation
    ystd = 0.01
    # ystd = 0.

    torch.set_default_dtype(torch.float64)

    logger.setup(logdir, action='d')

    # trajectory length (number of timesteps)
    N = 128

    # state dimension
    n = 3 * k

    B = b

    true_system = default_lorenz_system(k, obsdif=2)

    utils.set_rng_seed(sys_seed)

    xdim = true_system.xdim
    ydim = true_system.ydim

    dt = true_system._dt

    # mean initial condition: (-6, -6, 24) per attractor
    x0mean = torch.tensor([[-6] * k + [-6] * k + [24.] * k]).unsqueeze(0)

    # simulate true_dynamics over IC distribution
    x_test = x0mean.repeat(1024, 1, 1)
    x_test += 5.0 * torch.randn_like(x_test)
    x_test = x_test.detach()
    t_test = torch.zeros(1024, 1)
    tgt_test = true_system.step_derivs(t_test, x_test).detach()

    ## simulate the true system

    xs = [x0mean.repeat(B, 1, 1)]
    xs[0] += 2.5 * torch.randn_like(xs[0])
    with torch.no_grad():
        for t in range(N - 1):
            xs.append(true_system.step(torch.tensor([0.] * B), xs[-1]))

    xs = torch.cat(xs, dim=1)

    # plot each simulated trajectory in 3d (one subplot per batch element)
    fig = plt.figure()
    # NOTE(review): this loop shadows the function argument `b`; the savefig
    # filename below therefore uses b == B - 1, not the original argument.
    for b in range(B):
        ax = fig.add_subplot(int(np.ceil(B / 2.)), 2, b + 1, projection='3d')

        for k_ in range(k):
            plot3d(plt.gca(),
                   xs[b, :, k_],
                   xs[b, :, k + k_],
                   xs[b, :, 2 * k + k_],
                   linestyle='--',
                   alpha=0.5)

    plt.savefig(os.path.join(logger.get_dir(), 'figs/traj_%d.png' % b),
                dpi=300)
    # plt.show()
    plt.close()

    t = torch.tensor(range(N)).unsqueeze(0).expand(B, -1).to(
        torch.get_default_dtype())

    y = true_system.observe(t, xs).detach()

    # seed for real now
    utils.set_rng_seed(seed)

    y += ystd * torch.randn_like(y)

    # prep system
    system = deepcopy(true_system)

    true_params = parameters_to_vector(true_system.parameters())

    utils.set_rng_seed(seed)

    # multiplicative perturbation of the true parameters
    params = true_params * (
        (torch.rand_like(true_params) - 0.5) / 5. + 1.)  # within 10%

    vector_to_parameters(params, system.parameters())

    params = list(system.parameters())

    # specify smoothing criteria

    smoothing_criteria = []

    for b in range(B):

        obscrit = GaussianObservationCriterion(1.0 * torch.ones(ydim),
                                               t[b:b + 1], y[b:b + 1])

        dyncrit = GaussianDynamicsCriterion(1e0 * torch.ones(xdim), t[b:b + 1])

        smoothing_criteria.append(GroupSOSCriterion([obscrit, dyncrit]))

    smooth_solver_kwargs = {'verbose': 0, 'tr_rho': 0.01}

    # specify learning criteria
    learning_criteria = [GaussianDynamicsCriterion(1e0 * torch.ones(xdim), t)]
    learning_params = [params]
    # learning_opts = ['scipy_minimize']
    # learner_opt_kwargs = {'method': 'Nelder-Mead', 'tr_rho': 0.1,
    #                         'options':{'adaptive':True}}
    # learner_opt_kwargs = {'method': 'BFGS', 'tr_rho': 0.1,
    #                         'options':{'disp':True}}
    learning_opts = ['torch_minimize']
    # learner_opt_kwargs = {
    #     'method': 'Adam',
    #     'lr': 5e-4,
    #     'tr_rho': 0.1,
    #     'nepochs': 200,
    #     'max_grad_norm': 10.0
    # }
    learner_opt_kwargs = {'method': 'LBFGS'}

    # instantiate CEEM

    def ecb(epoch):
        # Epoch callback: log the current parameter-vector error.

        params = list(system.parameters())
        vparams = parameters_to_vector(params)

        error = (vparams - true_params).norm().item()

        logger.logkv('test/log10_paramerror', np.log10(error))

        return

    epoch_callbacks = [ecb]

    # Bare class used as a mutable namespace for the error history.
    class Last10Errors:
        def __init__(self):
            return

    # NOTE(review): the class object itself (not an instance) is used here.
    last_10_errors = Last10Errors
    last_10_errors._arr = []

    def tcb(epoch):
        # Termination callback: stop once the log10 one-step dynamics error
        # has been nearly flat over the last 100 epochs.

        with torch.no_grad():
            tgt_test_pr = system.step_derivs(t_test, x_test)
            error = float(torch.nn.functional.mse_loss(tgt_test_pr, tgt_test))

        logger.logkv('test/log10_error', np.log10(error))

        last_10_errors._arr.append(np.log10(error))

        if len(last_10_errors._arr) > 100:
            last_10_errors._arr = last_10_errors._arr[-100:]

            l10err = torch.tensor(last_10_errors._arr)

            # spread of the retained log-errors; small spread => converged
            convcrit = float((l10err.min() - l10err.max()).abs())
            logger.logkv('test/log10_convcrit', np.log10(convcrit))
            if convcrit < 1e-3:
                return True

        return False

    termination_callback = tcb

    # log initial errors before training starts
    ecb(-1)
    tcb(-1)
    logger.dumpkvs()

    ceem = CEEM(smoothing_criteria,
                learning_criteria,
                learning_params,
                learning_opts,
                epoch_callbacks,
                termination_callback,
                parallel=min(4, B))

    # run CEEM

    # x0 = torch.zeros_like(xs)
    # initialize smoothing from the true states plus unit Gaussian noise
    x0 = xs + torch.randn_like(xs)

    ceem.train(xs=x0,
               sys=system,
               nepochs=5000,
               smooth_solver_kwargs=smooth_solver_kwargs,
               learner_opt_kwargs=learner_opt_kwargs)

    return
Example #4
0
def train_net(net,
              train_data,
              train_trgt,
              test_data,
              test_trgt,
              valid_data,
              valid_trgt,
              y_std,
              lr,
              H,
              logdir,
              n_epochs=1000):
    """Trains `net` to predict targets from length-H input windows.

    Slides a window of H timesteps over the training sequences, predicting
    the target at the window's final step. Every 100 epochs validation and
    test RMS losses are computed, and the best-on-validation weights are
    checkpointed to `<logdir>/best_net.th`.

    Args:
        net: torch module mapping a (B, H, ...) input window to predictions
        train_data, train_trgt: training inputs/targets of shape (B, T, ...)
        test_data, test_trgt: test inputs/targets
        valid_data, valid_trgt: validation inputs/targets
        y_std: target scaling passed to `compute_rms`
        lr (float): initial Adam learning rate
        H (int): input window length
        logdir (str): directory for the best-model checkpoint
        n_epochs (int): number of training epochs

    Returns:
        The trained network (the same object as `net`).
    """

    T = train_data.shape[1]

    opt = torch.optim.Adam(net.parameters(), lr=lr)
    # harmonic learning-rate decay: lr_e = lr * off / (off + e)
    scheduler_off = 1000.
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        opt, lambda epoch: scheduler_off / (scheduler_off + epoch))

    best_val_loss = np.inf

    t0 = time.time()
    for e in range(n_epochs):
        logger.logkv('epoch', e)
        # train
        ll = []
        for t in range(T - H + 1):
            opt.zero_grad()

            u = train_data[:, t:t + H]
            y = train_trgt[:, t + H - 1]
            y_pred = net(u)
            assert y.size() == y_pred.size()
            loss = compute_rms(y.unsqueeze(0), y_pred.unsqueeze(0), y_std)
            loss.backward()
            opt.step()
            ll.append(float(loss))
        mean_train_loss = np.mean(ll)
        logger.logkv('log10_train_loss', np.log10(mean_train_loss))

        scheduler.step()

        for param_group in opt.param_groups:
            logger.logkv('log10_lr', np.log10(param_group['lr']))

        if e % 100 == 0:
            # validation
            ll = []
            for t in range(T - H):
                with torch.no_grad():
                    u = valid_data[:, t:t + H]
                    y = valid_trgt[:, t + H - 1]
                    y_pred = net(u)
                    loss = compute_rms(y.unsqueeze(0), y_pred.unsqueeze(0),
                                       y_std)
                    ll.append(float(loss))
            mean_val_loss = np.mean(ll)
            logger.logkv('log10_val_loss', np.log10(mean_val_loss))

            # Test
            ll = []
            for t in range(T - H):
                with torch.no_grad():
                    u = test_data[:, t:t + H]
                    y = test_trgt[:, t + H - 1]
                    y_pred = net(u)
                    loss = compute_rms(y.unsqueeze(0), y_pred.unsqueeze(0),
                                       y_std)
                ll.append(float(loss))
            mean_test_loss = np.mean(ll)
            logger.logkv('log10_test_loss', np.log10(mean_test_loss))

            # Save best model selected on validation loss
            if mean_val_loss < best_val_loss:
                torch.save(net.state_dict(),
                           os.path.join(logdir, 'best_net.th'))
                best_val_loss = mean_val_loss

        # rate-limit log flushing to at most once every 2 seconds
        if time.time() - t0 > 2:
            t0 = time.time()
            logger.dumpkvs()

    return net
Example #5
0
def train(seed, logdir, ystd=0.1, wstd=0.01, sys_seed=4):
    """Fits a Lorenz attractor with CEEM from noisy rollouts.

    Simulates B noisy rollouts of the true attractor, perturbs a copy of
    its parameters by up to ~10%, and recovers them with CEEM.

    Args:
        seed (int): RNG seed for rollout noise and parameter perturbation
        logdir (str): output directory for `logger`
        ystd (float): observation noise standard deviation
        wstd (float): process (dynamics) noise standard deviation
        sys_seed (int): RNG seed used when building the true system

    Returns:
        tuple: (sigma, rho, beta) of the learned system, as floats.
    """

    print('\n\n\n##### SEED %d #####\n\n'%seed)

    torch.set_default_dtype(torch.float64)

    logger.setup(logdir, action='d')

    # Number of timesteps in the trajectory
    T = 128

    # state dimension
    n = 3

    # Batch size
    B = 4

    # single attractor
    k = 1

    utils.set_rng_seed(sys_seed)

    sys = default_lorenz_attractor()

    dt = sys._dt

    utils.set_rng_seed(seed)

    # simulate the system

    x0mean = torch.tensor([[-6] * k + [-6] * k + [24.] * k])
    x0mean = x0mean.unsqueeze(0).repeat(B,1,1)

    # Rollout with noise

    # process / observation / initial-state covariances
    Q = (wstd ** 2) * torch.eye(sys.xdim)
    R = (ystd ** 2) * torch.eye(sys.ydim)
    Px0 = 5.0 * torch.eye(sys.xdim)

    Qpdf = MultivariateNormal(torch.zeros((B,1,sys.xdim)), Q.unsqueeze(0).unsqueeze(0))
    Rpdf = MultivariateNormal(torch.zeros((B,1,sys.ydim)), R.unsqueeze(0).unsqueeze(0))
    Px0pdf = MultivariateNormal(x0mean, Px0.unsqueeze(0).unsqueeze(0))


    # roll out T steps, adding process and observation noise at each step
    xs = [Px0pdf.sample()]
    ys = [sys.observe(0, xs[0]) + Rpdf.sample()]

    for t in range(T-1):

        tinp = torch.tensor([t] * B).unsqueeze(1).to(dtype=torch.get_default_dtype())
        xs.append(sys.step(tinp, xs[-1]) + Qpdf.sample())
        ys.append(sys.observe(tinp, xs[-1]) + Rpdf.sample())

    x = torch.cat(xs, dim=1)
    y = torch.cat(ys, dim=1)


    t = torch.tensor(range(T)).unsqueeze(0).to(torch.get_default_dtype()).repeat(B,1)

    m = y.shape[-1]


    # plot each rollout in 3d (one subplot per batch element)
    fig = plt.figure()
    for b in range(B):
        ax = fig.add_subplot(int(np.ceil(B / 2.)), 2, b + 1, projection='3d')

        for k_ in range(k):
            plot3d(plt.gca(), x[b, :, k_], x[b, :, k + k_], x[b, :, 2 * k + k_], linestyle='--',
                   alpha=0.5)

    plt.savefig(opj(logdir, 'traintrajs.png'), dpi=300)


    # prep system
    true_system = sys

    system = deepcopy(true_system)

    true_params = parameters_to_vector(true_system.parameters())

    # multiplicative perturbation of the true parameters
    params = true_params * ((torch.rand_like(true_params) - 0.5) / 5. + 1.)  # within 10%

    vector_to_parameters(params, system.parameters())

    params = list(system.parameters())

    # specify smoothing criteria

    smoothing_criteria = []

    for b in range(B):

        obscrit = GaussianObservationCriterion(torch.ones(2), t[b:b + 1], y[b:b + 1])

        dyncrit = GaussianDynamicsCriterion(wstd / ystd * torch.ones(3), t[b:b + 1])

        smoothing_criteria.append(GroupSOSCriterion([obscrit, dyncrit]))

    smooth_solver_kwargs = {'verbose': 0, 'tr_rho': 0.001}

    # specify learning criteria
    learning_criteria = [GaussianDynamicsCriterion(torch.ones(3), t)]
    learning_params = [params]
    learning_opts = ['scipy_minimize']
    learner_opt_kwargs = {'method': 'Nelder-Mead', 'tr_rho': 0.01}

    # instantiate CEEM



    timer = {'start_time':timeit.default_timer()}

    def ecb(epoch):
        # Epoch callback: log the learned Lorenz coefficients, their log10
        # percent errors, and the wall-clock time per epoch.
        logger.logkv('test/rho', float(system._rho))
        logger.logkv('test/sigma', float(system._sigma))
        logger.logkv('test/beta', float(system._beta))

        logger.logkv('test/rho_pcterr_log10',
                     float(torch.log10((true_system._rho - system._rho).abs() / true_system._rho)))
        logger.logkv(
            'test/sigma_pcterr_log10',
            float(torch.log10((true_system._sigma - system._sigma).abs() / true_system._sigma)))
        logger.logkv(
            'test/beta_pcterr_log10',
            float(torch.log10((true_system._beta - system._beta).abs() / true_system._beta)))


        logger.logkv('time/epochtime', timeit.default_timer() - timer['start_time'])

        timer['start_time'] = timeit.default_timer()

        return

    epoch_callbacks = [ecb]

    # Bare class used as a mutable namespace for the error history.
    class Last10Errors:

        def __init__(self):
            return

    # NOTE(review): the class object itself (not an instance) is used here.
    last_10_errors = Last10Errors
    last_10_errors._arr = []

    def tcb(epoch):
        # Termination callback: stop once the parameter error has been nearly
        # flat over the last 10 epochs.

        params = list(system.parameters())
        vparams = parameters_to_vector(params)

        error = (vparams - true_params).norm().item()

        last_10_errors._arr.append(float(error))

        logger.logkv('test/log10_error', np.log10(error))

        if len(last_10_errors._arr) > 10:
            last_10_errors._arr = last_10_errors._arr[-10:]

            l10err = torch.tensor(last_10_errors._arr)

            # spread of the retained errors; small spread => converged
            convcrit = float((l10err.min() - l10err.max()).abs())
            logger.logkv('test/log10_convcrit', np.log10(convcrit))
            if convcrit < 1e-4:
                return True

        return False

    termination_callback = tcb

    ceem = CEEM(smoothing_criteria, learning_criteria, learning_params, learning_opts,
                epoch_callbacks, termination_callback)

    # run CEEM

    # smoothing starts from an all-zeros state initialization
    x0 = torch.zeros_like(x)

    # log initial values before training starts
    ecb(-1)
    logger.dumpkvs()

    ceem.train(xs=x0, sys=system, nepochs=100, smooth_solver_kwargs=smooth_solver_kwargs,
               learner_opt_kwargs=learner_opt_kwargs)

    return float(system._sigma), float(system._rho), float(system._beta)
Example #6
0
def run(seed, lr, method, noise, damped, smoketest=False):
    """Trains a structured mechanical model of a double pendulum.

    Args:
        seed (int): seed
        lr (float): init learning rate
        method (str): training method in ['qdd', 'nqqd', 'del+mnorm', 'del+logdet']
        noise (float): noise level; must be a key of the dataset mapping
            below (0.01, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 1.0)
        damped (bool): use damped pendulum data
        smoketest (bool): if smoketest, runs 2 epochs
    """
    torch.set_default_dtype(torch.float64)
    dtype = torch.get_default_dtype()

    torch.manual_seed(seed)

    # load the data
    dataset = 'damped_' if damped else ''
    # maps the noise level to the filename token used by the dataset files
    noisedict = {
        0.01: '0p01',
        0.05: '0p05',
        0.10: '0p10',
        0.20: '0p20',
        0.30: '0p30',
        0.40: '0p40',
        0.50: '0p50',
        1.0: '1p0'
    }
    dataset += 'dubpen_%s_smoothed.td' % noisedict[noise]
    dataset = './datasets/' + dataset

    data = torch.load('./datasets/%sdubpen_qddot.td' %
                      ('damped_' if damped else ''))
    data_ = torch.load(dataset)
    dt = 0.05
    logdir = 'data/%s_%s_%.1e_%.3f_%d' % ('damped' if damped else 'undamped',
                                          method, lr, noise, seed)
    logger.setup(logdir, action='d')

    # random train/test/val split over the 16 trajectories
    inds = torch.randperm(16)

    Btr = 8
    Bte = 4
    Bva = 4

    train_data_ = data_[inds[:Btr]]
    test_data_ = data_[inds[Btr:Btr + Bte]]
    val_data_ = data_[inds[Btr + Bte:Btr + Bte + Bva]]
    train_data = data[inds[:Btr]]
    test_data = data[inds[Btr:Btr + Bte]]
    val_data = data[inds[Btr + Bte:Btr + Bte + Bva]]

    # smoothed (sm*) and raw splits: time, position, velocity, acceleration
    t_, smq, smdq, smddq = train_data_[:]
    ttest_, smqtest, smdqtest, smddqtest = test_data_[:]
    tval_, smqval, smdqval, smddqval = val_data_[:]
    t, q, dq, ddq = train_data[:]
    ttest, qtest, dqtest, ddqtest = test_data[:]
    tval, qval, dqval, ddqval = val_data[:]

    B, T, qdim = q.shape

    # create the appropriate dataloader
    if 'del' in method:
        # triplets of consecutive configurations for the DEL criterion
        smq_1 = smq[:, :-2]
        smq_2 = smq[:, 1:-1]
        smq_3 = smq[:, 2:]
        smq_B = torch.stack([smq_1, smq_2, smq_3], dim=2).reshape(-1, 3,
                                                                  2).detach()
        print(smq_B.shape, smq.shape)
        dataset = TensorDataset(smq_B)
        dataloader = DataLoader(dataset, batch_size=256, shuffle=True)
    elif method == 'qdd':
        dataset = TensorDataset(smq.reshape(-1, 1, 2), smdq.reshape(-1, 1, 2),
                                smddq.reshape(-1, 1, 2))
        dataloader = DataLoader(dataset, batch_size=256, shuffle=True)
    elif method == 'nqqd':
        # next-state prediction pairs on the concatenated (q, dq) state
        x = torch.cat([smq, smdq], dim=-1)
        inp = x[:, :-1]
        out = x[:, 1:]
        inp = inp.reshape(-1, 1, 4)
        out = out.reshape(-1, 1, 4)
        dataset = TensorDataset(inp, out)
        dataloader = DataLoader(dataset, batch_size=256, shuffle=True)
    else:
        raise NotImplementedError

    # set up logdir and model
    if damped:
        system = ForcedSMM(qdim=qdim, dt=dt)
    else:
        system = StructuredMechanicalModel(qdim=qdim, dt=dt)

    # create the appropriate closure
    def qddcrit(system, smq_, smdq_, smddq_):
        # MSE between predicted and target accelerations.

        ddq_ = system.compute_qddot(smq_, smdq_, create_graph=True)
        ddq_loss = torch.nn.functional.mse_loss(ddq_, smddq_)

        return ddq_loss

    def nqqdcrit(system, inp, out):
        # MSE on one-step state predictions.

        out_ = system.step(torch.ones_like(inp)[..., 0], inp)
        nqqd_loss = torch.nn.functional.mse_loss(out_, out)

        return nqqd_loss

    if 'del' in method:
        dyncrit = DELCriterion(t_)

        if 'logdet' in method:
            bc = LogDetBarrierCriterion
            bcf = LogDetBarrierCriterion.mineig
        else:
            bc = MxNormBarrierCriterion
            bcf = MxNormBarrierCriterion.mmxnorm

        # initialize the barrier criterion, and find an appropriate coefficient between it and DELcrit
        lb = bcf(system, smq).detach() * 0.99  # interior point init

        barriercrit = bc(lb)
        with torch.no_grad():
            dyncritloss = dyncrit(system, smq)
            barriercritloss = barriercrit(system, smq)

            mu = float(dyncritloss /
                       barriercritloss)  # mu makes them ~equal at init

        barriercrit = bc(lb, mu=mu, x_override=smq)

        crit = GroupCriterion([dyncrit, barriercrit])
    elif method == 'qdd':
        crit = qddcrit
    elif method == 'nqqd':
        crit = nqqdcrit
    else:
        raise NotImplementedError

    # setup optimizer, scheduler
    opt = torch.optim.Adam(system.parameters(), lr=lr)
    sched = torch.optim.lr_scheduler.LambdaLR(opt, lambda k: 500. / (500 + k))

    # train
    best_val_loss = np.inf
    # FIX: was initialized as `best_val_loss_test_qddot`, but the variable
    # read and updated below is `best_val_loss_test_qddot_loss`.
    best_val_loss_test_qddot_loss = np.inf

    next_params = ptv(system.parameters()).detach()

    for epoch in range(2 if smoketest else 500):

        # train with SGD
        for batch in dataloader:

            prev_params = next_params

            opt.zero_grad()

            loss = crit(system, *batch)

            loss.backward()

            opt.step()

            if 'del' in method:
                # line search: halve the parameter step until the criterion
                # is finite again (the barrier can make it NaN)
                n_ls = 0

                while True:

                    next_params = ptv(system.parameters()).detach()

                    del_params = next_params - prev_params

                    with torch.no_grad():
                        c = crit(system, smq)

                    if torch.isnan(c):
                        next_params = prev_params + 0.5 * del_params
                        vtp(next_params, system.parameters())

                        n_ls += 1
                    else:
                        break

        sched.step()

        with torch.no_grad():
            val_sqmddqloss = qddcrit(system, smqval, smdqval, smddqval)
            train_qdd_loss = qddcrit(system, q, dq, ddq)
            test_qdd_loss = qddcrit(system, qtest, dqtest, ddqtest)
            val_qdd_loss = qddcrit(system, qval, dqval, ddqval)

        # select best model using validation error
        if val_sqmddqloss < best_val_loss:
            best_val_loss = float(val_sqmddqloss)
            best_val_loss_test_qddot_loss = float(test_qdd_loss)

            torch.save(
                system.state_dict(),
                os.path.join(logger.get_dir(), 'ckpts', 'best_model.th'))

        logger.logkv("train/epoch", epoch)
        logger.logkv("train/loss", float(loss))
        logger.logkv("train/log10lr",
                     np.log10(float(opt.param_groups[0]['lr'])))

        logger.logkv("eval/val_sqmddqloss", float(val_sqmddqloss))
        logger.logkv("eval/train_qdd_loss", float(train_qdd_loss))
        logger.logkv("eval/test_qdd_loss", float(test_qdd_loss))
        logger.logkv("eval/val_qdd_loss", float(val_qdd_loss))

        logger.logkv("eval/best_val_loss", float(best_val_loss))
        logger.logkv("eval/best_val_loss_test_qddot_loss",
                     float(best_val_loss_test_qddot_loss))

        logger.dumpkvs()
Example #7
0
def train(seed, logdir, sys_seed, k, b):
    """Fits a Lorenz system with SAEM (particle-filter EM) from noisy data.

    Simulates B trajectories of the true system, perturbs a copy of its
    parameters by up to ~10%, then recovers them with `SAEMTrainer` driven
    by a fully-adapted particle filter, logging parameter and dynamics
    errors every iteration.

    Args:
        seed (int): RNG seed for observation noise / parameter perturbation
        logdir (str): output directory for `logger`
        sys_seed (int): RNG seed used while setting up the true system
        k (int): number of Lorenz attractors; state dimension is 3 * k
        b (int): batch size (number of trajectories)
    """
    from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection
    import matplotlib.pyplot as plt

    # observation noise standard deviation
    ystd = 0.01
    # ystd = 0.

    torch.set_default_dtype(torch.float64)

    logger.setup(logdir, action='d')

    # trajectory length (number of timesteps)
    N = 128

    # state dimension
    n = 3 * k

    B = b

    true_system = default_lorenz_system(k, obsdif=2)

    utils.set_rng_seed(sys_seed)

    xdim = true_system.xdim
    ydim = true_system.ydim

    dt = true_system._dt

    # mean initial condition: (-6, -6, 24) per attractor
    x0mean = torch.tensor([[-6] * k + [-6] * k + [24.] * k]).unsqueeze(0)

    # simulate true_dynamics over IC distribution
    x_test = x0mean.repeat(1024, 1, 1)
    x_test += 5.0 * torch.randn_like(x_test)
    x_test = x_test.detach()
    t_test = torch.zeros(1024, 1)
    tgt_test = true_system.step_derivs(t_test, x_test).detach()

    # process / observation / initial-state covariances for the particle filter
    Q = (0.01**2) * torch.eye(true_system.xdim)
    R = (ystd**2) * torch.eye(true_system.ydim)
    Px0 = 2.5**2 * torch.eye(true_system.xdim)

    ## simulate the true system

    xs = [x0mean.repeat(B, 1, 1)]
    xs[0] += 2.5 * torch.randn_like(xs[0])
    with torch.no_grad():
        for t in range(N - 1):
            xs.append(true_system.step(torch.tensor([0.] * B), xs[-1]))

    xs = torch.cat(xs, dim=1)

    # plot each simulated trajectory in 3d (one subplot per batch element)
    fig = plt.figure()
    # NOTE(review): this loop shadows the function argument `b`; the savefig
    # filename below therefore uses b == B - 1, not the original argument.
    for b in range(B):
        ax = fig.add_subplot(int(np.ceil(B / 2.)), 2, b + 1, projection='3d')

        for k_ in range(k):
            plot3d(plt.gca(),
                   xs[b, :, k_],
                   xs[b, :, k + k_],
                   xs[b, :, 2 * k + k_],
                   linestyle='--',
                   alpha=0.5)

    plt.savefig(os.path.join(logger.get_dir(), 'figs/traj_%d.png' % b),
                dpi=300)
    # plt.show()
    plt.close()

    t = torch.tensor(range(N)).unsqueeze(0).expand(B, -1).to(
        torch.get_default_dtype())

    y = true_system.observe(t, xs).detach()

    # seed for real now
    utils.set_rng_seed(seed)

    y += ystd * torch.randn_like(y)

    # prep system
    system = deepcopy(true_system)

    true_params = parameters_to_vector(true_system.parameters())

    utils.set_rng_seed(seed)

    # multiplicative perturbation of the true parameters
    params = true_params * (
        (torch.rand_like(true_params) - 0.5) / 5. + 1.)  # within 10%

    vector_to_parameters(params, system.parameters())

    params = list(system.parameters())

    # number of particles
    Np = 100

    fapf = faPF(Np, system, Q, R, Px0)

    timer = {'start_time': timeit.default_timer()}

    def ecb(epoch):
        # Epoch callback: log parameter error, per-epoch wall time, and
        # one-step dynamics error on the held-out IC samples.

        logger.logkv('time/epoch', epoch)

        params = list(system.parameters())
        vparams = parameters_to_vector(params)

        error = (vparams - true_params).norm().item()

        logger.logkv('test/log10_paramerror', np.log10(error))

        logger.logkv('time/epochtime',
                     timeit.default_timer() - timer['start_time'])

        timer['start_time'] = timeit.default_timer()

        with torch.no_grad():
            tgt_test_pr = system.step_derivs(t_test, x_test)
            error = float(torch.nn.functional.mse_loss(tgt_test_pr, tgt_test))

        logger.logkv('test/log10_error', np.log10(error))

        return

    epoch_callbacks = [ecb]

    # log initial errors before training starts
    ecb(-1)
    logger.dumpkvs()

    trainer = SAEMTrainer(
        fapf,
        y,
        # gamma_sched=lambda x: HarmonicDecayScheduler(x, a=50.),
        gamma_sched=lambda x: 0.2,
        max_k=5000,
        xlen_cutoff=15,
    )
    trainer.train(params, callbacks=epoch_callbacks)

    return