def forward(self, times, coeffs):
    """Integrate the Neural CDE from the first to the last observation time and read out a prediction.

    Arguments:
        times: Observation times of the input path; only ``times[0]`` and ``times[-1]``
            are requested from the solver.
        coeffs: Spline coefficients of the input path. Must unpack into exactly four
            tensors; all but the last two axes of the first one are treated as batch axes.

    Returns:
        ``self.linear`` applied to the hidden state at ``times[-1]``.
    """
    # The batch shape is read off the first coefficient tensor (dropping its last two axes).
    leading_coeff, _, _, _ = coeffs
    batch_shape = leading_coeff.shape[:-2]

    # Start the hidden state at zero, with the dtype/device of ``times``.
    initial_state = torch.zeros(
        *batch_shape, self.hidden_channels, dtype=times.dtype, device=times.device
    )

    path = controldiffeq.NaturalCubicSpline(times, coeffs)
    solution = controldiffeq.cdeint(
        dX_dt=path.derivative,
        z0=initial_state,
        func=self.func,
        t=times[[0, -1]],
        atol=1e-2,
        rtol=1e-2,
    )

    # The solver reports the state at both requested times; position 1 is the terminal state.
    terminal_state = solution[1]
    return self.linear(terminal_state)
def forward(self, times, coeffs):
    """Integrate the Neural CDE and return the readout together with ``self.func``'s l1/l2 values.

    The initial hidden state is computed by ``self.initial`` from the spline evaluated at
    ``times[0]``, so it depends on the first observation rather than being zero.

    Arguments:
        times: Observation times; the solve runs from ``times[0]`` to ``times[-1]``.
        coeffs: Spline coefficients passed to ``controldiffeq.NaturalCubicSpline``.

    Returns:
        A tuple ``(pred_y, l1, l2)``. ``pred_y`` is ``self.readout`` applied to the terminal
        hidden state. ``l1`` and ``l2`` are ``self.func.l1`` and ``self.func.l2`` read after the
        solve; they are also stored on ``self.l1`` / ``self.l2``.
    """
    path = controldiffeq.NaturalCubicSpline(times, coeffs)

    # Derive z0 from the first observation instead of starting from zeros.
    start_state = self.initial(path.evaluate(times[0]))

    trajectory = controldiffeq.cdeint(
        dX_dt=path.derivative,
        z0=start_state,
        func=self.func,
        t=times[[0, -1]],
        atol=1e-2,
        rtol=1e-2,
    )

    # Read l1/l2 only after the solve. NOTE(review): presumably self.func updates these
    # while integrating — confirm before moving this line.
    self.l1 = self.func.l1
    self.l2 = self.func.l2

    # Position 1 of the solver output is the state at times[-1].
    terminal_state = trajectory[1]
    return self.readout(terminal_state), self.l1, self.l2
def forward(self, task, output):
    """Decode hidden states at the query times ``task['x']`` and store predictions in ``output``.

    The CDE is solved starting from ``output['z0']`` at time ``output['t0']``, through every
    time in ``task['x']``. ``output`` is mutated in place (keys ``'mean_pred'`` and
    ``'std_pred'`` are set) and then returned.
    """
    query_times = torch.cat([output['t0'], task['x']])
    solution = controldiffeq.cdeint(
        dX_dt=self.derivative,
        z0=output['z0'],
        func=self.cde_func,
        t=query_times,
        adjoint=False,
        rtol=1e-3,
        atol=1e-4,
    )

    # Drop the entry for t0 and swap the first two axes. permute(1, 0, 2) requires the
    # solution to be 3-D; NOTE(review): presumably (time, batch, hidden) -> (batch, time, hidden).
    after_start = solution[1:]
    states = after_start.permute(1, 0, 2)

    output['mean_pred'] = self.mean_layer(states)
    output['std_pred'] = self.sigma_fn(self.sigma_layer(states))
    return output
def forward(self, times, coeffs, final_index, z0=None, stream=False, **kwargs):
    """Solve the Neural CDE driven by the cubic spline of ``coeffs`` and apply ``self.linear``.

    Arguments:
        times: The times of the observations for the input path X, e.g. as passed as an argument to
            `controldiffeq.natural_cubic_spline_coeffs`.
        coeffs: The coefficients describing the input path X, e.g. as returned by
            `controldiffeq.natural_cubic_spline_coeffs`. The batch shape is read from the first entry
            when there are four entries, otherwise from the last one.
        final_index: Each batch element may have a different final time. This defines the index within
            the tensor `times` of where the final time for each batch element is. When stream is False it
            must have shape ``coeff.shape[:-2]``.
        z0: Initial hidden state; see the 'initial' argument to __init__.
            - When ``self.initial`` is truthy, z0 must be None. It is then computed from the first
              observation via ``self.initial_network``, or set to zeros when ``self.func`` is a
              ContinuousRNNConverter.
            - When ``self.initial`` is falsy, a z0 is expected. If None is passed anyway, zeros of
              shape ``(*batch_dims, hidden_channels)`` are used.
            - For a ContinuousRNNConverter func, ``input_channels`` zero channels are prepended to a
              supplied/zero z0.
        stream: Whether to return the result of the Neural CDE model at all times (True), or just the
            final time (False). Defaults to just the final time. The `final_index` argument is ignored
            if stream is True.
        **kwargs: Will be passed to cdeint.
            - ``method`` defaults to 'rk4'.
            - For 'rk4' without a ``step_size`` or ``grid_constructor`` option, the step size defaults
              to the smallest gap in ``times``.
            - A caller-supplied ``options`` dict is copied and never modified.

    Returns:
        If stream is False, ``self.linear`` applied to the terminal state z_T of each batch element.
        If stream is True, ``self.linear`` applied to all intermediate states z_t, for those t for which
        there was data.

    Raises:
        AssertionError: If ``final_index`` has the wrong shape (stream=False), or if a z0 is given while
            ``self.initial`` is truthy.
    """
    # Extract the sizes of the batch dimensions from the coefficients
    if len(coeffs) == 4:
        coeff, _, _, _ = coeffs
    else:
        coeff = coeffs[-1]
    batch_dims = coeff.shape[:-2]
    if not stream:
        assert batch_dims == final_index.shape, (
            "coeff.shape[:-2] must be the same as final_index.shape. "
            "coeff.shape[:-2]={}, final_index.shape={}".format(batch_dims, final_index.shape)
        )

    cubic_spline = controldiffeq.NaturalCubicSpline(times, coeffs)

    # Choose the initial hidden state. The caller's z0 must be respected, so zeros are only
    # substituted when no initial network is configured and no z0 was supplied.
    if z0 is None and self.initial:
        if isinstance(self.func, ContinuousRNNConverter):  # still an ugly hack
            z0 = torch.zeros(*batch_dims, self.hidden_channels, dtype=coeff.dtype, device=coeff.device)
        else:
            z0 = self.initial_network(cubic_spline.evaluate(times[0]))
    else:
        assert not self.initial, "Was not expecting to be given a value of z0."
        if z0 is None:
            # No initial network and no z0 supplied: start from zeros.
            z0 = torch.zeros(*batch_dims, self.hidden_channels, dtype=coeff.dtype, device=coeff.device)
        if isinstance(self.func, ContinuousRNNConverter):  # continuing adventures in ugly hacks
            # Prepend input_channels zero channels along the last axis.
            z0_extra = torch.zeros(*batch_dims, self.input_channels, dtype=z0.dtype, device=z0.device)
            z0 = torch.cat([z0_extra, z0], dim=-1)

    # Figure out what times we need to solve for
    if stream:
        t = times
    else:
        # The solver must output at every distinct final time. t is always
        # [times[0], <distinct interior final times>, times[-1]], and final_index is remapped
        # to index into t rather than into times.
        sorted_final_index, inverse_final_index = final_index.unique(sorted=True, return_inverse=True)
        if 0 in sorted_final_index:
            # times[0] is already the first entry of t, so don't duplicate it.
            sorted_final_index = sorted_final_index[1:]
            final_index = inverse_final_index
        else:
            # Every position shifts by one because times[0] is prepended.
            final_index = inverse_final_index + 1
        if len(times) - 1 in sorted_final_index:
            # times[-1] is already the last entry of t, so don't duplicate it.
            sorted_final_index = sorted_final_index[:-1]
        t = torch.cat([times[0].unsqueeze(0), times[sorted_final_index], times[-1].unsqueeze(0)])

    # Switch default solver
    if 'method' not in kwargs:
        kwargs['method'] = 'rk4'
    if kwargs['method'] == 'rk4':
        # Work on a copy: filling in a default step size must not leak into a dict the caller
        # may reuse across calls with different `times`.
        options = dict(kwargs.get('options') or {})
        if 'step_size' not in options and 'grid_constructor' not in options:
            time_diffs = times[1:] - times[:-1]
            options['step_size'] = time_diffs.min().item()
        kwargs['options'] = options

    # Actually solve the CDE
    z_t = controldiffeq.cdeint(dX_dt=cubic_spline.derivative, z0=z0, func=self.func, t=t, **kwargs)

    # Organise the output
    if stream:
        # z_t is a tensor of shape (times, ..., channels), so change this to (..., times, channels)
        for i in range(len(z_t.shape) - 2, 0, -1):
            z_t = z_t.transpose(0, i)
    else:
        # final_index is a tensor of shape (...)
        # z_t is a tensor of shape (times, ..., channels)
        final_index_indices = final_index.unsqueeze(-1).expand(z_t.shape[1:]).unsqueeze(0)
        z_t = z_t.gather(dim=0, index=final_index_indices).squeeze(0)

    # Linear map and return
    pred_y = self.linear(z_t)
    return pred_y