Example #1
0
    def forward(ctx, *args):
        """Solve the SDE forward in time and stash solver settings on `ctx`.

        The final 13 positional arguments are the solver configuration;
        everything preceding them is the tuple of initial-state tensors.
        Returns the tuple of solution tensors produced by the integrator.
        """
        assert len(args) >= 14, 'Internal error: all arguments required.'
        y0, tail = args[:-13], args[-13:]
        (sde, ts, flat_params, dt, bm, method, adjoint_method, adaptive,
         rtol, atol, dt_min, options, adjoint_options) = tail

        # Everything the adjoint backward pass will need lives on the context.
        ctx.sde = sde
        ctx.dt = dt
        ctx.bm = bm
        ctx.adjoint_method = adjoint_method
        ctx.adaptive = adaptive
        ctx.rtol = rtol
        ctx.atol = atol
        ctx.dt_min = dt_min
        ctx.adjoint_options = adjoint_options

        forward_sde = base_sde.ForwardSDEIto(sde)
        # Gradients are recomputed via the adjoint, so the forward solve
        # itself runs without autograd tracking.
        with torch.no_grad():
            ans = sdeint.integrate(
                sde=forward_sde,
                y0=y0,
                ts=ts,
                bm=bm,
                method=method,
                dt=dt,
                adaptive=adaptive,
                rtol=rtol,
                atol=atol,
                dt_min=dt_min,
                options=options,
            )
        ctx.save_for_backward(ts, flat_params, *ans)
        return ans
Example #2
0
    def forward(ctx, *args):
        """Solve the SDE forward in time with the log-ratio (logqp) penalty.

        The final 13 positional arguments are the solver configuration;
        everything preceding them is the tuple of initial-state tensors.
        Returns the solution tensors followed by the log-ratio tensors.
        """
        assert len(args) >= 14, 'Internal error: all arguments required.'
        y0, tail = args[:-13], args[-13:]
        (sde, ts, flat_params, dt, bm, method, adjoint_method, adaptive,
         rtol, atol, dt_min, options, adjoint_options) = tail

        # Everything the adjoint backward pass will need lives on the context.
        ctx.sde = sde
        ctx.dt = dt
        ctx.bm = bm
        ctx.adjoint_method = adjoint_method
        ctx.adaptive = adaptive
        ctx.rtol = rtol
        ctx.atol = atol
        ctx.dt_min = dt_min
        ctx.adjoint_options = adjoint_options

        forward_sde = base_sde.ForwardSDEIto(sde)
        # Gradients are recomputed via the adjoint, so the forward solve
        # itself runs without autograd tracking.
        with torch.no_grad():
            outputs = sdeint.integrate(
                sde=forward_sde,
                y0=y0,
                ts=ts,
                bm=bm,
                method=method,
                dt=dt,
                adaptive=adaptive,
                rtol=rtol,
                atol=atol,
                dt_min=dt_min,
                options=options,
                logqp=True,
            )
            # The first len(y0) outputs are states; the rest are log-ratios.
            split = len(y0)
            ans, logqp = outputs[:split], outputs[split:]

        # Don't need to save `logqp`, since it is never used in the backward pass to compute gradients.
        ctx.save_for_backward(ts, flat_params, *ans)
        return ans + logqp
Example #3
0
def sdeint(sde,
           y0,
           ts,
           bm=None,
           logqp=False,
           method='srk',
           dt=1e-3,
           adaptive=False,
           rtol=1e-6,
           atol=1e-5,
           dt_min=1e-4,
           options=None,
           names=None):
    """Numerically integrate an Itô SDE.

    Args:
        sde: An object exposing drift `f` and diffusion `g`. Both take time
            `t` and state `y` and return a tensor or tuple of tensors; `f`
            matches `y` in signature, while `g` returns tensors of size
            (batch_size, d) for diagonal noise or (batch_size, d, m)
            otherwise, with d the state dimension and m the Brownian-motion
            dimension.
        y0: A single tensor (or tuple of tensors) of size (batch_size, d).
        ts: A list or 1-D tensor in non-descending order.
        bm: A `BrownianPath` or `BrownianTree` object. Defaults to
            `BrownianPath` for diagonal noise residing on CPU.
        logqp: If True, also return the Radon-Nikodym derivative — a
            log-ratio penalty across the whole path.
        method: One of (`euler`, `milstein`, `srk`). Defaults to `srk`.
        dt: Constant step size, or the initial step for adaptive stepping.
        adaptive: If True, use adaptive time-stepping.
        rtol: Relative tolerance.
        atol: Absolute tolerance.
        dt_min: Minimum step size for adaptive time-stepping.
        options: Optional dict of options for the chosen method.
        names: Optional dict remapping method names; expected keys are
            `drift`, `diffusion`, `prior_drift`.

    Returns:
        A state tensor of size (T, batch_size, d) or a tuple of such
        tensors. If logqp=True, additionally returns a log-ratio tensor of
        size (T - 1, batch_size) or a tuple of such tensors.

    Raises:
        ValueError: An error occurred due to unrecognized noise type/method,
            or sde module missing required methods.
    """
    rename = get_names_to_change(names)
    if len(rename) > 0:
        sde = base_sde.RenameMethodsSDE(sde, **rename)
    check_contract(sde=sde, method=method, adaptive=adaptive, logqp=logqp)

    if bm is None:
        # NOTE(review): building the default Brownian path assumes `y0` is a
        # single tensor — `torch.zeros_like` would fail on a tuple `y0`;
        # confirm whether tuple inputs are expected to supply their own `bm`.
        bm = brownian_path.BrownianPath(t0=ts[0],
                                        w0=torch.zeros_like(y0).cpu())

    tensor_input = isinstance(y0, torch.Tensor)
    if tensor_input:
        # Normalize single-tensor problems onto the tuple-based internal API.
        sde = base_sde.TupleSDE(sde)
        y0 = (y0,)
        single_bm = bm
        bm = lambda t: (single_bm(t),)

    sde = base_sde.ForwardSDEIto(sde)
    results = integrate(sde=sde,
                        y0=y0,
                        ts=ts,
                        bm=bm,
                        method=method,
                        dt=dt,
                        adaptive=adaptive,
                        rtol=rtol,
                        atol=atol,
                        dt_min=dt_min,
                        options=options,
                        logqp=logqp)
    # Unwrap the singleton tuple for tensor input, unless logqp results
    # must ride along too.
    if tensor_input and not logqp:
        return results[0]
    return results