def train(
    self, train_dataset: Dataset, eval_dataset: Optional[Dataset], num_epochs: int
) -> RLTrainingOutput:
    """
    Train the model

    Returns partially filled RLTrainingOutput. The fields that should not be
    filled are:
    - output_path
    - warmstart_output_path
    - vis_metrics
    - validation_output
    """
    logger.info("Creating reporter")
    reporter = DiscreteDQNReporter(
        self.trainer_param.actions,
        target_action_distribution=self.target_action_distribution,
    )
    logger.info("Adding reporter to trainer")
    self.trainer.add_observer(reporter)

    training_page_handler = TrainingPageHandler(self.trainer)
    training_page_handler.add_observer(reporter)

    evaluator = Evaluator(
        self.action_names,
        self.rl_parameters.gamma,
        self.trainer,
        metrics_to_score=self.metrics_to_score,
    )
    logger.info("Adding reporter to evaluator")
    evaluator.add_observer(reporter)
    evaluation_page_handler = EvaluationPageHandler(self.trainer, evaluator, reporter)

    batch_preprocessor = self.build_batch_preprocessor()
    train_and_evaluate_generic(
        train_dataset,
        eval_dataset,
        self.trainer,
        num_epochs,
        self.use_gpu,
        batch_preprocessor,
        training_page_handler,
        evaluation_page_handler,
        reader_options=self.reader_options,
    )
    training_report = RLTrainingReport.make_union_instance(
        reporter.generate_training_report()
    )
    return RLTrainingOutput(training_report=training_report)
def train(
    self,
    train_dataset: Dataset,
    eval_dataset: Optional[Dataset],
    num_epochs: int,
    reader_options: ReaderOptions,
) -> RLTrainingOutput:
    """
    Train the model

    Returns partially filled RLTrainingOutput. The fields that should not be
    filled are:
    - output_path
    """
    reporter = DiscreteDQNReporter(
        self.trainer_param.actions,
        target_action_distribution=self.target_action_distribution,
    )
    # pyre-fixme[16]: `RLTrainer` has no attribute `add_observer`.
    self.trainer.add_observer(reporter)

    evaluator = Evaluator(
        self.action_names,
        self.rl_parameters.gamma,
        self.trainer,
        metrics_to_score=self.metrics_to_score,
    )
    # pyre-fixme[16]: `Evaluator` has no attribute `add_observer`.
    evaluator.add_observer(reporter)

    batch_preprocessor = self.build_batch_preprocessor()
    train_and_evaluate_generic(
        train_dataset,
        eval_dataset,
        # pyre-fixme[6]: Expected `RLTrainer` for 3rd param but got `Trainer`.
        self.trainer,
        num_epochs,
        self.use_gpu,
        batch_preprocessor,
        reporter,
        evaluator,
        reader_options=self.reader_options,
    )
    # pyre-fixme[16]: `RLTrainingReport` has no attribute `make_union_instance`.
    training_report = RLTrainingReport.make_union_instance(
        reporter.generate_training_report()
    )
    return RLTrainingOutput(training_report=training_report)
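For orientation, a minimal caller-side sketch of how this train() might be driven. The names `manager`, `train_dataset`, `eval_dataset`, and `reader_options` are hypothetical, pre-built objects (a concrete model manager and its inputs), not anything defined above.

# Hypothetical usage sketch: `manager` is assumed to be a fully configured model
# manager exposing the train() above; the datasets and reader options are assumed
# to have been built by the surrounding workflow.
training_output = manager.train(
    train_dataset=train_dataset,    # assumed pre-built training Dataset
    eval_dataset=eval_dataset,      # Optional[Dataset]; may be None
    num_epochs=10,
    reader_options=reader_options,  # assumed ReaderOptions instance
)
# Per the docstring, the returned RLTrainingOutput is only partially filled
# (training_report is set; output_path is left for the caller to fill in).
print(training_output.training_report)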
def get_reporter(self):
    return DiscreteDQNReporter(
        self.trainer_param.actions,
        target_action_distribution=self.target_action_distribution,
    )
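get_reporter() simply factors out the reporter construction that train() does inline. A small sketch of the equivalent wiring, assuming `manager` exposes the same trainer that train() accesses via self.trainer:

# Sketch only: mirrors the observer registration performed inside train().
reporter = manager.get_reporter()       # same DiscreteDQNReporter as built in train()
manager.trainer.add_observer(reporter)  # assumes the trainer is exposed as `manager.trainer`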