Example #1
def _solve_cde(x):
    # x should be of shape (batch, length, channels)

    batch_size = x.size(0)
    input_channels = x.size(2)
    hidden_channels = 4  # hyperparameter, we can pick whatever we want for this

    coeffs = torchcde.natural_cubic_spline_coeffs(x)
    X = torchcde.NaturalCubicSpline(coeffs)
    z0 = torch.rand(batch_size, hidden_channels)

    class F(torch.nn.Module):
        def __init__(self):
            super(F, self).__init__()
            self.linear = torch.nn.Linear(hidden_channels,
                                          hidden_channels * input_channels)

        def forward(self, t, z):
            return self.linear(z).view(batch_size, hidden_channels, input_channels)

    func = F()
    zt = torchcde.cdeint(X=X, func=func, z0=z0, t=X.interval)
    zT = zt[:, -1]  # get the terminal value of the CDE

    return zT
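
A minimal usage sketch for _solve_cde above, assuming torch and torchcde are importable; the toy input shape is chosen purely for illustration. The result has shape (batch, hidden_channels), with hidden_channels hard-coded to 4 inside the function.

import torch
import torchcde

x = torch.rand(16, 100, 3)   # hypothetical data: 16 series, 100 time points, 3 channels
zT = _solve_cde(x)
print(zT.shape)              # torch.Size([16, 4])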
Example #2
def test_specification_and_derivative():
    # Check that the natural cubic spline passes through the data points exactly and
    # that its analytic derivative matches the derivative computed by autograd.
    for _ in range(10):
        for use_t in (False, True):
            for num_batch_dims in (0, 1, 2, 3):
                batch_dims = []
                for _ in range(num_batch_dims):
                    batch_dims.append(torch.randint(low=1, high=3, size=(1,)).item())
                length = torch.randint(low=5, high=10, size=(1,)).item()
                channels = torch.randint(low=1, high=5, size=(1,)).item()
                if use_t:
                    t = torch.linspace(0, 1, length, dtype=torch.float64)
                else:
                    t = torch.linspace(0, length - 1, length, dtype=torch.float64)
                x = torch.rand(*batch_dims, length, channels, dtype=torch.float64)
                coeffs = torchcde.natural_cubic_spline_coeffs(x, t)
                spline = torchcde.NaturalCubicSpline(coeffs, t)
                # Test specification
                for i, point in enumerate(t):
                    evaluate = spline.evaluate(point)
                    xi = x[..., i, :]
                    assert evaluate.allclose(xi, atol=1e-5, rtol=1e-5)
                # Test derivative
                for point in torch.rand(100, dtype=torch.float64):
                    point_with_grad = point.detach().requires_grad_(True)
                    evaluate = spline.evaluate(point_with_grad)
                    derivative = spline.derivative(point)
                    autoderivative = []
                    for elem in evaluate.view(-1):
                        elem.backward(retain_graph=True)
                        with torch.no_grad():
                            autoderivative.append(point_with_grad.grad.clone())
                        point_with_grad.grad.zero_()
                    autoderivative = torch.stack(autoderivative).view(*evaluate.shape)
                    assert derivative.shape == autoderivative.shape
                    assert derivative.allclose(autoderivative, atol=1e-5, rtol=1e-5)
Example #3
def test_grad_paths():
    # Check that gradients flow back through cdeint to the observation times, the path
    # values, the initial state, the vector field parameters and the integration times,
    # both with and without the adjoint method.
    for method in ('rk4', 'dopri5'):
        for adjoint in (True, False):
            t = torch.linspace(0, 9, 10, requires_grad=True)
            path = torch.rand(1, 10, 3, requires_grad=True)
            coeffs = torchcde.natural_cubic_coeffs(path, t)
            cubic_spline = torchcde.NaturalCubicSpline(coeffs, t)
            z0 = torch.rand(1, 3, requires_grad=True)
            func = _Func(input_size=3, hidden_size=3)
            t_ = torch.tensor([0., 9.], requires_grad=True)

            z = torchcde.cdeint(X=cubic_spline,
                                func=func,
                                z0=z0,
                                t=t_,
                                adjoint=adjoint,
                                method=method,
                                rtol=1e-4,
                                atol=1e-6)
            assert z.shape == (1, 2, 3)
            assert t.grad is None
            assert path.grad is None
            assert z0.grad is None
            assert func.variable.grad is None
            assert t_.grad is None
            z[:, 1].sum().backward()
            assert isinstance(t.grad, torch.Tensor)
            assert isinstance(path.grad, torch.Tensor)
            assert isinstance(z0.grad, torch.Tensor)
            assert isinstance(func.variable.grad, torch.Tensor)
            assert isinstance(t_.grad, torch.Tensor)
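
The tests above (and test_detach_trick further below) use a _Func helper that is not shown. A minimal sketch that would be compatible with them is given below; this is an assumption, not the actual helper from the torchcde test suite. It exposes a single `variable` parameter and maps z of shape (batch, hidden_size) to a matrix of shape (batch, hidden_size, input_size).

import torch

class _Func(torch.nn.Module):  # hypothetical stand-in for the helper used by these tests
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.variable = torch.nn.Parameter(torch.rand(hidden_size, input_size))

    def forward(self, t, z):
        # (batch, hidden_size, 1) broadcast against (hidden_size, input_size)
        return z.sigmoid().unsqueeze(-1) * self.variable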
Example #4
    def forward(self, coeffs):
        X = torchcde.NaturalCubicSpline(coeffs)
        X0 = X.evaluate(X.interval[0])
        z0 = self.initial(X0)
        zt = torchcde.cdeint(X=X, func=self.func, z0=z0, t=X.interval)
        zT = zt[:, -1]  # get the terminal value of the CDE
        return self.readout(zT)
Example #5
def test_linear():
    # On data that is exactly linear in time, natural cubic spline interpolation should
    # agree with linear interpolation.
    for interp_fn in (torchcde.natural_cubic_coeffs,
                      torchcde.natural_cubic_spline_coeffs):
        for use_t in (False, True):
            start = torch.rand(1).item() * 5 - 2.5
            end = torch.rand(1).item() * 5 - 2.5
            start, end = min(start, end), max(start, end)
            num_points = torch.randint(low=2, high=10, size=(1, )).item()
            num_channels = torch.randint(low=1, high=4, size=(1, )).item()
            m = torch.rand(num_channels) * 5 - 2.5
            c = torch.rand(num_channels) * 5 - 2.5
            if use_t:
                t = torch.linspace(start, end, num_points)
                t_ = t
            else:
                t = torch.linspace(0, num_points - 1, num_points)
                t_ = None
            values = m * t.unsqueeze(-1) + c
            coeffs = interp_fn(values, t_)
            spline = torchcde.NaturalCubicSpline(coeffs, t_)
            coeffs2 = torchcde.linear_interpolation_coeffs(values, t_)
            linear = torchcde.LinearInterpolation(coeffs2, t_)
            batch_dims = []
            _test_equal(batch_dims, num_channels, linear, spline,
                        torch.float32, -1.5, 1.5, 1e-4)
Example #6
    def forward(self, coeffs):
        if self.interpolation == 'cubic':
            X = torchcde.NaturalCubicSpline(coeffs)
        elif self.interpolation == 'linear':
            X = torchcde.LinearInterpolation(coeffs)
        else:
            raise ValueError(
                "Only 'linear' and 'cubic' interpolation methods are implemented."
            )

        ######################
        # Easy to forget gotcha: Initial hidden state should be a function of the first observation.
        ######################
        X0 = X.evaluate(X.interval[0])
        z0 = self.initial(X0)

        ######################
        # Actually solve the CDE.
        ######################
        z_T = torchcde.cdeint(X=X, z0=z0, func=self.func, t=X.interval)

        ######################
        # Both the initial value and the terminal value are returned from cdeint; extract just the terminal value,
        # and then apply a linear map.
        ######################
        z_T = z_T[:, 1]
        pred_y = self.readout(z_T)
        return pred_y
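
The forward method above expects the module to define self.func, self.initial, self.readout and self.interpolation. A minimal sketch of such a constructor follows; CDEFunc and the layer choices are assumptions made to be compatible with the method, not the original model definition.

import torch
import torchcde

class CDEFunc(torch.nn.Module):  # hypothetical vector field: z -> matrix of shape (batch, hidden, input)
    def __init__(self, input_channels, hidden_channels):
        super().__init__()
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.linear = torch.nn.Linear(hidden_channels, hidden_channels * input_channels)

    def forward(self, t, z):
        return self.linear(z).tanh().view(*z.shape[:-1], self.hidden_channels, self.input_channels)

class NeuralCDE(torch.nn.Module):  # hypothetical wrapper supplying the attributes used by forward above
    def __init__(self, input_channels, hidden_channels, output_channels, interpolation='cubic'):
        super().__init__()
        self.func = CDEFunc(input_channels, hidden_channels)
        self.initial = torch.nn.Linear(input_channels, hidden_channels)
        self.readout = torch.nn.Linear(hidden_channels, output_channels)
        self.interpolation = interpolation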
Example #7
def test_tuple_input():
    # Check that cdeint accepts a tuple of controls and a tuple of initial states. The two
    # components evolve independently, so backpropagating through the first output should
    # leave the second initial state with zero gradient.
    xa = torch.rand(2, 10, 2)
    xb = torch.rand(10, 1)

    coeffs_a = torchcde.natural_cubic_coeffs(xa)
    coeffs_b = torchcde.natural_cubic_coeffs(xb)
    spline_a = torchcde.NaturalCubicSpline(coeffs_a)
    spline_b = torchcde.NaturalCubicSpline(coeffs_b)
    X = torchcde.TupleControl(spline_a, spline_b)

    def func(t, z):
        za, zb = z
        return za.sigmoid().unsqueeze(-1).repeat_interleave(2, dim=-1), zb.tanh().unsqueeze(-1)

    z0 = torch.rand(2, 3), torch.rand(5, requires_grad=True)
    out = torchcde.cdeint(X=X, func=func, z0=z0, t=X.interval, adjoint_params=())
    out[0].sum().backward()
    assert (z0[1].grad == 0).all()
Example #8
def test_prod():
    # Check that cdeint accepts a vector field exposing prod(t, z, dXdt) directly,
    # rather than returning a matrix to be contracted against dX/dt.
    x = torch.rand(2, 5, 1)
    X = torchcde.NaturalCubicSpline(torchcde.natural_cubic_coeffs(x))

    class F:
        def prod(self, t, z, dXdt):
            assert t.shape == ()
            assert z.shape == (2, 3)
            assert dXdt.shape == (2, 1)
            return -z * dXdt

    z0 = torch.rand(2, 3, requires_grad=True)
    out = torchcde.cdeint(X=X, func=F(), z0=z0, t=X.interval, adjoint_params=())
    out.sum().backward()
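
For comparison, the same dynamics written in the usual form, where the vector field returns a matrix that cdeint contracts against dX/dt; a sketch, and F_matrix is a hypothetical name.

class F_matrix(torch.nn.Module):
    def forward(self, t, z):
        # Shape (2, 3, 1); cdeint multiplies this against dX/dt of shape (2, 1),
        # giving the same -z * dXdt as the prod method above.
        return -z.unsqueeze(-1)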
Example #9
def build_data_path(coeffs, times, interpolation_method):
    if interpolation_method == 'cubic':
        X = torchcde.NaturalCubicSpline(coeffs, t=times)
        cdeint_options = {}

    elif interpolation_method == 'linear':
        X = torchcde.LinearInterpolation(coeffs, t=times)
        cdeint_options = dict(grid_points=X.grid_points, eps=1e-5)

    elif interpolation_method == 'rectilinear':
        # rectilinear interpolation doesn't work when passing a time argument
        X = torchcde.LinearInterpolation(coeffs)
        cdeint_options = dict(grid_points=X.grid_points, eps=1e-5)

    else:
        raise ValueError(f"Unknown interpolation_method: {interpolation_method!r}")

    return X, cdeint_options
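
A sketch of how the returned pair might be consumed, assuming the older torchcde/torchdiffeq releases this snippet targets (where grid_points and eps are valid solver options); the data, vector field and hidden size are hypothetical placeholders.

import torch
import torchcde

times = torch.linspace(0., 1., 10)
values = torch.rand(1, 10, 2)                    # (batch, length, channels)
coeffs = torchcde.linear_interpolation_coeffs(values, times)

X, cdeint_options = build_data_path(coeffs, times, 'linear')

def func(t, z):                                  # hypothetical vector field: (1, 3) -> (1, 3, 2)
    return -z.unsqueeze(-1).expand(1, 3, 2)

z0 = torch.rand(1, 3)
zt = torchcde.cdeint(X=X, func=func, z0=z0, t=X.interval,
                     options=cdeint_options, adjoint_params=())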
Example #10
def test_short():
    # With only two observations the natural cubic spline should reduce to linear interpolation.
    for use_t in (False, True):
        if use_t:
            t = torch.tensor([0., 1.])
        else:
            t = None
        values = torch.rand(2, 1)
        coeffs = torchcde.natural_cubic_spline_coeffs(values, t)
        spline = torchcde.NaturalCubicSpline(coeffs, t)
        coeffs2 = torchcde.linear_interpolation_coeffs(values, t)
        linear = torchcde.LinearInterpolation(coeffs2, t)
        batch_dims = []
        num_channels = 1
        _test_equal(batch_dims, num_channels, linear, spline, torch.float32, -1.5, 1.5, 1e-4)
Example #11
def test_shape():
    # Check that cdeint returns the expected shape (*batch_dims, num_out_times, hidden_channels)
    # across random batch dimensions, channel counts and solvers.
    for method in ('rk4', 'dopri5'):
        for _ in range(10):
            num_points = torch.randint(low=5, high=100, size=(1, )).item()
            num_channels = torch.randint(low=1, high=3, size=(1, )).item()
            num_hidden_channels = torch.randint(low=1, high=5,
                                                size=(1, )).item()
            num_batch_dims = torch.randint(low=0, high=3, size=(1, )).item()
            batch_dims = []
            for _ in range(num_batch_dims):
                batch_dims.append(
                    torch.randint(low=1, high=3, size=(1, )).item())

            values = torch.rand(*batch_dims, num_points, num_channels)

            coeffs = torchcde.natural_cubic_coeffs(values)
            spline = torchcde.NaturalCubicSpline(coeffs)

            class _Func(torch.nn.Module):
                def __init__(self):
                    super(_Func, self).__init__()
                    self.variable = torch.nn.Parameter(
                        torch.rand(*[1 for _ in range(num_batch_dims)], 1,
                                   num_channels))

                def forward(self, t, z):
                    return z.sigmoid().unsqueeze(-1) + self.variable

            f = _Func()
            z0 = torch.rand(*batch_dims, num_hidden_channels)

            num_out_times = torch.randint(low=2, high=10, size=(1, )).item()
            start, end = spline.interval
            out_times = torch.rand(num_out_times, dtype=torch.float64).sort(
            ).values * (end - start) + start

            options = {}
            if method == 'rk4':
                options['step_size'] = 1. / num_points
            out = torchcde.cdeint(spline,
                                  f,
                                  z0,
                                  out_times,
                                  method=method,
                                  options=options,
                                  rtol=1e-4,
                                  atol=1e-6)
            assert out.shape == (*batch_dims, num_out_times,
                                 num_hidden_channels)
Example #12
    def forward(self, x):
        # NOTE: here x is the raw time series of shape (batch, length, channels); its natural cubic
        # spline coefficients are computed on the next line. See datasets.py for how to generate
        # the coefficients ahead of time instead.
        x = torchcde.natural_cubic_coeffs(x)
        if self.interpolation == "cubic":
            x = torchcde.NaturalCubicSpline(x)
        elif self.interpolation == "linear":
            x = torchcde.LinearInterpolation(x)
        else:
            raise ValueError("invalid interpolation given")

        x0 = x.evaluate(x.interval[0])
        z0 = self.initial(x0)
        zt = torchcde.cdeint(X=x, func=self.model, z0=z0, t=x.interval)

        return self.output(zt[..., -1, :])
Example #13
def test_backend():
    # Check that the torchdiffeq and torchsde backends agree when using the same
    # fixed-step midpoint solver.
    x = torch.randn(1, 10, 2)
    coeffs = torchcde.natural_cubic_coeffs(x)
    X = torchcde.NaturalCubicSpline(coeffs)

    def func(t, z):
        return -z.unsqueeze(-1).expand(1, 3, 2)

    z0 = torch.randn(1, 3)

    torchdiffeq_out = torchcde.cdeint(X=X, func=func, z0=z0, t=X.interval, backend="torchdiffeq", method="midpoint",
                                      options=dict(step_size=1.0))
    torchsde_out = torchcde.cdeint(X=X, func=func, z0=z0, t=X.interval, backend="torchsde", method="midpoint", dt=1.0)

    torch.testing.assert_allclose(torchdiffeq_out, torchsde_out)
Example #14
def test_interp():
    # Check that the natural cubic spline recovers a reference cubic, both with and without
    # explicitly supplied times, and with some values dropped (set to NaN).
    for _ in range(3):
        for use_t in (True, False):
            for drop in (False, True):
                num_points = torch.randint(low=5, high=100, size=(1,)).item()
                half_num_points = num_points // 2
                num_points = 2 * half_num_points + 1
                if use_t:
                    times1 = torch.rand(half_num_points, dtype=torch.float64) - 1
                    times2 = torch.rand(half_num_points, dtype=torch.float64)
                    t = torch.cat([times1, times2, torch.tensor([0.], dtype=torch.float64)]).sort().values
                    t_ = t
                    start, end = -1.5, 1.5
                    del times1, times2
                else:
                    t = torch.linspace(0, num_points - 1, num_points, dtype=torch.float64)
                    t_ = None
                    start = 0
                    end = num_points - 0.5
                num_channels = torch.randint(low=1, high=3, size=(1,)).item()
                num_batch_dims = torch.randint(low=0, high=3, size=(1,)).item()
                batch_dims = []
                for _ in range(num_batch_dims):
                    batch_dims.append(torch.randint(low=1, high=3, size=(1,)).item())
                if use_t:
                    cubic = _Cubic(batch_dims, num_channels, start=t[0], end=t[-1])
                    knot = 0
                else:
                    cubic = _Offset(batch_dims, num_channels, start=t[0], end=t[-1], offset=t[1] - t[0])
                    knot = t[1] - t[0]
                values = cubic.evaluate(t)
                if drop:
                    for values_slice in values.unbind(dim=-1):
                        num_drop = int(num_points * torch.randint(low=1, high=4, size=(1,)).item() / 10)
                        num_drop = min(num_drop, num_points - 4)
                        # don't drop first or last
                        to_drop = torch.randperm(num_points - 2)[:num_drop] + 1
                        to_drop = [x for x in to_drop if x != knot]
                        values_slice[..., to_drop] = float('nan')
                        del num_drop, to_drop, values_slice
                coeffs = torchcde.natural_cubic_spline_coeffs(values, t_)
                spline = torchcde.NaturalCubicSpline(coeffs, t_)
                _test_equal(batch_dims, num_channels, cubic, spline, torch.float64, start, end, 1e-3)
Example #15
    def forward(self, coeffs):
        X = torchcde.NaturalCubicSpline(coeffs)

        ######################
        # Easy to forget gotcha: Initial hidden state should be a function of the first observation.
        ######################
        X0 = X.evaluate(X.interval[0])
        z0 = self.initial(X0)

        ######################
        # Actually solve the CDE.
        ######################
        z_T = torchcde.cdeint(X=X, z0=z0, func=self.func, t=X.interval)

        ######################
        # Both the initial value and the terminal value are returned from cdeint; extract just the terminal value,
        # and then apply a linear map.
        ######################
        z_T = z_T[:, 1]
        pred_y = self.readout(z_T)
        return pred_y
Example #16
def test_detach_trick():
    # Check that the parameter gradients are the same whether or not t requires gradients,
    # using a fixed-step solver so that step sizes cannot differ.
    path = torch.rand(1, 10, 3)
    interp = torchcde.NaturalCubicSpline(torchcde.natural_cubic_coeffs(path))

    func = _Func(input_size=3, hidden_size=3)

    for adjoint in (True, False):
        variable_grads = []
        z0 = torch.rand(1, 3)
        for t_grad in (True, False):
            t_ = torch.tensor([0., 9.], requires_grad=t_grad)
            # Don't test dopri5: making t require gradients forces smaller step sizes, which
            # would give slightly different results.
            z = torchcde.cdeint(X=interp, z0=z0, func=func, t=t_, adjoint=adjoint, method='rk4',
                                options=dict(step_size=0.5))
            z[:, -1].sum().backward()
            variable_grads.append(func.variable.grad.clone())
            func.variable.grad.zero_()

        for elem in variable_grads[1:]:
            assert (elem == variable_grads[0]).all()
Example #17
def test_shape(backend, method, kwargs):
    # Parametrized over backend and solver: check that cdeint returns the expected shape
    # (*batch_dims, num_out_times, hidden_channels).
    for _ in range(5):
        num_points = torch.randint(low=5, high=100, size=(1,)).item()
        num_channels = torch.randint(low=1, high=3, size=(1,)).item()
        num_hidden_channels = torch.randint(low=1, high=5, size=(1,)).item()
        if backend == "torchdiffeq":
            num_batch_dims = torch.randint(low=0, high=3, size=(1,)).item()
            batch_dims = []
            for _ in range(num_batch_dims):
                batch_dims.append(torch.randint(low=1, high=3, size=(1,)).item())
        elif backend == "torchsde":
            num_batch_dims = 1
            batch_dims = [torch.randint(low=1, high=3, size=(1,)).item()]
        else:
            raise ValueError

        values = torch.rand(*batch_dims, num_points, num_channels)

        coeffs = torchcde.natural_cubic_coeffs(values)
        spline = torchcde.NaturalCubicSpline(coeffs)

        class _Func(torch.nn.Module):
            def __init__(self):
                super(_Func, self).__init__()
                self.variable = torch.nn.Parameter(torch.rand(*[1 for _ in range(num_batch_dims)], 1, num_channels))

            def forward(self, t, z):
                return z.sigmoid().unsqueeze(-1) + self.variable

        f = _Func()
        z0 = torch.rand(*batch_dims, num_hidden_channels)

        num_out_times = torch.randint(low=2, high=10, size=(1,)).item()
        start, end = spline.interval
        out_times = torch.rand(num_out_times, dtype=torch.float64).sort().values * (end - start) + start

        out = torchcde.cdeint(spline, f, z0, out_times, backend=backend, method=method, rtol=1e-1, atol=1e-1, **kwargs)
        assert out.shape == (*batch_dims, num_out_times, num_hidden_channels)
Example #18
def interp_():
    coeffs = torchcde.natural_cubic_coeffs(path)
    yield torchcde.NaturalCubicSpline(coeffs)
    coeffs = torchcde.linear_interpolation_coeffs(path)
    yield torchcde.LinearInterpolation(coeffs, reparameterise='bump')
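
A sketch of how such a generator might be consumed in a test; path is a free variable of interp_, so it is defined here as a hypothetical random control, and func is a hypothetical vector field.

import torch
import torchcde

path = torch.rand(1, 10, 3)      # hypothetical control path (batch, length, channels)

def func(t, z):                  # hypothetical vector field: (1, 3) -> (1, 3, 3)
    return z.tanh().unsqueeze(-1).expand(1, 3, 3)

for X in interp_():              # iterate over both interpolation schemes
    z = torchcde.cdeint(X=X, func=func, z0=torch.rand(1, 3), t=X.interval, adjoint_params=())
    assert z.shape == (1, 2, 3)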