コード例 #1
0
ファイル: train_test.py プロジェクト: yakazimir/allennlp
    def test_train_saves_all_keys_in_config(self):
        """Check that the config saved by training matches the input params.

        Trains a small ``simple_tagger`` for two epochs, then loads the JSON
        file named ``CONFIG_NAME`` from the serialization directory and
        asserts that it equals the ordered-dict view of the params captured
        before training started.
        """
        embedder_config = {
            "token_embedders": {"tokens": {"type": "embedding", "embedding_dim": 5}}
        }
        encoder_config = {"type": "lstm", "input_size": 5, "hidden_size": 7, "num_layers": 2}
        model_config = {
            "type": "simple_tagger",
            "text_field_embedder": embedder_config,
            "encoder": encoder_config,
        }
        params = Params(
            {
                "model": model_config,
                "pytorch_seed": 42,
                "numpy_seed": 42,
                "random_seed": 42,
                "dataset_reader": {"type": "sequence_tagging"},
                "train_data_path": SEQUENCE_TAGGING_DATA_PATH,
                "validation_data_path": SEQUENCE_TAGGING_DATA_PATH,
                "iterator": {"type": "basic", "batch_size": 2},
                "trainer": {"num_epochs": 2, "optimizer": "adam"},
            }
        )

        output_dir = os.path.join(self.TEST_DIR, "test_train_model")
        saved_config_file = os.path.join(output_dir, CONFIG_NAME)

        # Snapshot the params before training: train_model pops every value
        # out of ``params``, so nothing would be left to compare afterwards.
        expected_config = params.as_ordered_dict()
        train_model(params, serialization_dir=output_dir)

        with open(saved_config_file, "r") as config_file:
            loaded_config = json.load(config_file)
        actual_config = OrderedDict(loaded_config)

        assert expected_config == actual_config