Example #1
    def train(
        self, train_dataset: Dataset, eval_dataset: Optional[Dataset], num_epochs: int
    ) -> RLTrainingOutput:
        """
        Train the model

        Returns a partially filled RLTrainingOutput. The fields that should
        not be filled are:
        - output_path
        - warmstart_output_path
        - vis_metrics
        - validation_output
        """
        logger.info("Creating reporter")
        reporter = DiscreteDQNReporter(
            self.trainer_param.actions,
            target_action_distribution=self.target_action_distribution,
        )
        logger.info("Adding reporter to trainer")
        self.trainer.add_observer(reporter)

        training_page_handler = TrainingPageHandler(self.trainer)
        training_page_handler.add_observer(reporter)
        evaluator = Evaluator(
            self.action_names,
            self.rl_parameters.gamma,
            self.trainer,
            metrics_to_score=self.metrics_to_score,
        )
        logger.info("Adding reporter to evaluator")
        evaluator.add_observer(reporter)
        evaluation_page_handler = EvaluationPageHandler(
            self.trainer, evaluator, reporter
        )

        batch_preprocessor = self.build_batch_preprocessor()
        train_and_evaluate_generic(
            train_dataset,
            eval_dataset,
            self.trainer,
            num_epochs,
            self.use_gpu,
            batch_preprocessor,
            training_page_handler,
            evaluation_page_handler,
            reader_options=self.reader_options,
        )
        training_report = RLTrainingReport.make_union_instance(
            reporter.generate_training_report()
        )
        return RLTrainingOutput(training_report=training_report)
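
The example above wires a single DiscreteDQNReporter into the trainer, the training page handler, and the evaluator through add_observer, so one object collects metrics from both training and evaluation. A minimal, self-contained sketch of that observer pattern follows; Observable and PrintReporter are illustrative stand-ins, not ReAgent's actual classes:

    class Observable:
        """Base for objects that push updates to registered observers."""

        def __init__(self):
            self._observers = []

        def add_observer(self, observer):
            self._observers.append(observer)
            return self

        def notify_observers(self, **kwargs):
            for observer in self._observers:
                observer.update(**kwargs)

    class PrintReporter:
        """Stand-in reporter that just prints whatever it observes."""

        def update(self, **kwargs):
            print("observed:", kwargs)

    trainer = Observable()
    trainer.add_observer(PrintReporter())
    trainer.notify_observers(epoch=0, td_loss=0.42)

Registering the same reporter on several observables is what lets one report cover the whole run.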
Example #2
    def train(
        self,
        train_dataset: Dataset,
        eval_dataset: Optional[Dataset],
        num_epochs: int,
        reader_options: ReaderOptions,
    ) -> RLTrainingOutput:
        """
        Train the model

        Returns a partially filled RLTrainingOutput.
        The field that should not be filled is:
        - output_path
        """
        reporter = DiscreteDQNReporter(
            self.trainer_param.actions,
            target_action_distribution=self.target_action_distribution,
        )
        # pyre-fixme[16]: `RLTrainer` has no attribute `add_observer`.
        self.trainer.add_observer(reporter)

        evaluator = Evaluator(
            self.action_names,
            self.rl_parameters.gamma,
            self.trainer,
            metrics_to_score=self.metrics_to_score,
        )
        # pyre-fixme[16]: `Evaluator` has no attribute `add_observer`.
        evaluator.add_observer(reporter)

        batch_preprocessor = self.build_batch_preprocessor()
        train_and_evaluate_generic(
            train_dataset,
            eval_dataset,
            # pyre-fixme[6]: Expected `RLTrainer` for 3rd param but got `Trainer`.
            self.trainer,
            num_epochs,
            self.use_gpu,
            batch_preprocessor,
            reporter,
            evaluator,
            reader_options=self.reader_options,
        )
        # pyre-fixme[16]: `RLTrainingReport` has no attribute `make_union_instance`.
        training_report = RLTrainingReport.make_union_instance(
            reporter.generate_training_report()
        )
        return RLTrainingOutput(training_report=training_report)
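
Both docstrings stress that train returns a partially filled RLTrainingOutput: only training_report is set, while fields such as output_path are left for later pipeline stages. A sketch of what such a partially filled result type could look like; the field types below are assumptions, and only the field names come from the docstrings above:

    from dataclasses import dataclass
    from typing import Any, Optional

    @dataclass
    class TrainingOutputSketch:
        # Filled by train().
        training_report: Optional[Any] = None
        # Deliberately left unfilled; later pipeline stages populate these.
        output_path: Optional[str] = None
        warmstart_output_path: Optional[str] = None
        vis_metrics: Optional[Any] = None
        validation_output: Optional[Any] = None

    result = TrainingOutputSketch(training_report={"td_loss": 0.42})
    assert result.output_path is None  # not train()'s job to fill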
Example #3
    def get_reporter(self):
        return DiscreteDQNReporter(
            self.trainer_param.actions,
            target_action_distribution=self.target_action_distribution,
        )