Code example #1
File: train_test.py Project: zheng5yu9/allennlp
    def test_train_model_can_instantiate_from_params(self):
        params = Params.from_file(self.FIXTURES_ROOT / "simple_tagger" / "experiment.json")

        # Can instantiate from base class params
        TrainModel.from_params(
            params=params, serialization_dir=self.TEST_DIR, local_rank=0, batch_weight_key=""
        )
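
For context, TrainModel.from_params only builds the training loop; nothing is trained until run() is called. A minimal standalone sketch, assuming the AllenNLP 1.x+ API (the config path here is hypothetical):

import tempfile

from allennlp.common import Params
from allennlp.commands.train import TrainModel

# "experiment.json" is a placeholder; any valid experiment config works.
params = Params.from_file("experiment.json")
with tempfile.TemporaryDirectory() as serialization_dir:
    train_loop = TrainModel.from_params(
        params=params, serialization_dir=serialization_dir, local_rank=0, batch_weight_key=""
    )
    metrics = train_loop.run()   # train and collect final metrics
    train_loop.finish(metrics)   # write metrics and archive the model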
Code example #2
    def test_pretrained_configs(self, path):
        params = Params.from_file(
            path,
            params_overrides="{"
            "'trainer.cuda_device': -1, "
            "'trainer.use_amp': false, "
            "'trainer.num_epochs': 2, "
            "}",
        )

        # Patch max_instances in the multitask case
        patch_dataset_reader(params["dataset_reader"])
        if "validation_dataset_reader" in params:
            # Unclear why this doesn't work for biattentive_classification_network
            if "biattentive_classification_network" not in path:
                patch_dataset_reader(params["validation_dataset_reader"])

        # Patch any pretrained glove files with smaller fixtures.
        patch_glove(params)
        # Patch image_dir and feature_cache_dir keys so they point at our test fixtures instead.
        patch_image_dir(params)

        # Remove unnecessary keys.
        for key in ("random_seed", "numpy_seed", "pytorch_seed",
                    "distributed"):
            if key in params:
                del params[key]

        # Just make sure the train loop can be instantiated.
        TrainModel.from_params(params=params,
                               serialization_dir=self.TEST_DIR,
                               local_rank=0)
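
The params_overrides string above is evaluated as Jsonnet, which is why the single quotes and trailing comma are legal, and dotted keys reach into nested config blocks after the file is loaded. A small sketch of the same mechanism (hypothetical config path):

from allennlp.common import Params

# Dotted keys override nested values in the loaded config; recent
# AllenNLP versions also accept a plain dict instead of a Jsonnet string.
params = Params.from_file(
    "experiment.jsonnet",
    params_overrides="{'trainer.cuda_device': -1, 'trainer.num_epochs': 1}",
)
assert params["trainer"]["num_epochs"] == 1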
Code example #3
    def test_pretrained_configs(self, path):
        params = Params.from_file(
            path,
            params_overrides="{"
            "'trainer.cuda_device': -1, "
            "'trainer.use_amp': false, "
            "'trainer.num_epochs': 2, "
            "'dataset_reader.max_instances': 4, "
            "'dataset_reader.lazy': false, "
            "}",
        )

        # Patch any pretrained glove files with smaller fixtures.
        patch_glove(params)

        # Remove unnecessary keys.
        for key in ("random_seed", "numpy_seed", "pytorch_seed",
                    "distributed"):
            if key in params:
                del params[key]

        # Just make sure the train loop can be instantiated.
        TrainModel.from_params(params=params,
                               serialization_dir=self.TEST_DIR,
                               local_rank=0)
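
Both test_pretrained_configs variants take a path argument, which suggests they are parametrized over a set of config files. A plausible pytest setup (an assumption; the glob pattern and class name are illustrative, not the original source):

import glob

import pytest
from allennlp.common.testing import AllenNlpTestCase

# Assumed layout: one experiment config per file under training_config/.
CONFIGS = glob.glob("training_config/**/*.jsonnet", recursive=True)

class TestPretrainedConfigs(AllenNlpTestCase):
    @pytest.mark.parametrize("path", CONFIGS)
    def test_pretrained_configs(self, path):
        ...  # body as in the examples above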
Code example #4
File: train_test.py Project: zheng5yu9/allennlp
    def test_train_can_fine_tune_model_from_archive(self):
        params = Params.from_file(
            self.FIXTURES_ROOT / "basic_classifier" / "experiment_from_archive.jsonnet"
        )
        train_loop = TrainModel.from_params(
            params=params, serialization_dir=self.TEST_DIR, local_rank=0, batch_weight_key=""
        )
        train_loop.run()

        model = Model.from_archive(
            self.FIXTURES_ROOT / "basic_classifier" / "serialization" / "model.tar.gz"
        )

        # This is checking that the vocabulary actually got extended.  The data that we're using for
        # training is different from the data we used to produce the model archive, and we set
        # parameters such that the vocab should have been extended.
        assert train_loop.model.vocab.get_vocab_size() > model.vocab.get_vocab_size()
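
The assertion works because Model.from_archive rebuilds the archived model together with its saved vocabulary, so the two vocab sizes can be compared directly. A quick sketch, assuming a standard model.tar.gz produced by an earlier run (the path is hypothetical):

from allennlp.models import Model

# Load weights, config, and vocabulary from a training archive.
model = Model.from_archive("serialization/model.tar.gz")
# Vocabulary sizes are per namespace; "tokens" is the default.
print(model.vocab.get_vocab_size("tokens"))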
Code example #5
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--action",
                        type=Action,
                        choices=list(Action),
                        required=True)
    parser.add_argument("--config", required=True)
    parser.add_argument("--serialization-dir", required=True)
    parser.add_argument("--batch-count", type=int, default=0)
    parser.add_argument("--assume-multiprocess-types", action="store_true")
    args = parser.parse_args()

    # Build the training loop from the config, then pull out the trainer and a
    # raw batch generator from its (pre-1.0 style) data iterator for benchmarking.
    params = Params.from_file(args.config)
    train_model = TrainModel.from_params(params,
                                         args.serialization_dir,
                                         local_rank=0,
                                         batch_weight_key="")
    trainer = train_model.trainer
    raw_generator = trainer.iterator(trainer.train_data,
                                     num_epochs=1,
                                     shuffle=True)

    if args.action is Action.log:
        log_iterable(raw_generator, args.assume_multiprocess_types)
    elif args.action is Action.time:
        time_iterable(raw_generator, args.batch_count)
    elif args.action is Action.first:
        time_to_first(raw_generator)
    else:
        raise Exception(f"Unaccounted for action {args.action}")
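
The Action enum itself is not shown in the snippet. A plausible definition consistent with the --action choices above (an assumption, not the original source):

from enum import Enum

class Action(Enum):
    log = "log"
    time = "time"
    first = "first"

    def __str__(self):
        return self.value  # so argparse renders choices as log/time/first

With type=Action, argparse calls Action("log") on the raw string, which resolves by value to Action.log before the choices check runs.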