Example #1
def test_entropy_determinism(random_order, device, levy_area_approximation,
                             return_U, return_A):
    if device == gpu and not torch.cuda.is_available():
        pytest.skip("CUDA not available.")

    t0, t1 = 0.0, 1.0
    entropy = 56789
    points1 = torch.rand(1000)
    points2 = torch.rand(1000)
    outs = []

    tol = 1e-6 if random_order else 0.

    bm = torchsde.BrownianInterval(
        t0=t0,
        t1=t1,
        size=(),
        device=device,
        levy_area_approximation=levy_area_approximation,
        entropy=entropy,
        tol=tol,
        halfway_tree=random_order)
    for point1, point2 in zip(points1, points2):
        point1, point2 = sorted([point1, point2])
        outs.append(bm(point1, point2, return_U=return_U, return_A=return_A))

    bm = torchsde.BrownianInterval(
        t0=t0,
        t1=t1,
        size=(),
        device=device,
        levy_area_approximation=levy_area_approximation,
        entropy=entropy,
        tol=tol,
        halfway_tree=random_order)
    if random_order:
        perm = torch.randperm(1000)
        points1 = points1[perm]
        points2 = points2[perm]
        outs = [outs[i.item()] for i in perm]
    for point1, point2, out in zip(points1, points2, outs):
        point1, point2 = sorted([point1, point2])
        out_ = bm(point1, point2, return_U=return_U, return_A=return_A)

        # Assert equal
        if torch.is_tensor(out):
            out = (out, )
        if torch.is_tensor(out_):
            out_ = (out_, )
        for outi, outi_ in zip(out, out_):
            if torch.is_tensor(outi):
                assert (outi == outi_).all()
            else:
                assert outi == outi_
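
The determinism being tested above can also be seen in a minimal sketch (the entropy value, interval, and query times below are arbitrary illustrative choices, not taken from the test suite): two BrownianInterval objects built with the same entropy return identical samples for the same query.

import torchsde

bm1 = torchsde.BrownianInterval(t0=0.0, t1=1.0, size=(4,), entropy=42)
bm2 = torchsde.BrownianInterval(t0=0.0, t1=1.0, size=(4,), entropy=42)
# Same seed (entropy) and same query => exactly the same increment.
assert (bm1(0.2, 0.7) == bm2(0.2, 0.7)).all()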
Example #2
def test_reversibility(sde_cls):
    batch_size = 32
    state_size = 4
    t_size = 20
    dt = 0.1

    brownian_size = {
        NOISE_TYPES.scalar: 1,
        NOISE_TYPES.diagonal: state_size,
        NOISE_TYPES.general: 2,
        NOISE_TYPES.additive: 2
    }[sde_cls.noise_type]

    class MinusSDE(torch.nn.Module):
        def __init__(self, sde):
            super().__init__()
            self.noise_type = sde.noise_type
            self.sde_type = sde.sde_type
            self.f = lambda t, y: -sde.f(-t, y)
            self.g = lambda t, y: -sde.g(-t, y)

    sde = sde_cls(d=state_size, m=brownian_size, sde_type='stratonovich')
    minus_sde = MinusSDE(sde)
    y0 = torch.full((batch_size, state_size), 0.1)
    ts = torch.linspace(0, (t_size - 1) * dt, t_size)
    bm = torchsde.BrownianInterval(t0=ts[0], t1=ts[-1], size=(batch_size, brownian_size))
    ys, (f, g, z) = torchsde.sdeint(sde, y0, ts, bm=bm, method='reversible_heun', dt=dt, extra=True)
    backward_ts = -ts.flip(0)
    backward_ys = torchsde.sdeint(minus_sde, ys[-1], backward_ts, bm=torchsde.ReverseBrownian(bm),
                                  method='reversible_heun', dt=dt, extra_solver_state=(-f, -g, z))
    backward_ys = backward_ys.flip(0)

    torch.testing.assert_allclose(ys, backward_ys, rtol=1e-6, atol=1e-6)
Example #3
def test_consistency(device, levy_area_approximation):
    if device == gpu and not torch.cuda.is_available():
        pytest.skip("CUDA not available.")

    t0, t1 = 0.0, 1.0
    for _ in range(REPS):
        bm = torchsde.BrownianInterval(
            t0=t0,
            t1=t1,
            size=(LARGE_BATCH_SIZE, ),
            device=device,
            levy_area_approximation=levy_area_approximation)

        for _ in range(MEDIUM_REPS):
            ta, t_, tb = sorted(npr.uniform(low=t0, high=t1, size=(3, )))

            if levy_area_approximation == 'none':
                W = bm(ta, tb)
                W1 = bm(ta, t_)
                W2 = bm(t_, tb)
            else:
                W, U = bm(ta, tb, return_U=True)
                W1, U1 = bm(ta, t_, return_U=True)
                W2, U2 = bm(t_, tb, return_U=True)

            torch.testing.assert_allclose(W1 + W2, W, rtol=1e-6, atol=1e-6)
            if levy_area_approximation != 'none':
                torch.testing.assert_allclose(U1 + U2 + (tb - t_) * W1,
                                              U,
                                              rtol=1e-6,
                                              atol=1e-6)
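
For reference, the interval-additivity property these assertions rely on, W(ta, tb) = W(ta, t_) + W(t_, tb), can be checked in a minimal standalone sketch (batch size and split point below are arbitrary choices):

import torch
import torchsde

bm = torchsde.BrownianInterval(t0=0.0, t1=1.0, size=(8,))
whole = bm(0.0, 1.0)
left, right = bm(0.0, 0.4), bm(0.4, 1.0)
# The two sub-interval increments should sum to the full increment (up to floating-point error).
assert torch.allclose(left + right, whole, rtol=1e-6, atol=1e-6)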
Example #4
def test_normality_simple(device, levy_area_approximation):
    if device == gpu and not torch.cuda.is_available():
        pytest.skip("CUDA not available.")

    t0, t1 = 0.0, 1.0
    for _ in range(REPS):
        base_W = torch.tensor(npr.randn(),
                              device=device).repeat(LARGE_BATCH_SIZE)
        bm = torchsde.BrownianInterval(
            t0=t0,
            t1=t1,
            W=base_W,
            levy_area_approximation=levy_area_approximation)

        t_ = npr.uniform(low=t0, high=t1)

        W = bm(t0, t_)

        mean_W = base_W * (t_ - t0) / (t1 - t0)
        std_W = math.sqrt((t1 - t_) * (t_ - t0) / (t1 - t0))
        rescaled_W = (W - mean_W) / std_W

        _, pval = kstest(rescaled_W.cpu().detach().numpy(), 'norm')
        assert pval >= ALPHA

        if levy_area_approximation != 'none':
            W, U = bm(t0, t_, return_U=True)
            H = _U_to_H(W, U, t_ - t0)

            mean_H = 0
            std_H = math.sqrt((t_ - t0) / 12)
            rescaled_H = (H - mean_H) / std_H

            _, pval = kstest(rescaled_H.cpu().detach().numpy(), 'norm')
            assert pval >= ALPHA
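
The helper _U_to_H used above is defined elsewhere in the test suite and is not shown here. A plausible sketch, assuming the standard relation H = U/h - W/2 between the space-time Levy area H and the time integral U over an interval of length h (the actual helper may differ):

def _U_to_H(W, U, h):
    # Space-time Levy area from the increment W and the integral U = int_s^t (W_r - W_s) dr.
    # Under this definition H has mean 0 and variance h / 12, matching std_H above.
    return U / h - 0.5 * W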
Example #5
def main(
    batch_size=1024,
    context_size=64,
    hidden_size=128,
    lr_init=1e-2,
    t0=0.,
    t1=2.,
    lr_gamma=0.997,
    num_iters=5000,
    kl_anneal_iters=500,
    pause_every=50,
    noise_std=0.01,
    adjoint=False,
    train_dir='./dump/lorenz/',
    method="euler",
):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    xs, ts = make_dataset(t0=t0,
                          t1=t1,
                          batch_size=batch_size,
                          noise_std=noise_std,
                          train_dir=train_dir,
                          device=device)
    latent_sde = LatentSDE(data_size=3,
                           context_size=context_size,
                           hidden_size=hidden_size).to(device)
    optimizer = optim.Adam(params=latent_sde.parameters(), lr=lr_init)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer,
                                                       gamma=lr_gamma)
    kl_scheduler = LinearScheduler(iters=kl_anneal_iters)

    # Fix the same Brownian motion for visualization.
    bm_vis = torchsde.BrownianInterval(t0=t0,
                                       t1=t1,
                                       size=(
                                           batch_size,
                                           3,
                                       ),
                                       device=device,
                                       levy_area_approximation="space-time")

    for global_step in tqdm.tqdm(range(1, num_iters + 1)):
        latent_sde.zero_grad()
        log_pxs, log_ratio = latent_sde(xs, ts, noise_std, adjoint, method)
        loss = -log_pxs + log_ratio * kl_scheduler.val
        loss.backward()
        optimizer.step()
        scheduler.step()
        kl_scheduler.step()

        if global_step % pause_every == 0:
            lr_now = optimizer.param_groups[0]['lr']
            logging.warning(
                f'global_step: {global_step:06d}, lr: {lr_now:.5f}, '
                f'log_pxs: {log_pxs:.4f}, log_ratio: {log_ratio:.4f}, loss: {loss:.4f}, kl_coeff: {kl_scheduler.val:.4f}'
            )
            img_path = os.path.join(train_dir,
                                    f'global_step_{global_step:06d}.pdf')
            vis(xs, ts, latent_sde, bm_vis, img_path)
Example #6
    def forward(self, ts, batch_size, eps=None):
        eps = torch.randn(batch_size, 1).to(
            self.qy0_std) if eps is None else eps
        y0 = self.qy0_mean + eps * self.qy0_std
        qy0 = distributions.Normal(loc=self.qy0_mean, scale=self.qy0_std)
        py0 = distributions.Normal(loc=self.py0_mean, scale=self.py0_std)
        logqp0 = distributions.kl_divergence(qy0, py0).sum(dim=1)  # KL(t=0).

        bm = torchsde.BrownianInterval(t0=ts[0],
                                       t1=ts[-1],
                                       dtype=y0.dtype,
                                       device=y0.device,
                                       size=(batch_size, 2),
                                       pool_size=POOL_SIZE,
                                       cache_size=CACHE_SIZE)
        aug_y0 = torch.cat([y0, torch.zeros(batch_size, 1).to(y0)], dim=1)
        aug_ys = sdeint_fn(sde=self,
                           bm=bm,
                           y0=aug_y0,
                           ts=ts.to(device),
                           method=args.method,
                           dt=args.dt,
                           adaptive=args.adaptive,
                           adjoint_adaptive=args.adjoint_adaptive,
                           rtol=args.rtol,
                           atol=args.atol,
                           names={
                               'drift': 'f_aug',
                               'diffusion': 'g_aug'
                           })
        ys, logqp_path = aug_ys[:, :, 0:1], aug_ys[-1, :, 1]
        logqp = (logqp0 + logqp_path).mean(dim=0)  # KL(t=0) + KL(path).
        return ys, logqp
Example #7
def test_sdeint_run_shape_method(sde_cls, use_bm, levy_area_approximation,
                                 sde_type, method, adaptive, logqp, device):
    """Tests that sdeint:
    (a) runs/raises an error as appropriate
    (b) produces tensors of the right shape
    (c) accepts every method
    """

    if method == 'milstein_grad_free':
        method = 'milstein'
        options = dict(grad_free=True)
    else:
        options = dict()

    should_fail = False
    if sde_type == 'ito':
        if method not in ('euler', 'srk', 'milstein'):
            should_fail = True
    else:
        if method not in ('euler_heun', 'heun', 'midpoint', 'log_ode',
                          'milstein'):
            should_fail = True
    if method in ('milstein', 'srk') and sde_cls.noise_type == 'general':
        should_fail = True
    if method == 'srk' and levy_area_approximation == 'none':
        should_fail = True
    if method == 'log_ode' and levy_area_approximation in ('none',
                                                           'space-time'):
        should_fail = True

    if sde_cls.noise_type in (NOISE_TYPES.scalar, NOISE_TYPES.diagonal):
        kwargs = {'d': d}
    else:
        kwargs = {'d': d, 'm': m}
    sde = sde_cls(sde_type=sde_type, **kwargs).to(device)

    if use_bm:
        if sde_cls.noise_type == 'scalar':
            size = (batch_size, 1)
        elif sde_cls.noise_type == 'diagonal':
            size = (batch_size, d + 1) if logqp else (batch_size, d)
        else:
            assert sde_cls.noise_type in ('additive', 'general')
            size = (batch_size, m)
        bm = torchsde.BrownianInterval(
            t0=t0,
            t1=t1,
            size=size,
            dtype=dtype,
            device=device,
            levy_area_approximation=levy_area_approximation)
    else:
        bm = None

    _test_sdeint(sde, bm, method, adaptive, logqp, device, should_fail,
                 options)
Example #8
    def forward(self,
                y,
                adjoint=False,
                dt=0.02,
                adaptive=False,
                adjoint_adaptive=False,
                method="midpoint",
                rtol=1e-4,
                atol=1e-3):
        # Note: This works correctly, as long as we are requesting the nfe after each gradient update.
        #  There are obviously cleaner ways to achieve this.
        self.nfe = 0
        sdeint = torchsde.sdeint_adjoint if adjoint else torchsde.sdeint
        if self.aug_zeros.numel() > 0:  # Add zero channels.
            aug_zeros = self.aug_zeros.expand(y.shape[0], *self.aug_zeros_size)
            y = torch.cat((y, aug_zeros), dim=1)  # 235200
        aug_y = torch.cat(
            (y.reshape(-1), self.flat_initial_params,
             torch.tensor([0.], device=y.device)))  # 841609: (235200, 606408, 1)
        aug_y = aug_y[None]
        bm = torchsde.BrownianInterval(
            t0=self.ts[0],
            t1=self.ts[-1],
            size=aug_y.shape,
            dtype=aug_y.dtype,
            device=aug_y.device,
            cache_size=45 if adjoint else 30  # If not adjoint, don't really need to cache.
        )
        if adjoint_adaptive:
            _, aug_y1 = sdeint(self,
                               aug_y,
                               self.ts,
                               bm=bm,
                               method=method,
                               dt=dt,
                               adaptive=adaptive,
                               adjoint_adaptive=adjoint_adaptive,
                               rtol=rtol,
                               atol=atol)
        else:
            _, aug_y1 = sdeint(self,
                               aug_y,
                               self.ts,
                               bm=bm,
                               method=method,
                               dt=dt,
                               adaptive=adaptive,
                               rtol=rtol,
                               atol=atol)
        y1 = aug_y1[:y.numel()].reshape(y.size())
        logits = self.projection(y1)
        logqp = .5 * aug_y1[-1]
        return logits, logqp
Example #9
def _setup(device, levy_area_approximation, shape):
    t0, t1 = torch.tensor([0., 1.], device=device)
    ta = torch.rand([], device=device)
    tb = torch.rand([], device=device)
    ta, tb = min(ta, tb), max(ta, tb)
    bm = torchsde.BrownianInterval(
        t0=t0,
        t1=t1,
        size=shape,
        device=device,
        levy_area_approximation=levy_area_approximation)
    return ta, tb, bm
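
A hypothetical usage of this fixture helper (device, Levy area approximation, and shape below are placeholder choices), querying the returned interval the same way the tests do:

ta, tb, bm = _setup('cpu', 'space-time', (16, 3))
W = bm(ta, tb)                    # Brownian increment over [ta, tb], shape (16, 3).
W, U = bm(ta, tb, return_U=True)  # Also return the space-time integral term U.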
Example #10
def test_adjoint(sde_cls, method, sde_type, adaptive):
    # Skipping below, since method not supported for corresponding noise types.
    if sde_cls.noise_type == NOISE_TYPES.general and method in (
            METHODS.milstein, METHODS.srk):
        return

    d = 3
    m = {
        NOISE_TYPES.scalar: 1,
        NOISE_TYPES.diagonal: d,
        NOISE_TYPES.general: 2,
        NOISE_TYPES.additive: 2
    }[sde_cls.noise_type]
    batch_size = 4
    t0, t1 = ts = torch.tensor([0.0, 0.5], device=device)
    dt = 1e-3
    y0 = torch.zeros(batch_size, d).to(device).fill_(0.1)
    sde = sde_cls(d=d, m=m, sde_type=sde_type).to(device)

    levy_area_approximation = {
        'euler': 'none',
        'milstein': 'none',
        'srk': 'space-time',
        'midpoint': 'none'
    }[method]
    bm = torchsde.BrownianInterval(
        t0=t0,
        t1=t1,
        size=(batch_size, m),
        dtype=dtype,
        device=device,
        levy_area_approximation=levy_area_approximation)

    def func(inputs, modules):
        y0, sde = inputs[0], modules[0]
        ys = torchsde.sdeint_adjoint(sde,
                                     y0,
                                     ts,
                                     bm,
                                     dt=dt,
                                     method=method,
                                     adaptive=adaptive)
        return (ys[-1]**2).sum(dim=1).mean(dim=0)

    # `grad_inputs=True` also works, but we really only care about grad wrt params and want fast tests.
    utils.gradcheck(func,
                    y0,
                    sde,
                    eps=1e-6,
                    rtol=1e-2,
                    atol=1e-2,
                    grad_params=True)
Example #11
def _compare(w0, ts, msg=''):
    bm = torchsde.BrownianPath(t0=t0, w0=w0)
    bp_py_time = _time_query(bm, ts)
    logging.warning(f'{msg} (torchsde.BrownianPath): {bp_py_time:.4f}')

    bm = torchsde.BrownianTree(t0=t0, t1=t1, w0=w0, tol=1e-5)
    bt_py_time = _time_query(bm, ts)
    logging.warning(f'{msg} (torchsde.BrownianTree): {bt_py_time:.4f}')

    bm = torchsde.BrownianInterval(t0=t0,
                                   t1=t1,
                                   size=w0.shape,
                                   dtype=w0.dtype,
                                   device=w0.device)
    bi_py_time = _time_query(bm, ts)
    logging.warning(f'{msg} (torchsde.BrownianInterval): {bi_py_time:.4f}')

    return bp_py_time, bt_py_time, bi_py_time
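
The timing helper _time_query is defined elsewhere in the benchmark script. A minimal sketch of what such a helper might look like (behaviour assumed, not taken from the source): time a sequence of increment queries against the given Brownian object.

import time

def _time_query(bm, ts):
    # Hypothetical helper: query successive increments and return elapsed wall-clock seconds.
    start = time.perf_counter()
    for ta, tb in zip(ts[:-1], ts[1:]):
        bm(ta, tb)
    return time.perf_counter() - start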
Example #12
def test_specialised_functions(sde_type, method):
    vector = torch.randn(m)
    fg = problems.FGSDE(sde_type, vector)
    f_and_g = problems.FAndGSDE(sde_type, vector)
    g_prod = problems.GProdSDE(sde_type, vector)
    f_and_g_prod = problems.FAndGProdSDE(sde_type, vector)
    f_and_g_with_g_prod1 = problems.FAndGGProdSDE1(sde_type, vector)
    f_and_g_with_g_prod2 = problems.FAndGGProdSDE2(sde_type, vector)

    y0 = torch.randn(batch_size, d)

    outs = []
    for sde in (fg, f_and_g, g_prod, f_and_g_prod, f_and_g_with_g_prod1, f_and_g_with_g_prod2):
        bm = torchsde.BrownianInterval(t0, t1, (batch_size, m), entropy=45678)
        outs.append(torchsde.sdeint(sde, y0, [t0, t1], dt=dt, bm=bm)[1])
    for o in outs[1:]:
        # Equality of floating points, because we expect them to do everything exactly the same.
        assert o.shape == outs[0].shape
        assert (o == outs[0]).all()
Example #13
def main():
    # Dataset.
    ts_, ts_ext_, ts_vis_, ts, ts_ext, ts_vis, ys, ys_ = make_data()

    # Plotting parameters.
    vis_batch_size = 1024
    ylims = (-1.75, 1.75)
    alphas = [0.05, 0.10, 0.15, 0.20, 0.25, 0.30, 0.35, 0.40, 0.45, 0.50, 0.55]
    percentiles = [0.999, 0.99, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]
    vis_idx = np.random.permutation(vis_batch_size)
    # From https://colorbrewer2.org/.
    if args.color == "blue":
        sample_colors = ('#8c96c6', '#8c6bb1', '#810f7c')
        fill_color = '#9ebcda'
        mean_color = '#4d004b'
        num_samples = len(sample_colors)
    else:
        sample_colors = ('#fc4e2a', '#e31a1c', '#bd0026')
        fill_color = '#fd8d3c'
        mean_color = '#800026'
        num_samples = len(sample_colors)

    eps = torch.randn(vis_batch_size, 1).to(
        device)  # Fix seed for the random draws used in the plots.
    bm = torchsde.BrownianInterval(
        t0=ts_vis[0],
        t1=ts_vis[-1],
        size=(vis_batch_size, 1),
        device=device,
        levy_area_approximation='space-time'
    )  # We need space-time Levy area to use the SRK solver

    # Model.
    model = LatentSDE().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-2)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=.999)
    kl_scheduler = LinearScheduler(iters=args.kl_anneal_iters)

    logpy_metric = EMAMetric()
    kl_metric = EMAMetric()
    loss_metric = EMAMetric()

    if args.show_prior:
        with torch.no_grad():
            zs = model.sample_p(ts=ts_vis,
                                batch_size=vis_batch_size,
                                eps=eps,
                                bm=bm).squeeze()
            ts_vis_, zs_ = ts_vis.cpu().numpy(), zs.cpu().numpy()
            zs_ = np.sort(zs_, axis=1)

            img_dir = os.path.join(args.train_dir, 'prior.png')
            plt.subplot(frameon=False)
            for alpha, percentile in zip(alphas, percentiles):
                idx = int((1 - percentile) / 2. * vis_batch_size)
                zs_bot_ = zs_[:, idx]
                zs_top_ = zs_[:, -idx]
                plt.fill_between(ts_vis_,
                                 zs_bot_,
                                 zs_top_,
                                 alpha=alpha,
                                 color=fill_color)

            # `zorder` determines who's on top; the larger the more at the top.
            plt.scatter(ts_, ys_, marker='x', zorder=3, color='k',
                        s=35)  # Data.
            plt.ylim(ylims)
            plt.xlabel('$t$')
            plt.ylabel('$Y_t$')
            plt.tight_layout()
            plt.savefig(img_dir, dpi=args.dpi)
            plt.close()
            logging.info(f'Saved prior figure at: {img_dir}')

    for global_step in tqdm.tqdm(range(args.train_iters)):
        # Plot and save.
        if global_step % args.pause_iters == 0:
            img_path = os.path.join(args.train_dir,
                                    f'global_step_{global_step}.png')

            with torch.no_grad():
                zs = model.sample_q(ts=ts_vis,
                                    batch_size=vis_batch_size,
                                    eps=eps,
                                    bm=bm).squeeze()
                samples = zs[:, vis_idx]
                ts_vis_, zs_, samples_ = ts_vis.cpu().numpy(), zs.cpu().numpy(
                ), samples.cpu().numpy()
                zs_ = np.sort(zs_, axis=1)
                plt.subplot(frameon=False)

                if args.show_percentiles:
                    for alpha, percentile in zip(alphas, percentiles):
                        idx = int((1 - percentile) / 2. * vis_batch_size)
                        zs_bot_, zs_top_ = zs_[:, idx], zs_[:, -idx]
                        plt.fill_between(ts_vis_,
                                         zs_bot_,
                                         zs_top_,
                                         alpha=alpha,
                                         color=fill_color)

                if args.show_mean:
                    plt.plot(ts_vis_, zs_.mean(axis=1), color=mean_color)

                if args.show_samples:
                    for j in range(num_samples):
                        plt.plot(ts_vis_,
                                 samples_[:, j],
                                 color=sample_colors[j],
                                 linewidth=1.0)

                if args.show_arrows:
                    num, dt = 12, 0.12
                    t, y = torch.meshgrid([
                        torch.linspace(0.2, 1.8, num).to(device),
                        torch.linspace(-1.5, 1.5, num).to(device)
                    ])
                    t, y = t.reshape(-1, 1), y.reshape(-1, 1)
                    fty = model.f(t=t, y=y).reshape(num, num)
                    dt = torch.zeros(num, num).fill_(dt).to(device)
                    dy = fty * dt
                    dt_, dy_, t_, y_ = dt.cpu().numpy(), dy.cpu().numpy(
                    ), t.cpu().numpy(), y.cpu().numpy()
                    plt.quiver(t_,
                               y_,
                               dt_,
                               dy_,
                               alpha=0.3,
                               edgecolors='k',
                               width=0.0035,
                               scale=50)

                if args.hide_ticks:
                    plt.xticks([], [])
                    plt.yticks([], [])

                plt.scatter(ts_, ys_, marker='x', zorder=3, color='k',
                            s=35)  # Data.
                plt.ylim(ylims)
                plt.xlabel('$t$')
                plt.ylabel('$Y_t$')
                plt.tight_layout()
                plt.savefig(img_path, dpi=args.dpi)
                plt.close()
                logging.info(f'Saved figure at: {img_path}')

                if args.save_ckpt:
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict(),
                            'scheduler': scheduler.state_dict(),
                            'kl_scheduler': kl_scheduler
                        },
                        os.path.join(ckpt_dir,
                                     f'global_step_{global_step}.ckpt'))

        # Train.
        optimizer.zero_grad()
        zs, kl = model(ts=ts_ext, batch_size=args.batch_size)
        zs = zs.squeeze()
        zs = zs[
            1:
            -1]  # Drop first and last which are only used to penalize out-of-data region and spread uncertainty.

        likelihood_constructor = {
            "laplace": distributions.Laplace,
            "normal": distributions.Normal
        }[args.likelihood]
        likelihood = likelihood_constructor(loc=zs, scale=args.scale)
        logpy = likelihood.log_prob(ys).sum(dim=0).mean(dim=0)

        loss = -logpy + kl * kl_scheduler.val
        loss.backward()

        optimizer.step()
        scheduler.step()
        kl_scheduler.step()

        logpy_metric.step(logpy)
        kl_metric.step(kl)
        loss_metric.step(loss)

        logging.info(f'global_step: {global_step}, '
                     f'logpy: {logpy_metric.val:.3f}, '
                     f'kl: {kl_metric.val:.3f}, '
                     f'loss: {loss_metric.val:.3f}')
Example #14
def test_against_sdeint(sde_cls, sde_type, method, options, dt, rtol, atol,
                        len_ts):
    # Skipping below, since method not supported for corresponding noise types.
    if sde_cls.noise_type == NOISE_TYPES.general and method in (
            METHODS.milstein, METHODS.srk):
        return

    d = 3
    m = {
        NOISE_TYPES.scalar: 1,
        NOISE_TYPES.diagonal: d,
        NOISE_TYPES.general: 2,
        NOISE_TYPES.additive: 2
    }[sde_cls.noise_type]
    batch_size = 4
    ts = torch.linspace(0.0, 1.0, len_ts, device=device, dtype=torch.float64)
    t0 = ts[0]
    t1 = ts[-1]
    y0 = torch.full((batch_size, d),
                    0.1,
                    device=device,
                    dtype=torch.float64,
                    requires_grad=True)
    sde = sde_cls(d=d, m=m, sde_type=sde_type).to(device, torch.float64)

    if method == METHODS.srk:
        levy_area_approximation = LEVY_AREA_APPROXIMATIONS.space_time
    else:
        levy_area_approximation = LEVY_AREA_APPROXIMATIONS.none
    bm = torchsde.BrownianInterval(
        t0=t0,
        t1=t1,
        size=(batch_size, m),
        dtype=torch.float64,
        device=device,
        levy_area_approximation=levy_area_approximation)

    if method == METHODS.reversible_heun:
        adjoint_method = METHODS.adjoint_reversible_heun
        adjoint_options = options
    else:
        adjoint_method = None
        adjoint_options = None

    ys_true = torchsde.sdeint(sde,
                              y0,
                              ts,
                              dt=dt,
                              method=method,
                              bm=bm,
                              options=options)
    grad = torch.randn_like(ys_true)
    ys_true.backward(grad)

    true_grad = torch.cat([y0.grad.view(-1)] +
                          [param.grad.view(-1) for param in sde.parameters()])
    y0.grad.zero_()
    for param in sde.parameters():
        param.grad.zero_()

    ys_test = torchsde.sdeint_adjoint(sde,
                                      y0,
                                      ts,
                                      dt=dt,
                                      method=method,
                                      bm=bm,
                                      adjoint_method=adjoint_method,
                                      options=options,
                                      adjoint_options=adjoint_options)
    ys_test.backward(grad)
    test_grad = torch.cat([y0.grad.view(-1)] +
                          [param.grad.view(-1) for param in sde.parameters()])

    torch.testing.assert_allclose(ys_true, ys_test)
    torch.testing.assert_allclose(true_grad, test_grad, rtol=rtol, atol=atol)
Example #15
def test_against_numerical(sde_cls, sde_type, method, options, adaptive):
    # Skipping below, since method not supported for corresponding noise types.
    if sde_cls.noise_type == NOISE_TYPES.general and method in (
            METHODS.milstein, METHODS.srk):
        return

    d = 3
    m = {
        NOISE_TYPES.scalar: 1,
        NOISE_TYPES.diagonal: d,
        NOISE_TYPES.general: 2,
        NOISE_TYPES.additive: 2
    }[sde_cls.noise_type]
    batch_size = 4
    t0, t1 = ts = torch.tensor([0.0, 0.5], device=device)
    dt = 1e-3
    y0 = torch.full((batch_size, d), 0.1, device=device)
    sde = sde_cls(d=d, m=m, sde_type=sde_type).to(device)

    if method == METHODS.srk:
        levy_area_approximation = LEVY_AREA_APPROXIMATIONS.space_time
    else:
        levy_area_approximation = LEVY_AREA_APPROXIMATIONS.none
    bm = torchsde.BrownianInterval(
        t0=t0,
        t1=t1,
        size=(batch_size, m),
        dtype=dtype,
        device=device,
        levy_area_approximation=levy_area_approximation)

    if method == METHODS.reversible_heun:
        tol = 1e-6
        adjoint_method = METHODS.adjoint_reversible_heun
        adjoint_options = options
    else:
        tol = 1e-2
        adjoint_method = None
        adjoint_options = None

    def func(inputs, modules):
        y0, sde = inputs[0], modules[0]
        ys = torchsde.sdeint_adjoint(sde,
                                     y0,
                                     ts,
                                     dt=dt,
                                     method=method,
                                     adjoint_method=adjoint_method,
                                     adaptive=adaptive,
                                     bm=bm,
                                     options=options,
                                     adjoint_options=adjoint_options)
        return (ys[-1]**2).sum(dim=1).mean(dim=0)

    # `grad_inputs=True` also works, but we really only care about grad wrt params and want fast tests.
    utils.gradcheck(func,
                    y0,
                    sde,
                    eps=1e-6,
                    rtol=tol,
                    atol=tol,
                    grad_params=True)
Example #16
def test_normality_conditional(device, levy_area_approximation):
    if device == gpu and not torch.cuda.is_available():
        pytest.skip("CUDA not available.")

    t0, t1 = 0.0, 1.0
    for _ in range(REPS):
        bm = torchsde.BrownianInterval(
            t0=t0,
            t1=t1,
            size=(LARGE_BATCH_SIZE, ),
            device=device,
            levy_area_approximation=levy_area_approximation)

        for _ in range(MEDIUM_REPS):
            ta, t_, tb = sorted(npr.uniform(low=t0, high=t1, size=(3, )))

            W = bm(ta, tb)
            W1 = bm(ta, t_)
            W2 = bm(t_, tb)

            mean_W1 = W * (t_ - ta) / (tb - ta)
            std_W1 = math.sqrt((tb - t_) * (t_ - ta) / (tb - ta))
            rescaled_W1 = (W1 - mean_W1) / std_W1
            _, pval = kstest(rescaled_W1.cpu().detach().numpy(), 'norm')
            assert pval >= ALPHA

            mean_W2 = W * (tb - t_) / (tb - ta)
            std_W2 = math.sqrt((tb - t_) * (t_ - ta) / (tb - ta))
            rescaled_W2 = (W2 - mean_W2) / std_W2
            _, pval = kstest(rescaled_W2.cpu().detach().numpy(), 'norm')
            assert pval >= ALPHA

            if levy_area_approximation != 'none':
                W, U = bm(ta, tb, return_U=True)
                W1, U1 = bm(ta, t_, return_U=True)
                W2, U2 = bm(t_, tb, return_U=True)

                h = tb - ta
                h1 = t_ - ta
                h2 = tb - t_

                denom = math.sqrt(h1**3 + h2**3)
                a = h1**3.5 * h2**0.5 / (2 * h * denom)
                b = h1**0.5 * h2**3.5 / (2 * h * denom)
                c = math.sqrt(3) * h1**1.5 * h2**1.5 / (6 * denom)

                H = _U_to_H(W, U, h)
                H1 = _U_to_H(W1, U1, h1)
                H2 = _U_to_H(W2, U2, h2)

                mean_H1 = H * (h1 / h)**2
                std_H1 = math.sqrt(a**2 + c**2) / h1
                rescaled_H1 = (H1 - mean_H1) / std_H1

                _, pval = kstest(rescaled_H1.cpu().detach().numpy(), 'norm')
                assert pval >= ALPHA

                mean_H2 = H * (h2 / h)**2
                std_H2 = math.sqrt(b**2 + c**2) / h2
                rescaled_H2 = (H2 - mean_H2) / std_H2

                _, pval = kstest(rescaled_H2.cpu().detach().numpy(), 'norm')
                assert pval >= ALPHA
Example #17
def main():
    # Dataset.
    ts_, ts_ext_, ts_vis_, ts, ts_ext, ts_vis, ys, ys_ = make_data()
    summary = SummaryWriter(os.path.join(args.train_dir, 'tb'))

    # Plotting parameters.
    vis_batch_size = 1024
    ylims = (-1.75, 1.75)
    alphas = [0.05, 0.10, 0.15, 0.20, 0.25, 0.30, 0.35, 0.40, 0.45, 0.50, 0.55]
    percentiles = [0.999, 0.99, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]
    vis_idx = np.random.permutation(vis_batch_size)
    if args.color == "blue":
        fill_color = '#9ebcda'
        mean_color = '#4d004b'
        num_samples = 60
    else:
        sample_colors = ('#fc4e2a', '#e31a1c', '#bd0026')
        fill_color = '#fd8d3c'
        mean_color = '#800026'
        num_samples = len(sample_colors)

    eps = torch.randn(vis_batch_size, 1).to(
        device)  # Fix seed for the random draws used in the plots.
    bm = torchsde.BrownianInterval(
        t0=ts_vis[0],
        t1=ts_vis[-1],
        size=(vis_batch_size, 1),
        device=device,
        levy_area_approximation='space-time',
        pool_size=POOL_SIZE,
        cache_size=CACHE_SIZE,
    )  # We need space-time Levy area to use the SRK solver

    # Model.
    # Note: This `mu` is selected based on the y-value of the two endpoints of the left and right segments.
    model = LatentSDE(mu=-0.80901699, sigma=args.sigma).to(device)
    optimizer = make_optimizer(optimizer=args.optimizer,
                               params=model.parameters(),
                               lr=args.lr)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=.99997)
    kl_scheduler = LinearScheduler(iters=args.kl_anneal_iters,
                                   maxval=args.kl_coeff)
    nll_scheduler = ConstantScheduler(constant=args.nll_coef)

    logpy_metric = EMAMetric()
    kl_metric = EMAMetric()
    loss_metric = EMAMetric()

    if os.path.exists(os.path.join(args.train_dir, 'ckpts', f'state.ckpt')):
        logging.info("Loading checkpoints...")
        checkpoint = torch.load(
            os.path.join(args.train_dir, 'ckpts', f'state.ckpt'))
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        try:
            logpy_metric.set(checkpoint['logpy_metric'])
            kl_metric.set(checkpoint['kl_metric'])
            loss_metric.set(checkpoint['loss_metric'])
        except Exception:
            logging.warning(
                "Could not successfully load logpy, kl, and loss metrics from checkpoint"
            )
        logging.info(
            f"Successfully loaded checkpoints at global_step {checkpoint['global_step']}"
        )

    if args.show_prior:
        with torch.no_grad():
            zs = model.sample_p(ts=ts_vis,
                                batch_size=vis_batch_size,
                                eps=eps,
                                bm=bm).squeeze()
            ts_vis_, zs_ = ts_vis.cpu().numpy(), zs.cpu().numpy()
            zs_ = np.sort(zs_, axis=1)

            img_dir = os.path.join(args.train_dir, 'prior.png')

            plt.subplot(frameon=False)
            for alpha, percentile in zip(alphas, percentiles):
                idx = int((1 - percentile) / 2. * vis_batch_size)
                zs_bot_ = zs_[:, idx]
                zs_top_ = zs_[:, -idx]
                plt.fill_between(ts_vis_,
                                 zs_bot_,
                                 zs_top_,
                                 alpha=alpha,
                                 color=fill_color)

            # `zorder` determines who's on top; the larger the more at the top.
            plt.scatter(ts_, ys_[:, 0], marker='x', zorder=3, color='k',
                        s=35)  # Data.
            if args.data != "irregular_sine":
                plt.scatter(ts_,
                            ys_[:, 1],
                            marker='x',
                            zorder=3,
                            color='k',
                            s=35)  # Data.
            plt.ylim(ylims)
            plt.xlabel('$t$')
            plt.ylabel('$Y_t$')
            plt.tight_layout()
            plt.savefig(img_dir, dpi=args.dpi)
            summary.add_figure('Prior', plt.gcf(), 0)
            logging.info(f'Prior saved to tensorboard')
            plt.close()
            logging.info(f'Saved prior figure at: {img_dir}')

    for global_step in tqdm.tqdm(range(args.train_iters)):
        # Plot and save.
        if global_step % args.pause_iters == 0 or global_step == (
                args.train_iters - 1):
            img_path = os.path.join(args.train_dir, "plots",
                                    f'global_step_{global_step}.png')

            with torch.no_grad():
                # TODO:
                zs = model.sample_q(ts=ts_vis,
                                    batch_size=vis_batch_size,
                                    eps=None,
                                    bm=bm).squeeze()
                samples = zs[:, vis_idx]
                ts_vis_, zs_, samples_ = ts_vis.cpu().numpy(), zs.cpu().numpy(
                ), samples.cpu().numpy()
                zs_ = np.sort(zs_, axis=1)
                plt.subplot(frameon=False)

                if args.show_percentiles:
                    for alpha, percentile in zip(alphas, percentiles):
                        idx = int((1 - percentile) / 2. * vis_batch_size)
                        zs_bot_, zs_top_ = zs_[:, idx], zs_[:, -idx]
                        plt.fill_between(ts_vis_,
                                         zs_bot_,
                                         zs_top_,
                                         alpha=alpha,
                                         color=fill_color)

                if args.show_mean:
                    plt.plot(ts_vis_, zs_.mean(axis=1), color=mean_color)

                if args.show_samples:
                    for j in range(num_samples):
                        plt.plot(ts_vis_, samples_[:, j], linewidth=1.0)

                if args.show_arrows:
                    t_start, t_end = ts_vis_[0], ts_vis_[-1]
                    num, dt = 12, 0.12
                    t, y = torch.meshgrid([
                        torch.linspace(t_start, t_end, num).to(device),
                        torch.linspace(*ylims, num).to(device)
                    ])
                    t, y = t.reshape(-1, 1), y.reshape(-1, 1)
                    fty = model.f(t=t, y=y).reshape(num, num)
                    dt = torch.zeros(num, num).fill_(dt).to(device)
                    dy = fty * dt
                    dt_, dy_, t_, y_ = dt.cpu().numpy(), dy.cpu().numpy(
                    ), t.cpu().numpy(), y.cpu().numpy()
                    plt.quiver(t_,
                               y_,
                               dt_,
                               dy_,
                               alpha=0.3,
                               edgecolors='k',
                               width=0.0035,
                               scale=50)

                if args.hide_ticks:
                    plt.xticks([], [])
                    plt.yticks([], [])

                plt.scatter(ts_,
                            ys_[:, 0],
                            marker='x',
                            zorder=3,
                            color='k',
                            s=35)  # Data.
                if args.data != "irregular_sine":
                    plt.scatter(ts_,
                                ys_[:, 1],
                                marker='x',
                                zorder=3,
                                color='k',
                                s=35)  # Data.
                plt.ylim(ylims)
                plt.xlabel('$t$')
                plt.ylabel('$Y_t$')
                plt.tight_layout()
                if global_step % args.save_fig == 0:
                    plt.savefig(img_path, dpi=args.dpi)
                current_fig = plt.gcf()
                summary.add_figure('Predictions plot', current_fig,
                                   global_step)
                logging.info(f'Predictions plot saved to tensorboard')
                plt.close()
                logging.info(f'Saved figure at: {img_path}')

                if args.save_ckpt:
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict(),
                            'scheduler': scheduler.state_dict()
                        },
                        os.path.join(args.train_dir, 'ckpts',
                                     f'global_step_{global_step}.ckpt'))
                    # for preemption
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict(),
                            'scheduler': scheduler.state_dict(),
                            'global_step': global_step,
                            'logpy_metric': logpy_metric.val,
                            'kl_metric': kl_metric.val,
                            'loss_metric': loss_metric.val
                        }, os.path.join(args.train_dir, 'ckpts',
                                        f'state.ckpt'))

        # Train.
        optimizer.zero_grad()
        zs, kl = model(ts=ts_ext, batch_size=args.batch_size)
        zs = zs.squeeze()
        zs = zs[
            1:
            -1]  # Drop first and last which are only used to penalize out-of-data region and spread uncertainty.

        likelihood_constructor = {
            "laplace": distributions.Laplace,
            "normal": distributions.Normal,
            "cauchy": distributions.Cauchy
        }[args.likelihood]
        likelihood = likelihood_constructor(loc=zs, scale=args.scale)

        # Proper summation of log-likelihoods.
        logpy = 0.
        ys_split = ys.split(split_size=1, dim=-1)
        for _ys in ys_split:
            logpy = logpy + likelihood.log_prob(_ys).sum(dim=0).mean(dim=0)
        logpy = logpy / len(ys_split)

        loss = -logpy * nll_scheduler.val + kl * kl_scheduler.val
        loss.backward()

        optimizer.step()
        scheduler.step()
        kl_scheduler.step()
        nll_scheduler.step(global_step)

        logpy_metric.step(logpy)
        kl_metric.step(kl)
        loss_metric.step(loss)

        logging.info(f'global_step: {global_step}, '
                     f'logpy: {logpy_metric.val:.3f}, '
                     f'kl: {kl_metric.val:.3f}, '
                     f'loss: {loss_metric.val:.3f}')
        summary.add_scalar('KL Scheduler', kl_scheduler.val, global_step)
        summary.add_scalar('NLL Scheduler', nll_scheduler.val, global_step)
        summary.add_scalar('Loss', loss_metric.val, global_step)
        summary.add_scalar('KL', kl_metric.val, global_step)
        summary.add_scalar('Log(py) Likelihood', logpy_metric.val, global_step)
        logging.info(f'Logged loss, kl, logpy to tensorboard')

    summary.close()