Example #1
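All snippets below assume import torch and import torchcde at module scope. The forward methods in Examples #1, #2 and #4 additionally assume an owning torch.nn.Module that provides initial, func and readout submodules; a sketch of such a module follows Example #1.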
    def forward(self, coeffs):
        if self.interpolation == 'cubic':
            X = torchcde.CubicSpline(coeffs)
        elif self.interpolation == 'linear':
            X = torchcde.LinearInterpolation(coeffs)
        else:
            raise ValueError(
                "Only 'linear' and 'cubic' interpolation methods are implemented."
            )

        ######################
        # Easy to forget gotcha: Initial hidden state should be a function of the first observation.
        ######################
        X0 = X.evaluate(X.interval[0])
        z0 = self.initial(X0)

        ######################
        # Actually solve the CDE.
        ######################
        z_T = torchcde.cdeint(X=X, z0=z0, func=self.func, t=X.interval)

        ######################
        # Both the initial value and the terminal value are returned from cdeint; extract just the terminal value,
        # and then apply a linear map.
        ######################
        z_T = z_T[:, 1]
        pred_y = self.readout(z_T)
        return pred_y
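The forward method above references self.initial, self.func and self.readout on its owning module. The following is a minimal sketch of one such module, modeled on the time-series classification example in the torchcde repository; the layer choices and the 128-unit width are illustrative assumptions, not part of the snippet above.

class CDEFunc(torch.nn.Module):
    # Models the vector field f_theta in dz/dt = f_theta(z) dX/dt: maps the hidden
    # state z to a (hidden_channels, input_channels) matrix contracted against dX/dt.
    def __init__(self, input_channels, hidden_channels):
        super().__init__()
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.linear1 = torch.nn.Linear(hidden_channels, 128)
        self.linear2 = torch.nn.Linear(128, input_channels * hidden_channels)

    def forward(self, t, z):
        z = self.linear1(z).relu()
        z = self.linear2(z).tanh()  # a bounded final nonlinearity helps stability
        return z.view(*z.shape[:-1], self.hidden_channels, self.input_channels)


class NeuralCDE(torch.nn.Module):
    def __init__(self, input_channels, hidden_channels, output_channels, interpolation='cubic'):
        super().__init__()
        self.func = CDEFunc(input_channels, hidden_channels)
        self.initial = torch.nn.Linear(input_channels, hidden_channels)
        self.readout = torch.nn.Linear(hidden_channels, output_channels)
        self.interpolation = interpolation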
Example #2
    def forward(self, coeffs):
        X = torchcde.CubicSpline(coeffs)
        X0 = X.evaluate(X.interval[0])
        z0 = self.initial(X0)
        zt = torchcde.cdeint(X=X, func=self.func, z0=z0, t=X.interval)
        zT = zt[..., -1, :]  # get the terminal value of the CDE
        return self.readout(zT)
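A hypothetical end-to-end call, assuming the NeuralCDE module sketched under Example #1:

x = torch.rand(32, 100, 3)                 # (batch, length, channels) raw observations
coeffs = torchcde.natural_cubic_coeffs(x)  # fit interpolation coefficients once, up front
model = NeuralCDE(input_channels=3, hidden_channels=8, output_channels=1)
pred_y = model(coeffs)                     # shape (32, 1)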
Example #3
def test_detach_trick():
    path = torch.rand(1, 10, 3)
    interp = torchcde.CubicSpline(torchcde.natural_cubic_coeffs(path))

    func = _Func(input_size=3, hidden_size=3)

    for adjoint in (True, False):
        variable_grads = []
        z0 = torch.rand(1, 3)
        for t_grad in (True, False):
            t_ = torch.tensor([0., 9.], requires_grad=t_grad)
            # Don't test dopri5: a t that requires grad forces the adaptive solver to take
            # smaller step sizes, so the two runs would produce slightly different results.
            z = torchcde.cdeint(X=interp,
                                z0=z0,
                                func=func,
                                t=t_,
                                adjoint=adjoint,
                                method='rk4',
                                options=dict(step_size=0.5))
            z[:, -1].sum().backward()
            variable_grads.append(func.variable.grad.clone())
            func.variable.grad.zero_()

        for elem in variable_grads[1:]:
            assert (elem == variable_grads[0]).all()
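_Func is a helper defined elsewhere in torchcde's test suite. Below is a hypothetical stand-in sufficient to run this test; the exact architecture is an assumption, and any module exposing a variable parameter and mapping z of shape (batch, hidden_size) to a (batch, hidden_size, input_size) matrix would do.

class _Func(torch.nn.Module):
    # Hypothetical stand-in for the test-suite helper of the same name.
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.variable = torch.nn.Parameter(torch.rand(1, 1))
        self.linear = torch.nn.Linear(hidden_size, hidden_size * input_size)

    def forward(self, t, z):
        # Return a (batch, hidden_size, input_size) matrix to contract with dX/dt.
        out = self.linear(z) * self.variable
        return out.view(*z.shape[:-1], self.hidden_size, self.input_size)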
Example #4
    def forward(self, coeffs):
        if self.interpolation == "cubic":
            X = torchcde.CubicSpline(coeffs)
        elif self.interpolation == "linear":
            X = torchcde.LinearInterpolation(coeffs)
        else:
            raise ValueError(
                "Only 'linear' and 'cubic' interpolation methods are implemented."
            )

        X0 = X.evaluate(X.interval[0])
        z0 = self.initial(X0)

        if self.return_sequences:
            # One evaluation time per integer step across the interval.
            times = torch.arange(X.interval[0], X.interval[1] + 1)
            z_t = torchcde.cdeint(X=X, z0=z0, func=self.func, t=times)
            # z_t is ordered to match `times`, so index positionally, not by time value.
            pred_y = []
            for i in range(len(times)):
                z_ti = z_t[:, i]
                pred_y.append(self.readout(z_ti))
            pred_y = torch.stack(pred_y, dim=1)  # (batch, time, output)
        else:
            z_T = torchcde.cdeint(X=X, z0=z0, func=self.func, t=X.interval)
            z_T = z_T[:, 1]
            pred_y = self.readout(z_T)

        return pred_y
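cdeint evaluates the solution at exactly the times passed as t, in order, which is why the sequence branch indexes z_t positionally; the terminal-value branch takes index 1 because t=X.interval contains only the two endpoints.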
Example #5
def test_tuple_input():
    xa = torch.rand(2, 10, 2)
    xb = torch.rand(10, 1)

    coeffs_a = torchcde.natural_cubic_coeffs(xa)
    coeffs_b = torchcde.natural_cubic_coeffs(xb)
    spline_a = torchcde.CubicSpline(coeffs_a)
    spline_b = torchcde.CubicSpline(coeffs_b)
    X = torchcde.TupleControl(spline_a, spline_b)

    def func(t, z):
        za, zb = z
        return za.sigmoid().unsqueeze(-1).repeat_interleave(2, dim=-1), zb.tanh().unsqueeze(-1)

    z0 = torch.rand(2, 3), torch.rand(5, requires_grad=True)
    out = torchcde.cdeint(X=X, func=func, z0=z0, t=X.interval, adjoint_params=())
    out[0].sum().backward()
    assert (z0[1].grad == 0).all()
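The two controls here drive decoupled states: the vector field for za depends only on za, and the one for zb only on zb, so nothing computed from out[0] depends on z0[1] and the gradient flowing back to it is identically zero, which the final assert checks.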
Example #6
def test_prod():
    x = torch.rand(2, 5, 1)
    X = torchcde.CubicSpline(torchcde.natural_cubic_coeffs(x))

    class F:
        def prod(self, t, z, dXdt):
            assert t.shape == ()
            assert z.shape == (2, 3)
            assert dXdt.shape == (2, 1)
            return -z * dXdt

    z0 = torch.rand(2, 3, requires_grad=True)
    out = torchcde.cdeint(X=X, func=F(), z0=z0, t=X.interval, adjoint_params=())
    out.sum().backward()
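Rather than returning a (hidden, channels) matrix for cdeint to matrix-multiply against dX/dt, a vector field may expose a prod(t, z, dXdt) method that computes the product f(z) dX/dt directly; with a single control channel, as here, that product reduces to an elementwise multiply.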
Example #7
def test_short():
    for interp_fn in (torchcde.natural_cubic_coeffs, torchcde.natural_cubic_spline_coeffs):
        for use_t in (False, True):
            if use_t:
                t = torch.tensor([0., 1.])
            else:
                t = None
            values = torch.rand(2, 1)
            coeffs = interp_fn(values, t)
            spline = torchcde.CubicSpline(coeffs, t)
            coeffs2 = torchcde.linear_interpolation_coeffs(values, t)
            linear = torchcde.LinearInterpolation(coeffs2, t)
            batch_dims = []
            num_channels = 1
            _test_equal(batch_dims, num_channels, linear, spline, torch.float32, -1.5, 1.5, 1e-4)
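A natural cubic spline through just two data points has zero second derivative at both ends and degenerates to the straight line between them, so the test checks (via the _test_equal helper, defined elsewhere in the suite) that the cubic and linear interpolations agree to 1e-4 over the range -1.5 to 1.5, i.e. including extrapolation.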
Example #8
def test_backend():
    x = torch.randn(1, 10, 2)
    coeffs = torchcde.natural_cubic_coeffs(x)
    X = torchcde.CubicSpline(coeffs)

    def func(t, z):
        return -z.unsqueeze(-1).expand(1, 3, 2)

    z0 = torch.randn(1, 3)

    torchdiffeq_out = torchcde.cdeint(X=X, func=func, z0=z0, t=X.interval, backend="torchdiffeq", method="midpoint",
                                      options=dict(step_size=1.0))
    torchsde_out = torchcde.cdeint(X=X, func=func, z0=z0, t=X.interval, backend="torchsde", method="midpoint", dt=1.0)

    torch.testing.assert_allclose(torchdiffeq_out, torchsde_out)
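Both backends integrate the same CDE with a fixed-step midpoint method and matching steps (step_size=1.0 for torchdiffeq, dt=1.0 for torchsde), so their outputs are expected to coincide, which the final assert checks.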
Example #9
def test_interp():
    for interp_fn in (torchcde.natural_cubic_coeffs, torchcde.natural_cubic_spline_coeffs):
        for _ in range(3):
            for use_t in (True, False):
                for drop in (False, True):
                    num_points = torch.randint(low=5, high=100, size=(1,)).item()
                    half_num_points = num_points // 2
                    num_points = 2 * half_num_points + 1
                    if use_t:
                        times1 = torch.rand(half_num_points, dtype=torch.float64) - 1
                        times2 = torch.rand(half_num_points, dtype=torch.float64)
                        t = torch.cat([times1, times2, torch.tensor([0.], dtype=torch.float64)]).sort().values
                        t_ = t
                        start, end = -1.5, 1.5
                        del times1, times2
                    else:
                        t = torch.linspace(0, num_points - 1, num_points, dtype=torch.float64)
                        t_ = None
                        start = 0
                        end = num_points - 0.5
                    num_channels = torch.randint(low=1, high=3, size=(1,)).item()
                    num_batch_dims = torch.randint(low=0, high=3, size=(1,)).item()
                    batch_dims = []
                    for _ in range(num_batch_dims):
                        batch_dims.append(torch.randint(low=1, high=3, size=(1,)).item())
                    if use_t:
                        cubic = _Cubic(batch_dims, num_channels, start=t[0], end=t[-1])
                        knot = 0
                    else:
                        cubic = _Offset(batch_dims, num_channels, start=t[0], end=t[-1], offset=t[1] - t[0])
                        knot = t[1] - t[0]
                    values = cubic.evaluate(t)
                    if drop:
                        for values_slice in values.unbind(dim=-1):
                            num_drop = int(num_points * torch.randint(low=1, high=4, size=(1,)).item() / 10)
                            num_drop = min(num_drop, num_points - 4)
                            # don't drop first or last
                            to_drop = torch.randperm(num_points - 2)[:num_drop] + 1
                            to_drop = [x for x in to_drop if x != knot]
                            values_slice[..., to_drop] = float('nan')
                            del num_drop, to_drop, values_slice
                    coeffs = interp_fn(values, t_)
                    spline = torchcde.CubicSpline(coeffs, t_)
                    _test_equal(batch_dims, num_channels, cubic, spline, torch.float64, start, end, 1e-3)
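This test samples a synthetic cubic (the _Cubic and _Offset helpers, like _test_equal, are defined elsewhere in the suite), fits a spline to the samples, and checks the two agree over a range extending past the data; the drop branch sets random interior values to NaN to exercise missing-data handling, never dropping the first or last observation or the knot.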
Example #10
def test_shape(backend, method, kwargs):
    for _ in range(5):
        num_points = torch.randint(low=5, high=100, size=(1,)).item()
        num_channels = torch.randint(low=1, high=3, size=(1,)).item()
        num_hidden_channels = torch.randint(low=1, high=5, size=(1,)).item()
        if backend == "torchdiffeq":
            num_batch_dims = torch.randint(low=0, high=3, size=(1,)).item()
            batch_dims = []
            for _ in range(num_batch_dims):
                batch_dims.append(torch.randint(low=1, high=3, size=(1,)).item())
        elif backend == "torchsde":
            num_batch_dims = 1
            batch_dims = [torch.randint(low=1, high=3, size=(1,)).item()]
        else:
            raise ValueError

        values = torch.rand(*batch_dims, num_points, num_channels)

        coeffs = torchcde.natural_cubic_coeffs(values)
        spline = torchcde.CubicSpline(coeffs)

        class _Func(torch.nn.Module):
            def __init__(self):
                super(_Func, self).__init__()
                self.variable = torch.nn.Parameter(torch.rand(*[1 for _ in range(num_batch_dims)], 1, num_channels))

            def forward(self, t, z):
                return z.sigmoid().unsqueeze(-1) + self.variable

        f = _Func()
        z0 = torch.rand(*batch_dims, num_hidden_channels)

        num_out_times = torch.randint(low=2, high=10, size=(1,)).item()
        start, end = spline.interval
        out_times = torch.rand(num_out_times, dtype=torch.float64).sort().values * (end - start) + start

        out = torchcde.cdeint(spline, f, z0, out_times, backend=backend, method=method, rtol=1e-1, atol=1e-1, **kwargs)
        assert out.shape == (*batch_dims, num_out_times, num_hidden_channels)
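The branch on backend reflects that torchsde expects states of shape (batch, channels), i.e. exactly one batch dimension, whereas torchdiffeq handles any number of batch dimensions, including none; since only the output shape is asserted, the loose tolerances keep the test cheap.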
Example #11
def test_linear():
    for interp_fn in (torchcde.natural_cubic_coeffs, torchcde.natural_cubic_spline_coeffs):
        for use_t in (False, True):
            start = torch.rand(1).item() * 5 - 2.5
            end = torch.rand(1).item() * 5 - 2.5
            start, end = min(start, end), max(start, end)
            num_points = torch.randint(low=2, high=10, size=(1,)).item()
            num_channels = torch.randint(low=1, high=4, size=(1,)).item()
            m = torch.rand(num_channels) * 5 - 2.5
            c = torch.rand(num_channels) * 5 - 2.5
            if use_t:
                t = torch.linspace(start, end, num_points)
                t_ = t
            else:
                t = torch.linspace(0, num_points - 1, num_points)
                t_ = None
            values = m * t.unsqueeze(-1) + c
            coeffs = interp_fn(values, t_)
            spline = torchcde.CubicSpline(coeffs, t_)
            coeffs2 = torchcde.linear_interpolation_coeffs(values, t_)
            linear = torchcde.LinearInterpolation(coeffs2, t_)
            batch_dims = []
            _test_equal(batch_dims, num_channels, linear, spline, torch.float32, -1.5, 1.5, 1e-4)
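Data sampled from an affine function m * t + c should be reproduced exactly by both schemes: a natural cubic spline through collinear points is that same straight line, so the cubic and linear interpolations must agree even over the extrapolated evaluation range.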
Example #12
def test_grad_paths():
    for method in ('rk4', 'dopri5'):
        for adjoint in (True, False):
            t = torch.linspace(0, 9, 10, requires_grad=True)
            path = torch.rand(1, 10, 3, requires_grad=True)
            coeffs = torchcde.natural_cubic_coeffs(path, t)
            cubic_spline = torchcde.CubicSpline(coeffs, t)
            z0 = torch.rand(1, 3, requires_grad=True)
            func = _Func(input_size=3, hidden_size=3)
            t_ = torch.tensor([0., 9.], requires_grad=True)

            if adjoint:
                kwargs = dict(adjoint_params=tuple(func.parameters()) +
                              (coeffs, t))
            else:
                kwargs = {}
            z = torchcde.cdeint(X=cubic_spline,
                                func=func,
                                z0=z0,
                                t=t_,
                                adjoint=adjoint,
                                method=method,
                                rtol=1e-4,
                                atol=1e-6,
                                **kwargs)
            assert z.shape == (1, 2, 3)
            assert t.grad is None
            assert path.grad is None
            assert z0.grad is None
            assert func.variable.grad is None
            assert t_.grad is None
            z[:, 1].sum().backward()
            assert isinstance(t.grad, torch.Tensor)
            assert isinstance(path.grad, torch.Tensor)
            assert isinstance(z0.grad, torch.Tensor)
            assert isinstance(func.variable.grad, torch.Tensor)
            assert isinstance(t_.grad, torch.Tensor)
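Under adjoint=True, gradients for tensors the integration closes over, here the interpolation coefficients and the observation times, only flow if they are listed in adjoint_params; the test then confirms a gradient arrives at every input: the observation times t, the raw path, the initial state z0, the vector field's parameter, and the integration endpoints t_.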
Example #13
def test_specification_and_derivative():
    for interp_fn in (torchcde.natural_cubic_coeffs, torchcde.natural_cubic_spline_coeffs):
        for _ in range(10):
            for use_t in (False, True):
                for num_batch_dims in (0, 1, 2, 3):
                    batch_dims = []
                    for _ in range(num_batch_dims):
                        batch_dims.append(torch.randint(low=1, high=3, size=(1,)).item())
                    length = torch.randint(low=5, high=10, size=(1,)).item()
                    channels = torch.randint(low=1, high=5, size=(1,)).item()
                    if use_t:
                        t = torch.linspace(0, 1, length, dtype=torch.float64)
                    else:
                        t = torch.linspace(0, length - 1, length, dtype=torch.float64)
                    x = torch.rand(*batch_dims, length, channels, dtype=torch.float64)
                    coeffs = interp_fn(x, t)
                    spline = torchcde.CubicSpline(coeffs, t)
                    # Test specification
                    for i, point in enumerate(t):
                        evaluate = spline.evaluate(point)
                        xi = x[..., i, :]
                        assert evaluate.allclose(xi, atol=1e-5, rtol=1e-5)
                    # Test derivative
                    for point in torch.rand(100, dtype=torch.float64):
                        point_with_grad = point.detach().requires_grad_(True)
                        evaluate = spline.evaluate(point_with_grad)
                        derivative = spline.derivative(point)
                        autoderivative = []
                        for elem in evaluate.view(-1):
                            elem.backward(retain_graph=True)
                            with torch.no_grad():
                                autoderivative.append(point_with_grad.grad.clone())
                            point_with_grad.grad.zero_()
                        autoderivative = torch.stack(autoderivative).view(*evaluate.shape)
                        assert derivative.shape == autoderivative.shape
                        assert derivative.allclose(autoderivative, atol=1e-5, rtol=1e-5)
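Two properties are verified: the spline interpolates the data exactly at each knot (the "specification" check), and its analytic derivative matches the derivative obtained by backpropagating through evaluate at 100 random points.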