Code Example #1
File: example.py  Project: zhengdaoli/NeuralCDE
    def forward(self, times, coeffs):
        ######################
        # Extract the sizes of the batch dimensions from the coefficients
        ######################
        coeff, _, _, _ = coeffs
        batch_dims = coeff.shape[:-2]
        z0 = torch.zeros(*batch_dims,
                         self.hidden_channels,
                         dtype=times.dtype,
                         device=times.device)

        ######################
        # Actually solve the CDE.
        ######################
        z_T = controldiffeq.cdeint(
            dX_dt=controldiffeq.NaturalCubicSpline(times, coeffs).derivative,
            z0=z0,
            func=self.func,
            t=times[[0, -1]],
            atol=1e-2,
            rtol=1e-2)
        ######################
        # Both the initial value and the terminal value are returned from cdeint; extract just the terminal value,
        # and then apply a linear map.
        ######################
        z_T = z_T[1]
        pred_y = self.linear(z_T)
        return pred_y
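
For context, a minimal invocation sketch for this example. The controldiffeq package is the one used above; the NeuralCDE constructor and the data shapes are assumptions, not taken from the project:

import torch
import controldiffeq

# Hypothetical data: 32 series observed at 100 shared times, 3 channels each.
times = torch.linspace(0., 1., 100)
x = torch.randn(32, 100, 3)

# natural_cubic_spline_coeffs returns the tuple of coefficient tensors that
# forward() unpacks as `coeff, _, _, _`.
coeffs = controldiffeq.natural_cubic_spline_coeffs(times, x)

model = NeuralCDE(input_channels=3, hidden_channels=8, output_channels=1)  # hypothetical constructor
pred_y = model(times, coeffs)  # (32, 1) after the final linear map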
Code Example #2
File: example.py  Project: athon-millane/NeuralCDE
    def forward(self, times, coeffs):
        spline = controldiffeq.NaturalCubicSpline(times, coeffs)

        ######################
        # Easy to forget gotcha: Initial hidden state should be a function of the first observation.
        ######################
        z0 = self.initial(spline.evaluate(times[0]))

        ######################
        # Actually solve the CDE.
        ######################
        z_T = controldiffeq.cdeint(dX_dt=spline.derivative,
                                   z0=z0,
                                   func=self.func,
                                   t=times[[0, -1]],
                                   atol=1e-2,
                                   rtol=1e-2)

        # Stash the vector field's regularisation terms so they can be
        # returned alongside the prediction.
        self.l1 = self.func.l1
        self.l2 = self.func.l2

        ######################
        # Both the initial value and the terminal value are returned from cdeint; extract just the terminal value,
        # and then apply a linear map.
        ######################
        z_T = z_T[1]
        pred_y = self.readout(z_T)
        return pred_y, self.l1, self.l2
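
A sketch of how the extra return values might enter a training step. It assumes l1 and l2 are scalar regularisation terms computed by the vector field; lambda1, lambda2, loss_fn and the labels y are hypothetical placeholders:

pred_y, l1, l2 = model(times, coeffs)
loss = loss_fn(pred_y, y)                  # placeholder loss and labels
loss = loss + lambda1 * l1 + lambda2 * l2  # hypothetical penalty weights
loss.backward()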
Code Example #3
File: decoder.py  Project: ashysheya/Neural_CDE_CNP
    def forward(self, task, output):

        # Solve from the initial time t0 through all of the query times.
        t = torch.cat([output['t0'], task['x']])

        z = controldiffeq.cdeint(dX_dt=self.derivative,
                                 z0=output['z0'],
                                 func=self.cde_func,
                                 t=t,
                                 adjoint=False,
                                 rtol=1e-3,
                                 atol=1e-4)[1:].permute(1, 0, 2)
        # cdeint returns shape (times, batch, channels); [1:] drops the value
        # at t0 and the permute moves the time axis after the batch axis.

        output['mean_pred'] = self.mean_layer(z)
        output['std_pred'] = self.sigma_fn(self.sigma_layer(z))

        return output
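
A hypothetical call sketch; the dict keys are inferred from the reads in forward(), and the shapes are illustrative:

task = {'x': torch.linspace(0.1, 1., 20)}                  # 20 query times
output = {'t0': torch.zeros(1), 'z0': torch.randn(32, 8)}  # initial time and latent state
output = decoder(task, output)
output['mean_pred']  # (32, 20, mean_channels) under these assumed conventions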
Code Example #4
    def forward(self,
                times,
                coeffs,
                final_index,
                z0=None,
                stream=False,
                **kwargs):
        """
        Arguments:
            times: The times of the observations for the input path X, e.g. as passed as an argument to
                `controldiffeq.natural_cubic_spline_coeffs`.
            coeffs: The coefficients describing the input path X, e.g. as returned by
                `controldiffeq.natural_cubic_spline_coeffs`.
            final_index: Each batch element may have a different final time. This defines the index within the tensor
                `times` of where the final time for each batch element is.
            z0: See the 'initial' argument to __init__.
            stream: Whether to return the result of the Neural CDE model at all times (True), or just the final time
                (False). Defaults to just the final time. The `final_index` argument is ignored if stream is True.
            **kwargs: Will be passed to cdeint.

        Returns:
            If stream is False, then this will return the terminal time z_T. If stream is True, then this will return
            all intermediate times z_t, for those t for which there was data.
        """

        # Extract the sizes of the batch dimensions from the coefficients
        if len(coeffs) == 4:
            coeff, _, _, _ = coeffs
        else:
            coeff = coeffs[-1]

        batch_dims = coeff.shape[:-2]
        cubic_spline = controldiffeq.NaturalCubicSpline(times, coeffs)

        if not stream:
            assert batch_dims == final_index.shape, "coeff.shape[:-2] must be the same as final_index.shape. " \
                                                    "coeff.shape[:-2]={}, final_index.shape={}" \
                                                    "".format(batch_dims, final_index.shape)

        if z0 is None:
            assert self.initial, "Was not expecting to be given no value of z0."
            if isinstance(self.func, ContinuousRNNConverter):  # still an ugly hack
                z0 = torch.zeros(*batch_dims,
                                 self.hidden_channels,
                                 dtype=coeff.dtype,
                                 device=coeff.device)
            else:
                z0 = self.initial_network(cubic_spline.evaluate(times[0]))
        else:
            assert not self.initial, "Was expecting to be given a value of z0."
            if isinstance(self.func, ContinuousRNNConverter):  # continuing adventures in ugly hacks
                z0_extra = torch.zeros(*batch_dims,
                                       self.input_channels,
                                       dtype=z0.dtype,
                                       device=z0.device)
                z0 = torch.cat([z0_extra, z0], dim=-1)

        # Figure out what times we need to solve for
        if stream:
            t = times
        else:
            # faff around to make sure that we're outputting at all the times we need for final_index.
            sorted_final_index, inverse_final_index = final_index.unique(
                sorted=True, return_inverse=True)
            if 0 in sorted_final_index:
                sorted_final_index = sorted_final_index[1:]
                final_index = inverse_final_index
            else:
                final_index = inverse_final_index + 1
            if len(times) - 1 in sorted_final_index:
                sorted_final_index = sorted_final_index[:-1]
            t = torch.cat([
                times[0].unsqueeze(0), times[sorted_final_index],
                times[-1].unsqueeze(0)
            ])

        # Switch default solver
        if 'method' not in kwargs:
            kwargs['method'] = 'rk4'
        if kwargs['method'] == 'rk4':
            if 'options' not in kwargs:
                kwargs['options'] = {}
            options = kwargs['options']
            if 'step_size' not in options and 'grid_constructor' not in options:
                time_diffs = times[1:] - times[:-1]
                options['step_size'] = time_diffs.min().item()

        # Actually solve the CDE
        z_t = controldiffeq.cdeint(dX_dt=cubic_spline.derivative,
                                   z0=z0,
                                   func=self.func,
                                   t=t,
                                   **kwargs)

        # Organise the output
        if stream:
            # z_t is a tensor of shape (times, ..., channels), so change this to (..., times, channels)
            for i in range(len(z_t.shape) - 2, 0, -1):
                z_t = z_t.transpose(0, i)
        else:
            # final_index is a tensor of shape (...)
            # z_t is a tensor of shape (times, ..., channels)
            final_index_indices = final_index.unsqueeze(-1).expand(
                z_t.shape[1:]).unsqueeze(0)
            z_t = z_t.gather(dim=0, index=final_index_indices).squeeze(0)

        # Linear map and return
        pred_y = self.linear(z_t)
        return pred_y
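
A brief usage sketch for the two output modes. The model and the data shapes are hypothetical; controldiffeq is the package used above:

times = torch.linspace(0., 1., 100)
x = torch.randn(32, 100, 3)                                  # hypothetical observations
coeffs = controldiffeq.natural_cubic_spline_coeffs(times, x)
final_index = torch.full((32,), 99, dtype=torch.long)        # every series runs to the last time

z_T = model(times, coeffs, final_index)               # terminal prediction: (32, output_channels)
z_t = model(times, coeffs, final_index, stream=True)  # one prediction per time: (32, 100, output_channels)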