Example #1
def train_test():
    # Config Loader
    train_args = ConfigSection()
    ConfigLoader("config.cfg").load_config("./data_for_tests/config",
                                           {"POS": train_args})

    # Data Loader
    loader = TokenizeDatasetLoader(cws_data_path)
    train_data = loader.load_pku()

    # Preprocessor
    p = SeqLabelPreprocess()
    data_train = p.run(train_data, pickle_path=pickle_path)
    train_args["vocab_size"] = p.vocab_size
    train_args["num_classes"] = p.num_classes

    # Trainer
    trainer = SeqLabelTrainer(**train_args.data)

    # Model
    model = SeqLabeling(train_args)

    # Start training
    trainer.train(model, data_train)
    print("Training finished!")

    # Saver
    saver = ModelSaver("./data_for_tests/saved_model.pkl")
    saver.save_pytorch(model)
    print("Model saved!")

    del model, trainer, loader

    # Define the same model
    model = SeqLabeling(train_args)

    # Dump trained parameters into the model
    ModelLoader.load_pytorch(model, "./data_for_tests/saved_model.pkl")
    print("model loaded!")

    # Load test configuration
    test_args = ConfigSection()
    ConfigLoader("config.cfg").load_config("./data_for_tests/config",
                                           {"POS_test": test_args})

    # Tester
    tester = SeqLabelTester(**test_args.data)

    # Start testing
    tester.test(model, data_train)

    # print test results
    print(tester.show_metrics())
    print("model tested!")
Example #2
def train_test():
    # Config Loader
    train_args = ConfigSection()
    ConfigLoader().load_config(config_path, {"POS_infer": train_args})

    # define dataset
    data_train = TokenizeDataSetLoader().load(cws_data_path)
    word_vocab = Vocabulary()
    label_vocab = Vocabulary()
    data_train.update_vocab(word_seq=word_vocab, label_seq=label_vocab)
    data_train.index_field("word_seq",
                           word_vocab).index_field("label_seq", label_vocab)
    data_train.set_origin_len("word_seq")
    data_train.rename_field("label_seq", "truth").set_target(truth=False)
    train_args["vocab_size"] = len(word_vocab)
    train_args["num_classes"] = len(label_vocab)

    save_pickle(word_vocab, pickle_path, "word2id.pkl")
    save_pickle(label_vocab, pickle_path, "label2id.pkl")

    # Trainer
    trainer = SeqLabelTrainer(**train_args.data)

    # Model
    model = SeqLabeling(train_args)

    # Start training
    trainer.train(model, data_train)

    # Saver
    saver = ModelSaver("./save/saved_model.pkl")
    saver.save_pytorch(model)

    del model, trainer

    # Define the same model
    model = SeqLabeling(train_args)

    # Dump trained parameters into the model
    ModelLoader.load_pytorch(model, "./save/saved_model.pkl")

    # Load test configuration
    test_args = ConfigSection()
    ConfigLoader().load_config(config_path, {"POS_infer": test_args})
    test_args["evaluator"] = SeqLabelEvaluator()

    # Tester
    tester = SeqLabelTester(**test_args.data)

    # Start testing
    data_train.set_target(truth=True)
    tester.test(model, data_train)
Example #3
    def test_case_1(self):
        args = {
            "epochs": 3,
            "batch_size": 8,
            "validate": True,
            "use_cuda": True,
            "pickle_path": "./save/",
            "save_best_dev": True,
            "model_name": "default_model_name.pkl",
            "loss": Loss(None),
            "optimizer": Optimizer("Adam", lr=0.001, weight_decay=0),
            "vocab_size": 20,
            "word_emb_dim": 100,
            "rnn_hidden_units": 100,
            "num_classes": 3
        }
        trainer = SeqLabelTrainer(**args)
        # each example is a [word index sequence, label index sequence] pair
        train_data = [
            [[1, 2, 3, 4, 5, 6], [1, 0, 1, 0, 1, 2]],
            [[2, 3, 4, 5, 1, 6], [0, 1, 0, 1, 0, 2]],
            [[1, 4, 1, 4, 1, 6], [1, 0, 1, 0, 1, 2]],
            [[1, 2, 3, 4, 5, 6], [1, 0, 1, 0, 1, 2]],
            [[2, 3, 4, 5, 1, 6], [0, 1, 0, 1, 0, 2]],
            [[1, 4, 1, 4, 1, 6], [1, 0, 1, 0, 1, 2]],
        ]
        dev_data = train_data
        model = SeqLabeling(args)
        trainer.train(network=model, train_data=train_data, dev_data=dev_data)
Example #4
    def test_case_1(self):
        args = {
            "epochs": 3,
            "batch_size": 2,
            "validate": False,
            "use_cuda": False,
            "pickle_path": "./save/",
            "save_best_dev": True,
            "model_name": "default_model_name.pkl",
            "loss": Loss("cross_entropy"),
            "optimizer": Optimizer("Adam", lr=0.001, weight_decay=0),
            "vocab_size": 10,
            "word_emb_dim": 100,
            "rnn_hidden_units": 100,
            "num_classes": 5,
            "evaluator": SeqLabelEvaluator()
        }
        trainer = SeqLabelTrainer(**args)

        train_data = [
            [['a', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
            [['a', '@', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
            [['a', 'b', '#', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
            [['a', 'b', 'c', '?', 'e'], ['a', '@', 'c', 'd', 'e']],
            [['a', 'b', 'c', 'd', '$'], ['a', '@', 'c', 'd', 'e']],
            [['!', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
        ]
        vocab = {
            'a': 0,
            'b': 1,
            'c': 2,
            'd': 3,
            'e': 4,
            '!': 5,
            '@': 6,
            '#': 7,
            '$': 8,
            '?': 9
        }
        label_vocab = {'a': 0, '@': 1, 'c': 2, 'd': 3, 'e': 4}

        data_set = DataSet()
        for example in train_data:
            text, label = example[0], example[1]
            x = TextField(text, is_target=False)
            x_len = LabelField(len(text), is_target=False)
            y = TextField(label, is_target=False)
            ins = Instance(word_seq=x, truth=y, word_seq_origin_len=x_len)
            data_set.append(ins)

        data_set.index_field("word_seq", vocab)
        data_set.index_field("truth", label_vocab)

        model = SeqLabeling(args)

        trainer.train(network=model, train_data=data_set, dev_data=data_set)
        # If this can run, everything is OK.

        os.system("rm -rf save")
        print("pickle path deleted")
Example #5
def infer():
    # Load infer configuration, the same as test
    test_args = ConfigSection()
    ConfigLoader("config.cfg").load_config(config_dir,
                                           {"POS_infer": test_args})

    # fetch dictionary size and number of labels from pickle files
    word2index = load_pickle(pickle_path, "word2id.pkl")
    test_args["vocab_size"] = len(word2index)
    index2label = load_pickle(pickle_path, "id2class.pkl")
    test_args["num_classes"] = len(index2label)

    # Define the same model
    model = SeqLabeling(test_args)

    # Dump trained parameters into the model
    ModelLoader.load_pytorch(model, os.path.join(pickle_path, model_name))
    print("model loaded!")

    # Data Loader
    raw_data_loader = BaseLoader(data_infer_path)
    infer_data = raw_data_loader.load_lines()

    # Inference interface
    infer = SeqLabelInfer(pickle_path)
    results = infer.predict(model, infer_data)

    for res in results:
        print(res)
    print("Inference finished!")
Example #6
def infer():
    # Load infer configuration, the same as test
    test_args = ConfigSection()
    ConfigLoader().load_config(config_path, {"POS_infer": test_args})

    # fetch dictionary size and number of labels from pickle files
    word2index = load_pickle(pickle_path, "word2id.pkl")
    test_args["vocab_size"] = len(word2index)
    index2label = load_pickle(pickle_path, "label2id.pkl")
    test_args["num_classes"] = len(index2label)

    # Define the same model
    model = SeqLabeling(test_args)

    # Dump trained parameters into the model
    ModelLoader.load_pytorch(model, "./save/saved_model.pkl")
    print("model loaded!")

    # Load infer data
    infer_data = SeqLabelDataSet(load_func=BaseLoader.load)
    infer_data.load(data_infer_path, vocabs={"word_vocab": word2index}, infer=True)

    # inference
    infer = SeqLabelInfer(pickle_path)
    results = infer.predict(model, infer_data)
    print(results)
Example #7
def foo():
    loader = TokenizeDatasetLoader("./data_for_tests/cws_pku_utf_8")
    train_data = loader.load_pku()

    train_args = ConfigSection()
    ConfigLoader("config.cfg").load_config("./data_for_tests/config",
                                           {"POS": train_args})

    # Preprocessor
    p = SeqLabelPreprocess()
    train_data = p.run(train_data)
    train_args["vocab_size"] = p.vocab_size
    train_args["num_classes"] = p.num_classes

    model = SeqLabeling(train_args)

    valid_args = {
        "save_output": True,
        "validate_in_training": True,
        "save_dev_input": True,
        "save_loss": True,
        "batch_size": 8,
        "pickle_path": "./data_for_tests/",
        "use_cuda": True
    }
    validator = SeqLabelTester(**valid_args)

    print("start validation.")
    validator.test(model, train_data)
    print(validator.show_metrics())
Example #8
def train_test():
    # Config Loader
    train_args = ConfigSection()
    ConfigLoader().load_config(config_path, {"POS_infer": train_args})

    # define dataset
    data_train = SeqLabelDataSet(load_func=TokenizeDataSetLoader.load)
    data_train.load(cws_data_path)
    train_args["vocab_size"] = len(data_train.word_vocab)
    train_args["num_classes"] = len(data_train.label_vocab)

    save_pickle(data_train.word_vocab, pickle_path, "word2id.pkl")
    save_pickle(data_train.label_vocab, pickle_path, "label2id.pkl")

    # Trainer
    trainer = SeqLabelTrainer(**train_args.data)

    # Model
    model = SeqLabeling(train_args)

    # Start training
    trainer.train(model, data_train)

    # Saver
    saver = ModelSaver("./save/saved_model.pkl")
    saver.save_pytorch(model)

    del model, trainer

    # Define the same model
    model = SeqLabeling(train_args)

    # Dump trained parameters into the model
    ModelLoader.load_pytorch(model, "./save/saved_model.pkl")

    # Load test configuration
    test_args = ConfigSection()
    ConfigLoader().load_config(config_path, {"POS_infer": test_args})
    test_args["evaluator"] = SeqLabelEvaluator()

    # Tester
    tester = SeqLabelTester(**test_args.data)

    # Start testing
    change_field_is_target(data_train, "truth", True)
    tester.test(model, data_train)
Example #9
    def test_seq_label(self):
        model_args = {
            "vocab_size": 10,
            "word_emb_dim": 100,
            "rnn_hidden_units": 100,
            "num_classes": 5
        }

        infer_data = [['a', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e'],
                      ['a', 'b', '#', 'd', 'e'], ['a', 'b', 'c', '?', 'e'],
                      ['a', 'b', 'c', 'd', '$'], ['!', 'b', 'c', 'd', 'e']]

        vocab = Vocabulary()
        vocab.word2idx = {
            'a': 0,
            'b': 1,
            'c': 2,
            'd': 3,
            'e': 4,
            '!': 5,
            '@': 6,
            '#': 7,
            '$': 8,
            '?': 9
        }
        class_vocab = Vocabulary()
        class_vocab.word2idx = {"0": 0, "1": 1, "2": 2, "3": 3, "4": 4}

        os.system("mkdir save")
        save_pickle(class_vocab, "./save/", "class2id.pkl")
        save_pickle(vocab, "./save/", "word2id.pkl")

        model = SeqLabeling(model_args)
        predictor = Predictor("./save/", task="seq_label")

        results = predictor.predict(network=model, data=infer_data)

        self.assertTrue(isinstance(results, list))
        self.assertGreater(len(results), 0)
        for res in results:
            self.assertTrue(isinstance(res, list))
            self.assertEqual(len(res), 5)
            self.assertTrue(isinstance(res[0], str))

        os.system("rm -rf save")
        print("pickle path deleted")
Example #10
    def test_case_1(self):
        model_args = {
            "vocab_size": 10,
            "word_emb_dim": 100,
            "rnn_hidden_units": 100,
            "num_classes": 5
        }
        valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True,
                      "save_loss": True, "batch_size": 2, "pickle_path": "./save/",
                      "use_cuda": False, "print_every_step": 1}

        train_data = [
            [['a', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
            [['a', '@', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
            [['a', 'b', '#', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
            [['a', 'b', 'c', '?', 'e'], ['a', '@', 'c', 'd', 'e']],
            [['a', 'b', 'c', 'd', '$'], ['a', '@', 'c', 'd', 'e']],
            [['!', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
        ]
        vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9}
        label_vocab = {'a': 0, '@': 1, 'c': 2, 'd': 3, 'e': 4}

        data_set = DataSet()
        for example in train_data:
            text, label = example[0], example[1]
            x = TextField(text, is_target=False)
            y = TextField(label, is_target=True)
            ins = Instance(word_seq=x, label_seq=y)
            data_set.append(ins)

        data_set.index_field("word_seq", vocab)
        data_set.index_field("label_seq", label_vocab)

        model = SeqLabeling(model_args)

        tester = SeqLabelTester(**valid_args)
        tester.test(network=model, dev_data=data_set)
        # If this can run, everything is OK.

        os.system("rm -rf save")
        print("pickle path deleted")
Example #11
def infer():
    # Load infer configuration, the same as test
    test_args = ConfigSection()
    ConfigLoader("config.cfg").load_config("./data_for_tests/config",
                                           {"POS_test": test_args})

    # fetch dictionary size and number of labels from pickle files
    word2index = load_pickle(pickle_path, "word2id.pkl")
    test_args["vocab_size"] = len(word2index)
    index2label = load_pickle(pickle_path, "id2class.pkl")
    test_args["num_classes"] = len(index2label)

    # Define the same model
    model = SeqLabeling(test_args)

    # Dump trained parameters into the model
    ModelLoader.load_pytorch(model, "./data_for_tests/saved_model.pkl")
    print("model loaded!")

    # Data Loader
    raw_data_loader = BaseLoader(data_infer_path)
    infer_data = raw_data_loader.load_lines()
    """
        Transform strings into list of list of strings. 
        [
            [word_11, word_12, ...],
            [word_21, word_22, ...],
            ...
        ]
        In this case, each line in "people_infer.txt" is already a sentence. So load_lines() just splits them.
    """

    # Inference interface
    infer = Predictor(pickle_path)
    results = infer.predict(model, infer_data)

    print(results)
    print("Inference finished!")
Example #12
def infer():
    # Load infer configuration, the same as test
    test_args = ConfigSection()
    ConfigLoader().load_config(config_dir, {"POS_infer": test_args})

    # fetch dictionary size and number of labels from pickle files
    word_vocab = load_pickle(pickle_path, "word2id.pkl")
    label_vocab = load_pickle(pickle_path, "label2id.pkl")
    test_args["vocab_size"] = len(word_vocab)
    test_args["num_classes"] = len(label_vocab)
    print("vocabularies loaded")

    # Define the same model
    model = SeqLabeling(test_args)
    print("model defined")

    # Dump trained parameters into the model
    ModelLoader.load_pytorch(model, os.path.join(pickle_path, model_name))
    print("model loaded!")

    # Data Loader
    infer_data = SeqLabelDataSet(load_func=BaseLoader.load)
    infer_data.load(data_infer_path,
                    vocabs={
                        "word_vocab": word_vocab,
                        "label_vocab": label_vocab
                    },
                    infer=True)
    print("data set prepared")

    # Inference interface
    infer = SeqLabelInfer(pickle_path)
    results = infer.predict(model, infer_data)

    for res in results:
        print(res)
    print("Inference finished!")
Example #13
def test_training():
    # Config Loader
    trainer_args = ConfigSection()
    model_args = ConfigSection()
    ConfigLoader().load_config(config_dir, {
        "test_seq_label_trainer": trainer_args, "test_seq_label_model": model_args})

    data_set = TokenizeDataSetLoader().load(data_path)
    word_vocab = Vocabulary()
    label_vocab = Vocabulary()
    data_set.update_vocab(word_seq=word_vocab, label_seq=label_vocab)
    data_set.index_field("word_seq", word_vocab).index_field("label_seq", label_vocab)
    data_set.set_origin_len("word_seq")
    data_set.rename_field("label_seq", "truth").set_target(truth=False)
    data_train, data_dev = data_set.split(0.3, shuffle=True)
    model_args["vocab_size"] = len(word_vocab)
    model_args["num_classes"] = len(label_vocab)

    save_pickle(word_vocab, pickle_path, "word2id.pkl")
    save_pickle(label_vocab, pickle_path, "label2id.pkl")

    trainer = SeqLabelTrainer(
        epochs=trainer_args["epochs"],
        batch_size=trainer_args["batch_size"],
        validate=False,
        use_cuda=False,
        pickle_path=pickle_path,
        save_best_dev=trainer_args["save_best_dev"],
        model_name=model_name,
        optimizer=Optimizer("SGD", lr=0.01, momentum=0.9),
    )

    # Model
    model = SeqLabeling(model_args)

    # Start training
    trainer.train(model, data_train, data_dev)

    # Saver
    saver = ModelSaver(os.path.join(pickle_path, model_name))
    saver.save_pytorch(model)

    del model, trainer

    # Define the same model
    model = SeqLabeling(model_args)

    # Dump trained parameters into the model
    ModelLoader.load_pytorch(model, os.path.join(pickle_path, model_name))

    # Load test configuration
    tester_args = ConfigSection()
    ConfigLoader().load_config(config_dir, {"test_seq_label_tester": tester_args})

    # Tester
    tester = SeqLabelTester(batch_size=4,
                            use_cuda=False,
                            pickle_path=pickle_path,
                            model_name="seq_label_in_test.pkl",
                            evaluator=SeqLabelEvaluator()
                            )

    # Start testing with validation data
    data_dev.set_target(truth=True)
    tester.test(model, data_dev)
Example #14
def train_and_test():
    # Config Loader
    trainer_args = ConfigSection()
    model_args = ConfigSection()
    ConfigLoader("config.cfg").load_config(config_dir, {
        "test_seq_label_trainer": trainer_args,
        "test_seq_label_model": model_args
    })

    # Data Loader
    pos_loader = POSDatasetLoader(data_path)
    train_data = pos_loader.load_lines()

    # Preprocessor
    p = SeqLabelPreprocess()
    data_train, data_dev = p.run(train_data,
                                 pickle_path=pickle_path,
                                 train_dev_split=0.5)
    model_args["vocab_size"] = p.vocab_size
    model_args["num_classes"] = p.num_classes

    # Trainer: two definition styles
    # 1
    # trainer = SeqLabelTrainer(**trainer_args.data)

    # 2
    trainer = SeqLabelTrainer(
        epochs=trainer_args["epochs"],
        batch_size=trainer_args["batch_size"],
        validate=trainer_args["validate"],
        use_cuda=trainer_args["use_cuda"],
        pickle_path=pickle_path,
        save_best_dev=trainer_args["save_best_dev"],
        model_name=model_name,
        optimizer=Optimizer("SGD", lr=0.01, momentum=0.9),
    )

    # Model
    model = SeqLabeling(model_args)

    # Start training
    trainer.train(model, data_train, data_dev)
    print("Training finished!")

    # Saver
    saver = ModelSaver(os.path.join(pickle_path, model_name))
    saver.save_pytorch(model)
    print("Model saved!")

    del model, trainer, pos_loader

    # Define the same model
    model = SeqLabeling(model_args)

    # Dump trained parameters into the model
    ModelLoader.load_pytorch(model, os.path.join(pickle_path, model_name))
    print("model loaded!")

    # Load test configuration
    tester_args = ConfigSection()
    ConfigLoader("config.cfg").load_config(
        config_dir, {"test_seq_label_tester": tester_args})

    # Tester
    tester = SeqLabelTester(save_output=False,
                            save_loss=False,
                            save_best_dev=False,
                            batch_size=4,
                            use_cuda=False,
                            pickle_path=pickle_path,
                            model_name="seq_label_in_test.pkl",
                            print_every_step=1)

    # Start testing with validation data
    tester.test(model, data_dev)

    # print test results
    print(tester.show_metrics())
    print("model tested!")
Example #15
def train_and_test():
    # Config Loader
    trainer_args = ConfigSection()
    model_args = ConfigSection()
    ConfigLoader().load_config(config_dir, {
        "test_seq_label_trainer": trainer_args,
        "test_seq_label_model": model_args
    })

    data_set = SeqLabelDataSet()
    data_set.load(data_path)
    train_set, dev_set = data_set.split(0.3, shuffle=True)
    model_args["vocab_size"] = len(data_set.word_vocab)
    model_args["num_classes"] = len(data_set.label_vocab)

    save_pickle(data_set.word_vocab, pickle_path, "word2id.pkl")
    save_pickle(data_set.label_vocab, pickle_path, "label2id.pkl")

    trainer = SeqLabelTrainer(
        epochs=trainer_args["epochs"],
        batch_size=trainer_args["batch_size"],
        validate=False,
        use_cuda=trainer_args["use_cuda"],
        pickle_path=pickle_path,
        save_best_dev=trainer_args["save_best_dev"],
        model_name=model_name,
        optimizer=Optimizer("SGD", lr=0.01, momentum=0.9),
    )

    # Model
    model = SeqLabeling(model_args)

    # Start training
    trainer.train(model, train_set, dev_set)
    print("Training finished!")

    # Saver
    saver = ModelSaver(os.path.join(pickle_path, model_name))
    saver.save_pytorch(model)
    print("Model saved!")

    del model, trainer

    change_field_is_target(dev_set, "truth", True)

    # Define the same model
    model = SeqLabeling(model_args)

    # Dump trained parameters into the model
    ModelLoader.load_pytorch(model, os.path.join(pickle_path, model_name))
    print("model loaded!")

    # Load test configuration
    tester_args = ConfigSection()
    ConfigLoader().load_config(config_dir,
                               {"test_seq_label_tester": tester_args})

    # Tester
    tester = SeqLabelTester(batch_size=4,
                            use_cuda=False,
                            pickle_path=pickle_path,
                            model_name="seq_label_in_test.pkl",
                            evaluator=SeqLabelEvaluator())

    # Start testing with validation data
    tester.test(model, dev_set)
    print("model tested!")
Example #16
    def test_seq_label(self):
        model_args = {
            "vocab_size": 10,
            "word_emb_dim": 100,
            "rnn_hidden_units": 100,
            "num_classes": 5
        }

        infer_data = [['a', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e'],
                      ['a', 'b', '#', 'd', 'e'], ['a', 'b', 'c', '?', 'e'],
                      ['a', 'b', 'c', 'd', '$'], ['!', 'b', 'c', 'd', 'e']]

        vocab = Vocabulary()
        vocab.word2idx = {
            'a': 0,
            'b': 1,
            'c': 2,
            'd': 3,
            'e': 4,
            '!': 5,
            '@': 6,
            '#': 7,
            '$': 8,
            '?': 9
        }
        class_vocab = Vocabulary()
        class_vocab.word2idx = {"0": 0, "1": 1, "2": 2, "3": 3, "4": 4}

        os.system("mkdir save")
        save_pickle(class_vocab, "./save/", "label2id.pkl")
        save_pickle(vocab, "./save/", "word2id.pkl")

        model = CNNText(model_args)
        import fastNLP.core.predictor as pre
        predictor = Predictor("./save/", pre.text_classify_post_processor)

        # Load infer data
        infer_data_set = convert_seq_dataset(infer_data)
        infer_data_set.index_field("word_seq", vocab)

        results = predictor.predict(network=model, data=infer_data_set)

        self.assertTrue(isinstance(results, list))
        self.assertGreater(len(results), 0)
        self.assertEqual(len(results), len(infer_data))
        for res in results:
            self.assertTrue(isinstance(res, str))
            self.assertTrue(res in class_vocab.word2idx)

        del model, predictor
        infer_data_set.set_origin_len("word_seq")

        model = SeqLabeling(model_args)
        predictor = Predictor("./save/", pre.seq_label_post_processor)

        results = predictor.predict(network=model, data=infer_data_set)
        self.assertTrue(isinstance(results, list))
        self.assertEqual(len(results), len(infer_data))
        for i in range(len(infer_data)):
            res = results[i]
            self.assertTrue(isinstance(res, list))
            self.assertEqual(len(res), len(infer_data[i]))

        os.system("rm -rf save")
        print("pickle path deleted")
Example #17
pickle_path = "data_for_tests"

if __name__ == "__main__":
    # Data Loader
    pos = POSDatasetLoader(data_name, data_path)
    train_data = pos.load_lines()

    # Preprocessor
    p = POSPreprocess(train_data, pickle_path)
    vocab_size = p.vocab_size
    num_classes = p.num_classes

    # Trainer
    train_args = {
        "epochs": 20,
        "batch_size": 1,
        "num_classes": num_classes,
        "vocab_size": vocab_size,
        "pickle_path": pickle_path,
        "validate": True
    }
    trainer = POSTrainer(train_args)

    # Model
    model = SeqLabeling(100, 1, num_classes, vocab_size, bi_direction=True)

    # Start training
    trainer.train(model)

    print("Training finished!")