コード例 #1
0
def test_groupby(
    num_features_per_attribute: Dict[Text, int],
    specified_attributes: Optional[List[Text]],
):
    """Checks that `Features.groupby_attribute` groups features by attribute.

    For every attribute, `num_features_per_attribute` gives how many `Features`
    to create; the i-th one (0-based) wraps a `(1, i + 1)` matrix filled with
    `i + 1`. When `specified_attributes` is `None`, every attribute with at least
    one feature must be present in the result with the expected count, and
    attributes without features must be absent. Otherwise the result must have
    exactly the requested keys, each mapped to as many features as were created
    for that attribute (zero if none were created).
    """
    all_features = [
        Features(
            features=np.full(shape=(1, position + 1), fill_value=position + 1),
            attribute=attribute,
            # the feature type is irrelevant for grouping
            feature_type=FEATURE_TYPE_SEQUENCE,
            # the origin is irrelevant for grouping
            origin=f"origin-{position}",
        )
        for attribute, count in num_features_per_attribute.items()
        for position in range(count)
    ]

    grouped = Features.groupby_attribute(
        all_features, attributes=specified_attributes
    )

    if specified_attributes is not None:
        # exactly the requested attributes come back, even empty ones
        assert set(grouped.keys()) == set(specified_attributes)
        for attribute in specified_attributes:
            assert attribute in grouped
            expected = num_features_per_attribute.get(attribute, 0)
            assert len(grouped[attribute]) == expected
        return

    # no filter: only attributes that actually have features show up
    for attribute, count in num_features_per_attribute.items():
        if count > 0:
            assert attribute in grouped
            assert len(grouped[attribute]) == count
        else:
            assert attribute not in grouped
コード例 #2
0
ファイル: precomputation.py プロジェクト: zoovu/rasa
    def collect_features(self,
                         sub_state: SubState,
                         attributes: Optional[Iterable[Text]] = None
                         ) -> Dict[Text, List[Features]]:
        """Collects features for all attributes in the given substate.

        There might be multiple messages in the container that contain features
        relevant for the given substate, e.g. this is the case if `TEXT` and
        `INTENT` are present in the given substate. All of those messages will be
        collected and their features combined.

        Args:
          sub_state: substate for which we want to extract the relevant features
          attributes: if not `None`, this specifies the attributes of the
            `Features` that we're interested in (i.e. all other `Features` contained
            in the relevant messages will be ignored). Any iterable is accepted,
            including one-shot iterators; it is materialized once internally.

        Returns:
          a dictionary that maps all the (requested) attributes to a list of `Features`

        Raises:
          `ValueError`: if some key pair (i.e. key attribute and corresponding
            value) from the given substate has no message stored under it
          `RuntimeError`: if features for the same attribute are found in two
            different messages that are associated with the given substate
        """
        # `attributes` is iterated once to build the result skeleton below and
        # again for every matching message inside `Features.groupby_attribute`.
        # A one-shot iterable (e.g. a generator) would be exhausted after the
        # first pass, so materialize it up front.
        if attributes is not None:
            attributes = list(attributes)

        # If we specify a list of attributes, then we want a dict with one entry
        # for each attribute back - even if the corresponding list of features is empty.
        features: Dict[Text, List[Features]] = (
            dict() if attributes is None else
            {attribute: [] for attribute in attributes})
        # collect all relevant key attributes
        key_attributes = set(sub_state.keys()).intersection(
            self.KEY_ATTRIBUTES)
        for key_attribute in key_attributes:
            # messages are looked up by the stringified value of the key attribute
            key_value = str(sub_state[key_attribute])
            message = self._table[key_attribute].get(key_value)
            if not message:
                raise ValueError(
                    f"Unknown key ({key_attribute},{key_value}). Cannot retrieve "
                    f"features for substate {sub_state}")
            features_from_message = Features.groupby_attribute(
                message.features, attributes=attributes)
            for feat_attribute, feat_value in features_from_message.items():
                existing_values = features.get(feat_attribute)
                # Note: the following if-s are needed because if we specify a list of
                # attributes then `features_from_message` will contain one entry per
                # attribute even if the corresponding feature list is empty. Empty
                # lists are falsy, so pre-populated placeholders never count as
                # an earlier message's features.
                if feat_value and existing_values:
                    raise RuntimeError(
                        f"Feature for attribute {feat_attribute} has already been "
                        f"extracted from a different message stored under a key "
                        f"in {key_attributes} "
                        f"that is different from {key_attribute}. This means there's a "
                        f"redundancy in the message container.")
                if feat_value:
                    features[feat_attribute] = feat_value
        return features