def init_data(self, train_path: str = 'datasets/rus.train', valid_path: str = 'datasets/rus.test'):
    """Load the train/valid datasets, build all id-mappings, and compute class weights.

    :param train_path: path to the training samples file
    :param valid_path: path to the validation samples file
    :return: self, for fluent chaining
    """
    # Record the call arguments, but drop `self`: dumping raw locals() stored the
    # instance inside its own params dict (reference cycle, garbage in saved configs).
    self.params.update({k: v for k, v in locals().items() if k != 'self'})

    self.train_dataset = BucketDataset(samples=DataLoader(file_path=train_path).load())
    self.valid_dataset = BucketDataset(samples=DataLoader(file_path=valid_path).load())

    self.bmes_mapping = BMESToIdMapping()
    self.char_mapping = CharToIdMapping(chars=list(self.train_dataset.get_chars()), include_unknown=True)
    self.word_mapping = WordSegmentTypeToIdMapping(segments=self.train_dataset.get_segment_types(),
                                                   include_unknown=False)
    print('Char mapping:', self.char_mapping)
    print('Word Segment type mapping:', self.word_mapping)
    print('BMES mapping:', self.bmes_mapping)

    self.processor = DataProcessor(char_mapping=self.char_mapping,
                                   word_segment_mapping=self.word_mapping,
                                   bmes_mapping=self.bmes_mapping)

    print('Removing wrong labels (current labels are: the cross product [BMES x SegmentTypes])...')
    labels = list(chain.from_iterable(self.processor.segments_to_label(sample.segments)
                                      for sample in self.train_dataset))
    self.label_mapping = LabelToIdMapping(labels=labels)
    label_ids = [self.label_mapping[l] for l in labels]
    # Keyword arguments are required here: positional use of compute_class_weight
    # was deprecated in scikit-learn 0.24 and removed in 1.0.
    self.class_weights = class_weight.compute_class_weight(class_weight='balanced',
                                                           classes=np.unique(label_ids),
                                                           y=label_ids)
    self.processor.label_mapping = self.label_mapping
    print('Calculated class weights:', self.class_weights)
    print('Number of classes per char:', self.processor.nb_classes())
    return self
def test_dataset(self):
    """Space-separated word/segmentation lines parse into samples; untyped segments get type None."""
    raw_lines = [
        'accompanied ac/compani/ed',
        'acknowledging ac/knowledg/ing',
        'defections defect/ion/s'
    ]
    samples = DataLoader(samples=raw_lines).load()

    self.assertEqual(len(samples), 3)
    self.assertEqual(samples[0].word, 'accompanied')

    second = samples[1]
    self.assertEqual(second.word, 'acknowledging')
    self.assertEqual(second.segments[1].segment, 'knowledg')
    self.assertIsNone(second.segments[1].type)
def predict(model_path: str, batch_size: int = 1, input_path='datasets/rus.test', output_path='logs/rus.predictions'):
    """Run a saved Word2Morph model on the samples in input_path and write predictions to output_path.

    :param model_path: path of the serialized model to load
    :param batch_size: batch size used during evaluation
    :param input_path: file with input samples
    :param output_path: destination file for one predicted sample per line
    """
    model = Word2Morph.load_model(path=model_path)
    samples = DataLoader(file_path=input_path).load()
    # evaluate() also returns the correct/wrong splits; only the predictions are written out
    _, _, predictions = model.evaluate(samples, batch_size=batch_size)
    with open(output_path, 'w', encoding='utf-8') as out:
        out.write('\n'.join(str(prediction) for prediction in predictions))
def test_crf_model(self):
    """A CRF-headed RNN should emit one distribution over all classes per input character."""
    nb_classes = 25
    model = RNNModel(nb_symbols=37, embeddings_size=8, dropout=0.2, use_crf=True, nb_classes=nb_classes)
    model.summary()
    model.compile('adam', crf_loss, metrics=[crf_viterbi_accuracy])

    sample = DataLoader(samples=['одуматься о:PREF/дум:ROOT/а:SUFF/ть:SUFF/ся:POSTFIX']).load()[0]
    x, y = self.processor.parse_one(sample=sample)
    pred = model.predict(np.array([x]))

    # 9 characters in the input word -> 9 per-char predictions
    self.assertEqual(pred.shape, (1, 9, nb_classes))
    print('CRF output:', pred)
    print(x.shape)
    print(y.shape)
def test_model_structure(self):
    """The CNN model should produce a 25-way output for each character of the input word."""
    model = CNNModel(nb_symbols=37, embeddings_size=8, dropout=0.2, dense_output_units=64, nb_classes=25)
    model.summary()
    model.compile('adam', 'sparse_categorical_crossentropy', metrics=['acc'])

    loaded = DataLoader(samples=['одуматься о:PREF/дум:ROOT/а:SUFF/ть:SUFF/ся:POSTFIX']).load()
    x, y = self.processor.parse_one(sample=loaded[0])
    prediction = model.predict(np.array([x]))

    # 9-character word -> (batch, seq_len, nb_classes) == (1, 9, 25)
    self.assertEqual(prediction.shape, (1, 9, 25))
    print(x.shape)
    print(y.shape)
def test_dataset(self):
    """Dataset supports len() and indexing, preserving word and segment order."""
    raw_lines = [
        'accompanied\tac/compani/ed',
        'acknowledging\tac/knowledg/ing',
        'defections\tdefect/ion/s'
    ]
    dataset = Dataset(DataLoader(samples=raw_lines).load())

    self.assertEqual(len(dataset), 3)
    self.assertEqual(dataset[0].word, 'accompanied')

    segments = dataset[1].segments
    self.assertEqual(segments[0].segment, 'ac')
    self.assertEqual(segments[1].segment, 'knowledg')
def test_buckets(self):
    """Buckets group samples by word length and survive shuffling with no loss of samples."""
    raw_lines = [
        'accompanied\tac/compani/ed',
        'acknowledging\tac/knowledg/ing',
        'abcdowledging\tac/knowledg/ing',
        'akpmowledging\tac/knowledg/ing',
        'anowledging\tac/knowledg/ing',
        'defections\tdefect/ion/s'
    ]
    dataset = BucketDataset(DataLoader(samples=raw_lines).load())

    # Dump bucket contents and sizes before shuffling (for debugging)
    print([(length, [str(s) for s in bucket]) for length, bucket in dataset.buckets.items()])
    print([(length, len(bucket)) for length, bucket in dataset.buckets.items()])

    size_before = len(dataset)
    dataset.shuffle()
    self.assertEqual(len(dataset), size_before)

    # ...and again after shuffling (sizes first, then contents, as before)
    print([(length, len(bucket)) for length, bucket in dataset.buckets.items()])
    print([(length, [str(s) for s in bucket]) for length, bucket in dataset.buckets.items()])

    # 13-char words: acknowledging, abcdowledging, akpmowledging; 11-char: accompanied, anowledging
    self.assertEqual(len(dataset.buckets[13]), 3)
    self.assertEqual(len(dataset.buckets[11]), 2)