Ejemplo n.º 1
0
    def test_text_processor(self):
        """Check that a text ``SequenceProcessor`` survives a save/load cycle.

        A processor is fitted on the labeling corpus, used to transform five
        randomly drawn samples, serialized with ``to_dict`` and rebuilt with
        ``load_data_object``. The rebuilt processor must produce the same
        indices, and both processors must map their indices back to the
        original samples via ``inverse_transform``.
        """
        corpus_x, corpus_y = TestMacros.load_labeling_corpus()
        picked = random.sample(corpus_x, 5)
        # Original lengths are passed to inverse_transform alongside the
        # index sequences.
        picked_lengths = [len(tokens) for tokens in picked]

        original = SequenceProcessor(min_count=1)
        original.build_vocab(corpus_x, corpus_y)
        original_idx = original.transform(picked)

        # Serialize, then rebuild a second processor from the plain dict.
        restored: SequenceProcessor = load_data_object(original.to_dict())
        restored_idx = restored.transform(picked)

        assert (restored_idx == original_idx).all()

        round_trips = ((original, original_idx), (restored, restored_idx))
        for processor, indices in round_trips:
            recovered = processor.inverse_transform(indices,
                                                    lengths=picked_lengths)
            assert recovered == picked
Ejemplo n.º 2
0
    def test_base_cases(self):
        """Embed random corpus samples and verify the output shape.

        The embedding from ``self.build_embedding()`` is wired to a
        ``SequenceProcessor`` fitted on the SMP2018ECDT corpus. The embedded
        batch must have shape ``(batch, length, embedding_size)``, both for
        the original embedding and for a copy restored from ``to_dict``.
        """
        embedding = self.build_embedding()
        corpus_x, corpus_y = SMP2018ECDTCorpus.load_data()
        processor = SequenceProcessor()
        processor.build_vocab(corpus_x, corpus_y)
        embedding.setup_text_processor(processor)

        batch = random.sample(corpus_x, sample_count)
        embedded = embedding.embed(batch)

        # Longest sample plus two extra positions (presumably the start/end
        # markers added by the processor -- TODO confirm).
        longest = max(map(len, batch)) + 2
        # A declared max_position overrides the length derived from the data.
        expected_len = (longest if embedding.max_position is None
                        else embedding.max_position)
        expected_shape = (len(batch), expected_len, embedding.embedding_size)

        assert embedded.shape == expected_shape

        # Save/load round trip: the restored embedding needs the processor
        # attached again before it can embed.
        restored = load_data_object(embedding.to_dict())
        restored.setup_text_processor(processor)
        assert restored.embed(batch).shape == expected_shape
Ejemplo n.º 3
0
    def test_label_processor(self):
        """Check a label-vocabulary ``SequenceProcessor`` across save/load.

        A processor built with ``build_vocab_from_labels=True`` is fitted on
        the labeling corpus and used to transform 20 randomly drawn label
        sequences. It is then serialized with ``to_dict`` and rebuilt with
        ``load_data_object``. The test verifies that:

        * both processors produce identical indices,
        * each processor maps its own indices back to the original labels
          via ``inverse_transform``,
        * passing ``seq_length`` to ``transform`` yields sequences of exactly
          that length.
        """
        x_set, y_set = TestMacros.load_labeling_corpus()
        text_processor = SequenceProcessor(build_vocab_from_labels=True,
                                           min_count=1)
        text_processor.build_vocab(x_set, y_set)

        samples = random.sample(y_set, 20)
        lengths = [len(i) for i in samples]

        text_idx = text_processor.transform(samples)

        # Serialize, then rebuild a second processor from the plain dict.
        text_info_dict = text_processor.to_dict()
        text_processor2: SequenceProcessor = load_data_object(text_info_dict)

        text_idx2 = text_processor2.transform(samples)
        assert (text_idx2 == text_idx).all()

        # Exercise inverse_transform on BOTH processors. Checking only the
        # restored one twice would be redundant, since the index arrays were
        # just asserted equal, and would leave the original untested.
        assert text_processor.inverse_transform(text_idx,
                                                lengths=lengths) == samples
        assert text_processor2.inverse_transform(text_idx2,
                                                 lengths=lengths) == samples

        # An explicit seq_length must fix the length of every output row.
        fixed_length = 20
        text_idx3 = text_processor.transform(samples, seq_length=fixed_length)
        assert [len(i) for i in text_idx3] == [fixed_length] * len(text_idx3)