Example #1
    def retrieve_and_delete_saved(self, shard: Optional[int] = None):
        """
        Helper function for the tests below. Finds the weight and training state files in
        self.TEST_DIR, parses their names for the epochs that were saved, deletes them,
        and returns the saved epochs as two lists of integers.
        """
        serialization_files = os.listdir(self.TEST_DIR)

        model_checkpoints = [
            x for x in serialization_files if "model_state_" in x
        ]
        if shard is not None:
            model_checkpoints = [
                x for x in model_checkpoints if x.endswith(f"_w{shard}.th")
            ]
        found_model_states = [
            Checkpointer._parse_model_state_path(x) for x in model_checkpoints
        ]
        for f in model_checkpoints:
            os.remove(os.path.join(self.TEST_DIR, f))

        training_checkpoints = [
            x for x in serialization_files if "training_state_" in x
        ]
        if shard is not None:
            training_checkpoints = [
                x for x in training_checkpoints if x.endswith(f"_w{shard}.th")
            ]
        found_training_states = [
            Checkpointer._parse_training_state_path(x)
            for x in training_checkpoints
        ]
        for f in training_checkpoints:
            os.remove(os.path.join(self.TEST_DIR, f))
        return sorted(found_model_states), sorted(found_training_states)
Example #2
 def test_with_time(self):
     """
     Tests that keep_serialized_model_every_num_seconds parameter causes a checkpoint to be saved
     after enough time has elapsed between epochs.
     """
     num_to_keep = 10
     num_epochs = 30
     target = list(range(num_epochs - num_to_keep, num_epochs))
     pauses = [5, 18, 26]
     target = sorted(set(target + pauses))
     checkpointer = Checkpointer(
         serialization_dir=self.TEST_DIR,
         num_serialized_models_to_keep=num_to_keep,
         keep_serialized_model_every_num_seconds=1,
     )
     for e in range(num_epochs):
         if e in pauses:
             time.sleep(2)
         checkpointer.save_checkpoint(
             epoch=e,
             trainer=FakeTrainer(model_state={"epoch": e}, training_states={"epoch": e}),
             is_best_so_far=False,
         )
     models, training = self.retrieve_and_delete_saved()
     assert models == training == target
Example #3
    def test_trainer_saves_models_at_specified_interval(self):
        data_loader = DataLoader(self.instances,
                                 batch_size=4,
                                 collate_fn=allennlp_collate)

        trainer = GradientDescentTrainer(
            self.model,
            self.optimizer,
            data_loader,
            num_epochs=2,
            serialization_dir=self.TEST_DIR,
            checkpointer=Checkpointer(
                serialization_dir=self.TEST_DIR,
                model_save_interval=0.0001,
                num_serialized_models_to_keep=10,
            ),
        )

        trainer.train()

        # Now check the serialized files for models saved during the epoch.
        prefix = "model_state_epoch_*"
        file_names = sorted(glob.glob(os.path.join(self.TEST_DIR, prefix)))
        epochs = [
            re.search(r"_([0-9\.\-]+)\.th", fname).group(1)
            for fname in file_names
        ]
        # We should have checkpoints at the end of each epoch and during each, e.g.
        # [0.timestamp, 0, 1.timestamp, 1]
        assert len(epochs) == 4
        assert epochs[3] == "1"
        assert "." in epochs[0]

        # Now make certain we can restore from a timestamped checkpoint.
        # To do so, remove the end-of-epoch checkpoints for both epochs, so
        # that we are forced to restore from the timestamped checkpoints.
        for k in range(2):
            os.remove(
                os.path.join(self.TEST_DIR,
                             "model_state_epoch_{}.th".format(k)))
            os.remove(
                os.path.join(self.TEST_DIR,
                             "training_state_epoch_{}.th".format(k)))
        os.remove(os.path.join(self.TEST_DIR, "best.th"))

        restore_trainer = GradientDescentTrainer(
            self.model,
            self.optimizer,
            self.data_loader,
            num_epochs=2,
            serialization_dir=self.TEST_DIR,
            checkpointer=Checkpointer(serialization_dir=self.TEST_DIR,
                                      model_save_interval=0.0001),
        )
        epoch = restore_trainer._restore_checkpoint()
        assert epoch == 2
        # One batch per epoch.
        assert restore_trainer._batch_num_total == 2
Example #4
 def test_keep_zero(self):
     checkpointer = Checkpointer(
         serialization_dir=self.TEST_DIR, num_serialized_models_to_keep=0
     )
     for e in range(10):
         checkpointer.save_checkpoint(
             epoch=e,
             trainer=FakeTrainer(model_state={"epoch": e}, training_states={"epoch": e}),
             is_best_so_far=True,
         )
     files = os.listdir(self.TEST_DIR)
     assert "model_state_epoch_1.th" not in files
     assert "training_state_epoch_1.th" not in files
Example #5
 def test_keep_zero(self):
     checkpointer = Checkpointer(serialization_dir=self.TEST_DIR,
                                 keep_most_recent_by_count=0)
     for epochs_completed in range(5):
         state = {
             "epochs_completed": epochs_completed,
             "batches_in_epoch_completed": 0
         }
         checkpointer.maybe_save_checkpoint(
             FakeTrainer(model_state=state, training_state=state),
             epochs_completed, 0)
     files = os.listdir(self.TEST_DIR)
     assert not any("model_state_" in x for x in files)
     assert not any("training_state_" in x for x in files)
Example #6
def build_trainer(
    model: Model,
    serialization_dir: str,
    train_loader: DataLoader,
    dev_loader: DataLoader
) -> Trainer:
    parameters = [
        [n, p]
        for n, p in model.named_parameters() if p.requires_grad
    ]

    checkpointer = Checkpointer(serialization_dir, num_serialized_models_to_keep=0)
    optimizer = AdamOptimizer(parameters)
    trainer = GradientDescentTrainer(
        model=model,
        serialization_dir=serialization_dir,
        checkpointer=checkpointer,
        data_loader=train_loader,
        validation_data_loader=dev_loader,
        num_epochs=50,
        optimizer=optimizer,
        cuda_device=0,
        validation_metric="-loss",
        patience=5,
    )
    return trainer
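
For context, build_trainer would be wired up roughly as follows (a sketch only; my_model, train_loader, and dev_loader are hypothetical objects assumed to have been built elsewhere with the usual AllenNLP APIs, and a GPU is assumed because the function hard-codes cuda_device=0):

    # Hypothetical usage of the build_trainer helper defined above.
    trainer = build_trainer(my_model, "runs/experiment_1", train_loader, dev_loader)
    metrics = trainer.train()  # GradientDescentTrainer.train() returns a metrics dict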
Example #7
    def test_restoring_works_with_older_checkpointing(self):
        trainer = GradientDescentTrainer(
            self.model,
            self.optimizer,
            self.data_loader,
            validation_data_loader=self.validation_data_loader,
            num_epochs=3,
            serialization_dir=self.TEST_DIR,
            checkpointer=Checkpointer(
                serialization_dir=self.TEST_DIR, num_serialized_models_to_keep=4
            ),
        )
        trainer.train()

        for index in range(3):
            path = str(self.TEST_DIR / "training_state_epoch_{}.th".format(index))
            state = torch.load(path)
            state.pop("metric_tracker")
            state.pop("batch_num_total")
            state["val_metric_per_epoch"] = [0.4, 0.1, 0.8]
            torch.save(state, path)

        next_epoch = trainer._restore_checkpoint()
        best_epoch = trainer._metric_tracker.best_epoch

        # Loss decreases over the 3 epochs, but because we hard-coded the validation metrics above:
        assert next_epoch == 3
        assert best_epoch == 1
        assert trainer._metric_tracker._best_so_far == 0.1
        assert trainer._metric_tracker._epochs_with_no_improvement == 1
Example #8
    def test_default(self):
        """
        Tests that the default behavior keeps just the last 2 checkpoints.
        """
        default_num_to_keep = 2
        num_epochs = 30
        target = list(range(num_epochs - default_num_to_keep, num_epochs))

        checkpointer = Checkpointer(serialization_dir=self.TEST_DIR)

        for e in range(num_epochs):
            checkpointer.save_checkpoint(
                epoch=e,
                trainer=FakeTrainer(model_state={"epoch": e}, training_states={"epoch": e}),
                is_best_so_far=False,
            )
        models, training = self.retrieve_and_delete_saved()
        assert models == training == target
Example #9
 def test_with_time(self):
     num_epochs = 30
     pauses = [5, 18, 26]
     target = [(e, 0) for e in pauses]
     checkpointer = Checkpointer(
         serialization_dir=self.TEST_DIR,
         save_completed_epochs=False,
         save_every_num_seconds=1,
         keep_most_recent_by_count=3,
     )
     for e in range(num_epochs):
         if e in pauses:
             time.sleep(2)
         state = {"epochs_completed": e, "batches_in_epoch_completed": 0}
         checkpointer.maybe_save_checkpoint(
             trainer=FakeTrainer(model_state=state, training_state=state),
             num_epochs_completed=e,
             num_batches_in_epoch_completed=0,
         )
     models, training = self.retrieve_and_delete_saved()
     assert models == training == target
Example #10
    def test_default(self):
        """
        Tests that the default behavior keeps just the last 2 checkpoints.
        """
        default_num_to_keep = 2
        num_epochs = 5
        target = [(e, 0)
                  for e in range(num_epochs - default_num_to_keep, num_epochs)]

        checkpointer = Checkpointer(serialization_dir=self.TEST_DIR)
        for epochs_completed in range(num_epochs):
            for batches_completed in [0, 5, 10]:
                state = {
                    "epochs_completed": epochs_completed,
                    "batches_in_epoch_completed": batches_completed,
                }
                checkpointer.maybe_save_checkpoint(
                    FakeTrainer(model_state=state, training_state=state),
                    epochs_completed,
                    batches_completed,
                )
        models, training = self.retrieve_and_delete_saved()
        assert models == training == target
Example #11
    def test_registered_subclass(self):
        """
        Tests that registering Checkpointer subclasses works correctly.
        """

        @Checkpointer.register("checkpointer_subclass")
        class CheckpointerSubclass(Checkpointer):
            def __init__(self, x: int, y: int) -> None:
                super().__init__()
                self.x = x
                self.y = y

        sub_inst = Checkpointer.from_params(
            Params({"type": "checkpointer_subclass", "x": 1, "y": 3})
        )
        assert sub_inst.__class__ == CheckpointerSubclass
        assert sub_inst.x == 1 and sub_inst.y == 3
Example #12
    def test_trainer_respects_num_serialized_models_to_keep(self):
        trainer = GradientDescentTrainer(
            self.model,
            self.optimizer,
            self.data_loader,
            num_epochs=5,
            serialization_dir=self.TEST_DIR,
            checkpointer=Checkpointer(
                serialization_dir=self.TEST_DIR, num_serialized_models_to_keep=3
            ),
        )
        trainer.train()

        # Now check the serialized files
        for prefix in ["model_state_epoch_*", "training_state_epoch_*"]:
            file_names = glob.glob(os.path.join(self.TEST_DIR, prefix))
            epochs = [int(re.search(r"_([0-9])\.th", fname).group(1)) for fname in file_names]
            assert sorted(epochs) == [2, 3, 4]
Example #13
    def test_trainer_respects_keep_serialized_model_every_num_seconds(self):
        # To test:
        #   Create a fake data loader that sleeps for 2.5 seconds per epoch, so the total
        #   training time for one epoch is slightly greater than 2.5 seconds.
        #   Run for 6 epochs, keeping the last 2 models and also keeping a model every 5 seconds.
        #   Check the resulting checkpoints.  We should then have models at epochs
        #       2, 4, plus the last two at 5 and 6.

        class SlowDataLoader:
            data_loader = SimpleDataLoader(self.instances, batch_size=2)

            def __iter__(self):
                time.sleep(2.5)
                return iter(self.data_loader)

            def __len__(self):
                return len(self.data_loader)

            def set_target_device(self, _):
                pass

        trainer = GradientDescentTrainer(
            self.model,
            self.optimizer,
            SlowDataLoader(),
            num_epochs=6,
            serialization_dir=self.TEST_DIR,
            checkpointer=Checkpointer(
                serialization_dir=self.TEST_DIR,
                num_serialized_models_to_keep=2,
                keep_serialized_model_every_num_seconds=5,
            ),
        )
        trainer.train()

        # Now check the serialized files
        for prefix in ["model_state_epoch_*", "training_state_epoch_*"]:
            file_names = glob.glob(os.path.join(self.TEST_DIR, prefix))
            epochs = [
                int(re.search(r"_([0-9])\.th", fname).group(1))
                for fname in file_names
            ]
            # epoch N has N-1 in file name
            assert sorted(epochs) == [1, 3, 4, 5]
Example #14
    def test_trainer_saves_metrics_every_epoch(self):
        trainer = GradientDescentTrainer(
            model=self.model,
            optimizer=self.optimizer,
            data_loader=self.data_loader,
            validation_data_loader=self.validation_data_loader,
            num_epochs=5,
            serialization_dir=self.TEST_DIR,
            checkpointer=Checkpointer(serialization_dir=self.TEST_DIR,
                                      num_serialized_models_to_keep=3),
        )
        trainer.train()

        for epoch in range(5):
            epoch_file = self.TEST_DIR / f"metrics_epoch_{epoch}.json"
            assert epoch_file.exists()
            metrics = json.load(open(epoch_file))
            assert "validation_loss" in metrics
            assert "best_validation_loss" in metrics
            assert metrics.get("epoch") == epoch
Example #15
    def test_default_distributed_with_sharded_state(self):
        """
        Simulates using the Checkpointer during distributed training with a sharded model.
        """
        world_size = 2
        default_num_to_keep = 2
        num_epochs = 5
        target = [(e, 0)
                  for e in range(num_epochs - default_num_to_keep, num_epochs)]

        checkpointers = [
            Checkpointer(serialization_dir=self.TEST_DIR)
            for _ in range(world_size)
        ]
        for i, checkpointer in enumerate(checkpointers):
            checkpointer._rank = i
            checkpointer.state_is_sharded = True

        for epochs_completed in range(num_epochs):
            for batches_completed in [0, 5, 10]:
                for i, checkpointer in enumerate(checkpointers):
                    state = {
                        "epochs_completed": epochs_completed,
                        "batches_in_epoch_completed": batches_completed,
                        "rank": i,
                    }
                    checkpointer.maybe_save_checkpoint(
                        FakeTrainer(model_state=state, training_state=state),
                        epochs_completed,
                        batches_completed,
                    )

        for i, checkpointer in enumerate(checkpointers):
            checkpoint = checkpointer.load_checkpoint()
            assert checkpoint is not None
            model_state, training_state = checkpoint
            assert model_state["rank"] == i
            assert training_state["rank"] == i

            models, training = self.retrieve_and_delete_saved(shard=i)
            assert models == training == target
Example #16
 def test_base_class_from_params(self):
     Checkpointer.from_params(Params({}), serialization_dir=self.TEST_DIR)
Example #17
    def run(  # type: ignore
        self,
        model: Lazy[Model],
        dataset: DatasetDict,
        data_loader: Lazy[TangoDataLoader],
        optimizer: Lazy[Optimizer],
        validation_data_loader: Optional[Lazy[TangoDataLoader]] = None,
        training_split: str = "train",
        validation_split: Optional[str] = None,
        patience: Optional[int] = None,
        validation_metric: Union[str, List[str]] = "-loss",
        num_epochs: int = 20,
        checkpointer: Optional[Lazy[Checkpointer]] = None,
        grad_norm: Union[float, bool] = False,
        grad_clipping: Optional[float] = None,
        learning_rate_scheduler: Optional[Lazy[LearningRateScheduler]] = None,
        momentum_scheduler: Optional[Lazy[MomentumScheduler]] = None,
        moving_average: Optional[Lazy[MovingAverage]] = None,
        callbacks: Optional[List[Lazy[TrainerCallback]]] = None,
        num_gradient_accumulation_steps: int = 1,
        use_amp: bool = False,
        enable_default_callbacks: bool = True,
        run_confidence_checks: bool = True,
        no_grad: Optional[List[str]] = None,
        limit_batches_per_epoch: Optional[int] = None,
    ) -> Model:
        serialization_dir = self.work_dir()

        if validation_data_loader is None:
            validation_data_loader = data_loader
        if validation_split is None:
            validation_loader = None
        else:
            concrete_validation_data_loader = validation_data_loader.construct(
                instances=dataset.splits[validation_split])
            del validation_data_loader
            if limit_batches_per_epoch is not None:
                concrete_validation_data_loader = MaxBatchesDataLoader(
                    concrete_validation_data_loader, limit_batches_per_epoch)
            validation_loader = DataLoaderAdapter(
                tango_data_loader=concrete_validation_data_loader)

        concrete_data_loader = data_loader.construct(
            instances=dataset.splits[training_split])
        del data_loader
        if limit_batches_per_epoch is not None:
            concrete_data_loader = MaxBatchesDataLoader(
                concrete_data_loader, limit_batches_per_epoch)
        loader = DataLoaderAdapter(tango_data_loader=concrete_data_loader)

        if torch.cuda.device_count() > 0:
            cuda_device = torch.device(0)
        else:
            cuda_device = torch.device("cpu")
        check_for_gpu(cuda_device)
        loader.set_target_device(cuda_device)
        if validation_loader is not None:
            validation_loader.set_target_device(cuda_device)

        concrete_model = model.construct(vocab=dataset.vocab).to(cuda_device)
        del model
        if no_grad:
            for name, parameter in concrete_model.named_parameters():
                if any(re.search(regex, name) for regex in no_grad):
                    parameter.requires_grad_(False)
        parameters = [[n, p] for n, p in concrete_model.named_parameters()
                      if p.requires_grad]
        concrete_optimizer = optimizer.construct(model_parameters=parameters)
        del optimizer
        log_frozen_and_tunable_parameter_names(concrete_model)

        concrete_moving_average = (None if moving_average is None else
                                   moving_average.construct(
                                       parameters=parameters))
        del moving_average

        concrete_learning_rate_scheduler = (
            None if learning_rate_scheduler is None else
            learning_rate_scheduler.construct(
                optimizer=concrete_optimizer,
                num_epochs=num_epochs,
                num_steps_per_epoch=concrete_data_loader.num_batches_per_epoch(
                ),
            ))
        del learning_rate_scheduler

        concrete_momentum_scheduler = (None if momentum_scheduler is None else
                                       momentum_scheduler.construct(
                                           optimizer=concrete_optimizer))
        del momentum_scheduler

        if checkpointer is not None:
            concrete_checkpointer = checkpointer.construct(
                serialization_dir=serialization_dir)
        else:
            concrete_checkpointer = Checkpointer(serialization_dir)
        del checkpointer

        concrete_callbacks: List[TrainerCallback] = [
            cb.construct(serialization_dir=serialization_dir)
            for cb in callbacks or []
        ]
        del callbacks

        trainer = GradientDescentTrainer(
            concrete_model,
            optimizer=concrete_optimizer,
            data_loader=loader,
            patience=patience,
            validation_metric=validation_metric,
            validation_data_loader=validation_loader,
            num_epochs=num_epochs,
            serialization_dir=serialization_dir,
            checkpointer=concrete_checkpointer,
            grad_norm=grad_norm,
            grad_clipping=grad_clipping,
            learning_rate_scheduler=concrete_learning_rate_scheduler,
            momentum_scheduler=concrete_momentum_scheduler,
            moving_average=concrete_moving_average,
            callbacks=concrete_callbacks,
            num_gradient_accumulation_steps=num_gradient_accumulation_steps,
            use_amp=use_amp,
            enable_default_callbacks=enable_default_callbacks,
            run_confidence_checks=run_confidence_checks,
        )
        trainer.train()

        return trainer.model
Example #18
 def test_base_class_from_params(self):
     Checkpointer.from_params(Params({}))

# Data loaders
train_loader = PyTorchDataLoader(train_dataset, batch_size=32, shuffle=True)
validation_loader = PyTorchDataLoader(validation_dataset,
                                      batch_size=32,
                                      shuffle=False)

# Copy the model to the GPU
if args.cuda:
    model = model.cuda()

# Create the optimizer
optimizer = AdamOptimizer(model.named_parameters())
checkpointer = Checkpointer(serialization_dir=args.serialization_dir,
                            num_serialized_models_to_keep=None)
learning_rate_scheduler = LinearWithWarmup(
    optimizer=optimizer,
    num_epochs=args.num_epochs,
    num_steps_per_epoch=len(train_loader),
    warmup_steps=4000)
# Create the trainer
trainer = GradientDescentTrainer(
    model=model,
    optimizer=optimizer,
    data_loader=train_loader,
    validation_data_loader=validation_loader,
    validation_metric="-loss",
    num_epochs=args.num_epochs,
    learning_rate_scheduler=learning_rate_scheduler,
    checkpointer=checkpointer)
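
The snippet above stops after constructing the trainer. With those objects in place, training would be started with a call like the following (a minimal sketch; the metric key mirrors the one checked in Example #14):

# Run training; GradientDescentTrainer.train() returns the final metrics dict.
metrics = trainer.train()
print(metrics["best_validation_loss"])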