示例#1
0
文件: train.py 项目: pavelsof/mstnn
    def pick_best(self, dataset, num_best=1):
        """
        Keep only the num_best checkpoints that score highest against the
        given Dataset and delete the rest. Assumes that checkpoints have
        been saved during training; if there are no more checkpoints than
        num_best, nothing is deleted.
        """
        if num_best >= len(self.checkpoints):
            return

        scorer = Scorer(dataset)
        scores = {}

        with tempfile.TemporaryDirectory() as temp_dir:
            for checkpoint in self.checkpoints:
                # parse the dev data with this checkpoint and write the
                # output to a throwaway file so that it can be scored
                graphs = Model.load(checkpoint).parse(dataset)

                out_path = os.path.join(temp_dir, os.path.basename(checkpoint))
                Dataset(out_path).write_graphs(graphs)

                scores[checkpoint] = scorer.score(Dataset(out_path))
                print('{}: {:.2f}'.format(checkpoint, scores[checkpoint]))

        # rank by (score, path) descending and keep the top num_best paths
        ranked = sorted(((uas, path) for path, uas in scores.items()),
                        reverse=True)
        keep = [path for _, path in ranked[:num_best]]

        for checkpoint in self.checkpoints:
            if checkpoint not in keep:
                os.remove(checkpoint)

        print('kept: {}'.format(', '.join(keep)))
示例#2
0
    def test_write_graphs(self):
        """Writing the generated graphs reproduces the source file."""
        graphs = list(self.dataset.gen_graphs())

        with tempfile.TemporaryDirectory() as temp_dir:
            copy = Dataset(os.path.join(temp_dir, 'test'))
            copy.write_graphs(graphs)

            # byte-for-byte comparison against the original conllu file
            same = filecmp.cmp(self.dataset.file_path,
                               copy.file_path,
                               shallow=False)
            self.assertTrue(same)
示例#3
0
    def test_write_sentences(self):
        """Writing the generated sentences reproduces the source file."""
        sents = list(self.dataset.gen_sentences())

        with tempfile.TemporaryDirectory() as temp_dir:
            copy = Dataset(os.path.join(temp_dir, 'test'))
            copy.write_sentences(sents)

            # byte-for-byte comparison against the original conllu file
            same = filecmp.cmp(self.dataset.file_path,
                               copy.file_path,
                               shallow=False)
            self.assertTrue(same)
示例#4
0
文件: score.py 项目: pavelsof/mstnn
def score(parsed, standard, ud_version=2):
    """
    Calculate and return the UAS score of one dataset measured against
    another, usually parser output against gold-standard data. Both args
    are paths to conllu datasets; ud_version applies to both.

    Could raise a ConlluError or a ScoreError.

    This can be seen as the main function of the cli's score command;
    however, it is also used elsewhere in the code.
    """
    system = Dataset(parsed, ud_version)
    gold = Dataset(standard, ud_version)

    return Scorer(gold).score(system)
示例#5
0
文件: train.py 项目: pavelsof/mstnn
def train(model_fp,
          train_fp,
          ud_version=2,
          ignore_forms=False,
          ignore_lemmas=False,
          ignore_morph=False,
          epochs=10,
          batch_size=32,
          dev_fp=None,
          num_best=1,
          forms_word2vec=None,
          lemmas_word2vec=None):
    """
    Train an mstnn model. Expects a path where the models will be written
    to, and a path to a conllu dataset that will be used for training. The
    epochs, batch_size, and ignore_* args are passed on to the train_on
    method.

    The dev_fp optional path should specify a development dataset to check
    the trained model against. The UD version applies to both datasets.
    The num_best keyword arg specifies the number of best-performing
    checkpoints to keep when there is a development dataset to check
    against.

    Paths to pre-trained form and/or lemma embeddings can be specified;
    these are expected to be in binary word2vec format.

    This can be seen as the main function of the cli's train command.
    """
    def load_embeddings(path):
        # pre-trained embeddings are optional; binary word2vec format
        if path is None:
            return None
        return KeyedVectors.load_word2vec_format(path, binary=True)

    forms_vecs = load_embeddings(forms_word2vec)
    lemmas_vecs = load_embeddings(lemmas_word2vec)

    trainer = Trainer(model_fp)
    trainer.train_on(Dataset(train_fp, ud_version),
                     ignore_forms,
                     ignore_lemmas,
                     ignore_morph,
                     epochs,
                     batch_size,
                     save_checkpoints=True,
                     forms_vecs=forms_vecs,
                     lemmas_vecs=lemmas_vecs)

    # with a development dataset available, prune the saved checkpoints
    # down to the num_best best-scoring ones
    if dev_fp is not None:
        trainer.pick_best(Dataset(dev_fp, ud_version), num_best=num_best)
示例#6
0
    def test_bad_file(self):
        """gen_sentences raises ConlluError for missing or non-conllu files."""
        for bad_path in ('code/dontexist', 'code/conllu.py'):
            dataset = Dataset(bad_path)
            with self.assertRaises(ConlluError):
                list(dataset.gen_sentences())
示例#7
0
文件: diff.py 项目: pavelsof/mstnn
def diff(fp1, fp2, ud_version=2):
    """
    Expects the paths to two conllu datasets comprising parses of the same
    sentences and returns a string describing those parses that differ.
    The optional ud_version applies to both datasets.

    Could raise a ConlluError or a DiffError (the latter when the two
    datasets do not contain the same sentences).
    """
    dataset1 = Dataset(fp1, ud_version)
    dataset2 = Dataset(fp2, ud_version)

    output = []

    for graph1, graph2 in zip(dataset1.gen_graphs(), dataset2.gen_graphs()):
        if len(graph1) != len(graph2):
            raise DiffError('Sentences do not match: {} and {}'.format(
                fp1, fp2))

        sent_output = []
        is_same = True

        loop = zip(graph1.nodes(data=True), graph2.nodes(data=True))
        for (node1, data1), (node2, data2) in loop:
            # the artificial root node (0) carries no form of its own
            if node1 == 0: data1['FORM'] = 'ROOT'
            if node2 == 0: data2['FORM'] = 'ROOT'

            # explicit check instead of assert: asserts are stripped under
            # python -O, which would silently skip this validation
            if node1 != node2 or data1['FORM'] != data2['FORM']:
                raise DiffError('Sentences do not match: {} and {}'.format(
                    fp1, fp2))

            # NOTE(review): graph.edge[...] is the pre-2.0 networkx API —
            # assumes the project pins networkx < 2
            edges1 = ','.join(map(str, graph1.edge[node1]))
            edges2 = ','.join(map(str, graph2.edge[node2]))

            if edges1 != edges2:
                is_same = False

            # one line per word: id, form, then both sets of outgoing edges
            sent_output.append('{!s}\t{}\t\t{}\t{}'.format(
                node1, data1['FORM'], edges1, edges2))

        # only sentences whose parses differ make it into the output
        if not is_same:
            output.append('\n'.join(sent_output))

    return '\n\n'.join(output)
示例#8
0
 def setUp(self):
     """Load the UD v1 Basque dev dataset shared by the tests in this case."""
     self.dataset = Dataset('data/UD_Basque/eu-ud-dev.conllu', ud_version=1)
示例#9
0
class DatasetTestCase(TestCase):
    """Unit tests for the Dataset conllu reader/writer."""

    def setUp(self):
        """Every test works against the UD v1 Basque dev dataset."""
        self.dataset = Dataset('data/UD_Basque/eu-ud-dev.conllu', ud_version=1)

    def test_bad_file(self):
        """gen_sentences raises ConlluError for missing or non-conllu files."""
        for bad_path in ('code/dontexist', 'code/conllu.py'):
            dataset = Dataset(bad_path)
            with self.assertRaises(ConlluError):
                list(dataset.gen_sentences())

    def test_gen_sentences(self):
        """Sentences yield Word tuples with int IDs numbered from 1."""
        sents = []

        for sent in self.dataset.gen_sentences():
            for index, word in enumerate(sent, 1):
                self.assertIsInstance(word, Word)
                self.assertIsInstance(word.ID, int)
                self.assertIsInstance(word.HEAD, int)
                self.assertEqual(word.ID, index)

            sents.append(sent)

        self.assertEqual(len(sents[0]), 10)
        self.assertEqual(len(sents[-1]), 25)

        # spot-check the very first word of the dataset
        expected = Word._make([
            1, 'Atenasen', 'Atenas', 'PROPN', '_', {
                'Case': frozenset(['Ine']),
                'Definite': frozenset(['Def']),
                'Number': frozenset(['Sing'])
            }, 8, 'nmod', '_', '_'
        ])
        self.assertEqual(sents[0][0], expected)

    def test_gen_graphs(self):
        """Each sentence becomes a DiGraph with one extra node for the root."""
        graphs = []

        for graph in self.dataset.gen_graphs():
            self.assertIsInstance(graph, nx.DiGraph)
            graphs.append(graph)

        self.assertEqual(graphs[0].number_of_nodes(), 10 + 1)
        self.assertEqual(graphs[-1].number_of_nodes(), 25 + 1)

    def test_write_sentences(self):
        """Writing the generated sentences reproduces the source file."""
        sents = list(self.dataset.gen_sentences())

        with tempfile.TemporaryDirectory() as temp_dir:
            copy = Dataset(os.path.join(temp_dir, 'test'))
            copy.write_sentences(sents)

            same = filecmp.cmp(self.dataset.file_path,
                               copy.file_path,
                               shallow=False)
            self.assertTrue(same)

    def test_write_graphs(self):
        """Writing the generated graphs reproduces the source file."""
        graphs = list(self.dataset.gen_graphs())

        with tempfile.TemporaryDirectory() as temp_dir:
            copy = Dataset(os.path.join(temp_dir, 'test'))
            copy.write_graphs(graphs)

            same = filecmp.cmp(self.dataset.file_path,
                               copy.file_path,
                               shallow=False)
            self.assertTrue(same)
示例#10
0
from code.features import Extractor


@composite
def subdicts(draw, source_dict):
    """
    Hypothesis strategy: draw a random sub-dict of source_dict, mapping a
    subset of its keys to non-empty sorted tuples of elements drawn from
    the corresponding values.
    """
    d = {}

    # sampled_from requires an ordered, indexable collection — a dict keys
    # view is neither, and its iteration order is not reproducible across
    # runs on older Pythons; sorting fixes both
    keys = draw(sets(sampled_from(sorted(source_dict.keys()))))
    for key in keys:
        d[key] = tuple(
            sorted(draw(sets(sampled_from(source_dict[key]), min_size=1))))

    return d


# module-level fixture: read the UD v1 Basque dev data once and feed it to
# a feature Extractor, so all tests in this module can share both objects
dataset = Dataset('data/UD_Basque/eu-ud-dev.conllu', ud_version=1)

extractor = Extractor()
extractor.read(dataset)


class FeaturesTestCase(TestCase):
    """Unit tests for the feature Extractor."""

    def test_read(self):
        """Reading the dataset populates the POS tags and morph features."""
        self.assertIsInstance(extractor.pos_tags, tuple)
        for tag in ('PROPN', 'CONJ'):
            self.assertIn(tag, extractor.pos_tags)

        self.assertIsInstance(extractor.morph, dict)
        for feature in ('Case', 'Definite', 'Number'):
            self.assertIn(feature, extractor.morph)