Пример #1
0
    def __init__(self,
                 kf_step: 'GaussianStep',
                 processes: Sequence[Process],
                 measures: Sequence[str],
                 process_covariance: Covariance,
                 measure_covariance: Covariance,
                 initial_covariance: Covariance):
        """
        Store the step module, register the processes, and attach the three covariance modules.

        :param kf_step: Stored as-is on `self.kf_step`.
        :param processes: The process modules. Each must have a truthy `measure`; each is registered under its
         `id` and is assigned a `(start, end)` range of the state, in the order given.
        :param measures: Names of the measures; their order defines `measure_to_idx`.
        :param process_covariance: Covariance module; its id is set to 'process_covariance'.
        :param measure_covariance: Covariance module; its id is set to 'measure_covariance'.
        :param initial_covariance: Covariance module; its id is set to 'initial_covariance'.
        """
        super(ScriptKalmanFilter, self).__init__()

        self.kf_step = kf_step

        # measures: remember the names and the position of each one
        self.measures = measures
        self.measure_to_idx = {name: position for position, name in enumerate(measures)}

        # processes: registered by id; `state_rank` grows by each process's number of state-elements,
        # so consecutive processes get back-to-back (start, end) ranges
        self.processes = nn.ModuleDict()
        self.process_to_slice: Dict[str, Tuple[int, int]] = {}
        self.state_rank = 0
        for process in processes:
            assert process.measure, f"{process.id} does not have its `measure` set"
            self.processes[process.id] = process
            start = self.state_rank
            end = start + len(process.state_elements)
            self.process_to_slice[process.id] = (start, end)
            self.state_rank = end

        # covariances: each module is given a fixed id before being attached
        self.process_covariance = process_covariance.set_id('process_covariance')
        self.measure_covariance = measure_covariance.set_id('measure_covariance')
        self.initial_covariance = initial_covariance.set_id('initial_covariance')

        # can disable for debugging/tests:
        self._scale_by_measure_var = True
Пример #2
0
    def test_from_log_cholesky(self):
        """
        A rank-3 `Covariance` whose two cholesky entries are overwritten with [1, 2, 3] should
        produce the known matrix below (to within 1e-4 per element).
        """
        module = Covariance(id='test', rank=3)

        # overwrite both entries of the state-dict in place with the same values:
        for entry_name in ('cholesky_log_diag', 'cholesky_off_diag'):
            module.state_dict()[entry_name][:] = torch.arange(1., 3.1)

        expected = torch.tensor([[7.3891, 2.7183, 5.4366],
                                 [2.7183, 55.5982, 24.1672],
                                 [5.4366, 24.1672, 416.4288]])
        actual = module({}, {})
        abs_error = (expected - actual).abs()
        self.assertTrue((abs_error < .0001).all())
Пример #3
0
    def test_equations(self):
        """
        Compare the state-means predicted by `KalmanFilter` with those of a filterpy filter that is
        given the same initial state and the same F/H/R/Q matrices, element by element (3 places).
        """
        data = Tensor([[-50., 50., 1.]])[:, :, None]

        # predictions from the implementation under test:
        design = simple_mv_velocity_design(dims=1)
        batch_design = design.for_batch(1, 1)
        torch_kf = KalmanFilter(processes=design.processes.values(), measures=design.measures)
        pred = torch_kf(data)

        # reference filter, initialised from the same design:
        filter_kf = filterpy_KalmanFilter(dim_x=2, dim_z=1)
        init_mean = torch_kf.design.init_state_mean_params.detach().numpy()
        filter_kf.x = init_mean[:, None]
        init_cov = Covariance.from_log_cholesky(torch_kf.design.init_cholesky_log_diag,
                                                torch_kf.design.init_cholesky_off_diag)
        filter_kf.P = init_cov.detach().numpy()

        # copy the system matrices for the first timestep / first group:
        for matrix_name in ('F', 'H', 'R', 'Q'):
            matrix = getattr(batch_design, matrix_name)(0)[0]
            setattr(filter_kf, matrix_name, matrix.detach().numpy())

        # the state is recorded *before* each update, then the filter is updated and stepped forward:
        filter_kf.states = []
        num_timesteps = data.shape[1]
        for t in range(num_timesteps):
            filter_kf.states.append(filter_kf.x)
            filter_kf.update(data[:, t, :])
            filter_kf.predict()

        filterpy_states = np.stack(filter_kf.states).squeeze()
        kf_states = pred.means.detach().numpy().squeeze()

        for r, c in product(*map(range, kf_states.shape)):
            self.assertAlmostEqual(filterpy_states[r, c], kf_states[r, c], places=3)
Пример #4
0
    def __init__(self,
                 processes: Sequence[Process],
                 measures: Sequence[str],
                 process_covariance: Optional[Covariance] = None,
                 measure_covariance: Optional[Covariance] = None,
                 initial_covariance: Optional[Covariance] = None,
                 compiled: bool = True,
                 **kwargs):
        """
        :param processes: A list of `Process` modules.
        :param measures: A list of strings specifying the names of the dimensions of the time-series being measured.
        :param process_covariance: The process-covariance module. If omitted, one is created with
         `Covariance.for_processes(processes, cov_type='process')`.
        :param measure_covariance: The measure-covariance module. If omitted, one is created with
         `Covariance.for_measures(measures)`.
        :param initial_covariance: The initial-covariance module. If omitted, one is created with
         `Covariance.for_processes(processes, cov_type='initial')`.
        :param compiled: If True, each process and then the script-module itself are passed through
         `torch.jit.script`. Can be disabled if compilation issues arise.
        :param kwargs: Further keyword-arguments, forwarded unchanged to `self.script_cls`.
        """
        super(KalmanFilter, self).__init__()

        self._validate(processes, measures)

        # covariances: build a default for each one the caller did not supply
        process_covariance = (
            Covariance.for_processes(processes, cov_type='process')
            if process_covariance is None else process_covariance
        )
        measure_covariance = (
            Covariance.for_measures(measures)
            if measure_covariance is None else measure_covariance
        )
        initial_covariance = (
            Covariance.for_processes(processes, cov_type='initial')
            if initial_covariance is None else initial_covariance
        )

        self.script_module = self.script_cls(
            kf_step=self.kf_step(),
            processes=processes,
            measures=measures,
            process_covariance=process_covariance,
            measure_covariance=measure_covariance,
            initial_covariance=initial_covariance,
            **kwargs
        )

        if compiled:
            # script each process first, replacing it in place, then script the module that holds them
            script_processes = self.script_module.processes
            for pid in list(script_processes.keys()):
                script_processes[pid] = torch.jit.script(script_processes[pid])
            self.script_module = torch.jit.script(self.script_module)
Пример #5
0
    def test_from_log_cholesky(self):
        """
        `Covariance.from_log_cholesky` given three identical rows of [1, 2, 3] for both arguments
        should yield, for every item, the known matrix below (to within 1e-4 per element).
        """
        log_diag = torch.arange(1., 3.1).expand(3, -1)
        off_diag = torch.arange(1., 3.1).expand(3, -1)
        covs = Covariance.from_log_cholesky(log_diag=log_diag, off_diag=off_diag)

        gt = torch.tensor([[7.3891, 2.7183, 5.4366],
                           [2.7183, 55.5982, 24.1672],
                           [5.4366, 24.1672, 416.4288]])
        for cov in covs:
            abs_error = (gt - cov).abs()
            self.assertTrue((abs_error < .0001).all())
Пример #6
0
    def _Q_init(self) -> None:
        """
        Build the base process-covariance `self._Q_base` and collect the per-timestep ('dynamic')
        diagonal-multiplier assignments in `self._Q_diag_multi_dynamic_assignments`.

        Steps:
          1. Compute a covariance over the dynamic state-elements and copy it into the matching
             rows/columns of a zero matrix covering all state-elements.
          2. Rescale rows/columns per process, using the measures that process is associated with.
          3. Apply each process's 'base' diagonal multipliers; 'dynamic' ones are only recorded.
        """
        dynamic_cov = Covariance.from_log_cholesky(self.design.process_cholesky_log_diag,
                                                   self.design.process_cholesky_off_diag,
                                                   device=self.device)

        dynamic_names = list(self.design.all_dynamic_state_elements())
        all_names = list(self.design.all_state_elements())

        # move from partial cov to full w/block-diag:
        Q = torch.zeros(size=(self.num_groups, self.state_size, self.state_size), device=self.device)
        # where each dynamic state-element sits within the full list (looked up once, not per cell)
        full_positions = [all_names.index(name) for name in dynamic_names]
        for r, to_r in enumerate(full_positions):
            for c, to_c in enumerate(full_positions):
                Q[:, to_r, to_c] = dynamic_cov[r, c]

        # process variances are scaled by the variances of the measurements they are associated with:
        measure_log_stds = self.design.measure_scaling().diag().sqrt().log()
        scale_flat = torch.ones(self.state_size, device=self.device)
        for process_name, process in self.processes.items():
            positions = [self.measure_idx[m] for m in process.measures]
            # average in log-space, then exponentiate
            mean_log_std = measure_log_stds[positions].mean()
            scale_flat[self.process_idx[process_name]] = mean_log_std.exp()

        scale_mat = torch.diagflat(scale_flat).expand(self.num_groups, -1, -1)
        Q = scale_mat.matmul(Q).matmul(scale_mat)

        # adjustments from processes:
        base_multi = torch.eye(self.state_size, device=self.device).expand(self.num_groups, -1, -1).clone()
        dynamic_assignments = []
        for process_id, process in self.processes.items():
            offset = self.process_start_idx[process_id]
            # the process supplies its assignments as a pair, paired here with 'base' then 'dynamic'
            labelled = zip(['base', 'dynamic'], process.variance_diag_multi_assignments)
            for kind, assignments in labelled:
                for state_element, values in assignments.items():
                    idx = process.state_element_idx[state_element] + offset
                    if kind != 'dynamic':
                        base_multi[:, idx, idx] = values
                    else:
                        dynamic_assignments.append(((idx, idx), values))

        self._Q_base = base_multi.matmul(Q).matmul(base_multi)
        self._Q_diag_multi_dynamic_assignments = dynamic_assignments
Пример #7
0
 def test_empty_idx(self):
     """
     A scripted rank-3 `Covariance` with `empty_idx=[0]` should have an all-zero first row and
     first column, and should equal its own transpose.
     """
     scripted = torch.jit.script(Covariance(id='test', rank=3, empty_idx=[0]))
     cov = scripted({}, {})
     # first row, then first column:
     for zeroed in (cov[0, :], cov[:, 0]):
         self.assertTrue((zeroed == 0).all())
     transposed = cov.t()
     self.assertTrue((cov == transposed).all())
Пример #8
0
    def __init__(self,
                 design: 'Design',
                 num_groups: int,
                 num_timesteps: int,
                 process_kwargs: Optional[Dict[str, Dict]] = None):
        """
        Set up the per-batch version of a design: initial mean/covariance, per-batch processes, and
        the base matrices built by `_F_init`, `_H_init`, `_Q_init` and `_R_init`.

        :param design: The design this batch is created from.
        :param num_groups: Number of groups in the batch.
        :param num_timesteps: Number of timesteps in the batch.
        :param process_kwargs: Optional mapping from process-name to the extra keyword-arguments passed to
         that process's `for_batch` and `initial_state_means_for_batch`. Keys must be names of processes
         in the design.
        :raises TypeError: If a process's `for_batch` raises a TypeError; the message names the process.
        """
        process_kwargs = process_kwargs or {}

        self.design = design
        self.device = design.device

        # process indices:
        self.process_idx = design.process_idx
        self.process_start_idx = {process_id: idx.start for process_id, idx in self.process_idx.items()}

        # initial mean/cov:
        self.initial_mean = torch.zeros(num_groups, design.state_size, device=self.device)
        init_cov = Covariance.from_log_cholesky(log_diag=design.init_cholesky_log_diag,
                                                off_diag=design.init_cholesky_off_diag,
                                                device=self.device)
        self.initial_covariance = init_cov.expand(num_groups, -1, -1)

        # create processes for batch:
        assert set(process_kwargs.keys()).issubset(design.processes.keys())
        assert isinstance(design.processes, OrderedDict)  # below assumes key ordering
        self.processes: Dict[str, ProcessForBatch] = OrderedDict()
        for process_name, process in design.processes.items():
            this_proc_kwargs = process_kwargs.get(process_name, {})

            # assign process:
            try:
                self.processes[process_name] = process.for_batch(num_groups=num_groups,
                                                                 num_timesteps=num_timesteps,
                                                                 **this_proc_kwargs)
            except TypeError as e:
                # if missing kwargs, useful to know which process in the traceback; chain the original
                # error explicitly so it is reported as the direct cause
                raise TypeError(
                    "`{pn}.for_batch` raised the following error:\n{e}".format(pn=process_name, e=e)
                ) from e

            # assign initial mean:
            pslice = self.process_idx[process_name]
            self.initial_mean[:, pslice] = process.initial_state_means_for_batch(design.init_state_mean_params[pslice],
                                                                                 num_groups=num_groups,
                                                                                 **this_proc_kwargs)

        # measures:
        self.measure_idx = {measure_id: i for i, measure_id in enumerate(design.measures)}

        # size (must be set before the `_*_init` calls below; `_Q_init` reads `num_groups` and `state_size`):
        self.num_groups = num_groups
        self.num_timesteps = num_timesteps
        self.state_size = design.state_size
        self.measure_size = design.measure_size

        # each attribute below starts as None, hence Optional; the `_*_init` call that follows it is
        # expected to fill it in (only `_Q_init` is visible here — TODO confirm for the others)

        # transitions:
        self._F_base: Optional[Tensor] = None
        self._F_dynamic_assignments: Optional[List[Tuple[Tuple[int, int], SeqOfTensors]]] = None
        self._F_init()

        # measurements:
        self._H_base: Optional[Tensor] = None
        self._H_dynamic_assignments: Optional[List[Tuple[Tuple[int, int], SeqOfTensors]]] = None
        self._H_init()

        # process-var:
        self._Q_base: Optional[Tensor] = None
        self._Q_diag_multi_dynamic_assignments: Optional[List[Tuple[Tuple[int, int], SeqOfTensors]]] = None
        self._Q_init()

        # measure-var:
        self._R_base: Optional[Tensor] = None
        # R_dynamic_assignments not implemented yet
        self._R_init()
Пример #9
0
 def measure_scaling(self) -> Tensor:
     """
     Return the result of `Covariance.from_log_cholesky` applied to this object's measure
     cholesky entries (log-diagonal and off-diagonal), created on `self.device`.
     """
     log_diag = self.measure_cholesky_log_diag
     off_diag = self.measure_cholesky_off_diag
     return Covariance.from_log_cholesky(log_diag, off_diag, device=self.device)