Пример #1
0
    def train(
        self,
        train_dataset: Dataset,
        eval_dataset: Optional[Dataset],
        num_epochs: int,
        reader_options: ReaderOptions,
    ) -> RLTrainingOutput:
        """
        Train the model.

        Args:
            train_dataset: dataset to train on.
            eval_dataset: optional dataset for evaluation.
            num_epochs: number of training epochs.
            reader_options: options for the data reader.

        Returns partially filled RLTrainingOutput.
        The field that should not be filled are:
        - output_path
        """
        batch_preprocessor = self.build_batch_preprocessor()
        reporter = self.get_reporter()
        # pyre-fixme[16]: `RLTrainer` has no attribute `add_observer`.
        self.trainer.add_observer(reporter)

        train_eval_lightning(
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            trainer_module=self.trainer,
            num_epochs=num_epochs,
            use_gpu=self.use_gpu,
            batch_preprocessor=batch_preprocessor,
            # Fix: honor the `reader_options` argument; it was previously
            # ignored in favor of `self.reader_options`.
            reader_options=reader_options,
            checkpoint_path=self._lightning_checkpoint_path,
        )
        # pyre-fixme[16]: `RLTrainingReport` has no attribute `make_union_instance`.
        training_report = RLTrainingReport.make_union_instance(
            reporter.generate_training_report()
        )
        return RLTrainingOutput(training_report=training_report)
Пример #2
0
    def train(
        self,
        train_dataset: Optional[Dataset],
        eval_dataset: Optional[Dataset],
        data_module: Optional[ReAgentDataModule],
        num_epochs: int,
        reader_options: ReaderOptions,
    ) -> RLTrainingOutput:
        """
        Train the model.

        Args:
            train_dataset: optional dataset to train on.
            eval_dataset: optional dataset for evaluation.
            data_module: optional lightning data module that replaces the
                use of train/eval datasets.
            num_epochs: number of training epochs.
            reader_options: options for the data reader.

        Returns partially filled RLTrainingOutput; keeps a handle to the
        lightning trainer on `self._lightning_trainer`.
        """
        batch_preprocessor = self.build_batch_preprocessor()
        reporter = self.get_reporter()
        # pyre-fixme[16]: `Trainer` has no attribute `set_reporter`.
        self.trainer.set_reporter(reporter)

        self._lightning_trainer = train_eval_lightning(
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            trainer_module=self.trainer,
            data_module=data_module,
            num_epochs=num_epochs,
            use_gpu=self.use_gpu,
            batch_preprocessor=batch_preprocessor,
            # Fix: honor the `reader_options` argument; it was previously
            # ignored in favor of `self.reader_options`.
            reader_options=reader_options,
            checkpoint_path=self._lightning_checkpoint_path,
        )
        # pyre-fixme[16]: `RLTrainingReport` has no attribute `make_union_instance`.
        training_report = RLTrainingReport.make_union_instance(
            reporter.generate_training_report()
        )
        return RLTrainingOutput(training_report=training_report)
Пример #3
0
    def train(
        self,
        train_dataset: Optional[Dataset],
        eval_dataset: Optional[Dataset],
        data_module: Optional[ReAgentDataModule],
        num_epochs: int,
        reader_options: ReaderOptions,
        resource_options: Optional[ResourceOptions] = None,
    ) -> RLTrainingOutput:
        """
        Train the model

        Returns partially filled RLTrainingOutput.
        The field that should not be filled are:
        - output_path

        Args:
            train_dataset: optional dataset to train on.
            eval_dataset: optional dataset for evaluation.
            data_module: accepted for interface compatibility; see note below.
            num_epochs: number of training epochs.
            reader_options: options for the data reader.
            resource_options: options for training resources.
        """
        batch_preprocessor = self.build_batch_preprocessor()
        reporter = self.get_reporter()
        # pyre-fixme[16]: `RLTrainer` has no attribute `set_reporter`.
        self.trainer.set_reporter(reporter)

        train_eval_lightning(
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            trainer_module=self.trainer,
            # NOTE(review): the `data_module` parameter is accepted but a
            # hard-coded None is passed here — confirm this workflow really
            # does not support data modules.
            data_module=None,
            num_epochs=num_epochs,
            use_gpu=self.use_gpu,
            batch_preprocessor=batch_preprocessor,
            # Fix: honor the `reader_options` argument; it was previously
            # ignored in favor of `self.reader_options`.
            reader_options=reader_options,
            checkpoint_path=self._lightning_checkpoint_path,
            resource_options=resource_options,
        )
        rank = get_rank()
        if rank == 0:
            # pyre-fixme[16]: `RLTrainingReport` has no attribute `make_union_instance`.
            training_report = RLTrainingReport.make_union_instance(
                reporter.generate_training_report()
            )
            return RLTrainingOutput(training_report=training_report)
        # Output from processes with non-0 rank is not used
        return RLTrainingOutput()
Пример #4
0
    def train(
        self,
        train_dataset: Optional[Dataset],
        eval_dataset: Optional[Dataset],
        test_dataset: Optional[Dataset],
        data_module: Optional[ReAgentDataModule],
        num_epochs: int,
        reader_options: ReaderOptions,
        resource_options: Optional[ResourceOptions] = None,
    ) -> RLTrainingOutput:
        """
        Train the model

        Returns partially filled RLTrainingOutput.
        The field that should not be filled are:
        - output_path

        Args:
            train_dataset/eval_dataset/test_dataset: optional datasets.
            data_module: lightning data module; required by this workflow.
            num_epochs: number of training epochs.
            reader_options: options for the data reader.
            resource_options: options for training resources.

        Raises:
            ValueError: if `data_module` is not provided.
        """
        reporter = self.get_reporter()
        # pyre-fixme[16]: `RLTrainer` has no attribute `set_reporter`.
        self.trainer.set_reporter(reporter)
        # Explicit check instead of `assert`, which is stripped under `-O`.
        if not data_module:
            raise ValueError("data_module is required for this workflow")

        # pyre-fixme[16]: `DiscreteDQNBase` has no attribute `_lightning_trainer`.
        self._lightning_trainer = train_eval_lightning(
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            test_dataset=test_dataset,
            trainer_module=self.trainer,
            data_module=data_module,
            num_epochs=num_epochs,
            use_gpu=self.use_gpu,
            logger_name="DiscreteDqn",
            reader_options=reader_options,
            checkpoint_path=self._lightning_checkpoint_path,
            resource_options=resource_options,
        )
        rank = get_rank()
        if rank == 0:
            # pyre-fixme[16]: `RLTrainingReport` has no attribute `make_union_instance`.
            training_report = RLTrainingReport.make_union_instance(
                reporter.generate_training_report()
            )
            logger_data = self._lightning_trainer.logger.line_plot_aggregated
            self._lightning_trainer.logger.clear_local_data()
            return RLTrainingOutput(
                training_report=training_report, logger_data=logger_data
            )
        # Output from processes with non-0 rank is not used
        return RLTrainingOutput()
Пример #5
0
    def train(
        self,
        train_dataset: Optional[Dataset],
        eval_dataset: Optional[Dataset],
        test_dataset: Optional[Dataset],
        data_module: Optional[ReAgentDataModule],
        num_epochs: int,
        reader_options: ReaderOptions,
        resource_options: ResourceOptions,
    ) -> RLTrainingOutput:
        """
        Train the model.

        Args:
            train_dataset/eval_dataset/test_dataset: optional datasets.
            data_module: optional lightning data module that replaces the
                use of train/eval datasets.
            num_epochs: number of training epochs.
            reader_options: options for the data reader.
            resource_options: options for training resources.

        Returns partially filled RLTrainingOutput; keeps a handle to the
        lightning trainer on `self._lightning_trainer`.
        """
        batch_preprocessor = self.build_batch_preprocessor(resource_options.use_gpu)
        reporter = self.get_reporter()
        # pyre-fixme[16]: `Trainer` has no attribute `set_reporter`.
        self.trainer.set_reporter(reporter)

        # pyre-fixme[16]: `ActorCriticBase` has no attribute `_lightning_trainer`.
        self._lightning_trainer = train_eval_lightning(
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            test_dataset=test_dataset,
            trainer_module=self.trainer,
            data_module=data_module,
            num_epochs=num_epochs,
            logger_name="ActorCritic",
            batch_preprocessor=batch_preprocessor,
            # Fix: honor the `reader_options` argument; it was previously
            # ignored in favor of `self.reader_options`.
            reader_options=reader_options,
            checkpoint_path=self._lightning_checkpoint_path,
            # Defensive fallback kept from the original, even though the
            # annotation says `resource_options` is non-optional.
            resource_options=resource_options or ResourceOptions(),
        )
        if reporter is None:
            training_report = None
        else:
            # pyre-fixme[16]: `RLTrainingReport` has no attribute `make_union_instance`.
            training_report = RLTrainingReport.make_union_instance(
                reporter.generate_training_report()
            )
        logger_data = self._lightning_trainer.logger.line_plot_aggregated
        self._lightning_trainer.logger.clear_local_data()
        return RLTrainingOutput(
            training_report=training_report, logger_data=logger_data
        )
Пример #6
0
    def train(
        self,
        trainer_module: ReAgentLightningModule,
        train_dataset: Optional[Dataset],
        eval_dataset: Optional[Dataset],
        test_dataset: Optional[Dataset],
        data_module: Optional[ReAgentDataModule],
        num_epochs: int,
        reader_options: ReaderOptions,
        resource_options: ResourceOptions,
        checkpoint_path: Optional[str] = None,
    ) -> Tuple[RLTrainingOutput, pl.Trainer]:
        """
        Train the model

        Returns partially filled RLTrainingOutput.
        The field that should not be filled are:
        - output_path

        Arguments:
            trainer_module: the lightning module to train.
            train/eval/test_dataset: what you'd expect
            data_module: [pytorch lightning only] a lightning data module that replaces the use of train/eval datasets
            num_epochs: number of training epochs
            reader_options: options for the data reader
            resource_options: options for training resources (currently only used for setting num_nodes in pytorch lightning trainer)
            checkpoint_path: optional path to restore a checkpoint from.

        Raises:
            ValueError: if `data_module` is missing, or if a multi-stage
                trainer's per-stage epochs don't sum to `num_epochs`.
        """
        # Explicit checks instead of `assert`, which is stripped under `-O`.
        if isinstance(trainer_module, MultiStageTrainer):
            if trainer_module.multi_stage_total_epochs != num_epochs:
                raise ValueError(
                    f"The sum of each stage's epoch ({trainer_module.trainer_epoch_mapping})"
                    f" should be equal to num_epochs ({num_epochs})."
                )

        reporter = self.get_reporter()
        trainer_module.set_reporter(reporter)
        if not data_module:
            raise ValueError("data_module is required for this workflow")

        lightning_trainer = train_eval_lightning(
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            test_dataset=test_dataset,
            trainer_module=trainer_module,
            data_module=data_module,
            num_epochs=num_epochs,
            # NOTE(review): str(type(self)) yields "<class '...'>"; consider
            # type(self).__name__ for a cleaner logger name.
            logger_name=str(type(self)),
            reader_options=reader_options,
            checkpoint_path=checkpoint_path,
            resource_options=resource_options,
        )

        rank = get_rank()
        if rank == 0:
            # pyre-ignore
            trainer_logger = lightning_trainer.logger
            logger_data = trainer_logger.line_plot_aggregated
            trainer_logger.clear_local_data()
            if reporter is None:
                training_report = None
            else:
                # pyre-ignore
                training_report = RLTrainingReport.make_union_instance(
                    reporter.generate_training_report()
                )
            return (
                RLTrainingOutput(
                    training_report=training_report, logger_data=logger_data
                ),
                lightning_trainer,
            )
        # Output from processes with non-0 rank is not used
        return RLTrainingOutput(), lightning_trainer