def build_data_path(coeffs, times, interpolation_method):
    """Build the control path ``X`` and the matching ``cdeint`` options.

    Args:
        coeffs: interpolation coefficients produced by the torchcde
            coefficient function that matches ``interpolation_method``.
        times: observation times passed to the interpolation. Ignored for
            ``'rectilinear'``, which uses the default times instead.
        interpolation_method: one of ``'cubic'``, ``'linear'`` or
            ``'rectilinear'``.

    Returns:
        A ``(X, cdeint_options)`` tuple. For the piecewise-linear paths the
        options carry ``X.grid_points`` (and ``eps``) for the solver; for the
        cubic spline they are empty.

    Raises:
        ValueError: if ``interpolation_method`` is not one of the values above.
    """
    if interpolation_method == 'cubic':
        X = torchcde.NaturalCubicSpline(coeffs, t=times)
        cdeint_options = {}
    elif interpolation_method == 'linear':
        X = torchcde.LinearInterpolation(coeffs, t=times)
        # Piecewise-linear paths have kinks at their grid points; hand them
        # to the solver.
        cdeint_options = dict(grid_points=X.grid_points, eps=1e-5)
    elif interpolation_method == 'rectilinear':
        # Rectilinear interpolation doesn't work when passing the time
        # argument, so build the path on its default times.
        X = torchcde.LinearInterpolation(coeffs)
        cdeint_options = dict(grid_points=X.grid_points, eps=1e-5)
    else:
        # Previously this fell through and failed with an UnboundLocalError
        # on the return statement below.
        raise ValueError(
            f"Unrecognised interpolation_method {interpolation_method!r}; "
            "expected 'cubic', 'linear' or 'rectilinear'.")
    return X, cdeint_options
def forward(self, coeffs, y):
    """Classify a batch of paths; return the loss and the number correct.

    Args:
        coeffs: linear-interpolation coefficients of shape
            ``(batch, length, input_channels)``.
        y: class labels of shape ``(batch,)``.

    Returns:
        ``(loss, accuracy)``. ``loss`` is the cross-entropy of the readout.
        ``accuracy`` is the *count* of correct argmax predictions in the
        batch (a sum, not a mean), cast to the prediction dtype.
    """
    path = torchcde.LinearInterpolation(coeffs, self.times)
    initial_hidden = self.initial(path.evaluate(self.times[0]))

    # The linear path has kinks at its grid points; pass them to the solver.
    solver_options = dict(grid_points=path.grid_points, eps=1e-5)
    backward_options = dict(solver_options)
    if self.norm:
        backward_options['norm'] = common.make_norm(initial_hidden)

    hidden = torchcde.cdeint(
        X=path,
        z0=initial_hidden,
        func=self.func,
        t=self.times[[0, -1]],
        rtol=self.rtol,
        atol=self.atol,
        options=solver_options,
        adjoint_options=backward_options,
    )
    # Only the state at the final time feeds the classifier head.
    logits = self.readout(hidden[:, -1])

    loss = torch.nn.functional.cross_entropy(logits, y)
    predicted = logits.argmax(dim=1)
    num_correct = (predicted == y).sum().to(logits.dtype)
    return loss, num_correct
def test_small():
    """Two-point linear interpolation is an exact straight line.

    The line's value and slope are checked at 100 times in [-1, 2], which
    may lie inside or outside the knots.
    """
    for use_t in (False, True):
        if not use_t:
            start, end = 0, 1
            t = torch.tensor([0., 1.], dtype=torch.float64)
            t_ = None
        else:
            lo = torch.rand(1).item() * 10 - 5
            hi = torch.rand(1).item() * 10 - 5
            start, end = min(lo, hi), max(lo, hi)
            t = torch.tensor([start, end], dtype=torch.float64)
            t_ = t

        x = torch.rand(2, 1, dtype=torch.float64)
        slope = (x[1] - x[0]) / (end - start)
        linear = torchcde.LinearInterpolation(
            torchcde.linear_interpolation_coeffs(x, t=t_), t=t_)

        for time in torch.linspace(-1, 2, 100):
            expected = x[0] + slope * (time - t[0])
            value = linear.evaluate(time)
            deriv = linear.derivative(time)
            assert slope.shape == deriv.shape
            assert slope.allclose(deriv)
            assert expected.shape == value.shape
            assert expected.allclose(value)
def forward(self, coeffs):
    """Solve the neural CDE driven by ``coeffs`` and read out the final state.

    ``self.interpolation`` selects a cubic spline or linear interpolation of
    ``coeffs``.

    Raises:
        ValueError: if ``self.interpolation`` is neither ``'cubic'`` nor
            ``'linear'``.
    """
    if self.interpolation not in ('cubic', 'linear'):
        raise ValueError(
            "Only 'linear' and 'cubic' interpolation methods are implemented."
        )
    if self.interpolation == 'cubic':
        path = torchcde.CubicSpline(coeffs)
    else:
        path = torchcde.LinearInterpolation(coeffs)

    # Easy-to-forget gotcha: the initial hidden state must be a function of
    # the first observation.
    z0 = self.initial(path.evaluate(path.interval[0]))

    # cdeint returns the state at both ends of the interval; keep only the
    # terminal one (index 1) and map it through the linear readout.
    terminal = torchcde.cdeint(X=path, z0=z0, func=self.func,
                               t=path.interval)[:, 1]
    return self.readout(terminal)
def test_with_linear_interpolation():
    """The log-signature path's slope on each segment equals that window's log-signature.

    Random windows of length 4 are concatenated. ``logsig_windows`` is
    linearly interpolated, and its derivative at each segment midpoint is
    compared with a directly computed (signatory) log-signature of the
    matching window.
    """
    import signatory

    window_length = 4
    for depth in (1, 2, 3, 4):
        compute_logsignature = signatory.Logsignature(depth)
        for pieces in (1, 2, 3, 5, 10):
            num_channels = torch.randint(low=1, high=4, size=(1, )).item()
            chunks = [torch.randn(1, num_channels, dtype=torch.float64)]
            expected = []
            for _ in range(pieces):
                window = torch.randn(window_length, num_channels,
                                     dtype=torch.float64)
                # Each window starts from the last point of the previous chunk.
                anchored = torch.cat([chunks[-1][-1:], window]).unsqueeze(0)
                expected.append(compute_logsignature(anchored))
                chunks.append(window)

            full_path = torch.cat(chunks)
            logsig_path = torchcde.logsig_windows(full_path, depth,
                                                  window_length)
            X = torchcde.LinearInterpolation(
                torchcde.linear_interpolation_coeffs(logsig_path))
            for index, logsignature in enumerate(expected):
                assert X.derivative(0.5 + index).allclose(logsignature)
def forward(self, coeffs):
    """Solve the neural CDE driven by ``coeffs`` and read out predictions.

    ``self.interpolation`` selects a cubic spline or linear interpolation of
    ``coeffs``.

    Returns:
        If ``self.return_sequences`` is true, ``self.readout`` is applied to
        the hidden state at every time in
        ``torch.arange(X.interval[0], X.interval[1] + 1)``. The results are
        stacked along a new leading (time) dimension. Otherwise only the
        readout of the terminal hidden state is returned.

    Raises:
        ValueError: if ``self.interpolation`` is neither ``"cubic"`` nor
            ``"linear"``.
    """
    if self.interpolation == "cubic":
        X = torchcde.CubicSpline(coeffs)
    elif self.interpolation == "linear":
        X = torchcde.LinearInterpolation(coeffs)
    else:
        raise ValueError(
            "Only 'linear' and 'cubic' interpolation methods are implemented."
        )

    # The initial hidden state is a function of the first observation.
    X0 = X.evaluate(X.interval[0])
    z0 = self.initial(X0)

    if self.return_sequences:
        times = torch.arange(X.interval[0], X.interval[1] + 1)
        z_t = torchcde.cdeint(X=X, z0=z0, func=self.func, t=times)
        # Index by *position* along the time axis: z_t[:, i] is the state at
        # times[i]. Indexing by int(time) - 1 would map the first time to
        # index -1 (the terminal state) and shift every later prediction.
        pred_y = torch.stack(
            [self.readout(z_i) for z_i in z_t.unbind(dim=1)])
    else:
        z_T = torchcde.cdeint(X=X, z0=z0, func=self.func, t=X.interval)
        # cdeint returns both endpoint states; keep the terminal one.
        z_T = z_T[:, 1]
        pred_y = self.readout(z_T)
    return pred_y
def test_linear():
    """A natural cubic spline of exactly linear data agrees with linear interpolation.

    Both coefficient functions are tried, with and without explicit times,
    and the interpolants are compared over [-1.5, 1.5] by ``_test_equal``.
    """
    coeff_fns = (torchcde.natural_cubic_coeffs,
                 torchcde.natural_cubic_spline_coeffs)
    for interp_fn in coeff_fns:
        for use_t in (False, True):
            bounds = [torch.rand(1).item() * 5 - 2.5 for _ in range(2)]
            start, end = min(bounds), max(bounds)
            num_points = torch.randint(low=2, high=10, size=(1, )).item()
            num_channels = torch.randint(low=1, high=4, size=(1, )).item()
            slope = torch.rand(num_channels) * 5 - 2.5
            intercept = torch.rand(num_channels) * 5 - 2.5

            if use_t:
                t = torch.linspace(start, end, num_points)
                t_ = t
            else:
                t = torch.linspace(0, num_points - 1, num_points)
                t_ = None

            values = slope * t.unsqueeze(-1) + intercept
            spline = torchcde.NaturalCubicSpline(interp_fn(values, t_), t_)
            linear = torchcde.LinearInterpolation(
                torchcde.linear_interpolation_coeffs(values, t_), t_)
            _test_equal([], num_channels, linear, spline, torch.float32,
                        -1.5, 1.5, 1e-4)
def test_random():
    """Linear interpolation reproduces linear data exactly at every knot.

    Both value and slope are checked. Optionally, interior observations are
    first replaced by NaN, which the coefficient function must fill in.
    """
    def _points():
        yield from (2, 3, 100)
        for _ in range(10):
            yield torch.randint(low=2, high=100, size=(1, )).item()

    for drop in (False, True):
        for use_t in (False, True):
            for num_points in _points():
                if use_t:
                    a = torch.rand(1).item() * 10 - 5
                    b = torch.rand(1).item() * 10 - 5
                    t = torch.linspace(min(a, b), max(a, b), num_points,
                                       dtype=torch.float64)
                    t_ = t
                else:
                    t = torch.linspace(0, num_points - 1, num_points,
                                       dtype=torch.float64)
                    t_ = None

                num_channels = torch.randint(low=1, high=5,
                                             size=(1, )).item()
                slope = torch.rand(num_channels, dtype=torch.float64) * 10 - 5
                intercept = (torch.rand(num_channels, dtype=torch.float64)
                             * 10 - 5)
                values = slope * t.unsqueeze(-1) + intercept
                observed = values.clone()

                if drop:
                    # Per channel, NaN out a random 10-30% of interior points
                    # (never the first or last), capped at num_points - 4.
                    # For num_points <= 4 the cap is <= 0 and the slice below
                    # is empty, so nothing is dropped.
                    for channel in observed.unbind(dim=-1):
                        fraction = torch.randint(low=1, high=4,
                                                 size=(1, )).item()
                        num_drop = min(int(num_points * fraction / 10),
                                       num_points - 4)
                        interior = torch.randperm(num_points - 2)[:num_drop] + 1
                        channel[interior] = float('nan')

                linear = torchcde.LinearInterpolation(
                    torchcde.linear_interpolation_coeffs(observed, t=t_),
                    t=t_)
                for time, expected in zip(t, values):
                    got = linear.evaluate(time)
                    assert expected.shape == got.shape
                    assert expected.allclose(got, rtol=1e-4, atol=1e-6)
                    got_slope = linear.derivative(time)
                    assert slope.shape == got_slope.shape
                    assert slope.allclose(got_slope, rtol=1e-4, atol=1e-6)
def test_specification_and_derivative():
    """LinearInterpolation passes through its data and its derivative matches autograd.

    This is checked for several batch shapes, with and without explicit
    times, and for both reparameterisations.
    """
    for use_t in (False, True):
        for reparameterise in ('none', 'bump'):
            for _ in range(10):
                for num_batch_dims in (0, 1, 2, 3):
                    batch_dims = [
                        torch.randint(low=1, high=3, size=(1, )).item()
                        for _ in range(num_batch_dims)
                    ]
                    length = torch.randint(low=5, high=10, size=(1, )).item()
                    channels = torch.randint(low=1, high=5, size=(1, )).item()
                    if use_t:
                        t = torch.linspace(0, 1, length, dtype=torch.float64)
                        t_ = t
                    else:
                        t = torch.linspace(0, length - 1, length,
                                           dtype=torch.float64)
                        t_ = None
                    x = torch.rand(*batch_dims, length, channels,
                                   dtype=torch.float64)
                    coeffs = torchcde.linear_interpolation_coeffs(x, t=t_)
                    spline = torchcde.LinearInterpolation(
                        coeffs, t=t_, reparameterise=reparameterise)

                    # Specification: the interpolant hits every observation.
                    for i, point in enumerate(t):
                        assert spline.evaluate(point).allclose(
                            x[..., i, :], atol=1e-5, rtol=1e-5)

                    # Derivative: compare against an element-wise autograd
                    # derivative of evaluate().
                    for point in torch.rand(100, dtype=torch.float64):
                        leaf = point.detach().requires_grad_(True)
                        evaluate = spline.evaluate(leaf)
                        derivative = spline.derivative(point)
                        grads = []
                        for elem in evaluate.view(-1):
                            elem.backward(retain_graph=True)
                            with torch.no_grad():
                                grads.append(leaf.grad.clone())
                                leaf.grad.zero_()
                        autoderivative = torch.stack(grads).view(
                            *evaluate.shape)
                        assert derivative.shape == autoderivative.shape
                        assert derivative.allclose(autoderivative,
                                                   atol=1e-5, rtol=1e-5)
def test_short():
    """With only two observations, the natural cubic spline and linear interpolation coincide."""
    for use_t in (False, True):
        t = torch.tensor([0., 1.]) if use_t else None
        values = torch.rand(2, 1)
        spline = torchcde.NaturalCubicSpline(
            torchcde.natural_cubic_spline_coeffs(values, t), t)
        linear = torchcde.LinearInterpolation(
            torchcde.linear_interpolation_coeffs(values, t), t)
        _test_equal([], 1, linear, spline, torch.float32, -1.5, 1.5, 1e-4)
def forward(self, x):
    """Encode raw observations with a neural CDE and apply the output layer.

    Args:
        x: the raw observations, not precomputed coefficients: the
            interpolation coefficients are computed here. The shape is
            presumably ``(batch, length, channels)``; TODO confirm against
            the dataset code.

    Returns:
        ``self.output`` applied to the hidden state at the end of the path.

    Raises:
        ValueError: if ``self.interpolation`` is neither ``"cubic"`` nor
            ``"linear"``. The check runs before any coefficients are computed.
    """
    if self.interpolation == "cubic":
        path = torchcde.NaturalCubicSpline(torchcde.natural_cubic_coeffs(x))
    elif self.interpolation == "linear":
        # Linear interpolation needs linear-interpolation coefficients.
        # Natural cubic coefficients have a different layout and would
        # silently describe the wrong path.
        path = torchcde.LinearInterpolation(
            torchcde.linear_interpolation_coeffs(x))
    else:
        raise ValueError("invalid interpolation given")

    # The initial hidden state is a function of the first observation.
    z0 = self.initial(path.evaluate(path.interval[0]))
    zt = torchcde.cdeint(X=path, func=self.model, z0=z0, t=path.interval)
    return self.output(zt[..., -1, :])
def forward(self, ys_coeffs):
    """Score a batch of paths and return the mean score.

    Args:
        ys_coeffs: linear-interpolation coefficients of shape
            ``(batch_size, t_size, 1 + data_size)``. The extra channel is
            time.
    """
    # Time is treated as just another channel. When solving CDEs this turns
    # out to be the most natural choice. It also makes irregular data easy to
    # handle, because the times may differ between samples in the batch.
    path = torchcde.LinearInterpolation(ys_coeffs)
    h0 = self._initial(path.evaluate(path.interval[0]))
    # Fixed-step midpoint solve, without the adjoint method.
    hs = torchcde.cdeint(path, self._func, h0, path.interval,
                         adjoint=False, method='midpoint',
                         options=dict(step_size=1.0))
    # hs has shape (batch_size, 2, hidden_size): states at both endpoints.
    return self._readout(hs[:, -1]).mean()
def forward(self, ys_coeffs):
    """Score a batch of paths with a reversible-Heun CDE solve; return the mean score.

    Args:
        ys_coeffs: linear-interpolation coefficients of shape
            ``(batch_size, t_size, 1 + data_size)``. The extra channel is
            time.
    """
    # Time is treated as just another channel. When solving CDEs this turns
    # out to be the most natural choice. It also makes irregular data easy to
    # handle, because the times may differ between samples in the batch.
    path = torchcde.LinearInterpolation(ys_coeffs)
    h0 = self._initial(path.evaluate(path.interval[0]))
    # ys_coeffs is listed in adjoint_params alongside the vector-field
    # parameters, so the adjoint pass also tracks gradients for it.
    hs = torchcde.cdeint(
        path,
        self._func,
        h0,
        path.interval,
        method='reversible_heun',
        backend='torchsde',
        dt=1.0,
        adjoint_method='adjoint_reversible_heun',
        adjoint_params=(ys_coeffs, *self._func.parameters()),
    )
    return self._readout(hs[:, -1]).mean()
def plot(ts, generator, dataloader, num_plot_samples, plot_locs):
    """Visually compare real and generated samples.

    For each entry of ``plot_locs``, shows a histogram of real vs. generated
    values at one time index. Then shows a single figure that overlays
    ``num_plot_samples`` real and generated sample paths.

    Args:
        ts: times at which both real and generated paths are evaluated.
        generator: called as ``generator(ts, batch_size)``. Its output is
            linearly interpolated and evaluated at ``ts``.
        dataloader: yields 1-tuples whose single element is a batch of real
            paths (interpolated with ``LinearInterpolation``).
        num_plot_samples: number of sample paths of each kind to draw. Must
            not exceed the real batch size.
        plot_locs: fractions (presumably in [0, 1]) along the time grid at
            which to draw the marginal histograms.

    Channel 1 of every evaluated path is what gets plotted. Presumably
    channel 0 is time; confirm against the data pipeline.
    """
    # Get samples
    real_samples, = next(iter(dataloader))
    assert num_plot_samples <= real_samples.size(0)
    real_samples = torchcde.LinearInterpolation(real_samples).evaluate(ts)
    real_samples = real_samples[..., 1]

    with torch.no_grad():
        generated_samples = generator(ts, real_samples.size(0)).cpu()
    generated_samples = torchcde.LinearInterpolation(
        generated_samples).evaluate(ts)
    generated_samples = generated_samples[..., 1]

    # Plot histograms
    for prop in plot_locs:
        time = int(prop * (real_samples.size(1) - 1))
        real_samples_time = real_samples[:, time]
        generated_samples_time = generated_samples[:, time]
        _, bins, _ = plt.hist(real_samples_time.cpu().numpy(),
                              bins=32,
                              alpha=0.7,
                              label='Real',
                              color='dodgerblue',
                              density=True)
        # Reuse the real histogram's bin width so the two are comparable.
        # Clamp to at least one bin: if the generated values span less than
        # one bin width (e.g. they are all equal), the floor division yields 0
        # and a non-positive bin count is rejected by the histogram call.
        bin_width = bins[1] - bins[0]
        num_bins = max(1, int(
            (generated_samples_time.max() -
             generated_samples_time.min()).item() // bin_width))
        plt.hist(generated_samples_time.cpu().numpy(),
                 bins=num_bins,
                 alpha=0.7,
                 label='Generated',
                 color='crimson',
                 density=True)
        plt.legend()
        plt.xlabel('Value')
        plt.ylabel('Density')
        plt.title(f'Marginal distribution at time {time}.')
        plt.tight_layout()
        plt.show()

    real_samples = real_samples[:num_plot_samples]
    generated_samples = generated_samples[:num_plot_samples]

    # Plot samples. Only the first line of each kind gets a legend label.
    real_first = True
    generated_first = True
    for real_sample_ in real_samples:
        kwargs = {'label': 'Real'} if real_first else {}
        plt.plot(ts.cpu(),
                 real_sample_.cpu(),
                 color='dodgerblue',
                 linewidth=0.5,
                 alpha=0.7,
                 **kwargs)
        real_first = False
    for generated_sample_ in generated_samples:
        kwargs = {'label': 'Generated'} if generated_first else {}
        plt.plot(ts.cpu(),
                 generated_sample_.cpu(),
                 color='crimson',
                 linewidth=0.5,
                 alpha=0.7,
                 **kwargs)
        generated_first = False
    plt.legend()
    plt.title(
        f"{num_plot_samples} samples from both real and generated distributions."
    )
    plt.tight_layout()
    plt.show()
def interp_():
    """Yield two interpolations of ``path``, lazily, one at a time.

    The first is a natural cubic spline; the second is a linear
    interpolation with ``reparameterise='bump'``. ``path`` is a free
    variable resolved from an enclosing or global scope. Each set of
    coefficients is computed only when its interpolation is requested.
    """
    yield torchcde.NaturalCubicSpline(torchcde.natural_cubic_coeffs(path))
    yield torchcde.LinearInterpolation(
        torchcde.linear_interpolation_coeffs(path), reparameterise='bump')