예제 #1
0
    def __init__(self,
                 categories: Optional[List[Union[int, bool, str]]] = None,
                 allow_missing: Optional[bool] = None):
        """Set up the property, optionally with a fixed list of categories.

        Parameters
        ----------
        categories
            The possible categories. When given, the vocabulary is built
            immediately and ``allow_missing`` must be given as well. Elements
            that are NumPy scalars are converted to native Python values.
            When omitted, the vocabulary is left as ``None``.
        allow_missing
            Whether the categorical column is allowed to contain missing values.
            If true, the vocabulary is created without overriding ``unk_token``;
            otherwise it is created with ``unk_token=None``.
        """
        super().__init__()
        self._allow_missing = allow_missing
        self._freq = None
        if categories is None:
            self._vocab = None
            return
        # Decide per element whether it is a NumPy scalar. Looking only at
        # categories[0] would fail on an empty list and would mishandle lists
        # that mix NumPy scalars with plain Python values.
        is_np_scalar = [type(ele).__module__ == np.__name__ for ele in categories]
        if any(is_np_scalar):
            categories = [ele.item() if flag else ele
                          for ele, flag in zip(categories, is_np_scalar)]
        assert allow_missing is not None, \
            'allow_missing must be specified when categories are provided.'
        if allow_missing:
            self._vocab = Vocab(categories)
        else:
            self._vocab = Vocab(categories, unk_token=None)
예제 #2
0
 def parse(self, column_data: pd.Series):
     """Collect category statistics from the column.

     Calls the parent ``parse`` first. If ``allow_missing`` has not been
     decided yet, it is inferred from ``self.num_missing_sample``
     (presumably populated by the parent ``parse`` -- confirm). If no
     vocabulary exists yet, one is built from the sorted distinct values of
     the column. Finally the per-category frequencies are stored in
     ``self._freq``, using 0 for categories that do not occur in the column.

     Parameters
     ----------
     column_data
         The column to analyze.
     """
     super().parse(column_data=column_data)
     if self._allow_missing is None:
         # Missing values are allowed exactly when the data contains some.
         self._allow_missing = bool(self.num_missing_sample > 0)
     value_counts = column_data.value_counts()
     if self._vocab is None:
         # Convert NumPy scalars element by element. Checking only the first
         # element would raise IndexError when the column has no non-missing
         # value (value_counts() is then empty).
         categories = [ele.item() if type(ele).__module__ == np.__name__ else ele
                       for ele in sorted(value_counts.keys())]
         if self._allow_missing:
             self._vocab = Vocab(tokens=categories)
         else:
             self._vocab = Vocab(tokens=categories, unk_token=None)
     self._freq = [value_counts[ele] if ele in value_counts else 0 for ele in self.categories]
예제 #3
0
 def parse(self, column_data: pd.Series):
     """Scan the entity column and record span and label statistics.

     Cells that are ``None`` are skipped. Every other cell is either a single
     entity or a list of entities. An entity is a dict with ``'start'`` and
     ``'end'`` keys plus an optional ``'label'`` key, or a tuple
     ``(start, end)`` / ``(start, end, label)``.

     After the scan the method stores the total number of entities, the
     average number of entities per valid sample (0.0 when there is no valid
     sample), the average span length (NaN when no entity was seen) and, for
     categorical labels, the label vocabulary and label frequencies.

     Parameters
     ----------
     column_data
         The column to analyze.

     Raises
     ------
     AssertionError
         If a cell or an entity has an unsupported type, or if the label
         types of the entities are not consistent.
     ValueError
         If a label vocabulary was already set and a categorical label
         found in the data is not part of it.
     """
     super().parse(column_data)
     # Store statistics
     all_span_lengths = []
     categorical_label_counter = collections.Counter()
     for _, entities in column_data.items():
         if entities is None:
             continue
         # A lone entity is treated like a one-element list.
         if isinstance(entities, (dict, tuple)):
             entities = [entities]
         assert isinstance(entities, list),\
             'The entity type is "{}" and is not supported by ' \
             'GluonNLP. Received entities={}'.format(type(entities), entities)
         for entity in entities:
             if isinstance(entity, dict):
                 start = entity['start']
                 end = entity['end']
                 label = entity.get('label', None)
             else:
                 assert isinstance(entity, tuple)
                 if len(entity) == 2:
                     start, end = entity
                     label = None
                 else:
                     start, end, label = entity
             all_span_lengths.append(end - start)
             label_type = _get_entity_label_type(label)
             if label_type == _C.CATEGORICAL:
                 categorical_label_counter[label] += 1
             elif label_type == _C.NUMERICAL and self._label_shape is None:
                 # The first numerical label seen fixes the label shape.
                 self._label_shape = np.array(label).shape
             if self._label_type is not None:
                 assert self._label_type == label_type, \
                     'Unmatched label types. ' \
                     'The type of labels of all entities should be consistent. ' \
                     'Received label type="{}".' \
                     ' Stored label_type="{}"'.format(label_type, self._label_type)
             else:
                 self._label_type = label_type
     num_entities = len(all_span_lengths)
     self._num_total_entity = num_entities
     # Avoid ZeroDivisionError when the column has no valid sample.
     if self.num_valid_sample > 0:
         self._avg_entity_per_sample = num_entities / self.num_valid_sample
     else:
         self._avg_entity_per_sample = 0.0
     # np.mean([]) yields NaN together with a RuntimeWarning; produce the same
     # NaN directly so that an entity-free column stays silent.
     if all_span_lengths:
         self._avg_span_length = np.mean(all_span_lengths).item()
     else:
         self._avg_span_length = float('nan')
     if self._label_type == _C.CATEGORICAL:
         if self._label_vocab is None:
             keys = sorted(categorical_label_counter)
             self._label_vocab = Vocab(tokens=keys,
                                       unk_token=None)
             self._label_freq = [categorical_label_counter[ele] for ele in keys]
         else:
             # A vocabulary was supplied up front: every observed label must
             # already be part of it.
             for key in categorical_label_counter:
                 if key not in self._label_vocab:
                     raise ValueError('The entity label="{}" is not found in the provided '
                                      'vocabulary. The provided labels="{}"'
                                      .format(key,
                                              self._label_vocab.all_tokens))
             self._label_freq = [categorical_label_counter[ele]
                                 for ele in self._label_vocab.all_tokens]
예제 #4
0
    def __init__(self, parent,
                 label_type=None,
                 label_shape=None,
                 label_keys=None):
        """Create the entity property with optional, pre-specified label info.

        Parameters
        ----------
        parent
            The column name of its parent
        label_type
            The type of the labels. One of:
            - null
            - categorical
            - numerical
        label_shape
            The shape of the label; only meaningful when the entities carry
            numerical labels. Stored as a tuple when given.
        label_keys
            The vocabulary of the categorical label; only meaningful when the
            entities carry categorical labels. When given, a ``Vocab`` with
            ``unk_token=None`` is built from it.
        """
        super().__init__()
        self._parent = parent
        self._label_type = label_type
        # Normalise any sequence-like shape to a tuple; keep None untouched.
        self._label_shape = None if label_shape is None else tuple(label_shape)
        self._label_vocab = (None if label_keys is None
                             else Vocab(tokens=label_keys, unk_token=None))
        # Statistics start out unset.
        self._label_freq = None
        self._num_total_entity = None
        self._avg_entity_per_sample = None
        self._avg_span_length = None