Example #1
    def test_trainer_respects_keep_serialized_model_every_num_seconds(self):
        # To test:
        #   Create a fake data loader that sleeps for 2.5 seconds per epoch, so the total
        #   training time for one epoch is slightly greater than 2.5 seconds.
        #   Run for 6 epochs, keeping the last 2 models, models also kept every 5 seconds.
        #   Check the resulting checkpoints.  Should then have models at epochs
        #       2, 4, plus the last two at 5 and 6.

        class SlowDataLoader:
            data_loader = DataLoader(self.instances,
                                     batch_size=2,
                                     collate_fn=allennlp_collate)

            def __iter__(self):
                time.sleep(2.5)
                return iter(self.data_loader)

            def __len__(self):
                return len(self.data_loader)

        trainer = GradientDescentTrainer(
            self.model,
            self.optimizer,
            SlowDataLoader(),
            num_epochs=6,
            serialization_dir=self.TEST_DIR,
            checkpointer=Checkpointer(
                serialization_dir=self.TEST_DIR,
                num_serialized_models_to_keep=2,
                keep_serialized_model_every_num_seconds=5,
            ),
        )
        trainer.train()

        # Now check the serialized files
        for prefix in ["model_state_epoch_*", "training_state_epoch_*"]:
            file_names = glob.glob(os.path.join(self.TEST_DIR, prefix))
            epochs = [
                int(re.search(r"_([0-9])\.th", fname).group(1))
                for fname in file_names
            ]
            # epoch N has N-1 in file name
            assert sorted(epochs) == [1, 3, 4, 5]
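Most of the test methods in these examples reference fixture attributes such as self.model, self.optimizer, self.data_loader, self.validation_data_loader, self.instances, and self.TEST_DIR, whose setup is not shown (and the examples mix loader APIs from different AllenNLP versions). The sketch below is a minimal, hypothetical stand-in so the snippets can be read in context; it builds a tiny bag-of-embeddings classifier under an assumed AllenNLP 2.x API, whereas the original fixtures evidently use a tagging model (later examples check an accuracy3 metric), so every name and value here is an assumption rather than the original setup.

# Hypothetical fixture sketch (assumes the AllenNLP 2.x API); not the original test setup.
import pathlib
import tempfile

import torch

from allennlp.data import Instance, Vocabulary
from allennlp.data.data_loaders import SimpleDataLoader
from allennlp.data.fields import LabelField, TextField
from allennlp.data.token_indexers import SingleIdTokenIndexer
from allennlp.data.tokenizers import Token
from allennlp.models import BasicClassifier
from allennlp.modules.seq2vec_encoders import BagOfEmbeddingsEncoder
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
from allennlp.modules.token_embedders import Embedding


class TrainerExampleFixture:
    def setup_method(self):
        # A handful of toy classification instances.
        indexers = {"tokens": SingleIdTokenIndexer()}
        labeled_texts = [("a good movie", "pos"), ("a bad movie", "neg"),
                         ("good acting", "pos"), ("bad acting", "neg")]
        self.instances = [
            Instance({"tokens": TextField([Token(t) for t in text.split()], indexers),
                      "label": LabelField(label)})
            for text, label in labeled_texts
        ]
        vocab = Vocabulary.from_instances(self.instances)

        # Tiny model and optimizer.
        embedder = BasicTextFieldEmbedder(
            {"tokens": Embedding(embedding_dim=8,
                                 num_embeddings=vocab.get_vocab_size("tokens"))})
        self.model = BasicClassifier(vocab, embedder,
                                     BagOfEmbeddingsEncoder(embedding_dim=8))
        self.optimizer = torch.optim.Adam(self.model.parameters())

        # Train/validation loaders over the same toy data.
        self.data_loader = SimpleDataLoader(self.instances, batch_size=2)
        self.data_loader.index_with(vocab)
        self.validation_data_loader = SimpleDataLoader(self.instances, batch_size=2)
        self.validation_data_loader.index_with(vocab)

        # Scratch directory for serialization.
        self.TEST_DIR = pathlib.Path(tempfile.mkdtemp())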
Example #2
    def test_epoch_callback_is_called_at_every_epoch(self):
        class FakeEpochCallback(EpochCallback):
            def __call__(self, trainer: "GradientDescentTrainer",
                         metrics: Dict[str, Any], epoch: int) -> None:
                if not hasattr(trainer, "epoch_callback_calls"):
                    trainer.epoch_callback_calls = []  # type: ignore
                trainer.epoch_callback_calls.append(epoch)  # type: ignore

        trainer = GradientDescentTrainer(
            self.model,
            self.optimizer,
            self.data_loader,
            num_epochs=4,
            validation_data_loader=self.validation_data_loader,
            epoch_callbacks=[FakeEpochCallback()],
        )
        trainer.train()
        # The callback also fires once before the first epoch, with epoch == -1.
        expected_calls = [epoch for epoch in range(-1, 4)]
        assert trainer.epoch_callback_calls == expected_calls
Example #3
    def test_trainer_can_log_histograms(self):
        # enable activation logging
        for module in self.model.modules():
            module.should_log_activations = True

        trainer = GradientDescentTrainer(
            self.model,
            self.optimizer,
            self.data_loader,
            num_epochs=3,
            serialization_dir=self.TEST_DIR,
            callbacks=[
                TensorBoardCallback(
                    serialization_dir=self.TEST_DIR,
                    distribution_interval=2,
                )
            ],
        )
        trainer.train()
Example #4
    def test_trainer_can_run_and_resume_with_momentum_scheduler(self):
        scheduler = MomentumScheduler.from_params(
            optimizer=self.optimizer,
            params=Params({
                "type": "inverted_triangular",
                "cool_down": 2,
                "warm_up": 2
            }),
        )
        trainer = GradientDescentTrainer(
            model=self.model,
            optimizer=self.optimizer,
            data_loader=self.data_loader,
            momentum_scheduler=scheduler,
            validation_metric="-loss",
            validation_data_loader=self.validation_data_loader,
            num_epochs=4,
            serialization_dir=self.TEST_DIR,
        )
        trainer.train()

        new_scheduler = MomentumScheduler.from_params(
            optimizer=self.optimizer,
            params=Params({
                "type": "inverted_triangular",
                "cool_down": 2,
                "warm_up": 2
            }),
        )
        new_trainer = GradientDescentTrainer(
            model=self.model,
            optimizer=self.optimizer,
            data_loader=self.data_loader,
            momentum_scheduler=new_scheduler,
            validation_metric="-loss",
            validation_data_loader=self.validation_data_loader,
            num_epochs=6,
            serialization_dir=self.TEST_DIR,
        )
        epoch = new_trainer._restore_checkpoint()
        assert epoch == 4
        assert new_trainer._momentum_scheduler.last_epoch == 3
        new_trainer.train()
Example #5
    def test_trainer_respects_num_serialized_models_to_keep(self):
        trainer = GradientDescentTrainer(
            self.model,
            self.optimizer,
            self.data_loader,
            num_epochs=5,
            serialization_dir=self.TEST_DIR,
            checkpointer=Checkpointer(serialization_dir=self.TEST_DIR,
                                      num_serialized_models_to_keep=3),
        )
        trainer.train()

        # Now check the serialized files
        for prefix in ["model_state_epoch_*", "training_state_epoch_*"]:
            file_names = glob.glob(os.path.join(self.TEST_DIR, prefix))
            epochs = [
                int(re.search(r"_([0-9])\.th", fname).group(1))
                for fname in file_names
            ]
            assert sorted(epochs) == [2, 3, 4]
Example #6
    def test_trainer_saves_metrics_every_epoch(self):
        trainer = GradientDescentTrainer(
            model=self.model,
            optimizer=self.optimizer,
            data_loader=self.data_loader,
            validation_data_loader=self.validation_data_loader,
            num_epochs=5,
            serialization_dir=self.TEST_DIR,
            checkpointer=Checkpointer(serialization_dir=self.TEST_DIR,
                                      num_serialized_models_to_keep=3),
        )
        trainer.train()

        for epoch in range(5):
            epoch_file = self.TEST_DIR / f"metrics_epoch_{epoch}.json"
            assert epoch_file.exists()
            with open(epoch_file) as f:
                metrics = json.load(f)
            assert "validation_loss" in metrics
            assert "best_validation_loss" in metrics
            assert metrics.get("epoch") == epoch
Example #7
 def test_trainer_can_run_cuda(self):
     self.model.cuda()
     trainer = GradientDescentTrainer(
         self.model, self.optimizer, self.data_loader, num_epochs=2, cuda_device=0
     )
     metrics = trainer.train()
     assert "peak_cpu_memory_MB" in metrics
     assert isinstance(metrics["peak_cpu_memory_MB"], float)
     assert metrics["peak_cpu_memory_MB"] > 0
     assert "peak_gpu_0_memory_MB" in metrics
     assert isinstance(metrics["peak_gpu_0_memory_MB"], int)
Example #8
 def test_trainer_can_run_amp(self):
     self.model.cuda()
     trainer = GradientDescentTrainer(
         self.model,
         self.optimizer,
         self.data_loader,
         num_epochs=2,
         cuda_device=0,
         opt_level="O1",
     )
     _ = trainer.train()
Example #9
 def test_trainer_can_run_amp(self, grad_norm, num_gradient_accumulation_steps):
     self.model.cuda()
     trainer = GradientDescentTrainer(
         self.model,
         self.optimizer,
         self.data_loader,
         num_epochs=2,
         cuda_device=0,
         use_amp=True,
         grad_norm=grad_norm,
         num_gradient_accumulation_steps=num_gradient_accumulation_steps,
     )
     _ = trainer.train()
Example #10
    def test_trainer_can_run(self):
        trainer = GradientDescentTrainer(
            model=self.model,
            optimizer=self.optimizer,
            data_loader=self.data_loader,
            validation_data_loader=self.validation_data_loader,
            num_epochs=2,
        )
        metrics = trainer.train()
        assert "best_validation_loss" in metrics
        assert isinstance(metrics["best_validation_loss"], float)
        assert "best_validation_accuracy" in metrics
        assert isinstance(metrics["best_validation_accuracy"], float)
        assert "best_validation_accuracy3" in metrics
        assert isinstance(metrics["best_validation_accuracy3"], float)
        assert "best_epoch" in metrics
        assert isinstance(metrics["best_epoch"], int)

        # Making sure that both increasing and decreasing validation metrics work.
        trainer = GradientDescentTrainer(
            model=self.model,
            optimizer=self.optimizer,
            data_loader=self.data_loader,
            validation_data_loader=self.validation_data_loader,
            validation_metric="+loss",
            num_epochs=2,
        )
        metrics = trainer.train()
        assert "best_validation_loss" in metrics
        assert isinstance(metrics["best_validation_loss"], float)
        assert "best_validation_accuracy" in metrics
        assert isinstance(metrics["best_validation_accuracy"], float)
        assert "best_validation_accuracy3" in metrics
        assert isinstance(metrics["best_validation_accuracy3"], float)
        assert "best_epoch" in metrics
        assert isinstance(metrics["best_epoch"], int)
        assert "peak_cpu_memory_MB" in metrics
        assert isinstance(metrics["peak_cpu_memory_MB"], float)
        assert metrics["peak_cpu_memory_MB"] > 0
Example #11
    def test_trainer_saves_and_loads_best_validation_metrics_correctly_2(self):
        # Use +loss and run 1 epoch of original training, and one of restored training.
        # Run 1 epoch of original training.
        trainer = GradientDescentTrainer(
            self.model,
            self.optimizer,
            self.data_loader,
            validation_data_loader=self.validation_data_loader,
            validation_metric="+loss",
            num_epochs=1,
            serialization_dir=self.TEST_DIR,
        )
        trainer.train()

        _ = trainer._restore_checkpoint()
        best_epoch_1 = trainer._metric_tracker.best_epoch
        best_validation_metrics_epoch_1 = trainer._metric_tracker.best_epoch_metrics
        # best_validation_metrics_epoch_1: {'accuracy': 0.75, 'accuracy3': 1.0, 'loss': 0.6243013441562653}
        assert isinstance(best_validation_metrics_epoch_1, dict)
        assert "loss" in best_validation_metrics_epoch_1

        # Run 1 more epoch of restored training.
        restore_trainer = GradientDescentTrainer(
            self.model,
            self.optimizer,
            self.data_loader,
            validation_data_loader=self.validation_data_loader,
            validation_metric="+loss",
            num_epochs=2,
            serialization_dir=self.TEST_DIR,
        )
        restore_trainer.train()
        _ = restore_trainer._restore_checkpoint()
        best_epoch_2 = restore_trainer._metric_tracker.best_epoch
        best_validation_metrics_epoch_2 = restore_trainer._metric_tracker.best_epoch_metrics

        # Because we use +loss, the 2nd epoch won't be better than the 1st, so the best validation metrics should stay the same.
        assert best_epoch_1 == best_epoch_2 == 0
        assert best_validation_metrics_epoch_2 == best_validation_metrics_epoch_1
Example #12
    def test_restored_training_returns_best_epoch_metrics_even_if_no_better_epoch_is_found_after_restoring(
        self,
    ):
        # Instead of -loss, use +loss to ensure the 2nd epoch is considered worse.
        # Run 1 epoch of original training.
        original_trainer = GradientDescentTrainer(
            self.model,
            self.optimizer,
            self.data_loader,
            validation_data_loader=self.validation_data_loader,
            validation_metric="+loss",
            num_epochs=1,
            serialization_dir=self.TEST_DIR,
        )
        training_metrics = original_trainer.train()

        # Run 1 epoch of restored training.
        restored_trainer = GradientDescentTrainer(
            self.model,
            self.optimizer,
            self.data_loader,
            validation_data_loader=self.validation_data_loader,
            validation_metric="+loss",
            num_epochs=2,
            serialization_dir=self.TEST_DIR,
        )
        restored_metrics = restored_trainer.train()

        assert "best_validation_loss" in restored_metrics
        assert "best_validation_accuracy" in restored_metrics
        assert "best_validation_accuracy3" in restored_metrics
        assert "best_epoch" in restored_metrics

        # Epoch 2's validation loss should be less than Epoch 1's.
        assert training_metrics["best_validation_loss"] == restored_metrics[
            "best_validation_loss"]
        assert training_metrics["best_epoch"] == 0
        assert training_metrics["validation_loss"] > restored_metrics[
            "validation_loss"]
Example #13
 def test_data_loader_lazy_epoch_size_correct(self):
     num_epochs = 3
     trainer = GradientDescentTrainer(
         self.model,
         self.optimizer,
         self.data_loader_lazy,
         validation_data_loader=self.validation_data_loader,
         num_epochs=num_epochs,
         serialization_dir=self.TEST_DIR,
     )
     assert trainer._batch_num_total == 0
     metrics = trainer.train()
     epoch = metrics["epoch"]
     assert epoch == num_epochs - 1
     assert trainer._batch_num_total == num_epochs * 2
Example #14
 def test_trainer_respects_epoch_size_smaller_than_total(self):
     batches_per_epoch = 1
     num_epochs = 2
     data_loader_smaller_epoch = SimpleDataLoader(
         self.instances,
         2,
         batches_per_epoch=batches_per_epoch,
     )
     trainer = GradientDescentTrainer(
         self.model,
         self.optimizer,
         data_loader_smaller_epoch,
         validation_data_loader=self.validation_data_loader,
         num_epochs=num_epochs,
         serialization_dir=self.TEST_DIR,
     )
     assert trainer._batch_num_total == 0
     metrics = trainer.train()
     epoch = metrics["epoch"]
     assert epoch == num_epochs - 1
     assert trainer._batch_num_total == num_epochs * batches_per_epoch
Example #15
 def test_data_loader_lazy_epoch_size_correct_custom_epoch_size(self):
     batches_per_epoch = 3
     num_epochs = 3
     data_loader_custom_epoch_lazy = PyTorchDataLoader(
         self.instances_lazy,
         batch_size=2,
         collate_fn=allennlp_collate,
         batches_per_epoch=batches_per_epoch,
     )
     trainer = GradientDescentTrainer(
         self.model,
         self.optimizer,
         data_loader_custom_epoch_lazy,
         validation_data_loader=self.validation_data_loader,
         num_epochs=num_epochs,
         serialization_dir=self.TEST_DIR,
     )
     assert trainer._batch_num_total == 0
     metrics = trainer.train()
     epoch = metrics["epoch"]
     assert epoch == num_epochs - 1
     assert trainer._batch_num_total == num_epochs * batches_per_epoch
Example #16
    def test_trainer_can_run_gradient_accumulation(self):
        instances = list(self.instances)
        steps_to_accumulate = 2

        trainer = GradientDescentTrainer(
            self.model,
            self.optimizer,
            self.data_loader,
            validation_data_loader=self.validation_data_loader,
            num_epochs=2,
            num_gradient_accumulation_steps=steps_to_accumulate,
        )
        assert trainer._num_gradient_accumulation_steps == steps_to_accumulate

        metrics = trainer.train()

        num_batches_trained_per_epoch = trainer._batch_num_total // (metrics["training_epochs"] + 1)
        num_batches_expected = math.ceil(
            math.ceil(len(instances) / self.data_loader.batch_size) / steps_to_accumulate
        )

        assert num_batches_trained_per_epoch == num_batches_expected
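The expected batch count above follows from the accumulation arithmetic: the data loader's batches are grouped into accumulation steps, and trainer._batch_num_total advances once per group. A worked check with hypothetical numbers (not the original fixture sizes):

import math

# Hypothetical numbers, for illustration only: 4 instances, batch_size=2,
# accumulating gradients over 2 steps.
num_instances, batch_size, steps_to_accumulate = 4, 2, 2

micro_batches_per_epoch = math.ceil(num_instances / batch_size)                    # 2 forward/backward passes
batch_groups_per_epoch = math.ceil(micro_batches_per_epoch / steps_to_accumulate)  # 1 optimizer step
assert batch_groups_per_epoch == 1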
Example #17
 def test_trainer_respects_epoch_size_larger_than_total(self):
     batches_per_epoch = 7
     num_epochs = 3
     data_loader_larger_epoch = AllennlpDataLoader(
         self.instances,
         batch_size=2,
         collate_fn=allennlp_collate,
         batches_per_epoch=batches_per_epoch,
     )
     trainer = GradientDescentTrainer(
         self.model,
         self.optimizer,
         data_loader_larger_epoch,
         validation_data_loader=self.validation_data_loader,
         num_epochs=num_epochs,
         serialization_dir=self.TEST_DIR,
     )
     assert trainer._batch_num_total == 0
     metrics = trainer.train()
     epoch = metrics["epoch"]
     assert epoch == num_epochs - 1
     assert trainer._batch_num_total == num_epochs * batches_per_epoch
Example #18
    def test_total_loss_is_average_of_batch_loss(self):

        batches_per_epoch = 3

        data_loader_custom_epoch_lazy = PyTorchDataLoader(
            self.instances_lazy,
            batch_size=2,
            collate_fn=allennlp_collate,
            batches_per_epoch=batches_per_epoch,
        )

        class FakeBatchCallback(BatchCallback):
            def __call__(
                self,
                trainer: "GradientDescentTrainer",
                batch_inputs: List[List[TensorDict]],
                batch_outputs: List[Dict[str, Any]],
                batch_metrics: Dict[str, Any],
                epoch: int,
                batch_number: int,
                is_training: bool,
                is_master: bool,
            ) -> None:
                if not hasattr(trainer, "batch_losses"):
                    trainer.batch_losses = []  # type: ignore
                trainer.batch_losses.append(
                    batch_outputs[0]["loss"].item())  # type: ignore

        trainer = GradientDescentTrainer(
            self.model,
            self.optimizer,
            data_loader_custom_epoch_lazy,
            num_epochs=1,
            batch_callbacks=[FakeBatchCallback()],
        )
        metrics = trainer.train()

        assert metrics["training_loss"] == float(
            sum(trainer.batch_losses) / batches_per_epoch)
Example #19
def objective_fn(
        trial: Trial,
        device: int,
        direction: str,
        target_metric: str,
        base_serialization_dir: str,
):
    embedding_dim = trial.suggest_int("embedding_dim", 128, 256)
    max_filter_size = trial.suggest_int("max_filter_size", 3, 6)
    num_filters = trial.suggest_int("num_filters", 128, 256)
    output_dim = trial.suggest_int("output_dim", 128, 512)
    dropout = trial.suggest_float("dropout", 0, 1.0, log=False)
    lr = trial.suggest_float("lr", 1e-4, 1e-1, log=True)

    train_dataset, valid_dataset, vocab = prepare_data()
    model = create_model(vocab, embedding_dim, max_filter_size, num_filters, output_dim, dropout)

    if device > -1:
        model.to(torch.device("cuda:{}".format(device)))

    optimizer = SGD(model.parameters(), lr=lr)
    data_loader = DataLoader(train_dataset, batch_size=10, collate_fn=allennlp_collate)
    validation_data_loader = DataLoader(valid_dataset, batch_size=64, collate_fn=allennlp_collate)
    serialization_dir = os.path.join(base_serialization_dir, "trial_{}".format(trial.number))
    trainer = GradientDescentTrainer(
        model=model,
        optimizer=optimizer,
        data_loader=data_loader,
        validation_data_loader=validation_data_loader,
        validation_metric=("+" if direction == "MAXIMIZE" else "-") + target_metric,
        patience=None,  # `patience=None` since it could conflict with AllenNLPPruningCallback
        num_epochs=50,
        cuda_device=device,
        serialization_dir=serialization_dir,
        epoch_callbacks=[AllenNLPPruningCallback(trial, f"validation_{target_metric}")],
    )
    vocab.save_to_files(os.path.join(serialization_dir, "vocabulary"))
    return trainer.train()[f"best_validation_{target_metric}"]
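This objective function is meant to be handed to an Optuna study. A minimal driver might look like the sketch below; the device, metric name, output directory, and trial count are illustrative assumptions, and functools.partial is simply one way to bind the extra arguments.

import functools
import os

import optuna

if __name__ == "__main__":
    direction = "MAXIMIZE"
    base_serialization_dir = "result"  # hypothetical output directory
    os.makedirs(base_serialization_dir, exist_ok=True)

    # Maximize or minimize the target metric, matching the +/- prefix used above.
    study = optuna.create_study(direction=direction.lower())
    study.optimize(
        functools.partial(
            objective_fn,
            device=-1,                 # CPU; use a GPU index such as 0 if available
            direction=direction,
            target_metric="accuracy",  # assumes the model reports a validation accuracy
            base_serialization_dir=base_serialization_dir,
        ),
        n_trials=30,
    )
    print(study.best_trial)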
Example #20
    def test_trainer_callback_is_called_everywhere(self):
        class FakeTrainerCallback(TrainerCallback):
            def on_start(self,
                         trainer: "GradientDescentTrainer",
                         is_primary: bool = True,
                         **kwargs) -> None:
                if not hasattr(trainer, "start_callback_is_fired_first"):
                    trainer.start_callback_is_fired_first = True  # type: ignore

            def on_batch(
                self,
                trainer: "GradientDescentTrainer",
                batch_inputs: List[TensorDict],
                batch_outputs: List[Dict[str, Any]],
                batch_metrics: Dict[str, Any],
                epoch: int,
                batch_number: int,
                is_training: bool,
                is_primary: bool = True,
                batch_grad_norm: Optional[float] = None,
                **kwargs,
            ) -> None:
                if not hasattr(trainer, "start_callback_is_fired_first"):
                    trainer.start_callback_is_fired_first = False  # type: ignore

                if not hasattr(trainer, "batch_callback_calls"):
                    trainer.batch_callback_calls = []  # type: ignore
                trainer.batch_callback_calls.append(
                    (epoch, batch_number, is_training))  # type: ignore

            def on_epoch(
                self,
                trainer: "GradientDescentTrainer",
                metrics: Dict[str, Any],
                epoch: int,
                is_primary: bool = True,
                **kwargs,
            ) -> None:
                if not hasattr(trainer, "start_callback_is_fired_first"):
                    trainer.start_callback_is_fired_first = False  # type: ignore

                if not hasattr(trainer, "epoch_callback_calls"):
                    trainer.epoch_callback_calls = []  # type: ignore
                trainer.epoch_callback_calls.append(epoch)  # type: ignore

            def on_end(
                self,
                trainer: "GradientDescentTrainer",
                metrics: Dict[str, Any] = None,
                epoch: int = None,
                is_primary: bool = True,
                **kwargs,
            ) -> None:
                if not hasattr(trainer, "start_callback_is_fired_first"):
                    trainer.start_callback_is_fired_first = False  # type: ignore

                if not hasattr(trainer, "end_callback_calls"):
                    trainer.end_callback_calls = []  # type: ignore
                trainer.end_callback_calls.append(epoch)  # type: ignore

        trainer = GradientDescentTrainer(
            self.model,
            self.optimizer,
            self.data_loader,
            num_epochs=2,
            validation_data_loader=self.validation_data_loader,
            callbacks=[FakeTrainerCallback(serialization_dir=self.TEST_DIR)],
        )
        trainer.train()
        expected_batch_calls = [
            (epoch, batch_number + 1, is_train) for epoch in range(2)
            for is_train in (True, False)
            for batch_number in range(len(self.instances) // 2)
        ]
        expected_epoch_calls = [epoch for epoch in range(0, 2)]
        expected_end_calls = [1]

        assert trainer.start_callback_is_fired_first
        assert trainer.batch_callback_calls == expected_batch_calls
        assert trainer.epoch_callback_calls == expected_epoch_calls
        assert trainer.end_callback_calls == expected_end_calls
Example #21
# Build the text feature vectors
text_embedder = BasicTextFieldEmbedder({"tokens": embedding})
encoder = BagOfEmbeddingsEncoder(embedding_dim=100)

# Build the document classifier
model = BasicClassifier(vocab=vocab,
                        text_field_embedder=text_embedder,
                        seq2vec_encoder=encoder)

# Data loaders
train_loader = PyTorchDataLoader(train_dataset, batch_size=32, shuffle=True)
validation_loader = PyTorchDataLoader(validation_dataset,
                                      batch_size=32,
                                      shuffle=False)

# Copy the model to the GPU
# model = model.cuda()

# Create the optimizer
optimizer = AdamOptimizer(model.named_parameters())

# Create the trainer
trainer = GradientDescentTrainer(model=model,
                                 optimizer=optimizer,
                                 data_loader=train_loader,
                                 validation_data_loader=validation_loader,
                                 num_epochs=10,
                                 patience=3)

metrics = trainer.train()
pprint.pprint(metrics)
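The snippet above assumes that embedding, vocab, train_dataset, and validation_dataset were created in earlier, unshown steps, and that the relevant classes were already imported. The following is a sketch of those assumed imports for an AllenNLP 1.x style API; the PyTorchDataLoader module path in particular is an assumption that may differ between releases.

# Assumed imports for the snippet above; the dataset, vocabulary, and `embedding`
# objects themselves are expected to come from earlier, unshown steps.
import pprint

from allennlp.data.dataloader import PyTorchDataLoader
from allennlp.models import BasicClassifier
from allennlp.modules.seq2vec_encoders import BagOfEmbeddingsEncoder
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
from allennlp.training import GradientDescentTrainer
from allennlp.training.optimizers import AdamOptimizer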
Example #22
    def run(  # type: ignore
        self,
        model: Lazy[Model],
        dataset: DatasetDict,
        data_loader: Lazy[TangoDataLoader],
        optimizer: Lazy[Optimizer],
        validation_data_loader: Optional[Lazy[TangoDataLoader]] = None,
        training_split: str = "train",
        validation_split: Optional[str] = None,
        patience: Optional[int] = None,
        validation_metric: Union[str, List[str]] = "-loss",
        num_epochs: int = 20,
        checkpointer: Optional[Lazy[Checkpointer]] = None,
        grad_norm: Union[float, bool] = False,
        grad_clipping: Optional[float] = None,
        learning_rate_scheduler: Optional[Lazy[LearningRateScheduler]] = None,
        momentum_scheduler: Optional[Lazy[MomentumScheduler]] = None,
        moving_average: Optional[Lazy[MovingAverage]] = None,
        callbacks: Optional[List[Lazy[TrainerCallback]]] = None,
        num_gradient_accumulation_steps: int = 1,
        use_amp: bool = False,
        enable_default_callbacks: bool = True,
        run_confidence_checks: bool = True,
        no_grad: Optional[List[str]] = None,
        limit_batches_per_epoch: Optional[int] = None,
    ) -> Model:
        serialization_dir = self.work_dir()

        if validation_data_loader is None:
            validation_data_loader = data_loader
        if validation_split is None:
            validation_loader = None
        else:
            concrete_validation_data_loader = validation_data_loader.construct(
                instances=dataset.splits[validation_split])
            del validation_data_loader
            if limit_batches_per_epoch is not None:
                concrete_validation_data_loader = MaxBatchesDataLoader(
                    concrete_validation_data_loader, limit_batches_per_epoch)
            validation_loader = DataLoaderAdapter(
                tango_data_loader=concrete_validation_data_loader)

        concrete_data_loader = data_loader.construct(
            instances=dataset.splits[training_split])
        del data_loader
        if limit_batches_per_epoch is not None:
            concrete_data_loader = MaxBatchesDataLoader(
                concrete_data_loader, limit_batches_per_epoch)
        loader = DataLoaderAdapter(tango_data_loader=concrete_data_loader)

        if torch.cuda.device_count() > 0:
            cuda_device = torch.device(0)
        else:
            cuda_device = torch.device("cpu")
        check_for_gpu(cuda_device)
        loader.set_target_device(cuda_device)
        if validation_loader is not None:
            validation_loader.set_target_device(cuda_device)

        concrete_model = model.construct(vocab=dataset.vocab).to(cuda_device)
        del model
        if no_grad:
            for name, parameter in concrete_model.named_parameters():
                if any(re.search(regex, name) for regex in no_grad):
                    parameter.requires_grad_(False)
        parameters = [[n, p] for n, p in concrete_model.named_parameters()
                      if p.requires_grad]
        concrete_optimizer = optimizer.construct(model_parameters=parameters)
        del optimizer
        log_frozen_and_tunable_parameter_names(concrete_model)

        concrete_moving_average = (None if moving_average is None else
                                   moving_average.construct(
                                       parameters=parameters))
        del moving_average

        concrete_learning_rate_scheduler = (
            None if learning_rate_scheduler is None else
            learning_rate_scheduler.construct(
                optimizer=concrete_optimizer,
                num_epochs=num_epochs,
                num_steps_per_epoch=concrete_data_loader.num_batches_per_epoch(
                ),
            ))
        del learning_rate_scheduler

        concrete_momentum_scheduler = (None if momentum_scheduler is None else
                                       momentum_scheduler.construct(
                                           optimizer=concrete_optimizer))
        del momentum_scheduler

        if checkpointer is not None:
            concrete_checkpointer = checkpointer.construct(
                serialization_dir=serialization_dir)
        else:
            concrete_checkpointer = Checkpointer(serialization_dir)
        del checkpointer

        concrete_callbacks: List[TrainerCallback] = [
            cb.construct(serialization_dir=serialization_dir)
            for cb in callbacks or []
        ]
        del callbacks

        trainer = GradientDescentTrainer(
            concrete_model,
            optimizer=concrete_optimizer,
            data_loader=loader,
            patience=patience,
            validation_metric=validation_metric,
            validation_data_loader=validation_loader,
            num_epochs=num_epochs,
            serialization_dir=serialization_dir,
            checkpointer=concrete_checkpointer,
            grad_norm=grad_norm,
            grad_clipping=grad_clipping,
            learning_rate_scheduler=concrete_learning_rate_scheduler,
            momentum_scheduler=concrete_momentum_scheduler,
            moving_average=concrete_moving_average,
            callbacks=concrete_callbacks,
            num_gradient_accumulation_steps=num_gradient_accumulation_steps,
            use_amp=use_amp,
            enable_default_callbacks=enable_default_callbacks,
            run_confidence_checks=run_confidence_checks,
        )
        trainer.train()

        return trainer.model