Example #1
    def __init__(self, processes: Sequence[Process], measures: Sequence[str],
                 **kwargs):
        """
        :param processes: The processes that generate the latent state.
        :param measures: The names of the measured (observed) series.
        :param kwargs: Not used by this base class; passed through from `KalmanFilter.__init__()`.
        """
        self.measures = tuple(measures)

        self.processes = OrderedDict()
        for process in processes:
            if process.id in self.processes:
                raise ValueError(f"Duplicate process-ids: {process.id}.")
            self.processes[process.id] = process

        self._validate()

        # params:
        self._initial_mean = None
        self.init_mean_params = Parameter(
            torch.randn(len(self.state_elements)))
        self.init_covariance = PartialCovarianceFromLogCholesky(
            full_dim_names=self.state_elements,
            partial_dim_names=self.unfixed_state_elements)
        self.process_covariance = PartialCovarianceFromLogCholesky(
            full_dim_names=self.state_elements,
            partial_dim_names=self.dynamic_state_elements)
        self.measure_covariance = CovarianceFromLogCholesky(
            rank=len(self.measures))
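
A minimal standalone sketch of the log-Cholesky idea that `CovarianceFromLogCholesky` is named after (an assumption about the parameterization, not the library's actual code): storing an unconstrained strict lower triangle plus a log-diagonal guarantees that the rebuilt covariance is positive-definite.

import torch

rank = 3
log_diag = torch.randn(rank)                    # unconstrained log of the Cholesky diagonal
off_diag = torch.randn(rank * (rank - 1) // 2)  # unconstrained strict lower triangle

L = torch.diagflat(log_diag.exp())              # exp() keeps the diagonal positive
idx = torch.tril_indices(rank, rank, offset=-1)
L[idx[0], idx[1]] = off_diag
cov = L @ L.t()                                 # symmetric positive-definite by construction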
Example #2
    def __init__(self,
                 processes: Sequence[Process],
                 measures: Sequence[str],
                 measure_var_predict: Sequence[torch.nn.Module] = (),
                 process_var_predict: Sequence[torch.nn.Module] = ()):
        """
        :param processes: The processes that generate the latent state.
        :param measures: The names of the measured (observed) series.
        :param measure_var_predict: See the documentation for `KalmanFilter`.
        :param process_var_predict: See the documentation for `KalmanFilter`.
        """
        self.measures = tuple(measures)

        self.processes = OrderedDict()
        for process in processes:
            if process.id in self.processes:
                raise ValueError(f"Duplicate process-ids: {process.id}.")
            self.processes[process.id] = process

        self._validate()

        # process-variance predictions:
        self._process_var_nn = self._standardize_var_nn(process_var_predict,
                                                        var_type='process',
                                                        top_level=True)

        # measure-variance predictions:
        self._measure_var_nn = self._standardize_var_nn(measure_var_predict,
                                                        var_type='measure',
                                                        top_level=True)

        # params:

        # initial:
        self._initial_mean = None
        self.init_mean_params = Parameter(
            .1 * torch.randn(len(self.state_elements)))
        self.init_covariance = PartialCovarianceFromLogCholesky(
            full_dim_names=self.state_elements,
            partial_dim_names=self.unfixed_state_elements)

        # process:
        self.process_covariance = PartialCovarianceFromLogCholesky(
            full_dim_names=self.state_elements,
            partial_dim_names=self.dynamic_state_elements)

        # measure:
        self.measure_covariance = CovarianceFromLogCholesky(
            rank=len(self.measures))
        self._measure_var_adjustments = MeasureVarianceMultiplierMatrix(
            self.measures)
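
Example #2 routes extra variance predictions into a `MeasureVarianceMultiplierMatrix`; the adjustment itself (applied in `R(t)` and `Q(t)` of Example #3 below) is a diagonal-matrix "sandwich". A minimal standalone sketch of that arithmetic, with made-up numbers: pre- and post-multiplying by `diag(multi)` scales variance i by `multi[i]**2` and covariance (i, j) by `multi[i] * multi[j]`.

import torch

base_R = torch.tensor([[1.0, 0.3],
                       [0.3, 2.0]])  # base measure-covariance
multi = torch.tensor([1.5, 1.0])     # per-measure multipliers, e.g. from a var-predicting NN
diag_multi = torch.diagflat(multi)
adjusted_R = diag_multi @ base_R @ diag_multi
# diagonal becomes [2.25, 2.0]; off-diagonal becomes 0.45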
Example #3
class Design(NiceRepr, Batchable):
    """
    A class for specifying the 'design' of a KalmanFilter -- i.e., which measures are modeled by which processes.
    """
    _repr_attrs = ('process_list', 'measures')

    def __init__(self,
                 processes: Sequence[Process],
                 measures: Sequence[str],
                 measure_var_predict: Sequence[torch.nn.Module] = (),
                 process_var_predict: Sequence[torch.nn.Module] = ()):
        """
        :param processes: The processes that generate the latent state.
        :param measures: The names of the measured (observed) series.
        :param measure_var_predict: See the documentation for `KalmanFilter`.
        :param process_var_predict: See the documentation for `KalmanFilter`.
        """
        self.measures = tuple(measures)

        self.processes = OrderedDict()
        for process in processes:
            if process.id in self.processes:
                raise ValueError(f"Duplicate process-ids: {process.id}.")
            self.processes[process.id] = process

        self._validate()

        # process-variance predictions:
        self._process_var_nn = self._standardize_var_nn(process_var_predict,
                                                        var_type='process',
                                                        top_level=True)

        # measure-variance predictions:
        self._measure_var_nn = self._standardize_var_nn(measure_var_predict,
                                                        var_type='measure',
                                                        top_level=True)

        # params:

        # initial:
        self._initial_mean = None
        self.init_mean_params = Parameter(
            .1 * torch.randn(len(self.state_elements)))
        self.init_covariance = PartialCovarianceFromLogCholesky(
            full_dim_names=self.state_elements,
            partial_dim_names=self.unfixed_state_elements)

        # process:
        self.process_covariance = PartialCovarianceFromLogCholesky(
            full_dim_names=self.state_elements,
            partial_dim_names=self.dynamic_state_elements)

        # measure:
        self.measure_covariance = CovarianceFromLogCholesky(
            rank=len(self.measures))
        self._measure_var_adjustments = MeasureVarianceMultiplierMatrix(
            self.measures)

    @cached_property
    def state_elements(self) -> Sequence[Tuple[str, str]]:
        out = []
        for process_name, process in self.processes.items():
            out.extend((process_name, state_element)
                       for state_element in process.state_elements)
        return out

    @cached_property
    def dynamic_state_elements(self) -> Sequence[Tuple[str, str]]:
        out = []
        for process_name, process in self.processes.items():
            out.extend((process_name, state_element)
                       for state_element in process.dynamic_state_elements)
        return out

    @cached_property
    def unfixed_state_elements(self) -> Sequence[Tuple[str, str]]:
        out = []
        for process_name, process in self.processes.items():
            out.extend((process_name, state_element)
                       for state_element in process.state_elements
                       if state_element not in process.fixed_state_elements)
        return out

    @cached_property
    def process_slices(self) -> Dict[str, slice]:
        process_slices = OrderedDict()
        start_counter = 0
        for process_name, process in self.processes.items():
            end_counter = start_counter + len(process.state_elements)
            process_slices[process_name] = slice(start_counter, end_counter)
            start_counter = end_counter
        return process_slices

    def _validate(self):
        if not self.measures:
            raise ValueError("Empty `measures`")
        if len(self.measures) != len(set(self.measures)):
            raise ValueError("Duplicates in `measures`")
        if not self.processes:
            raise ValueError("Empty `processes`")

        used_measures = set()
        for process_name, process in self.processes.items():
            for measure in process.measures:
                if measure not in self.measures:
                    raise RuntimeError(f"{measure} not in `measures`")
                used_measures.add(measure)

        unused_measures = set(self.measures).difference(used_measures)
        if unused_measures:
            raise ValueError(
                f"The following `measures` are not in any of the `processes`:\n{unused_measures}"
            )

    # For Batch -------:
    def for_batch(self, num_groups: int, num_timesteps: int,
                  **kwargs) -> 'Design':
        for_batch = copy(self)
        for_batch.processes = OrderedDict()
        for_batch.batch_info = (num_groups, num_timesteps)
        for_batch._initial_mean = torch.zeros(num_groups,
                                              len(self.state_elements))

        batch_dim_kwargs = {
            'num_groups': num_groups,
            'num_timesteps': num_timesteps
        }

        unused_kwargs = set(kwargs.keys())

        # processes:
        for process_name, process in self.processes.items():
            proc_kwargs, used = self._parse_kwargs(
                batch_kwargs=process.batch_kwargs(),
                prefix=process.id,
                all_kwargs=kwargs,
                aliases=getattr(process, 'batch_kwargs_aliases', {}))
            for k in used:
                unused_kwargs.discard(k)

            # wrap calls w/process-name for easier tracebacks:
            try:
                for_batch.processes[process_name] = process.for_batch(
                    **batch_dim_kwargs, **proc_kwargs)
                for_batch._initial_mean[:, self.process_slices[
                    process_name]] = process.initial_state_means_for_batch(
                        parameters=self.init_mean_params[
                            self.process_slices[process_name]],
                        num_groups=num_groups,
                        **proc_kwargs)
            except Exception as e:
                # add process-name to traceback
                raise type(e)(
                    f"`{process}.for_batch()` failed (see traceback above)."
                ) from e

            if for_batch.processes[process_name] is None:
                raise RuntimeError(
                    f"{process_name}'s `for_batch` call did not return anything."
                )

        # var adjustments:
        for_batch._measure_var_adjustments = self._measure_var_adjustments.for_batch(
            **batch_dim_kwargs)
        for var_type, nn_list in {
                'measure': self._measure_var_nn,
                'process': self._process_var_nn
        }.items():
            for i, nn in enumerate(nn_list):
                nn_kwargs, used = self._parse_kwargs(
                    prefix=f'{var_type}_var_nn{i}',
                    batch_kwargs=nn._forward_kwargs,
                    all_kwargs={
                        **kwargs,
                        **batch_dim_kwargs
                    },
                    aliases=getattr(nn, '_forward_kwargs_aliases', {}))

                # a cheat that makes the `seasonal` alias more convenient:
                if ('datetimes' in nn._forward_kwargs
                        and 'datetimes' not in nn_kwargs
                        and hasattr(nn, '_dt_helper')):
                    if 'start_datetimes' in kwargs:
                        nn_kwargs['datetimes'] = nn._dt_helper.make_grid(
                            kwargs['start_datetimes'], num_timesteps)

                for k in used:
                    unused_kwargs.discard(k)

                adjustments = adjustments_from_nn(
                    nn=nn,
                    **batch_dim_kwargs,
                    nn_kwargs=nn_kwargs,
                    output_names=self.measures
                    if var_type == 'measure' else self.dynamic_state_elements,
                    time_split_kwargs=getattr(nn, '_time_split_kwargs', ()))

                for el, adj in adjustments.items():
                    for_batch._adjust_variance(el,
                                               adjustment=adj,
                                               check_slow_grad=False)

        if unused_kwargs:
            warn("Unexpected keyword arguments: {}".format(unused_kwargs))

        return for_batch

    @property
    def initial_mean(self):
        if self.is_for_batch:
            return self._initial_mean
        else:
            raise RuntimeError(
                f"Tried to access `{type(self).__name__}.initial_mean`, but only possible for output of `for_batch()`."
            )

    # Parameters -------:
    def param_dict(self) -> ModuleDict:
        p = ModuleDict()
        for process_name, process in self.processes.items():
            p[f"process:{process_name}"] = process.param_dict()

        p['measure_cov'] = self.measure_covariance.param_dict()
        p['measure_var_nn'] = self._measure_var_nn

        p['init_state'] = ParameterDict([('mean', self.init_mean_params)])
        p['init_state'].update(self.init_covariance.param_dict().items())

        p['process_cov'] = self.process_covariance.param_dict()
        p['process_var_nn'] = self._process_var_nn

        return p

    # Transition Matrix -------:
    @cached_property
    def F(self) -> DynamicMatrix:
        merged = TransitionMatrix.merge([
            (nm, process.transition_mat)
            for nm, process in self.processes.items()
        ])
        assert list(merged.from_elements) == list(self.state_elements) == list(
            merged.to_elements)
        return merged.compile()

    # Measurement Matrix ------:
    @cached_property
    def H(self) -> DynamicMatrix:
        merged = MeasureMatrix.merge([
            (nm, process.measure_mat)
            for nm, process in self.processes.items()
        ])
        assert list(merged.state_elements) == list(self.state_elements)
        # order dim:
        assert set(merged.measures) == set(self.measures)
        merged.measures[:] = self.measures
        return merged.compile()

    # Process-Covariance Matrix ------:
    def Q(self, t: int) -> torch.Tensor:
        # processes can apply multipliers to the variance of their state-elements:
        diag_multi = self._process_variance_multi(t=t)
        return diag_multi.matmul(self._base_Q).matmul(diag_multi)

    @cached_property
    def _process_variance_multi(self) -> DynamicMatrix:
        merged = ProcessVarianceMultiplierMatrix.merge([
            (nm, process.variance_multi_mat)
            for nm, process in self.processes.items()
        ])
        assert list(merged.state_elements) == list(self.state_elements)
        return merged.compile()

    @cached_property
    def _base_Q(self):
        Q = self.process_covariance.create(leading_dims=())

        # the process-covariance is rescaled by (the geometric mean of) the measurement standard-deviations:
        Q_rescaled = self._scale_covariance(Q)

        # expand for batch-size:
        return Q_rescaled.expand(self.num_groups, -1, -1)

    # Measure-Covariance Matrix ------:
    def R(self, t: int):
        diag_multi = self._measure_variance_multi(t=t)
        return diag_multi.matmul(self._base_R).matmul(diag_multi)

    @cached_property
    def _measure_variance_multi(self) -> DynamicMatrix:
        return self._measure_var_adjustments.compile()

    @cached_property
    def _base_R(self):
        return self.measure_covariance.create(leading_dims=(self.num_groups, ))

    # Initial Cov ------:
    @cached_property
    def initial_covariance(self) -> torch.Tensor:
        init_cov = self.init_covariance.create(leading_dims=())
        # the initial covariance is rescaled by (the geometric mean of) the measurement standard-deviations:
        init_cov_rescaled = self._scale_covariance(init_cov)
        # expand for batch-size:
        return init_cov_rescaled.expand(self.num_groups, -1, -1)

    def _scale_covariance(self, cov: torch.Tensor) -> torch.Tensor:
        """
        Rescale variances associated with processes (process-covariance or initial covariance) by the
        measurement-variances. Helpful in practice for training.
        """
        measure_idx_by_measure = {
            measure: i
            for i, measure in enumerate(self.measures)
        }
        measure_log_stds = self.measure_covariance.create().diag().sqrt().log()
        diag_flat = torch.ones(len(self.state_elements))
        for process_name, process in self.processes.items():
            measure_idx = [measure_idx_by_measure[m] for m in process.measures]
            diag_flat[self.process_slices[process_name]] = measure_log_stds[
                measure_idx].mean().exp()
        diag_multi = torch.diagflat(diag_flat)
        cov_rescaled = diag_multi.matmul(cov).matmul(diag_multi)
        return cov_rescaled

    @property
    def process_list(self):
        return list(self.processes.values())

    # Private -----:
    def _parse_kwargs(self, prefix: str, all_kwargs: dict,
                      batch_kwargs: Iterable[str],
                      aliases: dict) -> Tuple[dict, set]:
        too_generic = {'input', 'x'}

        # use sklearn-style disambiguation:
        used = set()
        out = {}
        for k in batch_kwargs:
            specific_key = "{}__{}".format(prefix, k)
            if specific_key in all_kwargs:
                out[k] = all_kwargs[specific_key]
                used.add(specific_key)
            elif k in all_kwargs:
                if k in too_generic:
                    raise ValueError(
                        f"The argument `{k}` is too generic, so it needs to be passed in a way that specifies which "
                        f"process it should be handed off to (e.g. {specific_key})."
                    )
                out[k] = all_kwargs[k]
                used.add(k)
            else:
                alias = aliases.get(k) or aliases.get(specific_key)
                if alias in all_kwargs:
                    out[k] = all_kwargs[alias]
                    used.add(alias)
        return out, used

    def _standardize_var_nn(self,
                            var_nn: Union[torch.nn.Module, Sequence],
                            var_type: str,
                            top_level: bool = False) -> torch.nn.Module:

        if top_level:
            if isinstance(var_nn, torch.nn.ModuleList):
                return var_nn

            if callable(var_nn):
                # they passed a single NN instead of a list, wrap it:
                var_nn = [var_nn]
            elif len(var_nn) > 0 and isinstance(var_nn[0], str):
                # they passed a single alias instead of a list, wrap it:
                var_nn = [var_nn]

            return torch.nn.ModuleList([
                self._standardize_var_nn(sub_nn, var_type) for sub_nn in var_nn
            ])
        else:
            if callable(var_nn):
                out_nn = var_nn
            elif isinstance(var_nn, (tuple, list)):
                alias, args_or_kwargs = var_nn
                num_outputs = len(self.measures if var_type ==
                                  'measure' else self.dynamic_state_elements)
                if alias == 'per_group' and isinstance(args_or_kwargs, int):
                    args_or_kwargs = (args_or_kwargs, )
                if isinstance(args_or_kwargs, dict):
                    args, kwargs = (), args_or_kwargs
                else:
                    args, kwargs = args_or_kwargs, {}

                if alias == 'per_group':
                    if 'embedding_dim' not in kwargs:
                        kwargs['embedding_dim'] = num_outputs
                    out_nn = NamedEmbedding(*args, **kwargs)
                    out_nn._forward_kwargs_aliases = {'input': 'group_names'}
                elif alias == 'seasonal':
                    out_nn = FourierSeasonNN(*args,
                                             **kwargs,
                                             num_outputs=num_outputs)
                    out_nn._time_split_kwargs = ['datetimes']
                else:
                    raise ValueError(
                        f"Known aliases are 'per_group' and 'seasonal'; got '{alias}'"
                    )
            else:
                raise TypeError(
                    f"Expected `{var_type}_var_nn` to be a callable/torch.nn.Module, or a tuple with format "
                    f"`('alias',(arg1,arg2,...)`. Instead got `{type(var_nn)}`."
                )
            if not hasattr(out_nn, '_forward_kwargs'):
                out_nn._forward_kwargs = infer_forward_kwargs(out_nn)
            if not hasattr(out_nn, '_forward_kwargs_aliases'):
                out_nn._forward_kwargs_aliases = {}
            return out_nn

    def _adjust_variance(
        self,
        *args,
        adjustment: 'DesignMatAdjustment',
        check_slow_grad: bool = True,
    ):
        if len(args) == 1:
            if isinstance(args[0], (list, tuple)):
                args = args[0]
        if len(args) == 1:
            assert args[0] in self.measures
            self._measure_var_adjustments.adjust(
                value=adjustment,
                check_slow_grad=check_slow_grad,
                measure=args[0])
        else:
            process, state_element = args
            self.processes[process]._adjust_variance(
                state_element=state_element,
                adjustment=adjustment,
                check_slow_grad=check_slow_grad)
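
The `{process_id}__{kwarg}` lookup in `_parse_kwargs` above mirrors scikit-learn's double-underscore convention for routing arguments to sub-components. A standalone sketch of just that lookup (the `lm` process-id and the kwargs below are hypothetical):

def parse_kwargs(prefix, all_kwargs, batch_kwargs):
    # process-specific "{prefix}__{kwarg}" keys win over bare kwarg names
    out = {}
    for k in batch_kwargs:
        specific_key = f"{prefix}__{k}"
        if specific_key in all_kwargs:
            out[k] = all_kwargs[specific_key]
        elif k in all_kwargs:
            out[k] = all_kwargs[k]
    return out

kwargs = {'lm__predictors': 'X', 'start_datetimes': 'dts'}
print(parse_kwargs('lm', kwargs, ['predictors', 'start_datetimes']))
# -> {'predictors': 'X', 'start_datetimes': 'dts'}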
Example #4
class Design(NiceRepr, Batchable):
    """
    A class for specifying the 'design' of a KalmanFilter -- i.e., which measures are modeled by which processes.
    """
    _repr_attrs = ('process_list', 'measures')

    def __init__(self, processes: Sequence[Process], measures: Sequence[str],
                 **kwargs):
        """
        :param processes: The processes that generate the latent state.
        :param measures: The names of the measured (observed) series.
        :param kwargs: Not used by this base class; passed through from `KalmanFilter.__init__()`.
        """
        self.measures = tuple(measures)

        self.processes = OrderedDict()
        for process in processes:
            if process.id in self.processes:
                raise ValueError(f"Duplicate process-ids: {process.id}.")
            self.processes[process.id] = process

        self._validate()

        # params:
        self._initial_mean = None
        self.init_mean_params = Parameter(
            torch.randn(len(self.state_elements)))
        self.init_covariance = PartialCovarianceFromLogCholesky(
            full_dim_names=self.state_elements,
            partial_dim_names=self.unfixed_state_elements)
        self.process_covariance = PartialCovarianceFromLogCholesky(
            full_dim_names=self.state_elements,
            partial_dim_names=self.dynamic_state_elements)
        self.measure_covariance = CovarianceFromLogCholesky(
            rank=len(self.measures))

    @cached_property
    def state_elements(self) -> Sequence[Tuple[str, str]]:
        out = []
        for process_name, process in self.processes.items():
            out.extend((process_name, state_element)
                       for state_element in process.state_elements)
        return out

    @cached_property
    def dynamic_state_elements(self) -> Sequence[Tuple[str, str]]:
        out = []
        for process_name, process in self.processes.items():
            out.extend((process_name, state_element)
                       for state_element in process.dynamic_state_elements)
        return out

    @cached_property
    def unfixed_state_elements(self) -> Sequence[Tuple[str, str]]:
        out = []
        for process_name, process in self.processes.items():
            out.extend((process_name, state_element)
                       for state_element in process.state_elements
                       if state_element not in process.fixed_state_elements)
        return out

    @cached_property
    def process_slices(self) -> Dict[str, slice]:
        process_slices = OrderedDict()
        start_counter = 0
        for process_name, process in self.processes.items():
            end_counter = start_counter + len(process.state_elements)
            process_slices[process_name] = slice(start_counter, end_counter)
            start_counter = end_counter
        return process_slices

    def _validate(self):
        if not self.measures:
            raise ValueError("Empty `measures`")
        if len(self.measures) != len(set(self.measures)):
            raise ValueError("Duplicates in `measures`")
        if not self.processes:
            raise ValueError("Empty `processes`")

        used_measures = set()
        for process_name, process in self.processes.items():
            for measure in process.measures:
                if measure not in self.measures:
                    raise RuntimeError(f"{measure} not in `measures`")
                used_measures.add(measure)

        unused_measures = set(self.measures).difference(used_measures)
        if unused_measures:
            raise ValueError(
                f"The following `measures` are not in any of the `processes`:\n{unused_measures}"
            )

    # For Batch -------:
    def for_batch(self, num_groups: int, num_timesteps: int,
                  **kwargs) -> 'Design':
        out = copy(self)
        out.processes = OrderedDict()
        out.batch_info = (num_groups, num_timesteps)
        out._initial_mean = torch.zeros(num_groups, len(self.state_elements))

        for process_name, process in self.processes.items():
            # get kwargs for this process using sklearn-style disambiguation:
            proc_kwargs = {}
            for k in process.for_batch_kwargs():
                specific_key = "{}__{}".format(process.id, k)
                if specific_key in kwargs:
                    proc_kwargs[k] = kwargs[specific_key]
                elif k in kwargs:
                    proc_kwargs[k] = kwargs[k]

            # wrap calls w/process-name for easier tracebacks:
            try:
                out.processes[process_name] = process.for_batch(
                    num_groups=num_groups,
                    num_timesteps=num_timesteps,
                    **proc_kwargs)
                out._initial_mean[:, self.process_slices[
                    process_name]] = process.initial_state_means_for_batch(
                        parameters=self.init_mean_params[
                            self.process_slices[process_name]],
                        num_groups=num_groups,
                        **proc_kwargs)
            except Exception as e:
                # add process-name to traceback
                raise type(e)(
                    f"`{process}.for_batch()` failed (see traceback above)."
                ) from e

            if out.processes[process_name] is None:
                raise RuntimeError(
                    f"{process_name}'s `for_batch` call did not return anything."
                )

        return out

    @property
    def initial_mean(self):
        if self.is_for_batch:
            return self._initial_mean
        else:
            raise RuntimeError(
                f"Tried to access `{type(self).__name__}.initial_mean`, but only possible for output of `for_batch()`."
            )

    # Parameters -------:
    def param_dict(self) -> ModuleDict:
        p = ModuleDict()
        for process_name, process in self.processes.items():
            p[f"process:{process_name}"] = process.param_dict()

        p['measure_cov'] = self.measure_covariance.param_dict()

        p['init_state'] = ParameterDict([('mean', self.init_mean_params)])
        p['init_state'].update(self.init_covariance.param_dict().items())

        p['process_cov'] = self.process_covariance.param_dict()

        return p

    # Transition Matrix -------:
    @cached_property
    def F(self) -> DynamicMatrix:
        merged = TransitionMatrix.merge([
            (nm, process.transition_mat)
            for nm, process in self.processes.items()
        ])
        assert list(merged.from_elements) == list(self.state_elements) == list(
            merged.to_elements)
        return merged.compile()

    # Measurement Matrix ------:
    @cached_property
    def H(self) -> DynamicMatrix:
        merged = MeasureMatrix.merge([
            (nm, process.measure_mat)
            for nm, process in self.processes.items()
        ])
        assert list(merged.state_elements) == list(self.state_elements)
        # order dim:
        assert set(merged.measures) == set(self.measures)
        merged.measures[:] = self.measures
        return merged.compile()

    # Process-Covariance Matrix ------:
    def Q(self, t: int) -> torch.Tensor:
        # processes can apply multipliers to the variance of their state-elements:
        diag_multi = self._process_variance_multi(t=t)
        return diag_multi.matmul(self._base_Q).matmul(diag_multi)

    @cached_property
    def _process_variance_multi(self) -> DynamicMatrix:
        merged = VarianceMultiplierMatrix.merge([
            (nm, process.variance_multi_mat)
            for nm, process in self.processes.items()
        ])
        assert list(merged.state_elements) == list(self.state_elements)
        return merged.compile()

    @cached_property
    def _base_Q(self):
        Q = self.process_covariance.create(leading_dims=())

        # the process-covariance is rescaled by (the geometric mean of) the measurement standard-deviations:
        Q_rescaled = self._scale_covariance(Q)

        # expand for batch-size:
        return Q_rescaled.expand(self.num_groups, -1, -1)

    # Measure-Covariance Matrix ------:
    def R(self, t: int):
        # the base class applies no adjustments to the measure-covariance:
        return self._base_R

    @cached_property
    def _base_R(self):
        return self.measure_covariance.create(leading_dims=(self.num_groups, ))

    # Initial Cov ------:
    @cached_property
    def initial_covariance(self) -> torch.Tensor:
        init_cov = self.init_covariance.create(leading_dims=())
        # the initial covariance is rescaled by (the geometric mean of) the measurement standard-deviations:
        init_cov_rescaled = self._scale_covariance(init_cov)
        # expand for batch-size:
        return init_cov_rescaled.expand(self.num_groups, -1, -1)

    def _scale_covariance(self, cov: torch.Tensor) -> torch.Tensor:
        """
        Rescale variances associated with processes (process-covariance or initial covariance) by the
        measurement-variances. Helpful in practice for training.
        """
        measure_idx_by_measure = {
            measure: i
            for i, measure in enumerate(self.measures)
        }
        measure_log_stds = self.measure_covariance.create().diag().sqrt().log()
        diag_flat = torch.ones(len(self.state_elements))
        for process_name, process in self.processes.items():
            measure_idx = [measure_idx_by_measure[m] for m in process.measures]
            diag_flat[self.process_slices[process_name]] = measure_log_stds[
                measure_idx].mean().exp()
        diag_multi = torch.diagflat(diag_flat)
        cov_rescaled = diag_multi.matmul(cov).matmul(diag_multi)
        return cov_rescaled

    @property
    def process_list(self):
        return list(self.processes.values())
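
For intuition on `_scale_covariance` above: each process's block of the covariance is pre- and post-multiplied by the geometric mean of the standard deviations of that process's measures. A minimal numeric sketch with made-up values:

import torch

measure_vars = torch.tensor([4.0, 9.0])       # diagonal of the measure-covariance
log_stds = measure_vars.sqrt().log()          # log std-devs: log(2), log(3)
scale = log_stds.mean().exp()                 # geometric mean of the std-devs: sqrt(6)
cov = torch.eye(2)                            # stand-in for a process's covariance block
diag_multi = torch.diagflat(scale.repeat(2))
cov_rescaled = diag_multi @ cov @ diag_multi  # each variance is scaled by sqrt(6)**2 == 6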