Example 1
    @classmethod
    def from_dataset(
        cls,
        dataset: TimeSeriesDataSet,
        allowed_encoder_known_variable_names: List[str] = None,
        **kwargs,
    ) -> LightningModule:
        """
        Create model from dataset and set parameters related to covariates.

        Args:
            dataset: timeseries dataset
            allowed_encoder_known_variable_names: list of known variables that are allowed in the encoder; defaults to all known variables
            **kwargs: additional arguments such as hyperparameters for model (see ``__init__()``)

        Returns:
            LightningModule
        """
        # default to allowing all known variables in the encoder
        if allowed_encoder_known_variable_names is None:
            allowed_encoder_known_variable_names = (
                dataset.time_varying_known_categoricals + dataset.time_varying_known_reals
            )

        # embeddings
        embedding_labels = {
            name: encoder.classes_
            for name, encoder in dataset.categorical_encoders.items()
            if name in dataset.categoricals
        }
        embedding_paddings = dataset.dropout_categoricals
        # determine embedding sizes based on heuristic
        embedding_sizes = {
            name: (len(encoder.classes_), get_embedding_size(len(encoder.classes_)))
            for name, encoder in dataset.categorical_encoders.items()
            if name in dataset.categoricals
        }
        # merge user-provided sizes into the heuristic defaults (user entries take precedence)
        embedding_sizes.update(kwargs.get("embedding_sizes", {}))
        kwargs["embedding_sizes"] = embedding_sizes

        new_kwargs = dict(
            static_categoricals=dataset.static_categoricals,
            time_varying_categoricals_encoder=[
                name for name in dataset.time_varying_known_categoricals if name in allowed_encoder_known_variable_names
            ]
            + dataset.time_varying_unknown_categoricals,
            time_varying_categoricals_decoder=dataset.time_varying_known_categoricals,
            static_reals=dataset.static_reals,
            time_varying_reals_encoder=[
                name for name in dataset.time_varying_known_reals if name in allowed_encoder_known_variable_names
            ]
            + dataset.time_varying_unknown_reals,
            time_varying_reals_decoder=dataset.time_varying_known_reals,
            x_reals=dataset.reals,
            x_categoricals=dataset.flat_categoricals,
            embedding_labels=embedding_labels,
            embedding_paddings=embedding_paddings,
            categorical_groups=dataset.variable_groups,
        )
        new_kwargs.update(kwargs)
        return super().from_dataset(dataset, **new_kwargs)
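
A minimal usage sketch (hypothetical column names and hyperparameters; assumes a concrete subclass such as ``TemporalFusionTransformer``, which inherits this ``from_dataset``). The ``get_embedding_size`` heuristic used above roughly follows the fastai rule ``min(round(1.6 * n ** 0.56), 100)`` for ``n > 2`` categories; any entry passed via ``embedding_sizes`` overrides it:

    # Hypothetical sketch - column names and hyperparameters are illustrative only.
    from pytorch_forecasting import TemporalFusionTransformer, TimeSeriesDataSet

    dataset = TimeSeriesDataSet(
        data,  # a pandas DataFrame prepared elsewhere
        time_idx="time_idx",
        target="volume",
        group_ids=["agency"],
        max_encoder_length=24,
        max_prediction_length=6,
        static_categoricals=["agency"],
        time_varying_known_reals=["price"],
        time_varying_unknown_reals=["volume"],
    )

    # from_dataset wires all covariate lists, embedding labels and embedding sizes
    # from the dataset; entries passed via embedding_sizes override the heuristic.
    model = TemporalFusionTransformer.from_dataset(
        dataset,
        learning_rate=0.03,
        hidden_size=16,
        embedding_sizes={"agency": (350, 10)},  # (number of categories, embedding dim)
    )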
Example 2
    def __init__(
        self,
        embedding_sizes: Union[Dict[str, Tuple[int, int]], Dict[str, int], List[int], List[Tuple[int, int]]],
        x_categoricals: List[str] = None,
        categorical_groups: Dict[str, List[str]] = {},
        embedding_paddings: List[str] = [],
        max_embedding_size: int = None,
    ):
        """Embedding layer for categorical variables including groups of categorical variables.

        Works for both static and dynamic categories, i.e. inputs with up to 3 dimensions (batch x time x categories).

        Args:
            embedding_sizes (Union[Dict[str, Tuple[int, int]], Dict[str, int], List[int], List[Tuple[int, int]]]):
                either

                * dictionary of embedding sizes, e.g. ``{'cat1': (10, 3)}``
                  indicates that the first categorical variable has 10 unique values which are mapped to 3 embedding
                  dimensions. Use :py:func:`~pytorch_forecasting.utils.get_embedding_size` to automatically obtain
                  reasonable embedding sizes depending on the number of categories.
                * dictionary of categorical sizes, e.g. ``{'cat1': 10}`` where embedding sizes are inferred by
                  :py:func:`~pytorch_forecasting.utils.get_embedding_size`.
                * list of tuples of (categorical size, embedding size), e.g. ``[(10, 3), (20, 2)]`` (requires
                  ``x_categoricals`` to be left as ``None``)
                * list of categorical sizes where embedding sizes are inferred by
                  :py:func:`~pytorch_forecasting.utils.get_embedding_size` (also requires ``x_categoricals`` to be
                  ``None``).

                If input is provided as list, output will be a single tensor of shape batch x (optional) time x
                sum(embedding_sizes). Otherwise, output is a dictionary of embedding tensors.
            x_categoricals (List[str]): list of categorical variables that are used as input.
            categorical_groups (Dict[str, List[str]]): dictionary of categories that should be summed up in an
                embedding bag, e.g. ``{'cat1': ['cat2', 'cat3']}`` indicates that a new categorical variable ``'cat1'``
                is mapped to an embedding bag containing the second and third categorical variables.
                Defaults to empty dictionary.
            embedding_paddings (List[str]): list of categorical variables for which the value 0 is mapped to a zero
                embedding vector. Defaults to empty list.
            max_embedding_size (int, optional): if embedding size defined by ``embedding_sizes`` is larger than
                ``max_embedding_size``, it will be constrained. Defaults to None.
        """
        super().__init__()
        if isinstance(embedding_sizes, dict):
            self.concat_output = False  # return dictionary of embeddings
            # conduct input data checks
            assert x_categoricals is not None, "x_categoricals must be provided."
            categorical_group_variables = [name for names in categorical_groups.values() for name in names]
            if len(categorical_groups) > 0:
                assert all(
                    name in embedding_sizes for name in categorical_groups
                ), "categorical_groups must be in embedding_sizes."
                assert not any(
                    name in embedding_sizes for name in categorical_group_variables
                ), "group variables in categorical_groups must not be in embedding_sizes."
                assert all(
                    name in x_categoricals for name in categorical_group_variables
                ), "group variables in categorical_groups must be in x_categoricals."
            assert all(
                name in x_categoricals for name in embedding_sizes if name not in categorical_groups
            ), (
                "all variables in embedding_sizes must be in x_categoricals - but only if "
                "not already in categorical_groups."
            )
        else:
            assert (
                x_categoricals is None and len(categorical_groups) == 0
            ), "If embedding_sizes is not a dictionary, categorical_groups and x_categoricals must be empty."
            # number embeddings based on order
            embedding_sizes = {str(name): size for name, size in enumerate(embedding_sizes)}
            x_categoricals = list(embedding_sizes.keys())
            self.concat_output = True

        # infer embedding sizes if not determined
        self.embedding_sizes = {
            name: (size, get_embedding_size(size)) if isinstance(size, int) else size
            for name, size in embedding_sizes.items()
        }
        self.categorical_groups = categorical_groups
        self.embedding_paddings = embedding_paddings
        self.max_embedding_size = max_embedding_size
        self.x_categoricals = x_categoricals

        self.init_embeddings()
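
A minimal sketch of the two input conventions (hypothetical variable names; assumes this ``__init__`` belongs to pytorch-forecasting's ``MultiEmbedding`` module):

    import torch

    from pytorch_forecasting.models.nn import MultiEmbedding

    # Dictionary form: forward returns a dict of embedding tensors keyed by name.
    # "special_days" is a hypothetical group summed up in an embedding bag.
    embedding = MultiEmbedding(
        embedding_sizes={"agency": (10, 3), "special_days": (2, 2)},
        x_categoricals=["agency", "christmas", "easter"],
        categorical_groups={"special_days": ["christmas", "easter"]},
        embedding_paddings=["special_days"],  # category index 0 maps to a zero vector
    )
    x = torch.randint(0, 2, (4, 20, 3))  # batch x time x len(x_categoricals)
    out = embedding(x)  # {"agency": (4, 20, 3) tensor, "special_days": (4, 20, 2) tensor}

    # List form: x_categoricals must be left as None; forward returns a single
    # tensor of shape batch x time x sum of embedding sizes.
    flat = MultiEmbedding(embedding_sizes=[(10, 3), (20, 2)])
    y = torch.randint(0, 10, (4, 20, 2))
    out_flat = flat(y)  # tensor of shape (4, 20, 5)

In the dictionary form, the column order of ``x`` must match ``x_categoricals``.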