Example #1
    def test_trainer_can_resume_with_lr_scheduler(self):
        lr_scheduler = LearningRateScheduler.from_params(
                self.optimizer, Params({"type": "exponential", "gamma": 0.5}))
        callbacks = self.default_callbacks() + [UpdateLearningRate(lr_scheduler)]

        trainer = CallbackTrainer(model=self.model,
                                  training_data=self.instances,
                                  iterator=self.iterator,
                                  optimizer=self.optimizer,
                                  callbacks=callbacks,
                                  num_epochs=2,
                                  serialization_dir=self.TEST_DIR)
        trainer.train()

        new_lr_scheduler = LearningRateScheduler.from_params(
                self.optimizer, Params({"type": "exponential", "gamma": 0.5}))
        callbacks = self.default_callbacks() + [UpdateLearningRate(new_lr_scheduler)]

        new_trainer = CallbackTrainer(model=self.model,
                                      training_data=self.instances,
                                      iterator=self.iterator,
                                      optimizer=self.optimizer,
                                      callbacks=callbacks,
                                      num_epochs=4,
                                      serialization_dir=self.TEST_DIR)
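        # Firing TRAINING_START lets the checkpointing callback (assumed to be
        # among default_callbacks()) restore the trainer and scheduler state
        # saved by the first two-epoch run above.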
        new_trainer.handler.fire_event(Events.TRAINING_START)
        assert new_trainer.epoch_number == 2
        assert new_lr_scheduler.lr_scheduler.last_epoch == 1
        new_trainer.train()
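
The resume check above leans on the scheduler's persisted state: the AllenNLP
LearningRateScheduler built from {"type": "exponential"} wraps PyTorch's
ExponentialLR, whose state_dict carries last_epoch (the test reaches it via
the wrapper's .lr_scheduler attribute). A minimal plain-PyTorch sketch of the
state a checkpoint round-trips, not the AllenNLP wrapper itself:

import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.5)

for _ in range(2):                  # two "epochs" of training
    optimizer.step()
    scheduler.step()                # lr: 1.0 -> 0.5 -> 0.25

state = scheduler.state_dict()      # what a checkpoint would save

# A fresh scheduler, as built when the trainer is reconstructed above...
new_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.5)
new_scheduler.load_state_dict(state)   # ...and restored on resume
print(new_scheduler.last_epoch)        # 2: decay continues where it left off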
Example #2
    def test_trainer_can_run_with_lr_scheduler(self):
        lr_params = Params({"type": "reduce_on_plateau"})
        lr_scheduler = LearningRateScheduler.from_params(self.optimizer, lr_params)
        callbacks = self.default_callbacks() + [UpdateLearningRate(lr_scheduler)]

        trainer = CallbackTrainer(model=self.model,
                                  optimizer=self.optimizer,
                                  callbacks=callbacks,
                                  num_epochs=2)
        trainer.train()
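
Unlike step-based schedulers, reduce_on_plateau has to be stepped with a
metric rather than on a fixed timetable, which is why it is wired in through
the UpdateLearningRate callback (which can hand it the validation metric).
A plain-PyTorch sketch of the scheduler those Params resolve to; the factor
and patience values here are illustrative, not taken from the test:

import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=2)

for epoch in range(10):
    val_loss = 1.0                  # stand-in for a real validation loss
    scheduler.step(val_loss)        # after `patience` epochs with no
                                    # improvement, lr is multiplied by `factor`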
Example #3
    def test_model_training(self):
        training_dataset = self.sample_instances if self.sample_only else self.train_instances
        #training_dataset = training_dataset[:500]
        validation_dataset = self.sample_instances if self.sample_only else self.test_instances
        serialization_dir = (self.TEST_DATA_ROOT / "serialized_sample") if self.sample_only else "serialized"
        tensorboard_dir = self.TEST_DATA_ROOT / "tensorboard.seq2seq"

        batch_size = 64

        train_iterator = BucketIterator(sorting_keys=[("source_tokens", "num_tokens")],
                                        padding_noise=0.1,
                                        batch_size=batch_size)
        train_iterator.index_with(vocab=self.vocab)
        multiproc_iterator = MultiprocessIterator(train_iterator,
                                                  num_workers=4,
                                                  output_queue_size=6000)

        # Note: get_batch_num_total is expected to supply a running batch
        # counter for log timestamps; this lambda returns the (constant)
        # number of batches per epoch instead. Requires numpy imported as np.
        tensorboard = TensorboardWriter(
            get_batch_num_total=lambda: np.ceil(len(training_dataset) / batch_size),
            serialization_dir=tensorboard_dir,
            summary_interval=5,
            histogram_interval=5,
            should_log_parameter_statistics=True,
            should_log_learning_rate=True)

        optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)
        scheduler = CosineWithRestarts(optimizer=optimizer, t_initial=5)

        trainer = CallbackTrainer(
            model=self.model,
            serialization_dir=serialization_dir,
            iterator=multiproc_iterator,
            training_data=training_dataset,
            num_epochs=100,
            cuda_device=0,
            optimizer=optimizer,
            callbacks=[
                LogToTensorboard(tensorboard),
                Validate(validation_data=validation_dataset,
                         validation_iterator=multiproc_iterator),
                TrackMetrics(),
                ResetMetricsCallback(),
                UpdateLearningRate(scheduler),
                ValidationLogCallback(self.train_reader, validation_dataset)
            ])

        # trainer = Trainer(model=self.model,
        #                   serialization_dir=serialization_dir,
        #                   iterator=train_iterator,
        #                   train_dataset=training_dataset,
        #                   num_epochs=1,
        #                   cuda_device=0,
        #                   optimizer=torch.optim.Adam(self.model.parameters(), lr=1e-3),
        #                   validation_dataset=training_dataset,
        #                   validation_iterator=train_iterator,
        #                   should_log_learning_rate=True,
        #                   learning_rate_scheduler=scheduler
        #                   )

        # for i in range(50):
        #     print('Epoch: {}'.format(i))
        #     trainer.train()
        #
        #     import itertools
        #
        #     predictor = Seq2SeqPredictor(self.model, self.train_reader)
        #
        #     for instance in itertools.islice(training_dataset, 10):
        #         print('SOURCE:', instance.fields['source_tokens'].tokens)
        #         print('GOLD:', instance.fields['target_tokens'].tokens)
        #         print('PRED:', predictor.predict_instance(instance)['predicted_tokens'])
        #
        # self.val_outputs_fp.close()

        trainer.train()
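
For reference, CosineWithRestarts(t_initial=5) follows the SGDR schedule of
Loshchilov & Hutter: the learning rate is annealed from its starting value
along a half cosine and snaps back at the end of each cycle. A self-contained
sketch of that curve, assuming fixed-length cycles (t_mul=1) and eta_min=0:

import math

def cosine_with_restarts(epoch, eta_max=1e-3, eta_min=0.0, t_initial=5):
    # SGDR: eta_t = eta_min + (eta_max - eta_min) * (1 + cos(pi * t / T)) / 2,
    # where t is the position inside the current cycle of length T.
    t = epoch % t_initial
    return eta_min + (eta_max - eta_min) * (1 + math.cos(math.pi * t / t_initial)) / 2

for epoch in range(11):
    print(epoch, round(cosine_with_restarts(epoch), 6))  # restarts at epochs 5 and 10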