예제 #1
0
    def _encode_data(self, dataset: SequenceDataset, params: EncoderParams):
        """Encode all sequences of ``dataset`` and wrap them in an ``EncodedData``.

        Args:
            dataset: Source of the sequence objects (``get_data``) and of the
                example ids (``get_example_ids``).
            params: Supplies ``pool_size`` (forwarded to ``dataset.get_data``)
                and ``encode_labels`` (whether labels are collected).

        Returns:
            An ``EncodedData`` holding the encoded examples, the labels (``None``
            when ``params.encode_labels`` is falsy), the example ids, the
            feature names and the encoder name.

        Raises:
            ValueError: Raised by ``max`` when the dataset yields no sequences.
        """
        items = list(dataset.get_data(params.pool_size))

        raw_sequences = [item.get_sequence() for item in items]
        example_ids = dataset.get_example_ids()

        # Every sequence is padded to the length of the longest one.
        longest = max(map(len, raw_sequences))

        if params.encode_labels:
            labels = self._get_labels(items, params)
        else:
            labels = None

        examples = self._encode_sequence_list(
            raw_sequences,
            pad_n_sequences=len(items),
            pad_sequence_len=longest,
        )

        feature_names = self._get_feature_names(longest)

        if self.flatten:
            # Collapse everything after the first axis into a single feature
            # axis of size longest * number of one-hot dimensions.
            flat_shape = (len(items), longest * len(self.onehot_dimensions))
            examples = examples.reshape(flat_shape)
            # Feature names are nested one level deep; flatten them to match.
            feature_names = [
                name for group in feature_names for name in group
            ]

        return EncodedData(
            examples=examples,
            labels=labels,
            example_ids=example_ids,
            feature_names=feature_names,
            encoding=OneHotEncoder.__name__,
        )
예제 #2
0
    def _encode_data(self, dataset: SequenceDataset, params: EncoderParams):
        """Encode all sequences of ``dataset`` and wrap them in an ``EncodedData``.

        Sequences are fetched for ``self.sequence_type``; the dataset is
        rejected if any sequence object returns ``None`` for that type.

        Args:
            dataset: Source of the sequence objects (``get_data``) and of the
                example ids (``get_example_ids``); its ``name`` and
                ``identifier`` are used in the error message.
            params: Supplies ``pool_size`` (forwarded to ``dataset.get_data``)
                and ``encode_labels`` (whether labels are collected).

        Returns:
            An ``EncodedData`` holding the encoded examples, the labels (``None``
            when ``params.encode_labels`` is falsy), the example ids, the
            feature names and the encoder name.

        Raises:
            ValueError: If any sequence is ``None`` for the requested sequence
                type; also raised by ``max`` when the dataset yields no
                sequences.
        """
        items = list(dataset.get_data(params.pool_size))

        raw_sequences = [
            item.get_sequence(self.sequence_type) for item in items
        ]

        # Fail early with a descriptive message rather than crashing in len().
        has_missing = any(seq is None for seq in raw_sequences)
        if has_missing:
            message = (
                f"{OneHotEncoder.__name__}: sequence dataset {dataset.name} (id: {dataset.identifier}) contains empty sequences for the specified "
                f"sequence type {self.sequence_type.name.lower()}. Please check that the dataset is imported correctly."
            )
            raise ValueError(message)

        example_ids = dataset.get_example_ids()

        # Every sequence is padded to the length of the longest one.
        longest = max(map(len, raw_sequences))

        if params.encode_labels:
            labels = self._get_labels(items, params)
        else:
            labels = None

        examples = self._encode_sequence_list(
            raw_sequences,
            pad_n_sequences=len(items),
            pad_sequence_len=longest,
        )

        feature_names = self._get_feature_names(longest)

        if self.flatten:
            # Collapse everything after the first axis into a single feature
            # axis of size longest * number of one-hot dimensions.
            flat_shape = (len(items), longest * len(self.onehot_dimensions))
            examples = examples.reshape(flat_shape)
            # Feature names are nested one level deep; flatten them to match.
            feature_names = [
                name for group in feature_names for name in group
            ]

        return EncodedData(
            examples=examples,
            labels=labels,
            example_ids=example_ids,
            feature_names=feature_names,
            encoding=OneHotEncoder.__name__,
        )