Exemplo n.º 1
0
def get_data(batch_size, device):
    """Build a DataLoader of irregularly observed Ornstein-Uhlenbeck-style sample paths.

    Simulates ``dataset_size`` paths of a one-dimensional SDE, replaces 30% of all
    observations with NaN to simulate irregular/missing data, normalises with the
    statistics of the observed initial values, prepends time as a channel and converts
    every path into linear interpolation coefficients.

    Args:
        batch_size: Batch size of the returned DataLoader.
        device: Device on which the SDE is simulated and the data is stored.

    Returns:
        ts: Observation times, shape (t_size,).
        data_size: Number of data channels, not counting the time channel.
        dataloader: Shuffling DataLoader over a TensorDataset of interpolation coefficients.
    """
    class OrnsteinUhlenbeckSDE(torch.nn.Module):
        # Class attributes read by torchsde to choose how the SDE is interpreted.
        sde_type = 'ito'
        noise_type = 'scalar'

        def __init__(self, mu, theta, sigma):
            super(OrnsteinUhlenbeckSDE, self).__init__()
            # Registered as buffers so that `.to(device)` moves them with the module.
            self.register_buffer('mu', torch.as_tensor(mu))
            self.register_buffer('theta', torch.as_tensor(theta))
            self.register_buffer('sigma', torch.as_tensor(sigma))

        def f(self, t, y):
            # Drift term.
            return self.mu * t - self.theta * y

        def g(self, t, y):
            # Diffusion term: the same sigma for every sample, shape (batch, 1, 1).
            return self.sigma.expand(y.size(0), 1, 1)

    dataset_size = 8192
    t_size = 64

    ou_sde = OrnsteinUhlenbeckSDE(mu=0.02, theta=0.1, sigma=0.4).to(device)
    # Initial values uniform in [-1, 1), shape (dataset_size, 1).
    y0 = torch.rand(dataset_size, device=device).unsqueeze(-1) * 2 - 1
    ts = torch.linspace(0, t_size - 1, t_size, device=device)
    # Shape (t_size, dataset_size, 1).
    ys = torchsde.sdeint(ou_sde, y0, ts, dt=1e-1)

    ###################
    # To demonstrate how to handle irregular data, we additionally drop some of the data here (by setting it to
    # NaN).
    ###################
    ys_num = ys.numel()
    # Build the indices on the same device as `ys` to avoid a host-to-device transfer.
    to_drop = torch.randperm(ys_num, device=ys.device)[:int(0.3 * ys_num)]
    ys.view(-1)[to_drop] = float('nan')

    ###################
    # Typically important to normalise data. Note that the data is normalised with respect to the statistics of the
    # initial data, _not_ the whole time series. This seems to help the learning process, presumably because if the
    # initial condition is wrong then it's pretty hard to learn the rest of the SDE correctly.
    ###################
    y0_flat = ys[0].view(-1)
    y0_not_nan = y0_flat.masked_select(~torch.isnan(y0_flat))
    ys = (ys - y0_not_nan.mean()) / y0_not_nan.std()

    ###################
    # As discussed, time must be included as a channel for the discriminator.
    ###################
    ys = torch.cat([ts.unsqueeze(0).unsqueeze(-1).expand(dataset_size, t_size, 1),
                    ys.transpose(0, 1)],
                   dim=2)
    # shape (dataset_size, t_size, 1 + data_size), i.e. (8192, 64, 2) with the values above.

    ###################
    # Package up.
    ###################
    data_size = ys.size(-1) - 1  # How many channels the data has (not including time, hence the minus one).
    ys_coeffs = torchcde.linear_interpolation_coeffs(ys)  # as per neural CDEs.
    dataset = torch.utils.data.TensorDataset(ys_coeffs)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batch_size,
                                             shuffle=True)

    return ts, data_size, dataloader
Exemplo n.º 2
0
def test_random():
    """Linear interpolation must reproduce exactly-linear data.

    For every combination of reparameterisation ('none'/'bump'), dropping (NaN-ing some
    interior points) and explicit-vs-default times, builds data ``m * t + c`` over a
    random number of points and checks that the interpolation evaluates back to the
    true values at each time. Without reparameterisation the derivative must also equal
    the slope ``m``.
    """
    def _points():
        yield 2
        yield 3
        yield 100
        for _ in range(10):
            yield torch.randint(low=2, high=100, size=(1, )).item()

    for reparameterise in ('none', 'bump'):
        for drop in (False, True):
            for use_t in (False, True):
                for num_points in _points():
                    if use_t:
                        start = torch.rand(1).item() * 10 - 5
                        end = torch.rand(1).item() * 10 - 5
                        start, end = min(start, end), max(start, end)
                        t = torch.linspace(start,
                                           end,
                                           num_points,
                                           dtype=torch.float64)
                        t_ = t
                    else:
                        t = torch.linspace(0,
                                           num_points - 1,
                                           num_points,
                                           dtype=torch.float64)
                        t_ = None
                    num_channels = torch.randint(low=1, high=5,
                                                 size=(1, )).item()
                    m = torch.rand(num_channels, dtype=torch.float64) * 10 - 5
                    c = torch.rand(num_channels, dtype=torch.float64) * 10 - 5
                    values = m * t.unsqueeze(-1) + c

                    values_clone = values.clone()
                    if drop:
                        # unbind returns views, so writing NaNs into each slice modifies values_clone.
                        for values_slice in values_clone.unbind(dim=-1):
                            num_drop = int(num_points * torch.randint(
                                low=1, high=4, size=(1, )).item() / 10)
                            # Keep at least four observations; clamp at zero for very short series.
                            num_drop = max(0, min(num_drop, num_points - 4))
                            # don't drop first or last
                            to_drop = torch.randperm(num_points - 2)[:num_drop] + 1
                            values_slice[to_drop] = float('nan')

                    coeffs = torchcde.linear_interpolation_coeffs(values_clone,
                                                                  t=t_)
                    linear = torchcde.LinearInterpolation(
                        coeffs, t=t_, reparameterise=reparameterise)

                    for time, value in zip(t, values):
                        linear_evaluate = linear.evaluate(time)
                        assert value.shape == linear_evaluate.shape
                        assert value.allclose(linear_evaluate,
                                              rtol=1e-4,
                                              atol=1e-6)
                        # `reparameterise` is a string, so the old `is False` check never ran.
                        # Only the un-reparameterised interpolation has derivative equal to m.
                        if reparameterise == 'none':
                            linear_derivative = linear.derivative(time)
                            assert m.shape == linear_derivative.shape
                            assert m.allclose(linear_derivative,
                                              rtol=1e-4,
                                              atol=1e-6)
Exemplo n.º 3
0
 def transform(self, data):
     """Return the linear interpolation coefficients of ``data``.

     ``self._rectilinear`` is forwarded unchanged as the ``rectilinear`` argument of
     ``linear_interpolation_coeffs``.
     """
     rectilinear = self._rectilinear
     coeffs = linear_interpolation_coeffs(data, rectilinear=rectilinear)
     return coeffs
Exemplo n.º 4
0
def test_rectilinear_preparation():
    """Check rectilinear coefficients produced by ``linear_interpolation_coeffs``.

    Covers:
    - a hand-computed 3d example with time in channel 0, and the same example with the
      channels swapped so that time is in channel 1;
    - 2d and 4d inputs;
    - the AssertionError raised when a time value is NaN;
    - randomised inputs compared against a reference built with ``forward_fill``.

    Runs on CPU and, when available, on CUDA.
    """
    devices = ['cpu']
    if torch.cuda.is_available():
        devices.append('cuda')

    for device in devices:
        # Simple test: two series of different lengths, padded with NaN to a common length.
        nan = float('nan')
        t1 = torch.tensor([0.1, 0.2, 0.9]).view(-1, 1).to(device)
        t2 = torch.tensor([0.2, 0.3]).view(-1, 1).to(device)
        x1 = torch.tensor([0.4, nan, 1.1]).view(-1, 1).to(device)
        x2 = torch.tensor([nan, 2.]).view(-1, 1).to(device)
        x = torch.nn.utils.rnn.pad_sequence(
            [torch.cat((t1, x1), -1),
             torch.cat((t2, x2), -1)],
            batch_first=True,
            padding_value=nan)
        # We have to fill the time index forward because we currently don't allow NaN times for rectilinear.
        x[:, :, 0] = torchcde.misc.forward_fill(x[:, :, 0], fill_index=-1)
        # Build true solution
        x1_true = torch.tensor([[0.1, 0.2, 0.2, 0.9, 0.9],
                                [0.4, 0.4, 0.4, 0.4,
                                 1.1]]).T.view(-1, 2).to(device)
        x2_true = torch.tensor([[0.2, 0.3, 0.3, 0.3, 0.3],
                                [2., 2., 2., 2., 2.]]).T.view(-1, 2).to(device)
        rect_true = torch.stack((x1_true, x2_true))
        # Apply rectilinear and compare
        rectilinear = torchcde.linear_interpolation_coeffs(x, rectilinear=0)
        assert torch.equal(rect_true[~torch.isnan(rect_true)],
                           rectilinear[~torch.isnan(rectilinear)])
        # Test also if we swap the channels, so that time lives in channel 1.
        x_swap = x[:, :, [1, 0]]
        rectilinear_swap = torchcde.linear_interpolation_coeffs(x_swap,
                                                                rectilinear=1)
        rect_swp = rect_true[:, :, [1, 0]]
        assert torch.equal(rect_swp, rectilinear_swap)

        # Additionally try a 2d case
        assert torch.equal(
            rect_true[0],
            torchcde.linear_interpolation_coeffs(x[0], rectilinear=0))
        # And a 4d case
        x_4d = torch.stack([x, x])
        rect_true_4d = torch.stack([rect_true, rect_true])
        assert torch.equal(
            rect_true_4d,
            torchcde.linear_interpolation_coeffs(x_4d, rectilinear=0))

        # Ensure an error is thrown if time has a NaN value anywhere
        x_time_nan = x.clone()
        x_time_nan[0, 1, 0] = float('nan')
        with pytest.raises(AssertionError):
            torchcde.linear_interpolation_coeffs(x_time_nan, rectilinear=0)

        # Some random tests
        for _ in range(5):
            # Build some data with time
            t_starts = torch.randn(5).to(device)**2
            ts = [
                torch.linspace(s, s + 10,
                               torch.randint(2, 50, (1, )).item()).to(device)
                for s in t_starts
            ]
            xs = [torch.randn(len(t), 10 - 1).to(device) for t in ts]
            x = torch.nn.utils.rnn.pad_sequence([
                torch.cat([t_.view(-1, 1), x_], dim=1)
                for t_, x_ in zip(ts, xs)
            ],
                                                batch_first=True,
                                                padding_value=nan)
            # Add some random NaNs about the place (roughly 1 in 5 non-time values)
            mask = torch.randint(0,
                                 5, (x.size(0), x.size(1), x.size(2) - 1),
                                 dtype=torch.float).to(device)
            mask[mask == 0] = float('nan')
            x[:, :, 1:] = x[:, :, 1:] * mask
            # We have to fill the time index forward because we currently don't allow NaN times for rectilinear.
            x[:, :, 0] = torchcde.misc.forward_fill(x[:, :, 0], fill_index=-1)
            # Fill
            x_ffilled = torchcde.misc.forward_fill(x)
            # Compute the true solution: even rows are the filled observations, odd rows
            # pair the next time with the previous values.
            N, L, C = x_ffilled.shape
            rect_true = torch.zeros(N, 2 * L - 1, C).to(device)
            lag = torch.cat([x_ffilled[:, 1:, [0]], x_ffilled[:, :-1, 1:]],
                            dim=-1)
            rect_true[:, ::2, ] = x_ffilled
            rect_true[:, 1::2] = lag
            # Positions where rect_true is still NaN (presumably values missing before the
            # first observation, which forward filling cannot fill) are excluded below.
            # Rectilinear solution
            rectilinear = torchcde.linear_interpolation_coeffs(x,
                                                               rectilinear=0)
            assert torch.equal(rect_true[~torch.isnan(rect_true)],
                               rectilinear[~torch.isnan(rect_true)])
Exemplo n.º 5
0
 def interp_():
     """Yield two interpolations of ``path``.

     The first is a natural cubic spline. The second is a linear interpolation with
     ``reparameterise='bump'``. ``path`` is not defined here; presumably it is bound in
     an enclosing scope. Each interpolation's coefficients are only computed when the
     generator is advanced to it.
     """
     spline = torchcde.NaturalCubicSpline(torchcde.natural_cubic_coeffs(path))
     yield spline
     linear = torchcde.LinearInterpolation(torchcde.linear_interpolation_coeffs(path),
                                           reparameterise='bump')
     yield linear
Exemplo n.º 6
0
def get_interpolation_coeffs(directory,
                             data,
                             times,
                             use_noskip,
                             reduced,
                             interpolation_method='cubic'):
    """Compute, or load from an HDF5 cache, interpolation coefficients for each data split.

    Interpolation is expensive, so the first call for a given configuration writes the
    coefficients to an HDF5 file in ``directory``. Later calls with the same
    configuration reload them from that file.

    Args:
        directory: Folder holding the cache file; created (including parents) if missing.
        data: Mapping with keys ``'train_data'``, ``'val_data'`` and ``'test_data'``. Each
            value is passed directly to the selected ``torchcde`` coefficient function.
        times: Optional times, passed as ``t=`` for ``'cubic'`` and ``'linear'``. Ignored
            for ``'rectilinear'``. ``None`` is named as equally spaced in the cache file.
        use_noskip: If true, adds ``noskip`` to the cache file name.
        reduced: If true, adds ``red`` to the cache file name.
        interpolation_method: ``'cubic'``, ``'linear'`` or ``'rectilinear'``.

    Returns:
        dict with keys ``'train_coeffs'``, ``'val_coeffs'`` and ``'test_coeffs'``.
        Freshly computed coefficients keep the dtype returned by torchcde. Coefficients
        loaded from the cache are float32 (built with ``torch.Tensor``).

    Raises:
        ValueError: if coefficients have to be computed and ``interpolation_method`` is
            not one of the supported methods.
    """
    # Folder for caching the coefficients (interpolation is expensive). makedirs also
    # creates missing parents and tolerates an already existing directory.
    os.makedirs(directory, exist_ok=True)

    # Dataset name
    noskip = 'noskip' if use_noskip else ''
    red = 'red' if reduced else ''
    # TODO: an equally spaced array passed in `times` is still named 'irrspaced'.
    timing = 'eqspaced' if times is None or interpolation_method == 'rectilinear' else 'irrspaced'
    coeffs_filename = f'{interpolation_method}_coeffs{noskip}{red}_{timing}.hdf5'
    absolute_coeffs_filename_path = os.path.join(directory, coeffs_filename)

    splits = ('train', 'val', 'test')
    coefficients = {}
    if not os.path.exists(absolute_coeffs_filename_path):
        if interpolation_method == 'cubic':
            def interpolate(x):
                return torchcde.natural_cubic_spline_coeffs(x, t=times)
        elif interpolation_method == 'linear':
            def interpolate(x):
                return torchcde.linear_interpolation_coeffs(x, t=times)
        elif interpolation_method == 'rectilinear':
            # Rectilinear interpolation doesn't accept the `t` argument; channel 0 of the
            # data is used as time instead (rectilinear=0). `timing` is always 'eqspaced'
            # here, so warn based on whether `times` was actually supplied.
            if times is not None:
                print(
                    'Warning: will do default equally spaced time array instead, rectifilinear interpolation currently works with it only.'
                )

            def interpolate(x):
                return torchcde.linear_interpolation_coeffs(x, rectilinear=0)
        else:
            raise ValueError(
                f"Unknown interpolation_method {interpolation_method!r}; "
                "expected 'cubic', 'linear' or 'rectilinear'.")

        for split in splits:
            coefficients[f'{split}_coeffs'] = interpolate(data[f'{split}_data'])

        # Save coefficients in the new directory
        print('Saving interpolation coefficients ...')
        # Convert to host arrays before opening the file, so a conversion failure cannot
        # leave an empty cache file behind.
        arrays = {
            split: coefficients[f'{split}_coeffs'].detach().cpu().numpy()
            for split in splits
        }
        with h5py.File(absolute_coeffs_filename_path, mode='w') as hdf5_coeffs:
            for split, array in arrays.items():
                # np.float64 is what the removed `np.float` alias (builtin float) meant.
                hdf5_coeffs.create_dataset(split, data=array, dtype=np.float64)

    else:
        print('Loading corresponding interpolation coefficients ...')
        with h5py.File(absolute_coeffs_filename_path, mode='r') as coeffs_dataset:
            for split in splits:
                # `[:]` reads the whole dataset into memory before the file is closed.
                coefficients[f'{split}_coeffs'] = torch.Tensor(coeffs_dataset[split][:])

    train_coeffs = coefficients['train_coeffs']
    val_coeffs = coefficients['val_coeffs']
    test_coeffs = coefficients['test_coeffs']
    print(f'Train data interpolation coefficients shape: {train_coeffs.shape}')
    print(
        f'Validation data interpolation coefficients shape: {val_coeffs.shape}'
    )
    print(f'Test data interpolation coefficients shape: {test_coeffs.shape}')
    return coefficients