コード例 #1
0
    def test_load_saved_model(self):
        """Save a freshly created task's model and check it loads back unchanged.

        The snapshot is written to a temporary file. The reloaded config and
        model must equal the originals, and both models must return the same
        output for one fixed input after ``eval()`` has been called on them.
        """
        with tempfile.NamedTemporaryFile() as snapshot_file:
            source_config = TSVDataSource.Config(
                train_filename=tests_module.test_file("train_data_tiny.tsv"),
                eval_filename=tests_module.test_file("test_data_tiny.tsv"),
                field_names=["label", "slots", "text"],
            )
            task_config = DocumentClassificationTask.Config(
                data=Data.Config(source=source_config)
            )
            config = PyTextConfig(
                task=task_config,
                version=LATEST_VERSION,
                save_snapshot_path=snapshot_file.name,
            )
            task = create_task(config.task)
            original_model = task.model

            # Round-trip the model through the snapshot file.
            save(config, original_model, meta=None, tensorizers=task.data.tensorizers)
            restored_task, restored_config = load(snapshot_file.name)

            self.assertEqual(config, restored_config)
            self.assertModulesEqual(original_model, restored_task.model)

            original_model.eval()
            restored_task.model.eval()

            # The same fixed tensors are fed to both models.
            model_inputs = (torch.LongTensor([[1, 2, 3]]), torch.LongTensor([3]))
            expected_output = original_model(*model_inputs).tolist()
            restored_output = restored_task.model(*model_inputs).tolist()
            self.assertEqual(expected_output, restored_output)
コード例 #2
0
ファイル: predictor_test.py プロジェクト: westinedu/pytext
    def test_batch_predict_caffe2_model(self):
        """Export a task to Caffe2 and check batch prediction yields 4 results.

        The task is built from the tiny TSV fixtures, exported to a temporary
        Caffe2 file and saved as a snapshot; ``batch_predict_caffe2_model`` is
        then run against both files.
        """
        with tempfile.NamedTemporaryFile() as snapshot_file:
            with tempfile.NamedTemporaryFile() as caffe2_model_file:
                train_data = tests_module.test_file("train_data_tiny.tsv")
                eval_data = tests_module.test_file("test_data_tiny.tsv")
                # The eval fixture doubles as the test set.
                source_config = TSVDataSource.Config(
                    train_filename=train_data,
                    eval_filename=eval_data,
                    test_filename=eval_data,
                    field_names=["label", "slots", "text"],
                )
                config = PyTextConfig(
                    task=DocumentClassificationTask.Config(
                        data=Data.Config(source=source_config)
                    ),
                    version=LATEST_VERSION,
                    save_snapshot_path=snapshot_file.name,
                    export_caffe2_path=caffe2_model_file.name,
                )
                task = create_task(config.task)
                task.export(task.model, caffe2_model_file.name)
                save(
                    config,
                    task.model,
                    meta=None,
                    tensorizers=task.data.tensorizers,
                )

                predictions = batch_predict_caffe2_model(
                    snapshot_file.name, caffe2_model_file.name
                )
                self.assertEqual(4, len(predictions))
コード例 #3
0
    def test_batch_predict_caffe2_model(self):
        """Compare Caffe2 batch predictions with the PyTorch task's predictions.

        The model takes token, dense-float and label inputs. After export and
        save, ``batch_predict_caffe2_model`` is run three times (default
        arguments, ``cache_size=2`` and ``cache_size=-1``); each run must
        return 4 results whose scores match ``task.predict`` on the test data.
        """
        with tempfile.NamedTemporaryFile() as snapshot_file, tempfile.NamedTemporaryFile() as caffe2_model_file:
            train_data = tests_module.test_file("train_data_tiny.tsv")
            eval_data = tests_module.test_file("test_data_tiny.tsv")

            model_config = DocModel.Config(
                inputs=DocModel.Config.ModelInput(
                    tokens=TokenTensorizer.Config(),
                    dense=FloatListTensorizer.Config(
                        column="dense", dim=1, error_check=True
                    ),
                    labels=LabelTensorizer.Config(),
                )
            )
            data_config = Data.Config(
                source=TSVDataSource.Config(
                    train_filename=train_data,
                    eval_filename=eval_data,
                    test_filename=eval_data,
                    field_names=["label", "slots", "text", "dense"],
                )
            )
            config = PyTextConfig(
                task=DocumentClassificationTask.Config(
                    model=model_config, data=data_config
                ),
                version=21,
                save_snapshot_path=snapshot_file.name,
                export_caffe2_path=caffe2_model_file.name,
            )

            task = create_task(config.task)
            task.export(task.model, caffe2_model_file.name)
            save(config, task.model, meta=None, tensorizers=task.data.tensorizers)

            # Reference predictions from the PyTorch task on the test split.
            pt_results = task.predict(task.data.data_source.test)

            def check_against_pytorch(caffe2_results):
                # Pair results positionally and compare the score vectors.
                for expected, actual in zip(pt_results, caffe2_results):
                    caffe2_scores = [score[0] for score in actual.values()]
                    np.testing.assert_array_almost_equal(
                        expected["score"].tolist()[0], caffe2_scores
                    )

            # Same call three times: default arguments, then two cache sizes.
            for extra_kwargs in ({}, {"cache_size": 2}, {"cache_size": -1}):
                caffe2_results = batch_predict_caffe2_model(
                    snapshot_file.name, caffe2_model_file.name, **extra_kwargs
                )
                self.assertEqual(4, len(caffe2_results))
                check_against_pytorch(caffe2_results)
コード例 #4
0
        def test_load_checkpoint(self):
            """Save a checkpoint with training state and verify it restores.

            A task, optimizer and scheduler are created, bundled into a
            ``TrainingState`` and written through ``save``. The checkpoint is
            then loaded by file name and the restored optimizer, scheduler,
            config and model are compared with the originals.
            """
            with tempfile.NamedTemporaryFile() as checkpoint_file:
                train_data = tests_module.test_file("train_data_tiny.tsv")
                eval_data = tests_module.test_file("test_data_tiny.tsv")
                config = PyTextConfig(
                    task=DocumentClassificationTask.Config(data=Data.Config(
                        source=TSVDataSource.Config(
                            train_filename=train_data,
                            eval_filename=eval_data,
                            field_names=["label", "slots", "text"],
                        ))),
                    version=LATEST_VERSION,
                    save_snapshot_path=checkpoint_file.name,
                )
                task = create_task(config.task)
                model = task.model
                # test checkpoint saving and loading
                optimizer = create_optimizer(Adam.Config(), model)
                scheduler = create_scheduler(Scheduler.Config(), optimizer)
                training_state = TrainingState(
                    model=model,
                    optimizer=optimizer,
                    scheduler=scheduler,
                    start_time=0,
                    epoch=0,
                    rank=0,
                    stage=Stage.TRAIN,
                    epochs_since_last_improvement=0,
                    best_model_state=None,
                    best_model_metric=None,
                    tensorizers=None,
                )

                checkpoint_path = checkpoint_file.name
                # The open temporary file object is handed to save() as the
                # last positional argument; load() below uses its name.
                save(
                    config,
                    model,
                    None,
                    task.data.tensorizers,
                    training_state,
                    checkpoint_file,
                )
                task_restored, config_restored, training_state_restored = load(
                    checkpoint_path)
                optimizer_restored = training_state_restored.optimizer
                scheduler_restored = training_state_restored.scheduler
                self.assertOptimizerEqual(optimizer, optimizer_restored)
                # unittest.TestCase has no assertNotNone; assertIsNotNone is
                # the real assertion for "value is not None".
                self.assertIsNotNone(scheduler_restored)
                self.assertEqual(config, config_restored)
                self.assertModulesEqual(model, task_restored.model)
                model.eval()
                task_restored.model.eval()

                # Both models must agree on the same fixed input.
                inputs = torch.LongTensor([[1, 2, 3]]), torch.LongTensor([3])
                self.assertEqual(
                    model(*inputs).tolist(),
                    task_restored.model(*inputs).tolist())
コード例 #5
0
    def save_checkpoint(self, state: TrainingState,
                        train_config: PyTextConfig) -> str:
        """Save sub-modules and, if configured, a full checkpoint.

        Nothing is done unless ``state.rank`` is 0.

        Args:
            state: current training state; its model, epoch, rank and
                ``epochs_since_last_improvement`` are read.
            train_config: config supplying ``save_module_checkpoints``,
                ``save_all_checkpoints`` and ``modules_save_dir``.

        Returns:
            Whatever ``save`` returns when ``save_all_checkpoints`` is set;
            otherwise ``None``.
        """
        # Only one worker should save checkpoints
        if state.rank != 0:
            return None

        # Per-epoch sub-modules are written when either flag is enabled.
        wants_epoch_modules = (
            train_config.save_module_checkpoints
            or train_config.save_all_checkpoints
        )
        if wants_epoch_modules:
            state.model.save_modules(
                base_path=train_config.modules_save_dir,
                suffix=f"-ep{state.epoch}",
            )

        # Zero epochs since the last improvement means this epoch produced a
        # better model, so the un-suffixed sub-modules are refreshed.
        improved_this_epoch = state.epochs_since_last_improvement == 0
        if improved_this_epoch:
            state.model.save_modules(base_path=train_config.modules_save_dir)

        # next to add new config and implementation of frequency on checkpointing
        if not train_config.save_all_checkpoints:
            return None
        return save(
            config=train_config,
            model=state.model,
            meta=None,
            tensorizers=None,
            training_state=state,
            identifier=str(state.epoch),
        )
コード例 #6
0
    def test_load_checkpoint_in_dist_training(self):
        """Reload a checkpoint with a different rank and world size.

        A checkpoint is saved under the identifier ``"epoch-1"`` and loaded
        from the path ``save`` returns with ``rank=2`` and ``world_size=4``.
        The restored checkpoint must match the original, and the restored
        data source must carry the new rank and world size.
        """
        with tempfile.NamedTemporaryFile() as checkpoint_file:
            source_config = BlockShardedTSVDataSource.Config(
                train_filename=tests_module.test_file("train_data_tiny.tsv"),
                eval_filename=tests_module.test_file("test_data_tiny.tsv"),
                field_names=["label", "slots", "text"],
            )
            config = PyTextConfig(
                task=DocumentClassificationTask.Config(
                    data=Data.Config(source=source_config)
                ),
                version=LATEST_VERSION,
                save_snapshot_path=checkpoint_file.name,
            )
            task = create_task(config.task)
            model = task.model

            # Build the training state that is stored with the checkpoint.
            optimizer = create_optimizer(Adam.Config(), model)
            scheduler = create_scheduler(Scheduler.Config(), optimizer)
            training_state = TrainingState(
                model=model,
                optimizer=optimizer,
                scheduler=scheduler,
                start_time=0,
                epoch=0,
                rank=0,
                stage=Stage.TRAIN,
                epochs_since_last_improvement=0,
                best_model_state=None,
                best_model_metric=None,
                tensorizers=task.data.tensorizers,
            )

            checkpoint_id = "epoch-1"
            saved_path = save(
                config,
                model,
                None,
                task.data.tensorizers,
                training_state,
                checkpoint_id,
            )

            # Restore with a rank/world size different from the saved state.
            new_rank, new_world_size = 2, 4
            task_restored, config_restored, training_state_restored = load(
                saved_path, rank=new_rank, world_size=new_world_size
            )
            self.assertCheckpointEqual(
                model,
                config,
                training_state,
                task_restored.model,
                config_restored,
                training_state_restored,
            )
            restored_source = task_restored.data.data_source
            self.assertEqual(restored_source.rank, new_rank)
            self.assertEqual(restored_source.world_size, new_world_size)
コード例 #7
0
    def save_checkpoint(self, state: TrainingState,
                        train_config: PyTextConfig):
        """Write per-epoch sub-modules and, if configured, a full checkpoint.

        Nothing is done unless ``state.rank`` is 0. The full checkpoint is
        written to the path from ``self.generate_checkpoint_path`` only when
        ``train_config.save_all_checkpoints`` is set.

        Args:
            state: current training state (model, epoch and rank are read).
            train_config: config supplying the checkpoint flags and
                ``modules_save_dir``.
        """
        # Only one worker should save checkpoints
        if state.rank != 0:
            return

        wants_epoch_modules = (
            train_config.save_module_checkpoints
            or train_config.save_all_checkpoints
        )
        if wants_epoch_modules:
            state.model.save_modules(
                base_path=train_config.modules_save_dir,
                suffix=f"-ep{state.epoch}",
            )

        # TODO: add new config and implementation of frequency on checkpointing
        if not train_config.save_all_checkpoints:
            return

        per_epoch_save_path = self.generate_checkpoint_path(
            train_config, state.epoch
        )
        # The file is opened in binary mode and passed to save() as a stream.
        with open(per_epoch_save_path, "wb") as checkpoint_stream:
            print("Saving checkpoint to ", per_epoch_save_path)
            save(
                config=train_config,
                model=state.model,
                meta=None,
                tensorizers=None,
                training_state=state,
                f=checkpoint_stream,
            )
コード例 #8
0
 def save_checkpoint(self, state: TrainingState,
                     train_config: PyTextConfig) -> str:
     """Save a whole-model checkpoint for the current epoch, if enabled.

     A checkpoint is saved only when ``state.rank`` is 0 and
     ``train_config.save_all_checkpoints`` is set.

     Returns:
         Whatever ``save`` returns when a checkpoint is written; otherwise
         ``None``.
     """
     # Only one worker should save checkpoints; the whole model is
     # checkpointed rather than its sub-modules, which users can load from
     # the model externally.
     if state.rank != 0 or not train_config.save_all_checkpoints:
         return None

     epoch_identifier = str(state.epoch)
     return save(
         config=train_config,
         model=state.model,
         meta=None,
         tensorizers=None,
         training_state=state,
         identifier=epoch_identifier,
     )
コード例 #9
0
ファイル: trainer.py プロジェクト: omargamal510/pytext
    def save_checkpoint(self, state: TrainingState, train_config: PyTextConfig) -> str:
        """Save per-epoch sub-modules and, if configured, a full checkpoint.

        Nothing is done unless ``state.rank`` is 0.

        Args:
            state: current training state (model, epoch and rank are read).
            train_config: config supplying ``save_module_checkpoints``,
                ``save_all_checkpoints`` and ``modules_save_dir``.

        Returns:
            Whatever ``save`` returns when ``save_all_checkpoints`` is set;
            otherwise ``None``.
        """
        # Only one worker should save checkpoints
        if state.rank != 0:
            return None

        if train_config.save_module_checkpoints or train_config.save_all_checkpoints:
            epoch_suffix = f"-ep{state.epoch}"
            state.model.save_modules(
                base_path=train_config.modules_save_dir, suffix=epoch_suffix
            )

        # next to add new config and implementation of frequency on checkpointing
        if train_config.save_all_checkpoints:
            # The checkpoint is identified by the epoch number as a string.
            checkpoint_kwargs = {
                "config": train_config,
                "model": state.model,
                "meta": None,
                "tensorizers": None,
                "training_state": state,
                "identifier": str(state.epoch),
            }
            return save(**checkpoint_kwargs)
        return None
コード例 #10
0
    def test_load_saved_model(self):
        """Save a model without training state and check it loads back unchanged.

        ``load`` returns three values here; the third (training state) must
        be ``None`` because none was passed to ``save``. The reloaded config
        and model must equal the originals and give the same output for one
        fixed input.
        """
        with tempfile.NamedTemporaryFile() as snapshot_file:
            source_config = TSVDataSource.Config(
                train_filename=tests_module.test_file("train_data_tiny.tsv"),
                eval_filename=tests_module.test_file("test_data_tiny.tsv"),
                field_names=["label", "slots", "text"],
            )
            config = PyTextConfig(
                task=DocumentClassificationTask.Config(
                    data=Data.Config(source=source_config)
                ),
                version=LATEST_VERSION,
                save_snapshot_path=snapshot_file.name,
            )
            task = create_task(config.task)
            original_model = task.model

            save(config, original_model, meta=None, tensorizers=task.data.tensorizers)
            restored_task, restored_config, restored_state = load(snapshot_file.name)

            self.assertEqual(config, restored_config)
            self.assertModulesEqual(original_model, restored_task.model)
            self.assertIsNone(restored_state)
            original_model.eval()
            restored_task.model.eval()

            # The same fixed tensors are fed to both models.
            model_inputs = (torch.LongTensor([[1, 2, 3]]), torch.LongTensor([3]))
            expected_output = original_model(*model_inputs).tolist()
            restored_output = restored_task.model(*model_inputs).tolist()
            self.assertEqual(expected_output, restored_output)

        def assertOptimizerEqual(self, optim_1, optim_2, msg=None):
            """Assert that two optimizers have equal state dicts.

            Args:
                optim_1: first optimizer; must be an ``Optimizer`` instance.
                optim_2: second optimizer; must be an ``Optimizer`` instance.
                msg: optional message passed to the type and value assertions.

            Raises:
                AssertionError: if either argument is not an ``Optimizer``, or
                    the state dicts differ in length or in any entry.
            """
            # NOTE(review): this def is indented one level deeper than the
            # preceding test method, so it parses as a nested function rather
            # than a method of the test class -- confirm the intended level.
            # isinstance is required here: ``x is Optimizer`` compares an
            # instance with the class object itself and is never true.
            self.assertTrue(
                isinstance(optim_1, Optimizer) and isinstance(optim_2, Optimizer),
                msg,
            )
            state_dict_1 = optim_1.state_dict()
            state_dict_2 = optim_2.state_dict()
            self.assertEqual(len(state_dict_1), len(state_dict_2))
            for key_1, val_1 in state_dict_1.items():
                self.assertEqual(val_1, state_dict_2[key_1], msg)

        def test_load_checkpoint(self):
            """Save a checkpoint with training state and verify it restores.

            A task, optimizer and scheduler are created, bundled into a
            ``TrainingState`` and saved under the identifier ``"epoch-1"``.
            The checkpoint is loaded back and the restored optimizer,
            scheduler, config and model are compared with the originals.
            """
            # NOTE(review): this def is indented one level deeper than the
            # preceding test method, so it parses as a nested function and
            # would not be collected as a test -- confirm the intended level.
            with tempfile.NamedTemporaryFile() as checkpoint_file:
                train_data = tests_module.test_file("train_data_tiny.tsv")
                eval_data = tests_module.test_file("test_data_tiny.tsv")
                config = PyTextConfig(
                    task=DocumentClassificationTask.Config(
                        data=Data.Config(
                            source=TSVDataSource.Config(
                                train_filename=train_data,
                                eval_filename=eval_data,
                                field_names=["label", "slots", "text"],
                            )
                        )
                    ),
                    version=LATEST_VERSION,
                    save_snapshot_path=checkpoint_file.name,
                )
                task = create_task(config.task)
                model = task.model
                # test checkpoint saving and loading
                optimizer = create_optimizer(Adam.Config(), model)
                scheduler = create_scheduler(Scheduler.Config(), optimizer)
                training_state = TrainingState(
                    model=model,
                    optimizer=optimizer,
                    scheduler=scheduler,
                    start_time=0,
                    epoch=0,
                    rank=0,
                    stage=Stage.TRAIN,
                    epochs_since_last_improvement=0,
                    best_model_state=None,
                    best_model_metric=None,
                    tensorizers=task.data.tensorizers,
                )

                # Load from the path save() returns rather than the temp-file
                # name: the checkpoint is saved under an identifier, and the
                # sibling distributed-training test loads from save()'s return
                # value. NOTE(review): presumably the identifier changes the
                # output path -- confirm against save().
                saved_path = save(
                    config,
                    model,
                    None,
                    task.data.tensorizers,
                    training_state,
                    "epoch-1",
                )
                task_restored, config_restored, training_state_restored = load(
                    saved_path
                )
                optimizer_restored = training_state_restored.optimizer
                scheduler_restored = training_state_restored.scheduler
                self.assertOptimizerEqual(optimizer, optimizer_restored)
                # unittest.TestCase has no assertNotNone; assertIsNotNone is
                # the real assertion for "value is not None".
                self.assertIsNotNone(scheduler_restored)
                self.assertEqual(config, config_restored)
                self.assertModulesEqual(model, task_restored.model)
                model.eval()
                task_restored.model.eval()

                # Both models must agree on the same fixed input.
                inputs = torch.LongTensor([[1, 2, 3]]), torch.LongTensor([3])
                self.assertEqual(
                    model(*inputs).tolist(), task_restored.model(*inputs).tolist()
                )