def forward(ctx, *args):
    """Solve the SDE forward (without building a graph) and stash what backward needs.

    This is an autograd-Function-style forward: it records state on `ctx` and
    calls `ctx.save_for_backward`.

    `args` is laid out as the initial state components (at least one)
    followed by exactly 13 configuration values, in this order:
    sde, ts, flat_params, dt, bm, method, adjoint_method, adaptive,
    rtol, atol, dt_min, options, adjoint_options.

    Side effects on `ctx`:
      * attributes sde, dt, bm, adjoint_method, adaptive, rtol, atol,
        dt_min, adjoint_options are set (note: `method` and `options`
        are not stored);
      * `ts`, `flat_params` and every solution tensor are saved for backward.

    Returns:
        Whatever `sdeint.integrate` returns for the state (presumably one
        tensor per component of the initial state — confirm in `integrate`).
    """
    assert len(args) >= 14, 'Internal error: all arguments required.'
    n_trailing = 13
    initial_state = args[:-n_trailing]
    trailing = args[-n_trailing:]
    (sde, ts, flat_params, dt, bm, method, adjoint_method,
     adaptive, rtol, atol, dt_min, options, adjoint_options) = trailing

    # Record the *unwrapped* sde; the Ito wrapper below is only for this solve.
    ctx.sde = sde
    ctx.dt = dt
    ctx.bm = bm
    ctx.adjoint_method = adjoint_method
    ctx.adaptive = adaptive
    ctx.rtol = rtol
    ctx.atol = atol
    ctx.dt_min = dt_min
    ctx.adjoint_options = adjoint_options

    ito_sde = base_sde.ForwardSDEIto(sde)
    # Gradients are produced by the backward pass, so no graph is needed here.
    with torch.no_grad():
        solution = sdeint.integrate(
            sde=ito_sde,
            y0=initial_state,
            ts=ts,
            bm=bm,
            method=method,
            dt=dt,
            adaptive=adaptive,
            rtol=rtol,
            atol=atol,
            dt_min=dt_min,
            options=options,
        )
    ctx.save_for_backward(ts, flat_params, *solution)
    return solution
def forward(ctx, *args):
    """Solve the SDE forward with the log-ratio (logqp) term, without building a graph.

    This is an autograd-Function-style forward: it records state on `ctx` and
    calls `ctx.save_for_backward`.

    `args` is laid out as the initial state components (at least one)
    followed by exactly 13 configuration values, in this order:
    sde, ts, flat_params, dt, bm, method, adjoint_method, adaptive,
    rtol, atol, dt_min, options, adjoint_options.

    `sdeint.integrate(..., logqp=True)` returns a flat sequence. The first
    `len(y0)` entries are the state solution and the rest are the log-ratio
    terms.

    Side effects on `ctx`:
      * attributes sde, dt, bm, adjoint_method, adaptive, rtol, atol,
        dt_min, adjoint_options are set (`method` and `options` are not);
      * `ts`, `flat_params` and the state solution tensors are saved for
        backward. The log-ratio terms are not saved.

    Returns:
        The state solution entries followed by the log-ratio entries.
    """
    assert len(args) >= 14, 'Internal error: all arguments required.'
    initial_state = args[:-13]
    (sde, ts, flat_params, dt, bm, method, adjoint_method,
     adaptive, rtol, atol, dt_min, options, adjoint_options) = args[-13:]

    # Record the *unwrapped* sde; the Ito wrapper below is only for this solve.
    ctx.sde = sde
    ctx.dt = dt
    ctx.bm = bm
    ctx.adjoint_method = adjoint_method
    ctx.adaptive = adaptive
    ctx.rtol = rtol
    ctx.atol = atol
    ctx.dt_min = dt_min
    ctx.adjoint_options = adjoint_options

    ito_sde = base_sde.ForwardSDEIto(sde)
    with torch.no_grad():
        combined = sdeint.integrate(
            sde=ito_sde,
            y0=initial_state,
            ts=ts,
            bm=bm,
            method=method,
            dt=dt,
            adaptive=adaptive,
            rtol=rtol,
            atol=atol,
            dt_min=dt_min,
            options=options,
            logqp=True,
        )

    split_at = len(initial_state)
    state_part = combined[:split_at]
    log_ratio_part = combined[split_at:]
    # Don't need to save `logqp`, since it is never used in the backward pass to compute gradients.
    ctx.save_for_backward(ts, flat_params, *state_part)
    return state_part + log_ratio_part
def sdeint(sde, y0, ts, bm=None, logqp=False, method='srk', dt=1e-3, adaptive=False, rtol=1e-6, atol=1e-5,
           dt_min=1e-4, options=None, names=None):
    """Numerically integrate an Itô SDE.

    Args:
        sde: An object with the methods `f` and `g` representing the drift and diffusion functions. The methods
            should take in time `t` and state `y` and return a tensor or tuple of tensors. The output signature of
            `f` should match `y`. The output of `g` should either be a single (or a tuple) of tensors of size
            (batch_size, d) for diagonal noise problems or (batch_size, d, m) for other problem types, where d is
            the dimensionality of state and m is the dimensionality of the Brownian motion.
        y0: A single (or a tuple) of tensors of size (batch_size, d).
        ts: A list or 1-D tensor in non-descending order.
        bm: A `BrownianPath` or `BrownianTree` object. Defaults to a `BrownianPath` residing on CPU whose initial
            value is zeros shaped like `y0`, i.e. the default assumes diagonal noise. When `y0` is a tuple, the
            default is one such `BrownianPath` per component of `y0`, queried together and returned as a tuple.
        logqp: If True, also return the Radon-Nikodym derivative, which is a log-ratio penalty across the whole
            path.
        method: Numerical integration method, one of (`euler`, `milstein`, `srk`). Defaults to `srk`.
        dt: A float for the constant step size or initial step size for adaptive time-stepping.
        adaptive: If True, use adaptive time-stepping.
        rtol: Relative tolerance.
        atol: Absolute tolerance.
        dt_min: Minimum step size for adaptive time-stepping.
        options: Optional dict of configuring options for the indicated integration method.
        names: Optional dict of method names to use as drift, diffusion, and prior drift. Expected keys are
            `drift`, `diffusion`, `prior_drift`.

    Returns:
        A single state tensor of size (T, batch_size, d) or a tuple of such tensors. Also returns a single
        log-ratio tensor of size (T - 1, batch_size) or a tuple of such tensors, if logqp=True.

    Raises:
        ValueError: An error occurred due to unrecognized noise type/method, or sde module missing required
            methods.
    """
    names_to_change = get_names_to_change(names)
    if len(names_to_change) > 0:
        sde = base_sde.RenameMethodsSDE(sde, **names_to_change)
    check_contract(sde=sde, method=method, adaptive=adaptive, logqp=logqp)

    tensor_input = isinstance(y0, torch.Tensor)

    if bm is None:
        # Default Brownian motion: zero initial value shaped like the state, kept on CPU.
        if tensor_input:
            bm = brownian_path.BrownianPath(t0=ts[0], w0=torch.zeros_like(y0).cpu())
        else:
            # `torch.zeros_like` cannot take a tuple, so build one path per state component. The callable
            # returns a tuple, which mirrors the tuple-returning wrapper built for tensor input below.
            component_bms = tuple(
                brownian_path.BrownianPath(t0=ts[0], w0=torch.zeros_like(y0_i).cpu()) for y0_i in y0
            )
            bm = lambda t: tuple(bm_i(t) for bm_i in component_bms)

    if tensor_input:
        # Normalize a single-tensor problem to the tuple form used internally.
        sde = base_sde.TupleSDE(sde)
        y0 = (y0,)
        base_bm = bm
        bm = lambda t: (base_bm(t),)

    sde = base_sde.ForwardSDEIto(sde)
    results = integrate(sde=sde, y0=y0, ts=ts, bm=bm, method=method, dt=dt, adaptive=adaptive, rtol=rtol,
                        atol=atol, dt_min=dt_min, options=options, logqp=logqp)
    # Unwrap the 1-tuple for plain tensor input. With logqp, the (state, log-ratio) results are returned as is.
    if not logqp and tensor_input:
        return results[0]
    return results