Example #1
    def forward(self, coeffs):
        if self.interpolation == "cubic":
            X = torchcde.CubicSpline(coeffs)
        elif self.interpolation == "linear":
            X = torchcde.LinearInterpolation(coeffs)
        else:
            raise ValueError(
                "Only 'linear' and 'cubic' interpolation methods are implemented."
            )

        X0 = X.evaluate(X.interval[0])
        z0 = self.initial(X0)

        if self.return_sequences:
            pred_y = []
            times = torch.arange(X.interval[0], X.interval[1] + 1)
            z_t = torchcde.cdeint(X=X, z0=z0, func=self.func, t=times)
            # z_t has shape (batch, len(times), hidden); index positionally,
            # since times starts at X.interval[0] rather than at 1.
            for i in range(len(times)):
                pred_y.append(self.readout(z_t[:, i]))
            pred_y = torch.stack(pred_y)
        else:
            z_T = torchcde.cdeint(X=X, z0=z0, func=self.func, t=X.interval)
            z_T = z_T[:, 1]
            pred_y = self.readout(z_T)

        return pred_y
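A minimal sketch (an assumption, not part of the original example) of how the coefficients passed to this forward method would typically be produced:

import torch
import torchcde

x = torch.rand(32, 100, 3)  # (batch, length, channels)

# Match the coefficients to the interpolation scheme configured above:
coeffs = torchcde.natural_cubic_coeffs(x)           # for interpolation == "cubic"
# coeffs = torchcde.linear_interpolation_coeffs(x)  # for interpolation == "linear"

pred_y = model(coeffs)  # `model` being an instance of the module defining forward()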
Example #2
def test_stacked_paths():
    class Record(torch.autograd.Function):
        @staticmethod
        def forward(ctx, name, x):
            ctx.name = name
            return x

        @staticmethod
        def backward(ctx, x):
            if hasattr(ctx, 'been_here_before'):
                pytest.fail(ctx.name)
            ctx.been_here_before = True
            return None, x

    ReparameterisedLinearInterpolation = ft.partial(
        torchcde.LinearInterpolation, reparameterise='bump')
    coeff_paths = [
        (torchcde.linear_interpolation_coeffs, torchcde.LinearInterpolation),
        (torchcde.linear_interpolation_coeffs,
         ReparameterisedLinearInterpolation),
        (torchcde.natural_cubic_coeffs, torchcde.NaturalCubicSpline)
    ]
    for adjoint in (False, True):
        for first_coeffs, First in coeff_paths:
            for second_coeffs, Second in coeff_paths:
                first_path = torch.rand(1, 1000, 2, requires_grad=True)
                first_coeff = first_coeffs(first_path)
                first_X = First(first_coeff)
                first_func = _Func(input_size=2, hidden_size=2)

                second_t = torch.linspace(0, 999, 100)
                second_path = torchcde.cdeint(X=first_X,
                                              func=first_func,
                                              z0=torch.rand(1, 2),
                                              t=second_t,
                                              adjoint=adjoint,
                                              method='rk4',
                                              options=dict(step_size=10))
                second_path = Record.apply('second', second_path)
                second_coeff = second_coeffs(second_path, second_t)
                second_X = Second(second_coeff, second_t)
                second_func = _Func(input_size=2, hidden_size=2)

                third_t = torch.linspace(0, 999, 10)
                third_path = torchcde.cdeint(X=second_X,
                                             func=second_func,
                                             z0=torch.rand(1, 2),
                                             t=third_t,
                                             adjoint=adjoint,
                                             method='rk4',
                                             options=dict(step_size=10))
                third_path = Record.apply('third', third_path)
                assert first_func.variable.grad is None
                assert second_func.variable.grad is None
                assert first_path.grad is None
                third_path[:, -1].sum().backward()
                assert isinstance(second_func.variable.grad, torch.Tensor)
                assert isinstance(first_func.variable.grad, torch.Tensor)
                assert isinstance(first_path.grad, torch.Tensor)
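This test (and several below) relies on a helper _Func and module-level imports (functools as ft, pytest, torch, torchcde) that are not shown in the snippet. A minimal sketch of _Func, consistent with how it is called here but an assumption rather than the library's actual test helper:

import torch

class _Func(torch.nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        # A learnable tensor, so the tests can check that gradients reach it.
        self.variable = torch.nn.Parameter(torch.rand(1, hidden_size, input_size))

    def forward(self, t, z):
        # (batch, hidden) -> (batch, hidden, input), the matrix shape cdeint expects.
        return z.sigmoid().unsqueeze(-1) * self.variable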
Example #3
def test_backend():
    x = torch.randn(1, 10, 2)
    coeffs = torchcde.natural_cubic_coeffs(x)
    X = torchcde.CubicSpline(coeffs)

    def func(t, z):
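        # A (batch, hidden, input) = (1, 3, 2) matrix; cdeint contracts it with dX/dt.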
        return -z.unsqueeze(-1).expand(1, 3, 2)

    z0 = torch.randn(1, 3)

    torchdiffeq_out = torchcde.cdeint(X=X, func=func, z0=z0, t=X.interval, backend="torchdiffeq", method="midpoint",
                                      options=dict(step_size=1.0))
    torchsde_out = torchcde.cdeint(X=X, func=func, z0=z0, t=X.interval, backend="torchsde", method="midpoint", dt=1.0)

    torch.testing.assert_allclose(torchdiffeq_out, torchsde_out)
Example #4
def _solve_cde(x):
    # x should be of shape (batch, length, channels)

    batch_size = x.size(0)
    input_channels = x.size(2)
    hidden_channels = 4  # hyperparameter, we can pick whatever we want for this

    coeffs = torchcde.natural_cubic_spline_coeffs(x)
    X = torchcde.NaturalCubicSpline(coeffs)
    z0 = torch.rand(batch_size, hidden_channels)

    class F(torch.nn.Module):
        def __init__(self):
            super(F, self).__init__()
            self.linear = torch.nn.Linear(hidden_channels,
                                          hidden_channels * input_channels)

        def forward(self, t, z):
            return self.linear(z).view(batch_size, hidden_channels, input_channels)

    func = F()
    zt = torchcde.cdeint(X=X, func=func, z0=z0, t=X.interval)
    zT = zt[:, -1]  # get the terminal value of the CDE

    return zT
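A hypothetical call to _solve_cde, just to make the expected shapes concrete:

import torch

x = torch.rand(16, 50, 3)   # (batch, length, channels)
zT = _solve_cde(x)
assert zT.shape == (16, 4)  # terminal hidden state; hidden_channels == 4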
Example #5
    def forward(self, coeffs):
        X = torchcde.CubicSpline(coeffs)
        X0 = X.evaluate(X.interval[0])
        z0 = self.initial(X0)
        zt = torchcde.cdeint(X=X, func=self.func, z0=z0, t=X.interval)
        zT = zt[..., -1, :]  # get the terminal value of the CDE
        return self.readout(zT)
Example #6
    def forward(self, coeffs):
        if self.interpolation == 'cubic':
            X = torchcde.CubicSpline(coeffs)
        elif self.interpolation == 'linear':
            X = torchcde.LinearInterpolation(coeffs)
        else:
            raise ValueError(
                "Only 'linear' and 'cubic' interpolation methods are implemented."
            )

        ######################
        # Easy to forget gotcha: Initial hidden state should be a function of the first observation.
        ######################
        X0 = X.evaluate(X.interval[0])
        z0 = self.initial(X0)

        ######################
        # Actually solve the CDE.
        ######################
        z_T = torchcde.cdeint(X=X, z0=z0, func=self.func, t=X.interval)

        ######################
        # Both the initial value and the terminal value are returned from cdeint; extract just the terminal value,
        # and then apply a linear map.
        ######################
        z_T = z_T[:, 1]
        pred_y = self.readout(z_T)
        return pred_y
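For context, a hedged sketch of a module this forward method could belong to. The attribute names (initial, func, readout, interpolation) match how they are used above; the layer definitions themselves are illustrative assumptions:

import torch
import torchcde

class CDEFunc(torch.nn.Module):
    def __init__(self, input_channels, hidden_channels):
        super().__init__()
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.linear = torch.nn.Linear(hidden_channels, hidden_channels * input_channels)

    def forward(self, t, z):
        # (batch, hidden) -> (batch, hidden, input), as cdeint expects.
        return self.linear(z).tanh().view(z.size(0), self.hidden_channels, self.input_channels)

class NeuralCDE(torch.nn.Module):
    def __init__(self, input_channels, hidden_channels, output_channels, interpolation="cubic"):
        super().__init__()
        self.initial = torch.nn.Linear(input_channels, hidden_channels)
        self.func = CDEFunc(input_channels, hidden_channels)
        self.readout = torch.nn.Linear(hidden_channels, output_channels)
        self.interpolation = interpolation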
Example #7
    def forward(self, coeffs, y):
        # coeffs is of shape (batch, length, input_channels) if using any linear interpolation
        # y is of shape (batch,)

        X = torchcde.LinearInterpolation(coeffs, self.times)
        z0 = self.initial(X.evaluate(self.times[0]))
        options = dict(grid_points=X.grid_points, eps=1e-5)
        adjoint_options = options.copy()
        if self.norm:
            adjoint_options['norm'] = common.make_norm(z0)
        z_t = torchcde.cdeint(X=X,
                              z0=z0,
                              func=self.func,
                              t=self.times[[0, -1]],
                              rtol=self.rtol,
                              atol=self.atol,
                              options=options,
                              adjoint_options=adjoint_options)
        z_T = z_t[:, -1]
        pred_y = self.readout(z_T)

        loss = torch.nn.functional.cross_entropy(pred_y, y)
        thresholded_y = torch.argmax(pred_y, dim=1)
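        # Note: this is a count of correct predictions, not a rate; presumably
        # the caller normalises it by the number of samples.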
        accuracy = (thresholded_y == y).sum().to(pred_y.dtype)
        return loss, accuracy
Example #8
def test_grad_paths():
    for method in ('rk4', 'dopri5'):
        for adjoint in (True, False):
            t = torch.linspace(0, 9, 10, requires_grad=True)
            path = torch.rand(1, 10, 3, requires_grad=True)
            coeffs = torchcde.natural_cubic_coeffs(path, t)
            cubic_spline = torchcde.NaturalCubicSpline(coeffs, t)
            z0 = torch.rand(1, 3, requires_grad=True)
            func = _Func(input_size=3, hidden_size=3)
            t_ = torch.tensor([0., 9.], requires_grad=True)

            z = torchcde.cdeint(X=cubic_spline,
                                func=func,
                                z0=z0,
                                t=t_,
                                adjoint=adjoint,
                                method=method,
                                rtol=1e-4,
                                atol=1e-6)
            assert z.shape == (1, 2, 3)
            assert t.grad is None
            assert path.grad is None
            assert z0.grad is None
            assert func.variable.grad is None
            assert t_.grad is None
            z[:, 1].sum().backward()
            assert isinstance(t.grad, torch.Tensor)
            assert isinstance(path.grad, torch.Tensor)
            assert isinstance(z0.grad, torch.Tensor)
            assert isinstance(func.variable.grad, torch.Tensor)
            assert isinstance(t_.grad, torch.Tensor)
Example #9
def test_detach_trick():
    path = torch.rand(1, 10, 3)
    func = _Func(input_size=3, hidden_size=3)

    def interp_():
        coeffs = torchcde.natural_cubic_coeffs(path)
        yield torchcde.NaturalCubicSpline(coeffs)
        coeffs = torchcde.linear_interpolation_coeffs(path)
        yield torchcde.LinearInterpolation(coeffs, reparameterise='bump')

    for interp in interp_():
        for adjoint in (True, False):
            variable_grads = []
            z0 = torch.rand(1, 3)
            for t_grad in (True, False):
                t_ = torch.tensor([0., 9.], requires_grad=t_grad)
                # Don't test dopri5. We will get different results then, because the t variable will force smaller step
                # sizes and thus slightly different results.
                z = torchcde.cdeint(X=interp,
                                    z0=z0,
                                    func=func,
                                    t=t_,
                                    adjoint=adjoint,
                                    method='rk4',
                                    options=dict(step_size=0.5))
                z[:, -1].sum().backward()
                variable_grads.append(func.variable.grad.clone())
                func.variable.grad.zero_()

            for elem in variable_grads[1:]:
                assert (elem == variable_grads[0]).all()
Example #10
def test_prod():
    x = torch.rand(2, 5, 1)
    X = torchcde.NaturalCubicSpline(torchcde.natural_cubic_coeffs(x))

    class F:
        def prod(self, t, z, dXdt):
            assert t.shape == ()
            assert z.shape == (2, 3)
            assert dXdt.shape == (2, 1)
            return -z * dXdt

    z0 = torch.rand(2, 3, requires_grad=True)
    out = torchcde.cdeint(X=X, func=F(), z0=z0, t=X.interval, adjoint_params=())
    out.sum().backward()
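Here func supplies a prod method, which cdeint uses to form the vector field-control product directly rather than materialising a (hidden, input) matrix. For comparison, a sketch of the equivalent matrix form (FMatrix is illustrative, not part of the original test):

import torch

class FMatrix(torch.nn.Module):
    def forward(self, t, z):
        # Shape (2, 3, 1); cdeint itself contracts this matrix with dX/dt,
        # reproducing the -z * dXdt of the prod-based version above.
        return -z.unsqueeze(-1)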
Example #11
    def forward(self, coeffs, times, interpolation_method):
        X, cdeint_options = build_data_path(coeffs, times, interpolation_method)
        # The initial hidden state must be a function of the first observation.
        z0 = self.initial(X.evaluate(X.interval[0]))
        if self.batch_norm:
            z0 = self.bn_initial(z0)
        # Note: t=times[[0, -1]] would be equivalent here (but only when times is not None).
        z_t = torchcde.cdeint(X=X, z0=z0, func=self.vector_field, t=X.interval, options=cdeint_options)

        # Both z0 and z_T are returned from cdeint, extract just last value
        z_T = z_t[:, -1]
        pred_y = self.readout(z_T)
        if self.batch_norm:
            pred_y = self.bn_output(pred_y)
        # Softmax to obtain class probabilities. Caution: if these predictions
        # are fed to torch.nn.functional.cross_entropy, the softmax should be
        # removed, since cross_entropy already applies log-softmax internally.
        pred_y = torch.nn.functional.softmax(pred_y, dim=-1)
        return pred_y
Example #12
def test_shape():
    for method in ('rk4', 'dopri5'):
        for _ in range(10):
            num_points = torch.randint(low=5, high=100, size=(1, )).item()
            num_channels = torch.randint(low=1, high=3, size=(1, )).item()
            num_hidden_channels = torch.randint(low=1, high=5,
                                                size=(1, )).item()
            num_batch_dims = torch.randint(low=0, high=3, size=(1, )).item()
            batch_dims = []
            for _ in range(num_batch_dims):
                batch_dims.append(
                    torch.randint(low=1, high=3, size=(1, )).item())

            values = torch.rand(*batch_dims, num_points, num_channels)

            coeffs = torchcde.natural_cubic_coeffs(values)
            spline = torchcde.NaturalCubicSpline(coeffs)

            class _Func(torch.nn.Module):
                def __init__(self):
                    super(_Func, self).__init__()
                    self.variable = torch.nn.Parameter(
                        torch.rand(*[1 for _ in range(num_batch_dims)], 1,
                                   num_channels))

                def forward(self, t, z):
                    return z.sigmoid().unsqueeze(-1) + self.variable

            f = _Func()
            z0 = torch.rand(*batch_dims, num_hidden_channels)

            num_out_times = torch.randint(low=2, high=10, size=(1, )).item()
            start, end = spline.interval
            out_times = torch.rand(num_out_times, dtype=torch.float64).sort().values * (end - start) + start

            options = {}
            if method == 'rk4':
                options['step_size'] = 1. / num_points
            out = torchcde.cdeint(spline,
                                  f,
                                  z0,
                                  out_times,
                                  method=method,
                                  options=options,
                                  rtol=1e-4,
                                  atol=1e-6)
            assert out.shape == (*batch_dims, num_out_times,
                                 num_hidden_channels)
Example #13
    def forward(self, x):
        # x is the raw time series of shape (batch, length, channels); the
        # interpolation coefficients are computed from it here. (They can also
        # be precomputed; see datasets.py for how to generate them.)
        if self.interpolation == "cubic":
            x = torchcde.NaturalCubicSpline(torchcde.natural_cubic_coeffs(x))
        elif self.interpolation == "linear":
            x = torchcde.LinearInterpolation(torchcde.linear_interpolation_coeffs(x))
        else:
            raise ValueError("invalid interpolation given")

        x0 = x.evaluate(x.interval[0])
        z0 = self.initial(x0)
        zt = torchcde.cdeint(X=x, func=self.model, z0=z0, t=x.interval)

        return self.output(zt[..., -1, :])
Example #14
def test_tuple_input():
    xa = torch.rand(2, 10, 2)
    xb = torch.rand(10, 1)

    coeffs_a = torchcde.natural_cubic_coeffs(xa)
    coeffs_b = torchcde.natural_cubic_coeffs(xb)
    spline_a = torchcde.NaturalCubicSpline(coeffs_a)
    spline_b = torchcde.NaturalCubicSpline(coeffs_b)
    X = torchcde.TupleControl(spline_a, spline_b)

    def func(t, z):
        za, zb = z
        return za.sigmoid().unsqueeze(-1).repeat_interleave(2, dim=-1), zb.tanh().unsqueeze(-1)

    z0 = torch.rand(2, 3), torch.rand(5, requires_grad=True)
    out = torchcde.cdeint(X=X, func=func, z0=z0, t=X.interval, adjoint_params=())
    out[0].sum().backward()
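    # out[0] does not depend on zb, so zb's gradient is zero (though not None:
    # z0[1] still participates in the autograd graph).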
    assert (z0[1].grad == 0).all()
Example #15
    def forward(self, ys_coeffs):
        # ys_coeffs has shape (batch_size, t_size, 1 + data_size).
        # The +1 corresponds to time. When solving CDEs, it turns out to be most
        # natural to treat time as just another channel: in particular this makes
        # handling irregular data quite easy, when the times may differ between
        # samples in the batch.

        Y = torchcde.LinearInterpolation(ys_coeffs)
        Y0 = Y.evaluate(Y.interval[0])

        h0 = self._initial(Y0)
        hs = torchcde.cdeint(
            Y,
            self._func,
            h0,
            Y.interval,
            adjoint=False,
            method='midpoint',
            options=dict(step_size=1.0))  # shape (batch_size, 2, hidden_size)
        score = self._readout(hs[:, -1])
        return score.mean()
Example #16
    def forward(self, ys_coeffs):
        # ys_coeffs has shape (batch_size, t_size, 1 + data_size).
        # The +1 corresponds to time. When solving CDEs, it turns out to be most
        # natural to treat time as just another channel: in particular this makes
        # handling irregular data quite easy, when the times may differ between
        # samples in the batch.

        Y = torchcde.LinearInterpolation(ys_coeffs)
        Y0 = Y.evaluate(Y.interval[0])
        h0 = self._initial(Y0)
        hs = torchcde.cdeint(Y,
                             self._func,
                             h0,
                             Y.interval,
                             method='reversible_heun',
                             backend='torchsde',
                             dt=1.0,
                             adjoint_method='adjoint_reversible_heun',
                             adjoint_params=(ys_coeffs, ) +
                             tuple(self._func.parameters()))
        score = self._readout(hs[:, -1])
        return score.mean()
Example #17
    def forward(self, coeffs):
        X = torchcde.NaturalCubicSpline(coeffs)

        ######################
        # Easy to forget gotcha: Initial hidden state should be a function of the first observation.
        ######################
        X0 = X.evaluate(X.interval[0])
        z0 = self.initial(X0)

        ######################
        # Actually solve the CDE.
        ######################
        z_T = torchcde.cdeint(X=X, z0=z0, func=self.func, t=X.interval)

        ######################
        # Both the initial value and the terminal value are returned from cdeint; extract just the terminal value,
        # and then apply a linear map.
        ######################
        z_T = z_T[:, 1]
        pred_y = self.readout(z_T)
        return pred_y
Example #18
File: ncde.py  Project: jambo6/batteries
    def forward(self, inputs):
        # Handle h0 and inputs
        coeffs, h0 = self._setup_h0(inputs)

        # Build the interpolated control path. self.spline is presumably a
        # torchcde interpolation class such as LinearInterpolation (it must
        # expose the grid_points attribute used below).
        data = self.spline(coeffs)

        # Solve the CDE (with the adjoint method if self.adjoint is set)
        hidden = torchcde.cdeint(
            data,
            self.func,
            h0,
            data.grid_points,
            adjoint=self.adjoint,
            method=self.solver,
        )

        # Convert to outputs
        outputs = self._make_outputs(hidden)

        return outputs
Example #19
def test_shape(backend, method, kwargs):
    for _ in range(5):
        num_points = torch.randint(low=5, high=100, size=(1,)).item()
        num_channels = torch.randint(low=1, high=3, size=(1,)).item()
        num_hidden_channels = torch.randint(low=1, high=5, size=(1,)).item()
        if backend == "torchdiffeq":
            num_batch_dims = torch.randint(low=0, high=3, size=(1,)).item()
            batch_dims = []
            for _ in range(num_batch_dims):
                batch_dims.append(torch.randint(low=1, high=3, size=(1,)).item())
        elif backend == "torchsde":
            num_batch_dims = 1
            batch_dims = [torch.randint(low=1, high=3, size=(1,)).item()]
        else:
            raise ValueError

        values = torch.rand(*batch_dims, num_points, num_channels)

        coeffs = torchcde.natural_cubic_coeffs(values)
        spline = torchcde.CubicSpline(coeffs)

        class _Func(torch.nn.Module):
            def __init__(self):
                super(_Func, self).__init__()
                self.variable = torch.nn.Parameter(torch.rand(*[1 for _ in range(num_batch_dims)], 1, num_channels))

            def forward(self, t, z):
                return z.sigmoid().unsqueeze(-1) + self.variable

        f = _Func()
        z0 = torch.rand(*batch_dims, num_hidden_channels)

        num_out_times = torch.randint(low=2, high=10, size=(1,)).item()
        start, end = spline.interval
        out_times = torch.rand(num_out_times, dtype=torch.float64).sort().values * (end - start) + start

        out = torchcde.cdeint(spline, f, z0, out_times, backend=backend, method=method, rtol=1e-1, atol=1e-1, **kwargs)
        assert out.shape == (*batch_dims, num_out_times, num_hidden_channels)
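The parametrize decorator for this test is not included in the snippet. A plausible reconstruction, with values that are illustrative assumptions consistent with how backend, method, and kwargs are used above:

import pytest

@pytest.mark.parametrize("backend, method, kwargs", [
    ("torchdiffeq", "rk4", dict(options=dict(step_size=1.0))),
    ("torchdiffeq", "dopri5", {}),
    ("torchsde", "midpoint", dict(dt=1.0)),
])
def test_shape(backend, method, kwargs):
    ...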