Example #1
0
def build_data_path(coeffs, times, interpolation_method):
    """Build a torchcde control path and the matching ``cdeint`` solver options.

    Args:
        coeffs: interpolation coefficients produced by the torchcde coefficient
            function that matches ``interpolation_method``.
        times: observation times, passed as ``t`` to the interpolation
            (not used for ``'rectilinear'``).
        interpolation_method: one of ``'cubic'``, ``'linear'`` or ``'rectilinear'``.

    Returns:
        ``(X, cdeint_options)``: the interpolated path and the ``options`` dict
        to hand to ``torchcde.cdeint``.

    Raises:
        ValueError: if ``interpolation_method`` is not one of the names above.
    """
    if interpolation_method == 'cubic':
        # A cubic spline is smooth, so the solver needs no special options.
        return torchcde.NaturalCubicSpline(coeffs, t=times), {}

    if interpolation_method == 'linear':
        X = torchcde.LinearInterpolation(coeffs, t=times)
    elif interpolation_method == 'rectilinear':
        # Rectilinear coefficients don't work when the time argument is passed.
        X = torchcde.LinearInterpolation(coeffs)
    else:
        # Without this, an unknown method fell through to `return X, ...` and
        # failed with an UnboundLocalError instead of a useful message.
        raise ValueError(
            f"Unknown interpolation_method {interpolation_method!r}; "
            "expected 'cubic', 'linear' or 'rectilinear'."
        )

    # Piecewise-linear paths have kinks at the knots; tell the solver where
    # they are (grid_points) so it can step onto them.
    return X, dict(grid_points=X.grid_points, eps=1e-5)
Example #2
0
    def forward(self, coeffs, y):
        """Classify a batch of series and score the predictions against ``y``.

        Args:
            coeffs: linear-interpolation coefficients of shape
                (batch, length, input_channels).
            y: class labels of shape (batch,).

        Returns:
            ``(loss, accuracy)``: the cross-entropy loss and the *number* of
            correct predictions in the batch (a count cast to the logits'
            dtype, not a fraction).
        """
        path = torchcde.LinearInterpolation(coeffs, self.times)
        # The initial hidden state is a function of the first observation.
        z0 = self.initial(path.evaluate(self.times[0]))

        # Let the solver know where the knots of the linear path are.
        solver_options = {'grid_points': path.grid_points, 'eps': 1e-5}
        adjoint_solver_options = dict(solver_options)
        if self.norm:
            adjoint_solver_options['norm'] = common.make_norm(z0)

        z_t = torchcde.cdeint(
            X=path,
            z0=z0,
            func=self.func,
            t=self.times[[0, -1]],  # only the start and end states are needed
            rtol=self.rtol,
            atol=self.atol,
            options=solver_options,
            adjoint_options=adjoint_solver_options,
        )
        logits = self.readout(z_t[:, -1])

        loss = torch.nn.functional.cross_entropy(logits, y)
        num_correct = (logits.argmax(dim=1) == y).sum().to(logits.dtype)
        return loss, num_correct
Example #3
0
def test_small():
    """Two-point linear interpolation reproduces, and extrapolates, the straight line."""
    for use_t in (False, True):
        if use_t:
            a = torch.rand(1).item() * 10 - 5
            b = torch.rand(1).item() * 10 - 5
            start, end = sorted((a, b))
            t = torch.tensor([start, end], dtype=torch.float64)
        else:
            start, end = 0, 1
            t = torch.tensor([0., 1.], dtype=torch.float64)
        t_ = t if use_t else None

        x = torch.rand(2, 1, dtype=torch.float64)
        slope = (x[1] - x[0]) / (end - start)
        linear = torchcde.LinearInterpolation(
            torchcde.linear_interpolation_coeffs(x, t=t_), t=t_)

        # The sample range deliberately extends beyond the observed interval.
        for s in torch.linspace(-1, 2, 100):
            expected = x[0] + slope * (s - t[0])
            value = linear.evaluate(s)
            deriv = linear.derivative(s)
            assert slope.shape == deriv.shape
            assert slope.allclose(deriv)
            assert expected.shape == value.shape
            assert expected.allclose(value)
Example #4
0
    def forward(self, coeffs):
        """Solve the neural CDE driven by ``coeffs`` and read out the final state.

        Args:
            coeffs: interpolation coefficients matching ``self.interpolation``.

        Returns:
            ``self.readout`` applied to the hidden state at the end of the path.

        Raises:
            ValueError: if ``self.interpolation`` is neither 'linear' nor 'cubic'.
        """
        kind = self.interpolation
        if kind not in ('cubic', 'linear'):
            raise ValueError(
                "Only 'linear' and 'cubic' interpolation methods are implemented."
            )
        X = (torchcde.CubicSpline(coeffs) if kind == 'cubic'
             else torchcde.LinearInterpolation(coeffs))

        # Easy-to-forget gotcha: the initial hidden state must be a function of
        # the first observation.
        z0 = self.initial(X.evaluate(X.interval[0]))

        # cdeint returns the state at both ends of X.interval; keep the terminal
        # one and map it through the readout.
        z_T = torchcde.cdeint(X=X, z0=z0, func=self.func, t=X.interval)[:, 1]
        return self.readout(z_T)
Example #5
0
def test_with_linear_interpolation():
    """Derivatives of the interpolated log-signature windows match signatory.

    Each window of ``window_length`` points (prefixed by the last point of the
    previous chunk) has its log-signature computed directly by signatory. The
    linear interpolation of ``torchcde.logsig_windows`` must have that value as
    its derivative at the midpoint of the corresponding unit interval.
    """
    import signatory
    window_length = 4
    for depth in (1, 2, 3, 4):
        logsig_fn = signatory.Logsignature(depth)
        for pieces in (1, 2, 3, 5, 10):
            num_channels = torch.randint(low=1, high=4, size=(1, )).item()
            chunks = [torch.randn(1, num_channels, dtype=torch.float64)]
            expected = []
            for _ in range(pieces):
                window = torch.randn(window_length, num_channels,
                                     dtype=torch.float64)
                # Windows overlap by one point: start from the previous chunk's end.
                segment = torch.cat([chunks[-1][-1:], window]).unsqueeze(0)
                expected.append(logsig_fn(segment))
                chunks.append(window)

            full_path = torch.cat(chunks)
            logsig_x = torchcde.logsig_windows(full_path, depth, window_length)
            X = torchcde.LinearInterpolation(
                torchcde.linear_interpolation_coeffs(logsig_x))

            for index, logsignature in enumerate(expected):
                assert X.derivative(index + 0.5).allclose(logsignature)
Example #6
0
    def forward(self, coeffs):
        """Run the neural CDE over the interpolated input and apply the readout.

        Args:
            coeffs: interpolation coefficients matching ``self.interpolation``
                ('cubic' or 'linear').

        Returns:
            If ``self.return_sequences``: ``self.readout`` applied to the hidden
            state at every integer time from ``X.interval[0]`` to
            ``X.interval[1]``, stacked along a new leading (time) dimension.
            Otherwise: ``self.readout`` applied to the terminal hidden state.

        Raises:
            ValueError: if ``self.interpolation`` is neither 'linear' nor 'cubic'.
        """
        if self.interpolation == "cubic":
            X = torchcde.CubicSpline(coeffs)
        elif self.interpolation == "linear":
            X = torchcde.LinearInterpolation(coeffs)
        else:
            raise ValueError(
                "Only 'linear' and 'cubic' interpolation methods are implemented."
            )

        # The initial hidden state is a function of the first observation.
        X0 = X.evaluate(X.interval[0])
        z0 = self.initial(X0)

        if self.return_sequences:
            # Integer time grid spanning the interval, built with the interval's
            # dtype and device so cdeint's `t` lives alongside the path.
            times = torch.arange(
                X.interval[0],
                X.interval[1] + 1,
                dtype=X.interval.dtype,
                device=X.interval.device,
            )
            z_t = torchcde.cdeint(X=X, z0=z0, func=self.func, t=times)
            # z_t[:, i] is the hidden state at times[i], so index by position.
            # Indexing by `int(ti) - 1` sent times[0] == 0 to index -1 (the
            # *final* state) and shifted every other output back one step.
            pred_y = torch.stack(
                [self.readout(z_t[:, i]) for i in range(times.size(0))]
            )
        else:
            # cdeint returns the state at both ends of X.interval; keep the end.
            z_T = torchcde.cdeint(X=X, z0=z0, func=self.func, t=X.interval)
            pred_y = self.readout(z_T[:, 1])

        return pred_y
Example #7
0
def test_linear():
    """Natural cubic splines fitted to exactly linear data agree with linear interpolation."""
    coeff_fns = (torchcde.natural_cubic_coeffs,
                 torchcde.natural_cubic_spline_coeffs)
    for interp_fn in coeff_fns:
        for use_t in (False, True):
            lo = torch.rand(1).item() * 5 - 2.5
            hi = torch.rand(1).item() * 5 - 2.5
            lo, hi = sorted((lo, hi))
            num_points = torch.randint(low=2, high=10, size=(1, )).item()
            num_channels = torch.randint(low=1, high=4, size=(1, )).item()
            slope = torch.rand(num_channels) * 5 - 2.5
            intercept = torch.rand(num_channels) * 5 - 2.5

            if use_t:
                t = torch.linspace(lo, hi, num_points)
            else:
                t = torch.linspace(0, num_points - 1, num_points)
            t_ = t if use_t else None

            values = slope * t.unsqueeze(-1) + intercept
            spline = torchcde.NaturalCubicSpline(interp_fn(values, t_), t_)
            linear = torchcde.LinearInterpolation(
                torchcde.linear_interpolation_coeffs(values, t_), t_)
            _test_equal([], num_channels, linear, spline,
                        torch.float32, -1.5, 1.5, 1e-4)
Example #8
0
def test_random():
    """Linear interpolation of linear data is exact, even with interior NaN gaps.

    For each combination of (with/without dropped values, with/without explicit
    times) and a selection of lengths, the interpolant must reproduce the data
    at every observation time and have derivative equal to the slope ``m``.
    """
    def _points():
        yield from (2, 3, 100)
        for _ in range(10):
            yield torch.randint(low=2, high=100, size=(1, )).item()

    for drop in (False, True):
        for use_t in (False, True):
            for num_points in _points():
                if use_t:
                    lo = torch.rand(1).item() * 10 - 5
                    hi = torch.rand(1).item() * 10 - 5
                    lo, hi = sorted((lo, hi))
                    t = torch.linspace(lo, hi, num_points, dtype=torch.float64)
                else:
                    t = torch.linspace(0, num_points - 1, num_points,
                                       dtype=torch.float64)
                t_ = t if use_t else None

                num_channels = torch.randint(low=1, high=5, size=(1, )).item()
                m = torch.rand(num_channels, dtype=torch.float64) * 10 - 5
                c = torch.rand(num_channels, dtype=torch.float64) * 10 - 5
                values = m * t.unsqueeze(-1) + c

                observed = values.clone()
                if drop:
                    # unbind gives views, so writing NaNs mutates `observed`.
                    for channel in observed.unbind(dim=-1):
                        tenths = torch.randint(low=1, high=4, size=(1, )).item()
                        # For num_points < 4 the cap is negative; slicing the
                        # (tiny) permutation with a negative stop then selects
                        # nothing, so no values are dropped.
                        num_drop = min(int(num_points * tenths / 10),
                                       num_points - 4)
                        # Shift by one so the first and last points are never dropped.
                        interior = torch.randperm(num_points - 2)[:num_drop] + 1
                        channel[interior] = float('nan')

                coeffs = torchcde.linear_interpolation_coeffs(observed, t=t_)
                linear = torchcde.LinearInterpolation(coeffs, t=t_)

                for time, value in zip(t, values):
                    evaluated = linear.evaluate(time)
                    assert value.shape == evaluated.shape
                    assert value.allclose(evaluated, rtol=1e-4, atol=1e-6)
                    derivative = linear.derivative(time)
                    assert m.shape == derivative.shape
                    assert m.allclose(derivative, rtol=1e-4, atol=1e-6)
Example #9
0
def test_specification_and_derivative():
    """LinearInterpolation hits every observation and its derivative matches autograd.

    Covers both implicit and explicit times, the 'none' and 'bump'
    reparameterisations, and zero to three random batch dimensions.
    """
    for use_t in (False, True):
        for reparameterise in ('none', 'bump'):
            for _trial in range(10):
                for num_batch_dims in (0, 1, 2, 3):
                    batch_dims = [
                        torch.randint(low=1, high=3, size=(1, )).item()
                        for _ in range(num_batch_dims)
                    ]
                    length = torch.randint(low=5, high=10, size=(1, )).item()
                    channels = torch.randint(low=1, high=5, size=(1, )).item()
                    if use_t:
                        t = torch.linspace(0, 1, length, dtype=torch.float64)
                    else:
                        t = torch.linspace(0, length - 1, length,
                                           dtype=torch.float64)
                    t_ = t if use_t else None

                    x = torch.rand(*batch_dims, length, channels,
                                   dtype=torch.float64)
                    coeffs = torchcde.linear_interpolation_coeffs(x, t=t_)
                    spline = torchcde.LinearInterpolation(
                        coeffs, t=t_, reparameterise=reparameterise)

                    # Specification: the interpolant passes through every observation.
                    for i, point in enumerate(t):
                        assert spline.evaluate(point).allclose(
                            x[..., i, :], atol=1e-5, rtol=1e-5)

                    # Derivative: compare against autograd, one output element at a time.
                    for point in torch.rand(100, dtype=torch.float64):
                        point_with_grad = point.detach().requires_grad_(True)
                        evaluate = spline.evaluate(point_with_grad)
                        derivative = spline.derivative(point)
                        grads = []
                        for elem in evaluate.view(-1):
                            elem.backward(retain_graph=True)
                            with torch.no_grad():
                                grads.append(point_with_grad.grad.clone())
                            point_with_grad.grad.zero_()
                        autoderivative = torch.stack(grads).view(*evaluate.shape)
                        assert derivative.shape == autoderivative.shape
                        assert derivative.allclose(autoderivative,
                                                   atol=1e-5, rtol=1e-5)
Example #10
0
def test_short():
    """On two points, a natural cubic spline and linear interpolation must agree."""
    for use_t in (False, True):
        t = torch.tensor([0., 1.]) if use_t else None
        values = torch.rand(2, 1)
        spline = torchcde.NaturalCubicSpline(
            torchcde.natural_cubic_spline_coeffs(values, t), t)
        linear = torchcde.LinearInterpolation(
            torchcde.linear_interpolation_coeffs(values, t), t)
        # No batch dimensions, a single channel.
        _test_equal([], 1, linear, spline, torch.float32, -1.5, 1.5, 1e-4)
Example #11
0
    def forward(self, x):
        """Interpolate the raw series ``x``, solve the neural CDE, and read out the end state.

        Args:
            x: raw observations. Interpolation coefficients are built here from
                ``x`` for the scheme named by ``self.interpolation``, so callers
                must NOT pass precomputed coefficients. Presumably shaped
                (batch, length, channels) — TODO confirm against the caller.

        Returns:
            ``self.output`` applied to the hidden state at the end of the path.

        Raises:
            ValueError: if ``self.interpolation`` is neither "cubic" nor "linear".
        """
        # Each interpolation scheme needs its own coefficient format: passing
        # natural-cubic coefficients to LinearInterpolation (as this code used
        # to for "linear") silently builds a path from the wrong data layout.
        if self.interpolation == "cubic":
            coeffs = torchcde.natural_cubic_coeffs(x)
            path = torchcde.NaturalCubicSpline(coeffs)
        elif self.interpolation == "linear":
            coeffs = torchcde.linear_interpolation_coeffs(x)
            path = torchcde.LinearInterpolation(coeffs)
        else:
            raise ValueError("invalid interpolation given")

        # The initial hidden state is a function of the first observation.
        x0 = path.evaluate(path.interval[0])
        z0 = self.initial(x0)
        zt = torchcde.cdeint(X=path, func=self.model, z0=z0, t=path.interval)

        # zt holds the states at both ends of the interval; use the last one.
        return self.output(zt[..., -1, :])
Example #12
0
    def forward(self, ys_coeffs):
        """Return the mean discriminator score for a batch of paths.

        ``ys_coeffs`` has shape (batch_size, t_size, 1 + data_size); the extra
        channel is time. Treating time as just another channel is the most
        natural choice when solving CDEs: it makes irregular data easy to
        handle, even when observation times differ between samples in a batch.
        """
        path = torchcde.LinearInterpolation(ys_coeffs)
        h0 = self._initial(path.evaluate(path.interval[0]))
        # Fixed-step midpoint solve without the adjoint. hs holds the hidden
        # state at both ends of the interval: (batch_size, 2, hidden_size).
        hs = torchcde.cdeint(
            path,
            self._func,
            h0,
            path.interval,
            adjoint=False,
            method='midpoint',
            options={'step_size': 1.0},
        )
        return self._readout(hs[:, -1]).mean()
Example #13
0
    def forward(self, ys_coeffs):
        """Return the mean discriminator score for a batch of paths.

        ``ys_coeffs`` has shape (batch_size, t_size, 1 + data_size); the extra
        channel is time. Treating time as just another channel is the most
        natural choice when solving CDEs: it makes irregular data easy to
        handle, even when observation times differ between samples in a batch.

        The CDE is solved through the torchsde backend with the reversible Heun
        scheme (dt=1.0) and its matching adjoint.
        """
        path = torchcde.LinearInterpolation(ys_coeffs)
        h0 = self._initial(path.evaluate(path.interval[0]))
        # Tensors the adjoint pass should compute gradients for.
        adjoint_params = (ys_coeffs, ) + tuple(self._func.parameters())
        hs = torchcde.cdeint(
            path,
            self._func,
            h0,
            path.interval,
            method='reversible_heun',
            backend='torchsde',
            dt=1.0,
            adjoint_method='adjoint_reversible_heun',
            adjoint_params=adjoint_params,
        )
        return self._readout(hs[:, -1]).mean()
Example #14
0
def plot(ts, generator, dataloader, num_plot_samples, plot_locs):
    """Visually compare real and generated sample paths.

    For each fraction in ``plot_locs``, shows overlaid histograms of the real
    and generated marginals at that index along ``ts``. It then shows one
    figure overlaying the first ``num_plot_samples`` real and generated paths.

    Args:
        ts: times at which both real and generated paths are evaluated.
        generator: called as ``generator(ts, batch_size)``. Its output is
            linearly interpolated and channel 1 is plotted.
        dataloader: its first batch is a 1-tuple whose tensor is linearly
            interpolated and evaluated at ``ts``.
        num_plot_samples: number of paths of each kind to overlay; must not
            exceed the batch size.
        plot_locs: fractions of the path length at which to plot marginals.
    """
    # Get samples
    real_samples, = next(iter(dataloader))
    assert num_plot_samples <= real_samples.size(0)
    real_samples = torchcde.LinearInterpolation(real_samples).evaluate(ts)
    # Keep channel 1 only (presumably the data channel, with channel 0 being
    # time — TODO confirm).
    real_samples = real_samples[..., 1]

    with torch.no_grad():
        generated_samples = generator(ts, real_samples.size(0)).cpu()
    generated_samples = torchcde.LinearInterpolation(
        generated_samples).evaluate(ts)
    generated_samples = generated_samples[..., 1]

    # Plot histograms
    for prop in plot_locs:
        time = int(prop * (real_samples.size(1) - 1))
        real_samples_time = real_samples[:, time]
        generated_samples_time = generated_samples[:, time]
        _, bins, _ = plt.hist(real_samples_time.cpu().numpy(),
                              bins=32,
                              alpha=0.7,
                              label='Real',
                              color='dodgerblue',
                              density=True)
        # Reuse the real histogram's bin width so both are comparable. Clamp
        # to at least one bin: if the generated samples span less than one
        # bin width, the count would be 0 and plt.hist rejects bins=0.
        bin_width = bins[1] - bins[0]
        generated_range = (generated_samples_time.max() -
                           generated_samples_time.min()).item()
        num_bins = max(1, int(generated_range // bin_width))
        plt.hist(generated_samples_time.cpu().numpy(),
                 bins=num_bins,
                 alpha=0.7,
                 label='Generated',
                 color='crimson',
                 density=True)
        plt.legend()
        plt.xlabel('Value')
        plt.ylabel('Density')
        plt.title(f'Marginal distribution at time {time}.')
        plt.tight_layout()
        plt.show()

    real_samples = real_samples[:num_plot_samples]
    generated_samples = generated_samples[:num_plot_samples]

    # Plot samples; only the first line of each kind gets a legend label.
    real_first = True
    generated_first = True
    for real_sample_ in real_samples:
        kwargs = {'label': 'Real'} if real_first else {}
        plt.plot(ts.cpu(),
                 real_sample_.cpu(),
                 color='dodgerblue',
                 linewidth=0.5,
                 alpha=0.7,
                 **kwargs)
        real_first = False
    for generated_sample_ in generated_samples:
        kwargs = {'label': 'Generated'} if generated_first else {}
        plt.plot(ts.cpu(),
                 generated_sample_.cpu(),
                 color='crimson',
                 linewidth=0.5,
                 alpha=0.7,
                 **kwargs)
        generated_first = False
    plt.legend()
    plt.title(
        f"{num_plot_samples} samples from both real and generated distributions."
    )
    plt.tight_layout()
    plt.show()
Example #15
0
 def interp_():
     """Yield each interpolation scheme under test, built over ``path``.

     ``path`` is a free variable taken from the enclosing scope. Yields, in
     order, a natural cubic spline and a linear interpolation using the
     'bump' reparameterisation. Each one is built lazily, just before it is yielded.
     """
     yield torchcde.NaturalCubicSpline(torchcde.natural_cubic_coeffs(path))
     yield torchcde.LinearInterpolation(
         torchcde.linear_interpolation_coeffs(path), reparameterise='bump')