Example #1
def inspect_strong_order():
    batch_size, d, m = 4096, 5, 5
    t0, t1 = ts = torch.tensor([0., 5.]).to(device)
    dts = tuple(2**-i for i in range(1, 9))
    y0 = torch.ones(batch_size, d).to(device)
    sde = Ex3Additive(d=d).to(device)

    euler_mses_ = []
    srk_mses_ = []

    with torch.no_grad():
        bm = BrownianPath(t0=t0, w0=torch.zeros(batch_size, m).to(device))  # w0 must have shape (batch_size, m) to match the noise dimension.

        for dt in tqdm.tqdm(dts):
            # Only take end value.
            _, ys_euler = sdeint(sde,
                                 y0=y0,
                                 ts=ts,
                                 dt=dt,
                                 bm=bm,
                                 method='euler')
            _, ys_srk = sdeint(sde,
                               y0=y0,
                               ts=ts,
                               dt=dt,
                               bm=bm,
                               method='srk',
                               options={'trapezoidal_approx': False})
            _, ys_analytical = sde.analytical_sample(y0=y0, ts=ts, bm=bm)

            euler_mse = compute_mse(ys_euler, ys_analytical)
            srk_mse = compute_mse(ys_srk, ys_analytical)

            euler_mse_, srk_mse_ = to_numpy(euler_mse, srk_mse)

            euler_mses_.append(euler_mse_)
            srk_mses_.append(srk_mse_)
    del euler_mse_, srk_mse_

    # Halve the log-MSE: textbook strong orders are stated for the root-mean-square error, and log(RMSE) = log(MSE) / 2, so the regression slope then matches the textbook order.
    log = lambda x: np.log(np.array(x))
    euler_slope, _, _, _, _ = stats.linregress(log(dts), log(euler_mses_) / 2)
    srk_slope, _, _, _, _ = stats.linregress(log(dts), log(srk_mses_) / 2)

    plt.figure()
    plt.plot(dts, euler_mses_, label=f'euler(k={euler_slope:.4f})')
    plt.plot(dts, srk_mses_, label=f'srk(k={srk_slope:.4f})')
    plt.xscale('log')
    plt.yscale('log')
    plt.legend()

    img_dir = os.path.join('.', 'diagnostics', 'plots', 'srk_additive')
    makedirs_if_not_found(img_dir)
    plt.savefig(os.path.join(img_dir, 'rate'))
    plt.close()
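This diagnostic (and the near-identical inspect_rate below) calls three helpers the snippet never defines. A minimal sketch of what they presumably do; the names are taken from the calls above, but the bodies are assumptions, not the project's actual code:

import os


def compute_mse(ys, ys_ref):
    # Mean squared error over the batch and state dimensions.
    return ((ys - ys_ref) ** 2).mean()


def to_numpy(*tensors):
    # Detach each tensor and move it to host memory as a NumPy array.
    return tuple(t.detach().cpu().numpy() for t in tensors)


def makedirs_if_not_found(path):
    if not os.path.exists(path):
        os.makedirs(path)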
Example #2
def inspect_rate():
    batch_size, d, m = 4096, 5, 5
    t0, t1 = ts = torch.tensor([0., 5.]).to(device)
    dts = tuple(2**-i for i in range(1, 10))
    y0 = torch.ones(batch_size, d).to(device)
    sde = Ex3Additive(d=d).to(device)

    euler_mses_ = []
    srk_mses_ = []

    with torch.no_grad():
        bm = BrownianPath(t0=t0, w0=torch.zeros(batch_size, m).to(device))  # w0 must have shape (batch_size, m) to match the noise dimension.

        for dt in dts:
            # Only take end value.
            _, ys_euler = sdeint(sde,
                                 y0=y0,
                                 ts=ts,
                                 dt=dt,
                                 bm=bm,
                                 method='euler')
            _, ys_srk = sdeint(sde,
                               y0=y0,
                               ts=ts,
                               dt=dt,
                               bm=bm,
                               method='srk',
                               options={'trapezoidal_approx': False})
            _, ys_analytical = sde.analytical_sample(y0=y0, ts=ts, bm=bm)

            euler_mse = compute_mse(ys_euler, ys_analytical)
            srk_mse = compute_mse(ys_srk, ys_analytical)

            euler_mse_, srk_mse_ = to_numpy(euler_mse, srk_mse)

            euler_mses_.append(euler_mse_)
            srk_mses_.append(srk_mse_)

    plt.figure()
    plt.plot(dts, euler_mses_, label='euler')
    plt.plot(dts, srk_mses_, label='srk')
    plt.xscale('log')
    plt.yscale('log')
    plt.legend()

    img_dir = os.path.join('.', 'diagnostics', 'plots', 'srk_additive')
    makedirs_if_not_found(img_dir)
    plt.savefig(os.path.join(img_dir, 'rate'))
    plt.close()
Example #3
    def test_normality(self):
        """Kolmogorov-Smirnov test."""
        t0_, t1_ = 0.0, 1.0
        t0, t1 = torch.tensor([t0_, t1_])
        eps = 1e-2
        for _ in range(REPS):
            w0_, w1_ = 0.0, npr.randn() * np.sqrt(t1_)
            # Use the same endpoint across the batch so all samples come from the same distribution.
            w0 = torch.tensor(w0_).repeat(BATCH_SIZE)
            w1 = torch.tensor(w1_).repeat(BATCH_SIZE)

            bm = BrownianPath(t0=t0, w0=w0)
            bm.insert(t=t1, w=w1)

            t_ = npr.uniform(low=t0_ + eps, high=t1_ - eps)
            samples = bm(t_)
            samples_ = samples.detach().numpy()

            mean_ = ((t1_ - t_) * w0_ + (t_ - t0_) * w1_) / (t1_ - t0_)
            std_ = np.sqrt((t1_ - t_) * (t_ - t0_) / (t1_ - t0_))
            ref_dist = norm(loc=mean_, scale=std_)

            _, pval = kstest(samples_, ref_dist.cdf)
            self.assertGreaterEqual(pval, ALPHA)
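The reference mean and standard deviation used above are the Brownian bridge conditional law: for a Brownian motion pinned at W(t_0) = w_0 and W(t_1) = w_1, and t_0 < t < t_1,

    W(t) \mid W(t_0) = w_0,\, W(t_1) = w_1
        \sim \mathcal{N}\!\left( \frac{(t_1 - t)\, w_0 + (t - t_0)\, w_1}{t_1 - t_0},\;
                                 \frac{(t_1 - t)(t - t_0)}{t_1 - t_0} \right),

where the second argument is the variance, hence the square root taken for std_.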
Example #4
def inspect_rate():
    batch_size, d = 4096, 10
    t0, t1 = ts = torch.tensor([0., 5.]).to(device)
    dts = tuple(2 ** -i for i in range(1, 10))
    y0 = torch.ones(batch_size, d).to(device)
    sde = Ex2(d=d).to(device)

    euler_mses_ = []
    milstein_mses_ = []
    srk_mses_ = []

    with torch.no_grad():
        bm = BrownianPath(t0=t0, w0=torch.zeros(batch_size, 1).to(device))

        for dt in dts:
            # Only take end value.
            _, ys_euler = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='euler')
            _, ys_milstein = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='milstein')
            _, ys_srk = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='srk')
            _, ys_analytical = sde.analytical_sample(y0=y0, ts=ts, bm=bm)

            euler_mse = compute_mse(ys_euler, ys_analytical)
            milstein_mse = compute_mse(ys_milstein, ys_analytical)
            srk_mse = compute_mse(ys_srk, ys_analytical)

            euler_mse_, milstein_mse_, srk_mse_ = to_numpy(euler_mse, milstein_mse, srk_mse)

            euler_mses_.append(euler_mse_)
            milstein_mses_.append(milstein_mse_)
            srk_mses_.append(srk_mse_)

    plt.figure()
    plt.plot(dts, euler_mses_, label='euler')
    plt.plot(dts, milstein_mses_, label='milstein')
    plt.plot(dts, srk_mses_, label='srk')
    plt.xscale('log')
    plt.yscale('log')
    plt.legend()

    img_dir = os.path.join('.', 'diagnostics', 'plots', 'srk_scalar')
    makedirs_if_not_found(img_dir)
    plt.savefig(os.path.join(img_dir, 'rate'))
    plt.close()
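Ex2 and its analytical_sample method are not shown. To illustrate the pattern the loop relies on (an SDE paired with its closed-form solution, evaluated on the same Brownian path the solvers consume), here is a hypothetical scalar-noise example; the class name, drift, and diffusion constants are made up, and only the structure mirrors the calls above:

import torch


class HypotheticalGBM(torch.nn.Module):
    """dY = mu * Y dt + sigma * Y dW, whose exact solution is
    Y_t = Y_0 * exp((mu - sigma^2 / 2) * t + sigma * W_t)."""

    noise_type = 'scalar'
    sde_type = 'ito'

    def __init__(self, mu=0.5, sigma=1.0):
        super().__init__()
        self.mu, self.sigma = mu, sigma

    def f(self, t, y):  # Drift.
        return self.mu * y

    def g(self, t, y):  # Diffusion (d = 1, a single noise channel).
        return self.sigma * y

    def analytical_sample(self, y0, ts, bm):
        # Evaluate the exact solution at each time, reusing the solvers' Brownian path.
        ws = torch.stack([bm(float(t)) for t in ts])  # (T, batch_size, 1).
        ts_ = ts.view(-1, 1, 1).to(y0)  # (T, 1, 1), broadcast against y0.
        return y0 * torch.exp((self.mu - 0.5 * self.sigma ** 2) * ts_ + self.sigma * ws)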
Example #5
def inspect_samples():
    batch_size, d, m = 32, 1, 5
    steps = 10

    ts = torch.linspace(0., 5., steps=steps).to(device)
    t0 = ts[0]
    dt = 3e-1
    y0 = torch.ones(batch_size, d).to(device)
    sde = AdditiveSDE(d=d, m=m).to(device)

    with torch.no_grad():
        bm = BrownianPath(t0=t0, w0=torch.zeros(batch_size, m).to(device))
        ys_em = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='euler')
        ys_srk = sdeint(sde,
                        y0=y0,
                        ts=ts,
                        dt=dt,
                        bm=bm,
                        method='srk',
                        options={'trapezoidal_approx': False})
        ys_true = sdeint(sde, y0=y0, ts=ts, dt=1e-3, bm=bm, method='euler')  # Fine-step Euler as the reference solution.

        ys_em = ys_em.squeeze().t()
        ys_srk = ys_srk.squeeze().t()
        ys_true = ys_true.squeeze().t()

        ts_, ys_em_, ys_srk_, ys_true_ = to_numpy(ts, ys_em, ys_srk, ys_true)

    # Visualize sample path.
    img_dir = os.path.join('.', 'diagnostics', 'plots', 'srk_additive')
    makedirs_if_not_found(img_dir)

    for i, (ys_em_i, ys_srk_i,
            ys_true_i) in enumerate(zip(ys_em_, ys_srk_, ys_true_)):
        plt.figure()
        plt.plot(ts_, ys_em_i, label='em')
        plt.plot(ts_, ys_srk_i, label='srk')
        plt.plot(ts_, ys_true_i, label='true')
        plt.legend()
        plt.savefig(os.path.join(img_dir, f'{i}'))
        plt.close()
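AdditiveSDE itself is not shown. The property these diagnostics exercise is additive noise, i.e. a diffusion that does not depend on the state. A minimal stand-in with that property (the drift and the diffusion matrix are assumptions):

import torch


class HypotheticalAdditiveSDE(torch.nn.Module):
    noise_type = 'general'
    sde_type = 'ito'

    def __init__(self, d, m):
        super().__init__()
        # Constant (d, m) diffusion matrix: the noise is "additive".
        self.register_buffer('sigma', torch.randn(d, m) / m ** 0.5)

    def f(self, t, y):
        return -y  # A simple mean-reverting drift.

    def g(self, t, y):
        # State-independent diffusion, broadcast over the batch: (batch_size, d, m).
        return self.sigma.expand(y.size(0), -1, -1)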
Example #6
def inspect_sample():
    batch_size, d = 32, 1
    steps = 100

    ts = torch.linspace(0., 5., steps=steps).to(device)
    t0 = ts[0]
    dt = 1e-1
    y0 = torch.ones(batch_size, d).to(device)
    sde = Ex2(d=d).to(device)
    sde.noise_type = "scalar"

    with torch.no_grad():
        bm = BrownianPath(t0=t0, w0=torch.zeros_like(y0))
        ys_euler = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='euler')
        ys_milstein = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='milstein')
        ys_srk = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method='srk', options={'trapezoidal_approx': False})
        ys_analytical = sde.analytical_sample(y0=y0, ts=ts, bm=bm)

        ys_euler = ys_euler.squeeze().t()
        ys_milstein = ys_milstein.squeeze().t()
        ys_srk = ys_srk.squeeze().t()
        ys_analytical = ys_analytical.squeeze().t()

        ts_, ys_euler_, ys_milstein_, ys_srk_, ys_analytical_ = to_numpy(
            ts, ys_euler, ys_milstein, ys_srk, ys_analytical)

    # Visualize sample path.
    img_dir = os.path.join('.', 'diagnostics', 'plots', 'srk_scalar')
    makedirs_if_not_found(img_dir)

    for i, (ys_euler_i, ys_milstein_i, ys_srk_i, ys_analytical_i) in enumerate(
            zip(ys_euler_, ys_milstein_, ys_srk_, ys_analytical_)):
        plt.figure()
        plt.plot(ts_, ys_euler_i, label='euler')
        plt.plot(ts_, ys_milstein_i, label='milstein')
        plt.plot(ts_, ys_srk_i, label='srk')
        plt.plot(ts_, ys_analytical_i, label='analytical')
        plt.legend()
        plt.savefig(os.path.join(img_dir, f'{i}'))
        plt.close()
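The squeeze/transpose step above is what turns solver output into per-sample paths for plotting: sdeint returns a tensor of shape (T, batch_size, d), and with d = 1 the result becomes (batch_size, T), one row per trajectory. A quick standalone check:

import torch

ys = torch.randn(100, 32, 1)  # (T, batch_size, d), as returned by the solver.
paths = ys.squeeze().t()  # (batch_size, T): one row per sample path.
assert paths.shape == (32, 100)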
Example #7
    def _test_gradient(self, problem, method, adaptive, rtol=1e-6, atol=1e-5):
        if method == 'euler' and adaptive:
            return  # Skip: adaptive stepping is not tested with Euler.

        bm = BrownianPath(t0=t0, w0=w0)
        with torch.no_grad():
            grad_outputs = torch.ones(batch_size, d).to(device)
            alt_grad = problem.analytical_grad(y0, t1, grad_outputs, bm)

        problem.zero_grad()
        _, yt = sdeint_adjoint(problem,
                               y0,
                               ts,
                               bm=bm,
                               method=method,
                               dt=dt,
                               adaptive=adaptive,
                               rtol=rtol,
                               atol=atol)
        loss = yt.sum(dim=1).mean(dim=0)
        loss.backward()
        adj_grad = torch.cat(tuple(p.grad for p in problem.parameters()))
        self.tensorAssertAllClose(alt_grad, adj_grad)
Example #8
def _test_forward_and_backward(sde):
    bm = BrownianPath(t0=t0, w0=w0)
    for method in methods:
        _test_forward(sde, bm, method=method)
        _test_backward(sde, bm, method=method)
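_test_forward and _test_backward are not shown. Assuming module-level globals y0, ts, and dt like the other test snippets, minimal versions might only check that integration runs with the expected shape and that gradients flow:

def _test_forward(sde, bm, method):
    ys = sdeint(sde, y0=y0, ts=ts, dt=dt, bm=bm, method=method)
    assert ys.shape == (len(ts), *y0.shape)


def _test_backward(sde, bm, method):
    y0_ = y0.clone().requires_grad_(True)
    ys = sdeint(sde, y0=y0_, ts=ts, dt=dt, bm=bm, method=method)
    ys.sum().backward()
    assert y0_.grad is not None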
Example #9
    def on_epoch_end(self, vis_n_sim=1024):

        img_path = os.path.join(train_dir, f'global_step_{self.current_epoch}.png')
        ylims = (-1.75, 1.75)
        alphas = [0.05, 0.10, 0.15, 0.20, 0.25, 0.30, 0.35, 0.40, 0.45, 0.50, 0.55]
        percentiles = [0.999, 0.99, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]
        sample_colors = ('#8c96c6', '#8c6bb1', '#810f7c')
        fill_color = '#9ebcda'
        mean_color = '#4d004b'
        num_samples = len(sample_colors)
        vis_idx = np.random.permutation(vis_n_sim)

        eps = torch.randn(vis_n_sim, 1)
        bm = BrownianPath(t0=self.vis_span[0], w0=torch.zeros(vis_n_sim, 1))

        # -- Not used -- From show_prior option in original implementation
        # zs = self.model.sample_p(vis_span=self.vis_span, n_sim=vis_n_sim, eps=eps, bm=bm).squeeze()
        # ts_vis_, zs_ = self.vis_span.cpu().numpy(), zs.cpu().numpy()
        # zs_ = np.sort(zs_, axis=1)

        zs = self.model.lsde.sample_q(vis_span=self.vis_span, n_sim=vis_n_sim, eps=eps, bm=bm).squeeze()
        samples = zs[:, vis_idx]
        s_span_vis_ = self.vis_span.cpu().detach().numpy()
        zs_ = zs.cpu().detach().numpy()
        samples_ = samples.cpu().detach().numpy()

        zs_ = np.sort(zs_, axis=1)

        with torch.no_grad():

            plt.subplot(frameon=False)

            for alpha, percentile in zip(alphas, percentiles):
                idx = int((1 - percentile) / 2. * vis_n_sim)
                zs_bot_, zs_top_ = zs_[:, idx], zs_[:, -(idx + 1)]  # -(idx + 1): with idx == 0, -idx would wrap back to column 0.
                plt.fill_between(s_span_vis_, zs_bot_, zs_top_, alpha=alpha, color=fill_color)

            plt.plot(s_span_vis_, zs_.mean(axis=1), color=mean_color)

            for j in range(num_samples):
                plt.plot(s_span_vis_, samples_[:, j], color=sample_colors[j], linewidth=1.0)

            num, ds = 12, 0.12
            s, x = torch.meshgrid(
                [torch.linspace(0.2, 1.8, num), torch.linspace(-1.5, 1.5, num)]
            )

            s, x = s.reshape(-1, 1).to(self.device), x.reshape(-1, 1).to(self.device)

            ftx = self.model.lsde.defunc.f(s=s, x=x)
            ftx = ftx.cpu().reshape(num, num)

            ds = torch.zeros(num, num).fill_(ds)
            dx = ftx * ds
            ds_, dx_ = ds.cpu().detach().numpy(), dx.cpu().detach().numpy()
            s_, x_ = s.cpu().detach().numpy(), x.cpu().detach().numpy()

            plt.quiver(s_, x_, ds_, dx_, alpha=0.3, edgecolors='k', width=0.0035, scale=50)

            # Data.
            plt.scatter(self.s_span.cpu().numpy(), self.x_sample.cpu().numpy(), marker='x', zorder=3, color='k', s=35)

            plt.ylim(ylims)
            plt.xlabel('$t$')
            plt.ylabel('$Y_t$')
            plt.tight_layout()
            plt.savefig(img_path, dpi=400)
            plt.close()
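The quiver block above draws the learned drift as a direction field: at each grid point the arrow is (ds, f(s, x) * ds), the step a trajectory following the drift alone would take. A self-contained illustration with an assumed drift f(s, x) = -x:

import matplotlib.pyplot as plt
import numpy as np

s, x = np.meshgrid(np.linspace(0.2, 1.8, 12), np.linspace(-1.5, 1.5, 12))
ds = np.full_like(s, 0.12)
dx = -x * ds  # dx = f(s, x) * ds with the assumed drift.
plt.quiver(s, x, ds, dx, alpha=0.3, width=0.0035, scale=50)
plt.show()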
Example #10
t0 = 0.0
t1 = 0.3
T = 5
batch_size = 16
dt = 1e-2
ts = torch.linspace(t0, t1, steps=T).to(device)
y0 = torch.ones(batch_size, d).to(device)

basic_sdes = (
    basic_sde.BasicSDE1(d=d).to(device),
    basic_sde.BasicSDE2(d=d).to(device),
    basic_sde.BasicSDE3(d=d).to(device),
    basic_sde.BasicSDE4(d=d).to(device),
)

bm_diagonal = BrownianPath(t0=ts[0], w0=torch.zeros(batch_size, d).to(device))  # Diagonal noise: one Brownian channel per state dimension.
bm_general = BrownianPath(t0=ts[0], w0=torch.zeros(batch_size, m).to(device))  # General noise: m channels.
bm_scalar = BrownianPath(t0=ts[0], w0=torch.zeros(batch_size, 1).to(device))  # Scalar noise: a single shared channel.


class TestSdeint(TorchTestCase):

    def test_rename_methods(self):
        # Test renaming works with a subset of names when `logqp=False`.
        sde = basic_sde.CustomNamesSDE().to(device)
        ans = sdeint(sde, y0, ts, dt=dt, names={'drift': 'forward'})
        self.assertEqual(ans.shape, (T, batch_size, d))

        # Test renaming works with a subset of names when `logqp=True`.
        sde = basic_sde.CustomNamesSDELogqp().to(device)
        ans = sdeint(sde, y0, ts, dt=dt, names={'drift': 'forward', 'prior_drift': 'w'}, logqp=True)
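basic_sde.CustomNamesSDE is not shown; the point of the test is that the names argument lets sdeint look up the drift under a different attribute name, here forward. A plausible minimal class (the drift and diffusion bodies are assumptions):

class CustomNamesSDE(torch.nn.Module):
    noise_type = 'diagonal'
    sde_type = 'ito'

    def forward(self, t, y):  # Found by sdeint via names={'drift': 'forward'}.
        return -y

    def g(self, t, y):
        return torch.ones_like(y)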
Example #11
    def _setUp(self, device=None):
        t0, t1 = torch.tensor([0., 1.]).to(device)
        w0, w1 = torch.randn([2, BATCH_SIZE, D]).to(device)

        self.t = torch.rand([]).to(device)
        self.bm = BrownianPath(t0=t0, w0=w0)
Example #12
def main():
    # Dataset.
    ts_, ts_ext_, ts_vis_, ts, ts_ext, ts_vis, ys, ys_ = make_data()

    # Plotting parameters.
    vis_batch_size = 1024
    ylims = (-1.75, 1.75)
    alphas = [0.05, 0.10, 0.15, 0.20, 0.25, 0.30, 0.35, 0.40, 0.45, 0.50, 0.55]
    percentiles = [0.999, 0.99, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]
    vis_idx = npr.permutation(vis_batch_size)
    # From https://colorbrewer2.org/.
    if args.color == "blue":
        sample_colors = ('#8c96c6', '#8c6bb1', '#810f7c')
        fill_color = '#9ebcda'
        mean_color = '#4d004b'
        num_samples = len(sample_colors)
    else:
        sample_colors = ('#fc4e2a', '#e31a1c', '#bd0026')
        fill_color = '#fd8d3c'
        mean_color = '#800026'
        num_samples = len(sample_colors)

    # Fix seed for the random draws used in the plots.
    eps = torch.randn(vis_batch_size, 1).to(device)
    bm = BrownianPath(t0=ts_vis[0],
                      w0=torch.zeros(vis_batch_size, 1).to(device))

    # Model.
    model = LatentSDE().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-2)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=.999)
    kl_scheduler = utils.LinearScheduler(iters=args.kl_anneal_iters)

    logp_metric = utils.EMAMetric()
    log_ratio_metric = utils.EMAMetric()
    loss_metric = utils.EMAMetric()

    if args.show_prior:
        with torch.no_grad():
            zs = model.sample_p(ts=ts_vis,
                                batch_size=vis_batch_size,
                                eps=eps,
                                bm=bm).squeeze()
            ts_vis_, zs_ = ts_vis.cpu().numpy(), zs.cpu().numpy()
            zs_ = np.sort(zs_, axis=1)

            img_path = os.path.join(args.train_dir, 'prior.png')
            plt.subplot(frameon=False)
            for alpha, percentile in zip(alphas, percentiles):
                idx = int((1 - percentile) / 2. * vis_batch_size)
                zs_bot_ = zs_[:, idx]
                zs_top_ = zs_[:, -(idx + 1)]  # -(idx + 1): with idx == 0, -idx would wrap back to column 0.
                plt.fill_between(ts_vis_,
                                 zs_bot_,
                                 zs_top_,
                                 alpha=alpha,
                                 color=fill_color)

            # `zorder` controls stacking; larger values are drawn on top.
            plt.scatter(ts_, ys_, marker='x', zorder=3, color='k',
                        s=35)  # Data.
            plt.ylim(ylims)
            plt.xlabel('$t$')
            plt.ylabel('$Y_t$')
            plt.tight_layout()
            plt.savefig(img_path, dpi=args.dpi)
            plt.close()
            logging.info(f'Saved prior figure at: {img_path}')

    for global_step in tqdm.tqdm(range(args.train_iters)):
        # Plot and save.
        if global_step % args.pause_iters == 0:
            img_path = os.path.join(args.train_dir,
                                    f'global_step_{global_step}.png')

            with torch.no_grad():
                zs = model.sample_q(ts=ts_vis,
                                    batch_size=vis_batch_size,
                                    eps=eps,
                                    bm=bm).squeeze()
                samples = zs[:, vis_idx]
                ts_vis_ = ts_vis.cpu().numpy()
                zs_ = zs.cpu().numpy()
                samples_ = samples.cpu().numpy()
                zs_ = np.sort(zs_, axis=1)
                plt.subplot(frameon=False)

                if args.show_percentiles:
                    for alpha, percentile in zip(alphas, percentiles):
                        idx = int((1 - percentile) / 2. * vis_batch_size)
                        zs_bot_, zs_top_ = zs_[:, idx], zs_[:, -(idx + 1)]  # Avoid -0 wrapping to the bottom sample.
                        plt.fill_between(ts_vis_,
                                         zs_bot_,
                                         zs_top_,
                                         alpha=alpha,
                                         color=fill_color)

                if args.show_mean:
                    plt.plot(ts_vis_, zs_.mean(axis=1), color=mean_color)

                if args.show_samples:
                    for j in range(num_samples):
                        plt.plot(ts_vis_,
                                 samples_[:, j],
                                 color=sample_colors[j],
                                 linewidth=1.0)

                if args.show_arrows:
                    num, dt = 12, 0.12
                    t, y = torch.meshgrid([
                        torch.linspace(0.2, 1.8, num).to(device),
                        torch.linspace(-1.5, 1.5, num).to(device)
                    ])
                    t, y = t.reshape(-1, 1), y.reshape(-1, 1)
                    fty = model.f(t=t, y=y).reshape(num, num)
                    dt = torch.zeros(num, num).fill_(dt).to(device)
                    dy = fty * dt
                    dt_, dy_ = dt.cpu().numpy(), dy.cpu().numpy()
                    t_, y_ = t.cpu().numpy(), y.cpu().numpy()
                    plt.quiver(t_,
                               y_,
                               dt_,
                               dy_,
                               alpha=0.3,
                               edgecolors='k',
                               width=0.0035,
                               scale=50)

                if args.hide_ticks:
                    plt.xticks([], [])
                    plt.yticks([], [])

                plt.scatter(ts_, ys_, marker='x', zorder=3, color='k',
                            s=35)  # Data.
                plt.ylim(ylims)
                plt.xlabel('$t$')
                plt.ylabel('$Y_t$')
                plt.tight_layout()
                plt.savefig(img_path, dpi=args.dpi)
                plt.close()
                logging.info(f'Saved figure at: {img_path}')

                if args.save_ckpt:
                    torch.save({'model': model.state_dict()},
                               os.path.join(ckpt_dir,
                                            f'global_step_{global_step}.ckpt'))

        # Train.
        optimizer.zero_grad()
        zs, log_ratio = model(ts=ts_ext, batch_size=args.batch_size)
        zs = zs.squeeze()
        zs = zs[1:-1]  # Drop first and last, which are only used to penalize the out-of-data region and spread uncertainty.

        likelihood = {
            "laplace": Laplace(loc=zs, scale=args.scale),
            "normal": Normal(loc=zs, scale=args.scale)
        }[args.likelihood]
        logp = likelihood.log_prob(ys).sum(dim=0).mean(dim=0)

        loss = -logp + log_ratio * kl_scheduler()
        loss.backward()
        optimizer.step()
        scheduler.step()
        kl_scheduler.step()

        logp_metric.step(logp)
        log_ratio_metric.step(log_ratio)
        loss_metric.step(loss)

        logging.info(
            f'global_step: {global_step}, '
            f'logp: {logp_metric.val():.3f}, log_ratio: {log_ratio_metric.val():.3f}, loss: {loss_metric.val():.3f}'
        )
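utils.EMAMetric and utils.LinearScheduler are referenced but never defined in this snippet. Implementations consistent with how they are called above (step() to update, val() to read, and a callable KL weight) might look like the following; the decay constant and the linear 0-to-1 schedule are assumptions:

import torch


class EMAMetric:
    """Exponential moving average of a scalar metric."""

    def __init__(self, gamma=0.99):
        self._val, self._gamma = 0., gamma

    def step(self, x):
        x = float(x.detach().cpu()) if torch.is_tensor(x) else float(x)
        self._val = self._gamma * self._val + (1. - self._gamma) * x
        return self._val

    def val(self):
        return self._val


class LinearScheduler:
    """KL weight annealed linearly from 0 to 1 over `iters` steps."""

    def __init__(self, iters):
        self._iters = max(1, iters)
        self._val = 0.

    def step(self):
        self._val = min(1., self._val + 1. / self._iters)

    def __call__(self):
        return self._val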