# Example #1
# 0
def test_train(logger, monkeypatch):
    """Smoke-test ``Trainer.train()`` with external services mocked out.

    The wandb entry points used here (``run``, ``log``, ``config``,
    ``tensorboard``) and ``tensorboardX.SummaryWriter`` are replaced with
    ``MagicMock`` objects, ``data.load_dataset`` is wrapped so its result is
    passed through ``patch_dataset``, and the parsed arguments are overridden
    with minimal values before one training run is executed.

    Args:
        logger: Logger handed straight to ``Trainer``.
        monkeypatch: pytest's ``monkeypatch`` fixture; every patch is applied
            inside ``monkeypatch.context()`` and undone when it exits.
    """
    import wandb
    import data
    from train import Trainer, parse_args
    import data.evaluation.multiwoz
    from unittest.mock import MagicMock
    import tensorboardX

    with tempfile.TemporaryDirectory() as run_dir, \
            monkeypatch.context() as patcher:
        # Replace the wandb / tensorboardX entry points with mocks.
        for wandb_attr in ('run', 'log', 'config', 'tensorboard'):
            patcher.setattr(wandb, wandb_attr, MagicMock())
        patcher.setattr(tensorboardX, 'SummaryWriter', MagicMock())
        # Point the mocked run at the temporary directory.
        wandb.run.dir = run_dir

        # Wrap dataset loading so every loaded dataset goes through
        # patch_dataset.
        original_load_dataset = data.load_dataset

        def load_patched_dataset(name, **kwargs):
            return patch_dataset(original_load_dataset(name, **kwargs))

        patcher.setattr(data, 'load_dataset', load_patched_dataset)

        # Parse arguments as if train.py was started without any CLI options.
        patcher.setattr("sys.argv", ["train.py"])
        args = parse_args()
        # NOTE: fp16 stays at whatever parse_args() returns; an override for
        # it was disabled in an earlier revision of this test.
        overrides = {
            'batch_size': 2,
            'epochs': 1,
            'logging_steps': 1,
            'validation_steps': 1,
            'evaluation_dialogs': 1,
        }
        for option, value in overrides.items():
            setattr(args, option, value)
        trainer = Trainer(args, logger)

        # Patch prediction: cap both generation length limits at 2 right
        # before delegating to the real implementation (presumably to keep
        # generation short -- confirm against the predictor).
        real_prediction = trainer._run_prediction

        def short_prediction(*call_args, **call_kwargs):
            predictor = trainer.dev_predictor.pipeline.predictor
            for limit_name in ('max_belief_length', 'max_response_length'):
                patcher.setattr(predictor, limit_name, 2)
            return real_prediction(*call_args, **call_kwargs)

        patcher.setattr(trainer, '_run_prediction', short_prediction)

        # NOTE: an analogous patch of trainer._run_evaluation (caching the
        # output of model.generate) used to live here but was disabled; it
        # is not applied.

        # Patch publish artifact: replace it with a no-op.
        def skip_publish():
            return None

        trainer._publish_artifact = skip_publish

        # Run train
        trainer.train()