def test_groupby(
    num_features_per_attribute: Dict[Text, int],
    specified_attributes: Optional[List[Text]],
):
    """Checks that `Features.groupby_attribute` groups features by attribute.

    For every attribute in `num_features_per_attribute`, the requested number
    of `Features` is created (the i-th one wraps a `(1, i + 1)` matrix filled
    with `i + 1`). The resulting flat list is grouped and the outcome is
    compared against the expected number of features per attribute.

    Args:
        num_features_per_attribute: maps an attribute name to the number of
            `Features` that should be created for it
        specified_attributes: passed on as `attributes` to `groupby_attribute`;
            if `None`, only attributes with at least one feature are expected
            in the result, otherwise exactly the specified attributes are
            expected (possibly with an empty list of features)
    """
    all_features = [
        Features(
            features=np.full(shape=(1, position + 1), fill_value=position + 1),
            attribute=attribute_name,
            feature_type=FEATURE_TYPE_SEQUENCE,  # irrelevant for this test
            origin=f"origin-{position}",  # irrelevant for this test
        )
        for attribute_name, count in num_features_per_attribute.items()
        for position in range(count)
    ]

    grouped = Features.groupby_attribute(all_features, attributes=specified_attributes)

    if specified_attributes is not None:
        # Every requested attribute must show up - even without any features.
        assert set(grouped.keys()) == set(specified_attributes)
        for attribute_name in specified_attributes:
            assert attribute_name in grouped
            expected_count = num_features_per_attribute.get(attribute_name, 0)
            assert len(grouped[attribute_name]) == expected_count
        return

    # Without a filter, only attributes that own at least one feature appear.
    for attribute_name, count in num_features_per_attribute.items():
        if count <= 0:
            assert attribute_name not in grouped
            continue
        assert attribute_name in grouped
        assert len(grouped[attribute_name]) == count
def collect_features(
    self, sub_state: SubState, attributes: Optional[Iterable[Text]] = None
) -> Dict[Text, List[Features]]:
    """Collects features for all attributes in the given substate.

    There might be multiple messages in the container that contain features
    relevant for the given substate, e.g. this is the case if `TEXT` and
    `INTENT` are present in the given substate. All of those messages will be
    collected and their features combined.

    Args:
        sub_state: substate for which we want to extract the relevant features
        attributes: if not `None`, this specifies the list of the attributes of
            the `Features` that we're interested in (i.e. all other `Features`
            contained in the relevant messages will be ignored)

    Returns:
        a dictionary that maps all the (requested) attributes to a list of
        `Features`

    Raises:
        `ValueError`: if there exists some key pair (i.e. key attribute and
            corresponding value) from the given substate that cannot be found
        `RuntimeError`: if features for the same attribute are found in two
            different messages that are associated with the given substate
    """
    # `attributes` is traversed several times below (once to pre-seed the
    # result and once per key attribute during grouping). Materialize it so
    # that a one-shot iterator/generator is not exhausted after the first pass.
    if attributes is not None:
        attributes = list(attributes)

    # If we specify a list of attributes, then we want a dict with one entry
    # for each attribute back - even if the corresponding list of features
    # is empty.
    features: Dict[Text, List[Features]] = (
        {} if attributes is None else {attribute: [] for attribute in attributes}
    )

    # collect all relevant key attributes
    key_attributes = set(sub_state.keys()).intersection(self.KEY_ATTRIBUTES)

    for key_attribute in key_attributes:
        # The lookup table is keyed by the string form of the substate value.
        key_value = str(sub_state[key_attribute])
        message = self._table[key_attribute].get(key_value)
        if not message:
            raise ValueError(
                f"Unknown key ({key_attribute},{key_value}). Cannot retrieve "
                f"features for substate {sub_state}"
            )

        features_from_message = Features.groupby_attribute(
            message.features, attributes=attributes
        )
        for feat_attribute, feat_value in features_from_message.items():
            # Note: empty feature lists are skipped because if we specify a
            # list of attributes then `features_from_message` will contain one
            # entry per attribute even if the corresponding feature list is
            # empty.
            if not feat_value:
                continue
            if features.get(feat_attribute):
                raise RuntimeError(
                    f"Feature for attribute {feat_attribute} has already been "
                    f"extracted from a different message stored under a key "
                    f"in {key_attributes} "
                    f"that is different from {key_attribute}. This means there's a "
                    f"redundancy in the message container."
                )
            features[feat_attribute] = feat_value

    return features