def forward(self, coeffs):
    """Solve the neural CDE driven by ``coeffs`` and read out predictions.

    Args:
        coeffs: interpolation coefficients matching ``self.interpolation``
            ('cubic' -> CubicSpline, 'linear' -> LinearInterpolation).

    Returns:
        If ``self.return_sequences`` is true: ``self.readout`` applied to the
        hidden state at each time in
        ``torch.arange(X.interval[0], X.interval[1] + 1)``, stacked along a new
        leading dimension. Otherwise: ``self.readout`` applied to the hidden
        state at the end of ``X.interval``.

    Raises:
        ValueError: if ``self.interpolation`` is neither 'cubic' nor 'linear'.
    """
    if self.interpolation == "cubic":
        X = torchcde.CubicSpline(coeffs)
    elif self.interpolation == "linear":
        X = torchcde.LinearInterpolation(coeffs)
    else:
        raise ValueError(
            "Only 'linear' and 'cubic' interpolation methods are implemented."
        )

    # The initial hidden state is a function of the first observation.
    X0 = X.evaluate(X.interval[0])
    z0 = self.initial(X0)

    if self.return_sequences:
        times = torch.arange(X.interval[0], X.interval[1] + 1)
        z_t = torchcde.cdeint(X=X, z0=z0, func=self.func, t=times)
        # z_t[:, i] is the hidden state at times[i]. Index by position, not by
        # the time value: the previous ``int(ti) - 1`` wrapped round to the
        # terminal state for ti == 0 and shifted every other output by one.
        pred_y = torch.stack(
            [self.readout(z_t[:, i]) for i in range(len(times))]
        )
    else:
        # cdeint returns the state at both ends of X.interval; keep the end.
        z_T = torchcde.cdeint(X=X, z0=z0, func=self.func, t=X.interval)
        z_T = z_T[:, 1]
        pred_y = self.readout(z_T)
    return pred_y
def test_stacked_paths():
    """Chain two CDE solves: the sampled solution of the first drives the second.

    Every combination of (coefficient function, interpolation class) is used for
    each stage, both with and without the adjoint method. The test checks that
    backpropagating from the final output populates gradients of the first
    path and of both vector fields. It also checks that the backward of each
    recorded intermediate runs at most once per forward call.
    """
    class Record(torch.autograd.Function):
        # Identity op whose backward fails the test (reporting `name`) if it is
        # invoked a second time for the same forward call.
        @staticmethod
        def forward(ctx, name, x):
            ctx.name = name
            return x

        @staticmethod
        def backward(ctx, x):
            if hasattr(ctx, 'been_here_before'):
                pytest.fail(ctx.name)
            ctx.been_here_before = True
            # No gradient for `name`; pass the incoming gradient through for `x`.
            return None, x

    ReparameterisedLinearInterpolation = ft.partial(
        torchcde.LinearInterpolation, reparameterise='bump')
    # (coefficient function, interpolation class) pairs tried for each stage.
    coeff_paths = [
        (torchcde.linear_interpolation_coeffs, torchcde.LinearInterpolation),
        (torchcde.linear_interpolation_coeffs, ReparameterisedLinearInterpolation),
        (torchcde.natural_cubic_coeffs, torchcde.NaturalCubicSpline)
    ]
    for adjoint in (False, True):
        for first_coeffs, First in coeff_paths:
            for second_coeffs, Second in coeff_paths:
                first_path = torch.rand(1, 1000, 2, requires_grad=True)
                first_coeff = first_coeffs(first_path)
                first_X = First(first_coeff)
                first_func = _Func(input_size=2, hidden_size=2)

                # The first solve, sampled at 100 times, is the data for the second stage.
                second_t = torch.linspace(0, 999, 100)
                second_path = torchcde.cdeint(X=first_X, func=first_func, z0=torch.rand(1, 2), t=second_t,
                                              adjoint=adjoint, method='rk4', options=dict(step_size=10))
                second_path = Record.apply('second', second_path)
                second_coeff = second_coeffs(second_path, second_t)
                second_X = Second(second_coeff, second_t)
                second_func = _Func(input_size=2, hidden_size=2)

                third_t = torch.linspace(0, 999, 10)
                third_path = torchcde.cdeint(X=second_X, func=second_func, z0=torch.rand(1, 2), t=third_t,
                                             adjoint=adjoint, method='rk4', options=dict(step_size=10))
                third_path = Record.apply('third', third_path)

                # No gradients exist before backward ...
                assert first_func.variable.grad is None
                assert second_func.variable.grad is None
                assert first_path.grad is None
                third_path[:, -1].sum().backward()
                # ... and all three are populated afterwards.
                assert isinstance(second_func.variable.grad, torch.Tensor)
                assert isinstance(first_func.variable.grad, torch.Tensor)
                assert isinstance(first_path.grad, torch.Tensor)
def test_backend():
    """The torchdiffeq and torchsde backends must agree on the same fixed-step midpoint solve."""
    x = torch.randn(1, 10, 2)
    coeffs = torchcde.natural_cubic_coeffs(x)
    X = torchcde.CubicSpline(coeffs)

    def func(t, z):
        # Vector field of shape (batch=1, hidden=3, channels=2).
        return -z.unsqueeze(-1).expand(1, 3, 2)

    z0 = torch.randn(1, 3)
    torchdiffeq_out = torchcde.cdeint(X=X, func=func, z0=z0, t=X.interval, backend="torchdiffeq",
                                      method="midpoint", options=dict(step_size=1.0))
    torchsde_out = torchcde.cdeint(X=X, func=func, z0=z0, t=X.interval, backend="torchsde",
                                   method="midpoint", dt=1.0)
    # torch.testing.assert_allclose is deprecated; assert_close is its
    # replacement and uses the same default float32 tolerances.
    torch.testing.assert_close(torchdiffeq_out, torchsde_out)
def _solve_cde(x):
    """Solve a randomly initialised neural CDE driven by ``x``; return its terminal state.

    Args:
        x: tensor of shape (batch, length, channels).

    Returns:
        Tensor of shape (batch, hidden_channels): the hidden state at the end
        of the spline's interval.
    """
    batch_size = x.size(0)
    input_channels = x.size(2)
    hidden_channels = 4  # hyperparameter, we can pick whatever we want for this

    # natural_cubic_coeffs is the current name; natural_cubic_spline_coeffs is
    # the deprecated spelling. This matches the rest of the file.
    coeffs = torchcde.natural_cubic_coeffs(x)
    X = torchcde.NaturalCubicSpline(coeffs)

    z0 = torch.rand(batch_size, hidden_channels)

    class F(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(hidden_channels, hidden_channels * input_channels)

        def forward(self, t, z):
            # The vector field is a (hidden_channels x input_channels) matrix per batch element.
            return self.linear(z).view(batch_size, hidden_channels, input_channels)

    func = F()
    zt = torchcde.cdeint(X=X, func=func, z0=z0, t=X.interval)
    zT = zt[:, -1]  # get the terminal value of the CDE
    return zT
def forward(self, coeffs):
    """Map cubic-spline coefficients to a prediction from the terminal CDE state.

    The spline built from ``coeffs`` drives the CDE. Its value at the start of
    its interval seeds the hidden state via ``self.initial``, and
    ``self.readout`` is applied to the hidden state at the end of the interval.
    """
    path = torchcde.CubicSpline(coeffs)
    hidden_init = self.initial(path.evaluate(path.interval[0]))
    trajectory = torchcde.cdeint(X=path, func=self.func, z0=hidden_init, t=path.interval)
    # Time is the second-to-last axis; take its final entry.
    return self.readout(trajectory[..., -1, :])
def forward(self, coeffs):
    """Predict from the terminal hidden state of a neural CDE.

    Args:
        coeffs: coefficients for the interpolation named by ``self.interpolation``
            ('cubic' or 'linear').

    Raises:
        ValueError: for any other ``self.interpolation``.
    """
    if self.interpolation not in ('cubic', 'linear'):
        raise ValueError(
            "Only 'linear' and 'cubic' interpolation methods are implemented."
        )
    if self.interpolation == 'cubic':
        path = torchcde.CubicSpline(coeffs)
    else:
        path = torchcde.LinearInterpolation(coeffs)

    # The initial hidden state must be computed from the first observation.
    hidden_init = self.initial(path.evaluate(path.interval[0]))

    # cdeint yields the state at both the start and the end of the interval;
    # index 1 is the end, which is passed through the readout.
    endpoints = torchcde.cdeint(X=path, z0=hidden_init, func=self.func, t=path.interval)
    return self.readout(endpoints[:, 1])
def forward(self, coeffs, y):
    """Return ``(loss, accuracy)`` for a batch classified by the neural CDE.

    Args:
        coeffs: linear-interpolation coefficients over ``self.times``
            (presumably (batch, length, input_channels) — per the original
            comment; confirm with the caller).
        y: class labels, presumably of shape (batch,).

    Returns:
        loss: cross entropy between the readout of the terminal state and ``y``.
        accuracy: the *number* (not fraction) of correct argmax predictions,
            cast to the prediction dtype.
    """
    path = torchcde.LinearInterpolation(coeffs, self.times)
    hidden_init = self.initial(path.evaluate(self.times[0]))

    solver_options = {'grid_points': path.grid_points, 'eps': 1e-5}
    # The adjoint pass gets its own copy so an optional norm can be added to it alone.
    backward_options = dict(solver_options)
    if self.norm:
        backward_options['norm'] = common.make_norm(hidden_init)

    span = self.times[[0, -1]]
    states = torchcde.cdeint(X=path, z0=hidden_init, func=self.func, t=span, rtol=self.rtol,
                             atol=self.atol, options=solver_options, adjoint_options=backward_options)
    logits = self.readout(states[:, -1])

    loss = torch.nn.functional.cross_entropy(logits, y)
    num_correct = (logits.argmax(dim=1) == y).sum().to(logits.dtype)
    return loss, num_correct
def test_grad_paths():
    """Gradients must reach every tensor input of cdeint, for rk4/dopri5 with and without adjoint."""
    for method in ('rk4', 'dopri5'):
        for adjoint in (True, False):
            knots = torch.linspace(0, 9, 10, requires_grad=True)
            data = torch.rand(1, 10, 3, requires_grad=True)
            spline = torchcde.NaturalCubicSpline(torchcde.natural_cubic_coeffs(data, knots), knots)
            z0 = torch.rand(1, 3, requires_grad=True)
            func = _Func(input_size=3, hidden_size=3)
            eval_times = torch.tensor([0., 9.], requires_grad=True)

            solution = torchcde.cdeint(X=spline, func=func, z0=z0, t=eval_times, adjoint=adjoint,
                                       method=method, rtol=1e-4, atol=1e-6)
            assert solution.shape == (1, 2, 3)

            leaves = (knots, data, z0, func.variable, eval_times)
            assert all(leaf.grad is None for leaf in leaves)
            solution[:, 1].sum().backward()
            assert all(isinstance(leaf.grad, torch.Tensor) for leaf in leaves)
def test_detach_trick():
    """Parameter gradients must be identical whether or not the output times require grad."""
    path = torch.rand(1, 10, 3)
    func = _Func(input_size=3, hidden_size=3)

    def interpolations():
        yield torchcde.NaturalCubicSpline(torchcde.natural_cubic_coeffs(path))
        yield torchcde.LinearInterpolation(torchcde.linear_interpolation_coeffs(path), reparameterise='bump')

    for control in interpolations():
        for adjoint in (True, False):
            z0 = torch.rand(1, 3)
            collected = []
            for times_need_grad in (True, False):
                times = torch.tensor([0., 9.], requires_grad=times_need_grad)
                # Fixed-step rk4 only: an adaptive solver such as dopri5 would take
                # smaller steps when t requires grad, changing the result slightly.
                solution = torchcde.cdeint(X=control, z0=z0, func=func, t=times, adjoint=adjoint,
                                           method='rk4', options=dict(step_size=0.5))
                solution[:, -1].sum().backward()
                collected.append(func.variable.grad.clone())
                func.variable.grad.zero_()
            reference, *others = collected
            for grad in others:
                assert (grad == reference).all()
def test_prod():
    """cdeint accepts a vector field that defines only ``prod`` and passes it correctly shaped arguments."""
    data = torch.rand(2, 5, 1)
    spline = torchcde.NaturalCubicSpline(torchcde.natural_cubic_coeffs(data))

    class _ProdOnly:
        def prod(self, t, z, dXdt):
            # Scalar time, (batch, hidden) state, (batch, channels) derivative of the control.
            assert t.shape == ()
            assert z.shape == (2, 3)
            assert dXdt.shape == (2, 1)
            return -z * dXdt

    z0 = torch.rand(2, 3, requires_grad=True)
    solution = torchcde.cdeint(X=spline, func=_ProdOnly(), z0=z0, t=spline.interval, adjoint_params=())
    solution.sum().backward()
def forward(self, coeffs, times, interpolation_method):
    """Classify paths with the neural CDE and return class probabilities.

    Args:
        coeffs, times, interpolation_method: forwarded to ``build_data_path``,
            which returns the control path and the solver options.

    Returns:
        Softmax over the last dimension of the (optionally batch-normalised)
        readout of the terminal hidden state.
    """
    path, solver_options = build_data_path(coeffs, times, interpolation_method)

    # The initial hidden state must be a function of the first observation.
    hidden_init = self.initial(path.evaluate(path.interval[0]))
    if self.batch_norm:
        hidden_init = self.bn_initial(hidden_init)

    # Solving over path.interval; the original author notes times[[0, -1]] would
    # be equivalent but only when times is not None.
    states = torchcde.cdeint(X=path, z0=hidden_init, func=self.vector_field, t=path.interval,
                             options=solver_options)
    # cdeint returns both the initial and terminal states; keep the terminal one.
    logits = self.readout(states[:, -1])
    if self.batch_norm:
        logits = self.bn_output(logits)

    # NOTE(review): softmax output is returned. If the caller feeds it to
    # torch.nn.CrossEntropyLoss / F.cross_entropy (which expect raw logits), the
    # softmax is effectively applied twice — confirm which loss is used.
    return torch.nn.functional.softmax(logits, dim=-1)
def test_shape():
    """Check cdeint's output shape for random batch shapes, channel counts and output times.

    Runs 10 random trials for each of the 'rk4' and 'dopri5' methods.
    """
    for method in ('rk4', 'dopri5'):
        for _ in range(10):
            num_points = torch.randint(low=5, high=100, size=(1, )).item()
            num_channels = torch.randint(low=1, high=3, size=(1, )).item()
            num_hidden_channels = torch.randint(low=1, high=5, size=(1, )).item()
            # Zero to two leading batch dimensions, each of size 1 or 2.
            num_batch_dims = torch.randint(low=0, high=3, size=(1, )).item()
            batch_dims = []
            for _ in range(num_batch_dims):
                batch_dims.append(
                    torch.randint(low=1, high=3, size=(1, )).item())
            values = torch.rand(*batch_dims, num_points, num_channels)
            coeffs = torchcde.natural_cubic_coeffs(values)
            spline = torchcde.NaturalCubicSpline(coeffs)

            class _Func(torch.nn.Module):
                # Vector field: (*batch_dims, hidden, 1) + (1, ..., 1, 1, channels)
                # broadcasts to (*batch_dims, hidden, channels).
                def __init__(self):
                    super(_Func, self).__init__()
                    self.variable = torch.nn.Parameter(
                        torch.rand(*[1 for _ in range(num_batch_dims)], 1, num_channels))

                def forward(self, t, z):
                    return z.sigmoid().unsqueeze(-1) + self.variable

            f = _Func()
            z0 = torch.rand(*batch_dims, num_hidden_channels)
            num_out_times = torch.randint(low=2, high=10, size=(1, )).item()
            start, end = spline.interval
            # Sorted random output times scaled into [start, end].
            out_times = torch.rand(num_out_times, dtype=torch.float64).sort().values * (end - start) + start
            options = {}
            if method == 'rk4':
                # Only rk4 gets an explicit step size; dopri5 relies on rtol/atol below.
                options['step_size'] = 1. / num_points
            out = torchcde.cdeint(spline, f, z0, out_times, method=method,
                                  options=options, rtol=1e-4, atol=1e-6)
            assert out.shape == (*batch_dims, num_out_times, num_hidden_channels)
def forward(self, x):
    """Run the neural CDE on raw observations ``x``; return the output at the final time.

    Interpolation coefficients are computed here from ``x``, so ``x`` is the
    raw data, not precomputed coefficients. The coefficient function matches
    ``self.interpolation``.

    Raises:
        ValueError: if ``self.interpolation`` is neither "cubic" nor "linear".
    """
    # Each interpolation needs coefficients in its own format. Previously
    # natural cubic coefficients were also handed to LinearInterpolation,
    # which expects linear-interpolation coefficients.
    if self.interpolation == "cubic":
        coeffs = torchcde.natural_cubic_coeffs(x)
        X = torchcde.NaturalCubicSpline(coeffs)
    elif self.interpolation == "linear":
        coeffs = torchcde.linear_interpolation_coeffs(x)
        X = torchcde.LinearInterpolation(coeffs)
    else:
        raise ValueError("invalid interpolation given")

    # The initial hidden state is a function of the first observation.
    X0 = X.evaluate(X.interval[0])
    z0 = self.initial(X0)
    zt = torchcde.cdeint(X=X, func=self.model, z0=z0, t=X.interval)
    # Time is the second-to-last axis; read out the terminal state.
    return self.output(zt[..., -1, :])
def test_tuple_input():
    """A TupleControl drives a tuple-valued state; a loss on the first component gives zero gradient to the second's z0."""
    data_a = torch.rand(2, 10, 2)
    data_b = torch.rand(10, 1)
    spline_a = torchcde.NaturalCubicSpline(torchcde.natural_cubic_coeffs(data_a))
    spline_b = torchcde.NaturalCubicSpline(torchcde.natural_cubic_coeffs(data_b))
    control = torchcde.TupleControl(spline_a, spline_b)

    def vector_field(t, z):
        z_a, z_b = z
        field_a = z_a.sigmoid().unsqueeze(-1).repeat_interleave(2, dim=-1)
        field_b = z_b.tanh().unsqueeze(-1)
        return field_a, field_b

    z0_a = torch.rand(2, 3)
    z0_b = torch.rand(5, requires_grad=True)
    solution = torchcde.cdeint(X=control, func=vector_field, z0=(z0_a, z0_b), t=control.interval,
                               adjoint_params=())
    solution[0].sum().backward()
    assert (z0_b.grad == 0).all()
def forward(self, ys_coeffs):
    """Score a batch of paths and return the mean score.

    ys_coeffs has shape (batch_size, t_size, 1 + data_size); the extra channel
    is time. For CDEs it is most natural to treat time as just another channel.
    This makes irregular data easy to handle, because different samples in the
    batch may be observed at different times.
    """
    path = torchcde.LinearInterpolation(ys_coeffs)
    h0 = self._initial(path.evaluate(path.interval[0]))
    hs = torchcde.cdeint(path, self._func, h0, path.interval, adjoint=False, method='midpoint',
                         options={'step_size': 1.0})  # shape (batch_size, 2, hidden_size)
    return self._readout(hs[:, -1]).mean()
def forward(self, ys_coeffs):
    """Score a batch of paths with a reversible-Heun CDE solve; return the mean score.

    ys_coeffs has shape (batch_size, t_size, 1 + data_size); the extra channel
    is time. Treating time as an ordinary channel is natural for CDEs and
    handles irregular sampling, since each sample may have its own times.
    """
    path = torchcde.LinearInterpolation(ys_coeffs)
    h0 = self._initial(path.evaluate(path.interval[0]))
    # Tensors the adjoint pass differentiates with respect to: the coefficients
    # plus every parameter of the vector field.
    params = (ys_coeffs, *self._func.parameters())
    hs = torchcde.cdeint(path, self._func, h0, path.interval, method='reversible_heun', backend='torchsde',
                         dt=1.0, adjoint_method='adjoint_reversible_heun', adjoint_params=params)
    return self._readout(hs[:, -1]).mean()
def forward(self, coeffs):
    """Predict from the terminal hidden state of a CDE driven by a natural cubic spline.

    Args:
        coeffs: natural cubic spline coefficients.
    """
    path = torchcde.NaturalCubicSpline(coeffs)

    # Easy-to-forget gotcha: the initial hidden state must depend on the first observation.
    hidden_init = self.initial(path.evaluate(path.interval[0]))

    # cdeint returns the state at the start and at the end of the interval;
    # index 1 selects the end, which is mapped through the linear readout.
    endpoints = torchcde.cdeint(X=path, z0=hidden_init, func=self.func, t=path.interval)
    return self.readout(endpoints[:, 1])
def forward(self, inputs):
    """Solve the CDE on the path built from ``inputs`` and convert hidden states to outputs.

    ``self._setup_h0`` splits ``inputs`` into coefficients and the initial
    hidden state. ``self.spline`` turns the coefficients into the control path.
    The solution is requested at every ``grid_points`` entry of that path.
    """
    coeffs, h0 = self._setup_h0(inputs)
    path = self.spline(coeffs)
    # self.adjoint toggles adjoint backpropagation; self.solver names the method.
    hidden = torchcde.cdeint(path, self.func, h0, path.grid_points, adjoint=self.adjoint, method=self.solver)
    return self._make_outputs(hidden)
def test_shape(backend, method, kwargs):
    """Check cdeint's output shape for random problem sizes under a given backend/method.

    Args:
        backend: "torchdiffeq" (0-2 random batch dims) or "torchsde" (exactly one batch dim).
        method: solver method name forwarded to cdeint.
        kwargs: extra keyword arguments forwarded to cdeint.

    Raises:
        ValueError: for any other backend.
    """
    for _ in range(5):
        num_points = torch.randint(low=5, high=100, size=(1,)).item()
        num_channels = torch.randint(low=1, high=3, size=(1,)).item()
        num_hidden_channels = torch.randint(low=1, high=5, size=(1,)).item()
        if backend == "torchdiffeq":
            num_batch_dims = torch.randint(low=0, high=3, size=(1,)).item()
            batch_dims = []
            for _ in range(num_batch_dims):
                batch_dims.append(torch.randint(low=1, high=3, size=(1,)).item())
        elif backend == "torchsde":
            # Exactly one batch dimension is generated for torchsde.
            num_batch_dims = 1
            batch_dims = [torch.randint(low=1, high=3, size=(1,)).item()]
        else:
            raise ValueError(f"Unknown backend: {backend!r}")
        values = torch.rand(*batch_dims, num_points, num_channels)
        coeffs = torchcde.natural_cubic_coeffs(values)
        spline = torchcde.CubicSpline(coeffs)

        class _Func(torch.nn.Module):
            # Output broadcasts to (*batch_dims, hidden, channels).
            def __init__(self):
                super(_Func, self).__init__()
                self.variable = torch.nn.Parameter(torch.rand(*[1 for _ in range(num_batch_dims)], 1, num_channels))

            def forward(self, t, z):
                return z.sigmoid().unsqueeze(-1) + self.variable

        f = _Func()
        z0 = torch.rand(*batch_dims, num_hidden_channels)
        num_out_times = torch.randint(low=2, high=10, size=(1,)).item()
        start, end = spline.interval
        # Sorted random output times scaled into [start, end].
        out_times = torch.rand(num_out_times, dtype=torch.float64).sort().values * (end - start) + start
        out = torchcde.cdeint(spline, f, z0, out_times, backend=backend, method=method, rtol=1e-1, atol=1e-1,
                              **kwargs)
        assert out.shape == (*batch_dims, num_out_times, num_hidden_channels)