def train(
    self,
    train_dataset: Dataset,
    eval_dataset: Optional[Dataset],
    num_epochs: int,
    reader_options: ReaderOptions,
) -> RLTrainingOutput:
    """
    Train the model with the Lightning-based training loop.

    Returns partially filled RLTrainingOutput.
    The field that should not be filled are:
    - output_path

    Args:
        train_dataset: dataset to optimize on.
        eval_dataset: optional dataset used for evaluation.
        num_epochs: number of training epochs.
        reader_options: options for the data reader.
    """
    batch_preprocessor = self.build_batch_preprocessor()
    reporter = self.get_reporter()
    # pyre-fixme[16]: `RLTrainer` has no attribute `add_observer`.
    self.trainer.add_observer(reporter)
    train_eval_lightning(
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        trainer_module=self.trainer,
        num_epochs=num_epochs,
        use_gpu=self.use_gpu,
        batch_preprocessor=batch_preprocessor,
        # Fix: use the reader_options passed by the caller; previously this
        # parameter was silently ignored in favor of self.reader_options.
        reader_options=reader_options,
        checkpoint_path=self._lightning_checkpoint_path,
    )
    # pyre-fixme[16]: `RLTrainingReport` has no attribute `make_union_instance`.
    training_report = RLTrainingReport.make_union_instance(
        reporter.generate_training_report()
    )
    return RLTrainingOutput(training_report=training_report)
def train(
    self,
    train_dataset: Optional[Dataset],
    eval_dataset: Optional[Dataset],
    data_module: Optional[ReAgentDataModule],
    num_epochs: int,
    reader_options: ReaderOptions,
) -> RLTrainingOutput:
    """
    Train the model with the Lightning-based training loop.

    Returns partially filled RLTrainingOutput.
    The field that should not be filled are:
    - output_path

    Args:
        train_dataset: optional dataset to optimize on.
        eval_dataset: optional dataset used for evaluation.
        data_module: optional lightning data module replacing the datasets.
        num_epochs: number of training epochs.
        reader_options: options for the data reader.
    """
    batch_preprocessor = self.build_batch_preprocessor()
    reporter = self.get_reporter()
    # pyre-fixme[16]: `Trainer` has no attribute `set_reporter`.
    self.trainer.set_reporter(reporter)
    self._lightning_trainer = train_eval_lightning(
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        trainer_module=self.trainer,
        data_module=data_module,
        num_epochs=num_epochs,
        use_gpu=self.use_gpu,
        batch_preprocessor=batch_preprocessor,
        # NOTE(review): the `reader_options` argument is ignored here in
        # favor of self.reader_options — confirm this is intentional.
        reader_options=self.reader_options,
        checkpoint_path=self._lightning_checkpoint_path,
    )
    # pyre-fixme[16]: `RLTrainingReport` has no attribute `make_union_instance`.
    training_report = RLTrainingReport.make_union_instance(
        reporter.generate_training_report()
    )
    return RLTrainingOutput(training_report=training_report)
def train(
    self, train_dataset: Dataset, eval_dataset: Optional[Dataset], num_epochs: int
) -> RLTrainingOutput:
    """
    Run the generic train-and-evaluate loop for this model.

    A single ActorCriticReporter observes both the trainer and the
    evaluator; its aggregated report is wrapped in a partially filled
    RLTrainingOutput.
    """
    preprocessor = self.build_batch_preprocessor()

    progress_reporter = ActorCriticReporter()
    # pyre-fixme[16]: `RLTrainer` has no attribute `add_observer`.
    self.trainer.add_observer(progress_reporter)

    policy_evaluator = Evaluator(
        action_names=None,
        # pyre-fixme[16]: `ActorCriticBase` has no attribute `rl_parameters`.
        gamma=self.rl_parameters.gamma,
        model=self.trainer,
        metrics_to_score=self.metrics_to_score,
    )
    # pyre-fixme[16]: `Evaluator` has no attribute `add_observer`.
    policy_evaluator.add_observer(progress_reporter)

    train_and_evaluate_generic(
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        trainer=self.trainer,
        num_epochs=num_epochs,
        use_gpu=self.use_gpu,
        batch_preprocessor=preprocessor,
        reporter=progress_reporter,
        evaluator=policy_evaluator,
        reader_options=self.reader_options,
    )
    # pyre-fixme[16]: `RLTrainingReport` has no attribute `make_union_instance`.
    report = RLTrainingReport.make_union_instance(
        progress_reporter.generate_training_report()
    )
    return RLTrainingOutput(training_report=report)
def train(
    self, train_dataset: Dataset, eval_dataset: Optional[Dataset], num_epochs: int
) -> RLTrainingOutput:
    """
    Train the model

    Returns partially filled RLTrainingOutput. The field that should not be
    filled are:
    - output_path
    - warmstart_output_path
    - vis_metrics
    - validation_output
    """
    logger.info("Creating reporter")
    dqn_reporter = DiscreteDQNReporter(
        self.trainer_param.actions,
        target_action_distribution=self.target_action_distribution,
    )
    logger.info("Adding reporter to trainer")
    self.trainer.add_observer(dqn_reporter)

    train_handler = TrainingPageHandler(self.trainer)
    train_handler.add_observer(dqn_reporter)

    cpe_evaluator = Evaluator(
        self.action_names,
        self.rl_parameters.gamma,
        self.trainer,
        metrics_to_score=self.metrics_to_score,
    )
    logger.info("Adding reporter to evaluator")
    cpe_evaluator.add_observer(dqn_reporter)
    eval_handler = EvaluationPageHandler(self.trainer, cpe_evaluator, dqn_reporter)

    train_and_evaluate_generic(
        train_dataset,
        eval_dataset,
        self.trainer,
        num_epochs,
        self.use_gpu,
        self.build_batch_preprocessor(),
        train_handler,
        eval_handler,
        reader_options=self.reader_options,
    )
    report = RLTrainingReport.make_union_instance(
        dqn_reporter.generate_training_report()
    )
    return RLTrainingOutput(training_report=report)
def train(
    self,
    train_dataset: Dataset,
    eval_dataset: Optional[Dataset],
    num_epochs: int,
    reader_options: ReaderOptions,
) -> RLTrainingOutput:
    """
    Train the model

    Returns partially filled RLTrainingOutput.
    The field that should not be filled are:
    - output_path

    Args:
        train_dataset: dataset to optimize on.
        eval_dataset: optional dataset used for evaluation.
        num_epochs: number of training epochs.
        reader_options: options for the data reader.
    """
    reporter = DiscreteDQNReporter(
        self.trainer_param.actions,
        target_action_distribution=self.target_action_distribution,
    )
    # pyre-fixme[16]: `RLTrainer` has no attribute `add_observer`.
    self.trainer.add_observer(reporter)
    evaluator = Evaluator(
        self.action_names,
        self.rl_parameters.gamma,
        self.trainer,
        metrics_to_score=self.metrics_to_score,
    )
    # pyre-fixme[16]: `Evaluator` has no attribute `add_observer`.
    evaluator.add_observer(reporter)
    batch_preprocessor = self.build_batch_preprocessor()
    train_and_evaluate_generic(
        train_dataset,
        eval_dataset,
        # pyre-fixme[6]: Expected `RLTrainer` for 3rd param but got `Trainer`.
        self.trainer,
        num_epochs,
        self.use_gpu,
        batch_preprocessor,
        reporter,
        evaluator,
        # NOTE(review): the `reader_options` argument is ignored in favor of
        # self.reader_options — confirm this is intentional.
        reader_options=self.reader_options,
    )
    # pyre-fixme[16]: `RLTrainingReport` has no attribute `make_union_instance`.
    training_report = RLTrainingReport.make_union_instance(
        reporter.generate_training_report()
    )
    return RLTrainingOutput(training_report=training_report)
def train(
    self,
    train_dataset: Optional[Dataset],
    eval_dataset: Optional[Dataset],
    test_dataset: Optional[Dataset],
    data_module: Optional[ReAgentDataModule],
    num_epochs: int,
    reader_options: ReaderOptions,
    resource_options: Optional[ResourceOptions] = None,
) -> RLTrainingOutput:
    """
    Train the model

    Returns partially filled RLTrainingOutput. The field that should not be
    filled are:
    - output_path
    """
    dqn_reporter = self.get_reporter()
    # pyre-fixme[16]: `RLTrainer` has no attribute `set_reporter`.
    self.trainer.set_reporter(dqn_reporter)
    assert data_module

    # pyre-fixme[16]: `DiscreteDQNBase` has no attribute `_lightning_trainer`.
    self._lightning_trainer = train_eval_lightning(
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        test_dataset=test_dataset,
        trainer_module=self.trainer,
        data_module=data_module,
        num_epochs=num_epochs,
        use_gpu=self.use_gpu,
        logger_name="DiscreteDqn",
        reader_options=reader_options,
        checkpoint_path=self._lightning_checkpoint_path,
        resource_options=resource_options,
    )

    if get_rank() != 0:
        # Output from processes with non-0 rank is not used
        return RLTrainingOutput()

    # pyre-fixme[16]: `RLTrainingReport` has no attribute `make_union_instance`.
    report = RLTrainingReport.make_union_instance(
        dqn_reporter.generate_training_report()
    )
    plot_data = self._lightning_trainer.logger.line_plot_aggregated
    self._lightning_trainer.logger.clear_local_data()
    return RLTrainingOutput(training_report=report, logger_data=plot_data)
def train(
    self,
    train_dataset: Optional[Dataset],
    eval_dataset: Optional[Dataset],
    test_dataset: Optional[Dataset],
    data_module: Optional[ReAgentDataModule],
    num_epochs: int,
    reader_options: ReaderOptions,
    resource_options: ResourceOptions,
) -> RLTrainingOutput:
    """
    Train the model with the Lightning-based training loop.

    Returns partially filled RLTrainingOutput.
    The field that should not be filled are:
    - output_path

    Args:
        train/eval/test_dataset: optional datasets for each phase.
        data_module: optional lightning data module replacing the datasets.
        num_epochs: number of training epochs.
        reader_options: options for the data reader.
        resource_options: options for training resources.
    """
    batch_preprocessor = self.build_batch_preprocessor(resource_options.use_gpu)
    reporter = self.get_reporter()
    # pyre-fixme[16]: `Trainer` has no attribute `set_reporter`.
    self.trainer.set_reporter(reporter)
    # pyre-fixme[16]: `ActorCriticBase` has no attribute `_lightning_trainer`.
    self._lightning_trainer = train_eval_lightning(
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        test_dataset=test_dataset,
        trainer_module=self.trainer,
        data_module=data_module,
        num_epochs=num_epochs,
        logger_name="ActorCritic",
        batch_preprocessor=batch_preprocessor,
        # NOTE(review): the `reader_options` argument is ignored in favor of
        # self.reader_options — confirm this is intentional.
        reader_options=self.reader_options,
        checkpoint_path=self._lightning_checkpoint_path,
        # The `or` guards against a falsy value even though the annotation
        # says resource_options is required.
        resource_options=resource_options or ResourceOptions(),
    )
    if reporter is None:
        training_report = None
    else:
        # pyre-fixme[16]: `RLTrainingReport` has no attribute `make_union_instance`.
        training_report = RLTrainingReport.make_union_instance(
            reporter.generate_training_report()
        )
    logger_data = self._lightning_trainer.logger.line_plot_aggregated
    self._lightning_trainer.logger.clear_local_data()
    return RLTrainingOutput(training_report=training_report, logger_data=logger_data)
def train(
    self,
    train_dataset: Optional[Dataset],
    eval_dataset: Optional[Dataset],
    data_module: Optional[ReAgentDataModule],
    num_epochs: int,
    reader_options: ReaderOptions,
    resource_options: Optional[ResourceOptions] = None,
) -> RLTrainingOutput:
    """
    Train the model

    Returns partially filled RLTrainingOutput. The field that should not be
    filled are:
    - output_path
    """
    batch_preprocessor = self.build_batch_preprocessor()
    reporter = self.get_reporter()
    # pyre-fixme[16]: `RLTrainer` has no attribute `set_reporter`.
    self.trainer.set_reporter(reporter)
    # NOTE(review): `data_module` is accepted but `data_module=None` is passed
    # below, and the `reader_options` argument is shadowed by
    # self.reader_options — confirm both are intentional.
    train_eval_lightning(
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        trainer_module=self.trainer,
        data_module=None,
        num_epochs=num_epochs,
        use_gpu=self.use_gpu,
        batch_preprocessor=batch_preprocessor,
        reader_options=self.reader_options,
        checkpoint_path=self._lightning_checkpoint_path,
        resource_options=resource_options,
    )
    rank = get_rank()
    # Only the rank-0 process builds and returns a training report.
    if rank == 0:
        # pyre-fixme[16]: `RLTrainingReport` has no attribute `make_union_instance`.
        training_report = RLTrainingReport.make_union_instance(
            reporter.generate_training_report()
        )
        return RLTrainingOutput(training_report=training_report)
    # Output from processes with non-0 rank is not used
    return RLTrainingOutput()
def train(
    self,
    trainer_module: ReAgentLightningModule,
    train_dataset: Optional[Dataset],
    eval_dataset: Optional[Dataset],
    test_dataset: Optional[Dataset],
    data_module: Optional[ReAgentDataModule],
    num_epochs: int,
    reader_options: ReaderOptions,
    resource_options: ResourceOptions,
    checkpoint_path: Optional[str] = None,
) -> Tuple[RLTrainingOutput, pl.Trainer]:
    """
    Train the model and return the (partially filled) training output
    together with the Lightning trainer that ran the loop.

    The output_path field of the returned RLTrainingOutput is left unfilled.

    Arguments:
        trainer_module: the lightning module to optimize.
        train/eval/test_dataset: what you'd expect
        data_module: [pytorch lightning only] a lightning data module that
            replaces the use of train/eval datasets
        num_epochs: number of training epochs
        reader_options: options for the data reader
        resource_options: options for training resources (currently only
            used for setting num_nodes in pytorch lightning trainer)
        checkpoint_path: optional path to resume training from.
    """
    # A multi-stage trainer splits num_epochs across its stages, so the
    # per-stage epochs must sum to num_epochs exactly.
    if isinstance(trainer_module, MultiStageTrainer):
        assert trainer_module.multi_stage_total_epochs == num_epochs, (
            f"The sum of each stage's epoch ({trainer_module.trainer_epoch_mapping})"
            f" should be equal to num_epochs ({num_epochs})."
        )
    reporter = self.get_reporter()
    trainer_module.set_reporter(reporter)
    assert data_module

    pl_trainer = train_eval_lightning(
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        test_dataset=test_dataset,
        trainer_module=trainer_module,
        data_module=data_module,
        num_epochs=num_epochs,
        logger_name=str(type(self)),
        reader_options=reader_options,
        checkpoint_path=checkpoint_path,
        resource_options=resource_options,
    )

    if get_rank() != 0:
        # Output from processes with non-0 rank is not used
        return RLTrainingOutput(), pl_trainer

    # pyre-ignore
    agg_logger = pl_trainer.logger
    plot_data = agg_logger.line_plot_aggregated
    agg_logger.clear_local_data()
    if reporter is None:
        report = None
    else:
        # pyre-ignore
        report = RLTrainingReport.make_union_instance(
            reporter.generate_training_report()
        )
    return (
        RLTrainingOutput(training_report=report, logger_data=plot_data),
        pl_trainer,
    )