コード例 #1
0
class Process(NiceRepr, Batchable):
    _repr_attrs = ('id', )

    def __init__(self,
                 id: str,
                 state_elements: Sequence[str],
                 initial_state: Optional[torch.nn.Module] = None):
        self.id = str(id)
        self.state_elements = state_elements

        # transitions:
        self.transition_mat = TransitionMatrix(self.state_elements,
                                               self.state_elements)

        # state-element -> measure
        # measures will be appended in add_measure, but state-elements need to be known at init
        self.measure_mat = MeasureMatrix(dim1_names=None,
                                         dim2_names=self.state_elements)

        # variance of dynamic state elements:
        self.variance_multi_mat = ProcessVarianceMultiplierMatrix(
            self.state_elements, self.dynamic_state_elements)

        # a callable that predicts the initial state
        self.initial_state = initial_state or InitialState(self.state_elements)

        self._validate()

    def for_batch(self, num_groups: int, num_timesteps: int,
                  **kwargs) -> 'Process':
        if not self.measures:
            raise TypeError(f"The process `{self.id}` has no measures.")
        if self.transition_mat.empty:
            raise TypeError(f"The process `{self.id}` has no transitions.")
        for_batch = copy(self)
        for_batch.batch_info = num_groups, num_timesteps
        for_batch.variance_multi_mat = self.variance_multi_mat.for_batch(
            num_groups, num_timesteps)
        for_batch.measure_mat = self.measure_mat.for_batch(
            num_groups, num_timesteps)
        for_batch.transition_mat = self.transition_mat.for_batch(
            num_groups, num_timesteps)
        return for_batch

    @property
    def measures(self):
        return self.measure_mat.measures

    def param_dict(self) -> torch.nn.ParameterDict:
        """
        Any parameters that should be exposed to the owning nn.Module.
        """
        p = torch.nn.ParameterDict()
        if hasattr(self.initial_state, 'named_parameters'):
            for nm, param in self.initial_state.named_parameters():
                p['initial_state_' + nm.replace('.', '_')] = param
        return p

    # children should implement ----------------
    def add_measure(self, measure: str) -> 'Process':
        """
        Calls '_set_measure' with default state_element, value
        """
        raise NotImplementedError

    @property
    def dynamic_state_elements(self) -> Sequence[str]:
        """
        state elements with process-variance. defaults to all
        """
        return self.state_elements

    @property
    def fixed_state_elements(self) -> Sequence[str]:
        """
        state elements with neither process-variance nor initial-variance -- i.e., they are fixed at their initial mean
        """
        return []

    def initial_state_means_for_batch(self, num_groups: int,
                                      **kwargs) -> Tensor:
        if 'num_groups' in self.initial_state._forward_kwargs:
            kwargs['num_groups'] = num_groups
        return self.initial_state(**kwargs)

    # For specifying design -----------:
    def _set_measure(self,
                     measure: str,
                     state_element: str,
                     value: DesignMatAssignment,
                     ilink: Optional[Callable] = None,
                     force: bool = False):
        self.measure_mat.assign(measure=measure,
                                state_element=state_element,
                                value=value,
                                overwrite=force)
        self.measure_mat.set_ilink(measure=measure,
                                   state_element=state_element,
                                   ilink=ilink,
                                   overwrite=force)

    def _adjust_measure(self,
                        measure: str,
                        state_element: str,
                        adjustment: 'DesignMatAdjustment',
                        check_slow_grad: bool = True):
        self.measure_mat.adjust(measure=measure,
                                state_element=state_element,
                                value=adjustment,
                                check_slow_grad=check_slow_grad)

    def _set_transition(self,
                        from_element: str,
                        to_element: str,
                        value: DesignMatAssignment,
                        ilink: Optional[Callable] = None,
                        force: bool = False):
        self.transition_mat.assign(from_element=from_element,
                                   to_element=to_element,
                                   value=value,
                                   overwrite=force)
        self.transition_mat.set_ilink(from_element=from_element,
                                      to_element=to_element,
                                      ilink=ilink,
                                      overwrite=force)

    def _adjust_transition(self,
                           from_element: str,
                           to_element: str,
                           adjustment: 'DesignMatAdjustment',
                           check_slow_grad: bool = True):
        self.transition_mat.adjust(from_element=from_element,
                                   to_element=to_element,
                                   value=adjustment,
                                   check_slow_grad=check_slow_grad)

    # no _set_variance: base handled by design, adjustments forced to be link='log'
    def _adjust_variance(self,
                         state_element: str,
                         adjustment: 'DesignMatAdjustment',
                         check_slow_grad: bool = True):
        self.variance_multi_mat.adjust(state_element=state_element,
                                       value=adjustment,
                                       check_slow_grad=check_slow_grad)

    # util methods ----------------
    def _validate(self):
        if len(self.state_elements) != len(set(self.state_elements)):
            raise ValueError("Duplicate `state_elements`.")
        if not set(self.dynamic_state_elements).isdisjoint(
                self.fixed_state_elements):
            raise ValueError(
                "Class has been misconfigured: some fixed state-elements are also dynamic-state-elements."
            )

    def batch_kwargs(self) -> Iterable[str]:
        if type(self).for_batch.__code__ == Process.for_batch.__code__:
            yield from ()
            return
        excluded = {'self', 'num_groups', 'num_timesteps'}
        for kwarg in inspect.signature(self.for_batch).parameters:
            if kwarg in excluded:
                continue
            if kwarg == 'kwargs':
                raise TypeError(
                    f"Signature of `{type(self).__name__}.for_batch` must define its keyword args explicitly."
                )
            yield kwarg

    def init_mean_kwargs(self) -> Iterable[str]:
        if not hasattr(self.initial_state, '_forward_kwargs'):
            self.initial_state._forward_kwargs = infer_forward_kwargs(
                self.initial_state)
        return self.initial_state._forward_kwargs
コード例 #2
0
class Process(NiceRepr, Batchable):
    _repr_attrs = ('id', )

    def __init__(self, id: str, state_elements: Sequence[str]):
        self.id = str(id)
        self.state_elements = state_elements

        # transitions:
        self.transition_mat = TransitionMatrix(self.state_elements,
                                               self.state_elements)

        # state-element -> measure
        # measures will be appended in add_measure, but state-elements need to be known at init
        self.measure_mat = MeasureMatrix(dim1_names=None,
                                         dim2_names=self.state_elements)

        # variance of dynamic state elements:
        self.variance_multi_mat = ProcessVarianceMultiplierMatrix(
            self.state_elements, self.dynamic_state_elements)

        self._validate()

    def for_batch(self, num_groups: int, num_timesteps: int,
                  **kwargs) -> 'Process':
        if not self.measures:
            raise TypeError(f"The process `{self.id}` has no measures.")
        if self.transition_mat.empty:
            raise TypeError(f"The process `{self.id}` has no transitions.")
        for_batch = copy(self)
        for_batch.batch_info = num_groups, num_timesteps
        for_batch.variance_multi_mat = self.variance_multi_mat.for_batch(
            num_groups, num_timesteps)
        for_batch.measure_mat = self.measure_mat.for_batch(
            num_groups, num_timesteps)
        for_batch.transition_mat = self.transition_mat.for_batch(
            num_groups, num_timesteps)
        return for_batch

    @property
    def measures(self):
        return self.measure_mat.measures

    # children should implement ----------------
    def param_dict(self) -> torch.nn.ParameterDict:
        """
        Any parameters that should be exposed to the owning nn.Module.
        """
        raise NotImplementedError

    def add_measure(self, measure: str) -> 'Process':
        """
        Calls '_set_measure' with default state_element, value
        """
        raise NotImplementedError

    @property
    def dynamic_state_elements(self) -> Sequence[str]:
        """
        state elements with process-variance. defaults to all
        """
        return self.state_elements

    @property
    def fixed_state_elements(self) -> Sequence[str]:
        """
        state elements with neither process-variance nor initial-variance -- i.e., they are fixed at their initial mean
        """
        return []

    def initial_state_means_for_batch(self, parameters: Parameter,
                                      num_groups: int, **kwargs) -> Tensor:
        """
        Most children should use default. Handles rearranging of state-means based on for_batch keyword args. E.g. a
        discrete seasonal process w/ a state-element for each season would need to know on which season the batch starts
        """
        return parameters.expand(num_groups, -1)

    # For specifying design -----------:
    def _set_measure(self,
                     measure: str,
                     state_element: str,
                     value: DesignMatAssignment,
                     ilink: Optional[Callable] = None,
                     force: bool = False):
        self.measure_mat.assign(measure=measure,
                                state_element=state_element,
                                value=value,
                                overwrite=force)
        self.measure_mat.set_ilink(measure=measure,
                                   state_element=state_element,
                                   ilink=ilink,
                                   overwrite=force)

    def _adjust_measure(self,
                        measure: str,
                        state_element: str,
                        adjustment: 'DesignMatAdjustment',
                        check_slow_grad: bool = True):
        self.measure_mat.adjust(measure=measure,
                                state_element=state_element,
                                value=adjustment,
                                check_slow_grad=check_slow_grad)

    def _set_transition(self,
                        from_element: str,
                        to_element: str,
                        value: DesignMatAssignment,
                        ilink: Optional[Callable] = None,
                        force: bool = False):
        self.transition_mat.assign(from_element=from_element,
                                   to_element=to_element,
                                   value=value,
                                   overwrite=force)
        self.transition_mat.set_ilink(from_element=from_element,
                                      to_element=to_element,
                                      ilink=ilink,
                                      overwrite=force)

    def _adjust_transition(self,
                           from_element: str,
                           to_element: str,
                           adjustment: 'DesignMatAdjustment',
                           check_slow_grad: bool = True):
        self.transition_mat.adjust(from_element=from_element,
                                   to_element=to_element,
                                   value=adjustment,
                                   check_slow_grad=check_slow_grad)

    # no _set_variance: base handled by design, adjustments forced to be link='log'
    def _adjust_variance(self,
                         state_element: str,
                         adjustment: 'DesignMatAdjustment',
                         check_slow_grad: bool = True):
        self.variance_multi_mat.adjust(state_element=state_element,
                                       value=adjustment,
                                       check_slow_grad=check_slow_grad)

    # util methods ----------------
    def _validate(self):
        if len(self.state_elements) != len(set(self.state_elements)):
            raise ValueError("Duplicate `state_elements`.")
        if not set(self.dynamic_state_elements).isdisjoint(
                self.fixed_state_elements):
            raise ValueError(
                "Class has been misconfigured: some fixed state-elements are also dynamic-state-elements."
            )

    def __init_subclass__(cls, **kwargs):
        overrides_batch_kwargs = (cls.batch_kwargs.__code__ !=
                                  Process.batch_kwargs.__code__)
        if not overrides_batch_kwargs:

            batch_kwargs = set(cls.batch_kwargs(cls.for_batch))
            init_mean_kwargs = set(
                cls.batch_kwargs(cls.initial_state_means_for_batch))

            overrides_for_batch = (cls.for_batch.__code__ !=
                                   Process.for_batch.__code__)
            overrides_init_mean = (
                cls.initial_state_means_for_batch.__code__ !=
                Process.initial_state_means_for_batch.__code__)

            if overrides_for_batch:
                if 'kwargs' in batch_kwargs:
                    raise TypeError(
                        f"Signature of `{cls.__name__}.for_batch` must define its keyword args explicitly."
                    )
            if overrides_init_mean:
                if 'kwargs' in init_mean_kwargs:
                    raise TypeError(
                        f"Signature of `{cls.__name__}.initial_state_means_for_batch` must define kwargs explicitly."
                    )
            if overrides_for_batch and overrides_init_mean:
                if batch_kwargs != init_mean_kwargs:
                    raise TypeError(
                        f"`{cls.__name__}.initial_state_means_for_batch()` must match signature of .for_batch()"
                    )
        super().__init_subclass__()

    @classmethod
    def batch_kwargs(cls, method: Optional[Callable] = None) -> Iterable[str]:
        if method is None:
            method = cls.for_batch
        excluded = {'self', 'num_groups', 'num_timesteps', 'parameters'}
        for kwarg in inspect.signature(method).parameters:
            if kwarg in excluded:
                continue
            yield kwarg