def test_text_processor(self):
    """Check that a text SequenceProcessor survives a to_dict / load round trip.

    A processor is fitted on the labeling corpus, five random input
    sequences are transformed, and the processor is rebuilt from its
    ``to_dict()`` output via ``load_data_object``. The rebuilt processor
    must produce the same indices, and both processors must map their
    indices back to the sampled sequences.
    """
    corpus_x, corpus_y = TestMacros.load_labeling_corpus()
    picked = random.sample(corpus_x, 5)

    original = SequenceProcessor(min_count=1)
    original.build_vocab(corpus_x, corpus_y)
    original_ids = original.transform(picked)

    # Rebuild a second processor purely from the serialized description.
    restored: SequenceProcessor = load_data_object(original.to_dict())
    restored_ids = restored.transform(picked)

    # True (unpadded) length of every sampled sequence, passed to
    # inverse_transform alongside the index arrays.
    token_counts = list(map(len, picked))

    assert (restored_ids == original_ids).all()
    for processor, ids in ((original, original_ids), (restored, restored_ids)):
        assert processor.inverse_transform(ids, lengths=token_counts) == picked
def test_base_cases(self):
    """Check the embedding output shape before and after a save/load cycle.

    The embedding returned by ``self.build_embedding()`` is wired to a
    ``SequenceProcessor`` fitted on the SMP2018ECDT corpus, and a random
    batch of ``sample_count`` inputs is embedded. The result must have
    shape ``(batch, expected_len, embedding_size)``; the same must hold
    for an embedding rebuilt from ``to_dict()`` via ``load_data_object``.
    """
    embedding = self.build_embedding()
    corpus_x, corpus_y = SMP2018ECDTCorpus.load_data()

    vocab_processor = SequenceProcessor()
    vocab_processor.build_vocab(corpus_x, corpus_y)
    embedding.setup_text_processor(vocab_processor)

    batch = random.sample(corpus_x, sample_count)
    embedded = embedding.embed(batch)

    # Longest sample plus two extra positions.
    # NOTE(review): the +2 presumably accounts for begin/end marker
    # tokens added by the processor -- confirm against SequenceProcessor.
    longest = max(len(tokens) for tokens in batch)
    expected_len = longest + 2
    # An embedding that declares max_position dictates the length instead.
    if embedding.max_position is not None:
        expected_len = embedding.max_position

    assert embedded.shape == (len(batch), expected_len, embedding.embedding_size)

    # Test Save And Load
    restored = load_data_object(embedding.to_dict())
    restored.setup_text_processor(vocab_processor)
    restored_shape = restored.embed(batch).shape
    assert restored_shape == (len(batch), expected_len, embedding.embedding_size)
def test_label_processor(self):
    """Check that a label SequenceProcessor survives a to_dict / load round trip.

    A processor built with ``build_vocab_from_labels=True`` is fitted on the
    labeling corpus and twenty random label sequences are transformed. The
    processor is then rebuilt from its ``to_dict()`` output via
    ``load_data_object``. The rebuilt processor must produce the same
    indices, and both the original and the rebuilt processor must map their
    indices back to the sampled label sequences. Finally, transforming with
    ``seq_length=20`` must yield rows of length 20.
    """
    x_set, y_set = TestMacros.load_labeling_corpus()

    text_processor = SequenceProcessor(build_vocab_from_labels=True, min_count=1)
    text_processor.build_vocab(x_set, y_set)

    samples = random.sample(y_set, 20)
    text_idx = text_processor.transform(samples)

    # Rebuild a second processor purely from the serialized description.
    text_info_dict = text_processor.to_dict()
    text_processor2: SequenceProcessor = load_data_object(text_info_dict)
    text_idx2 = text_processor2.transform(samples)

    # True (unpadded) length of every sampled sequence, passed to
    # inverse_transform alongside the index arrays.
    lengths = [len(i) for i in samples]

    assert (text_idx2 == text_idx).all()
    # Exercise inverse_transform on BOTH processors. Previously both
    # assertions used the reloaded processor, which (given the equality
    # check above) tested the same thing twice and left the original
    # processor's inverse_transform unverified.
    assert text_processor.inverse_transform(text_idx, lengths=lengths) == samples
    assert text_processor2.inverse_transform(text_idx2, lengths=lengths) == samples

    # An explicit seq_length must fix the length of every output row.
    text_idx3 = text_processor.transform(samples, seq_length=20)
    assert [len(i) for i in text_idx3] == [20] * len(text_idx3)