def training_loop(self) -> None:
    scaling_config_dataclass = self._validate_and_get_scaling_config_data_class(
        self.scaling_config
    )

    train_loop_per_worker = construct_train_func(
        self._train_loop_per_worker,
        self._train_loop_config,
        fn_arg_name="train_loop_per_worker",
    )

    additional_resources_per_worker = (
        scaling_config_dataclass.additional_resources_per_worker
    )

    trial_info = TrialInfo(
        name=session.get_trial_name(),
        id=session.get_trial_id(),
        resources=session.get_trial_resources(),
        logdir=os.getcwd(),
    )

    backend_executor = BackendExecutor(
        backend_config=self._backend_config,
        trial_info=trial_info,
        num_workers=scaling_config_dataclass.num_workers,
        num_cpus_per_worker=scaling_config_dataclass.num_cpus_per_worker,
        num_gpus_per_worker=scaling_config_dataclass.num_gpus_per_worker,
        additional_resources_per_worker=additional_resources_per_worker,
        max_retries=0,
    )

    checkpoint_manager = self._checkpoint_manager_cls(
        preprocessor=self.preprocessor
    )

    # Start the remote actors.
    backend_executor.start(initialization_hook=None)

    training_iterator = TrainingIterator(
        backend_executor=backend_executor,
        backend_config=self._backend_config,
        train_func=train_loop_per_worker,
        dataset_spec=self._ingest_spec,
        checkpoint_manager=checkpoint_manager,
        checkpoint=self.resume_from_checkpoint,
        checkpoint_strategy=None,
    )

    for results in training_iterator:
        # TODO(ml-team): add ability to report results from multiple workers.
        first_worker_results = results[0]
        tune.report(**first_worker_results)

    # Shutdown workers.
    backend_executor.shutdown()
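# A minimal usage sketch for the training_loop above, assuming the Ray AIR
# DataParallelTrainer API; the worker function and scaling values below are
# illustrative, not taken from this section. Calling fit() is what ultimately
# drives training_loop() inside a Tune trial.
from ray.air import session
from ray.air.config import ScalingConfig
from ray.train.data_parallel_trainer import DataParallelTrainer


def example_loop_per_worker():
    # Each worker reports once; training_loop() forwards the first worker's
    # results to Tune via tune.report().
    session.report({"loss": 0.0})


trainer = DataParallelTrainer(
    train_loop_per_worker=example_loop_per_worker,
    scaling_config=ScalingConfig(num_workers=2),
)
result = trainer.fit()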
from ray.air import session
from ray.air.checkpoint import Checkpoint


def train_fn(config):
    start_epoch = 0
    print(session.get_trial_resources())

    checkpoint = session.get_checkpoint()
    if checkpoint:
        # Assume that we have run the session.report() example
        # and successfully saved some model weights.
        checkpoint_dict = checkpoint.to_dict()
        start_epoch = checkpoint_dict.get("epoch", -1) + 1

    # Wrap the model in DDP here.
    for epoch in range(start_epoch, config["num_epochs"]):
        checkpoint = Checkpoint.from_dict(dict(epoch=epoch))
        session.report(
            {
                "metric": config["metric"] * epoch,
                "epoch": epoch,
                "num_cpus": session.get_trial_resources().required_resources["CPU"],
            },
            checkpoint=checkpoint,
        )
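# A hedged sketch of launching train_fn with Tune, assuming the Ray 2.x Tuner
# API; the resource request and param_space values are illustrative. Tune
# allocates the trial's resources, so session.get_trial_resources() inside
# train_fn reflects what tune.with_resources() requested here.
from ray import tune

tuner = tune.Tuner(
    tune.with_resources(train_fn, resources={"cpu": 2}),
    param_space={"num_epochs": 3, "metric": 1},
)
results = tuner.fit()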
@property
def _ray_params(self):
    # _ray_params is accessed as an attribute (see super()._ray_params below),
    # so this override is declared as a property to match the parent trainer.
    # Assert that the resources Tune granted the trial match the placement
    # group the scaling config requested.
    scaling_config = self._validate_scaling_config(self.scaling_config)
    assert (
        scaling_config.as_placement_group_factory() == session.get_trial_resources()
    )
    return super()._ray_params
def training_loop(self) -> None:
    scaling_config = self._validate_scaling_config(self.scaling_config)
    pgf = scaling_config.as_placement_group_factory()
    tr = session.get_trial_resources()
    assert pgf == tr, (pgf, tr)
    return super().training_loop()
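# The two overrides above read like assertion hooks on trainer subclasses:
# each checks that the placement group Tune actually granted matches what the
# scaling config requested. A self-contained sketch of that pattern, assuming
# a DataParallelTrainer parent (the subclass name and trivial train loop are
# hypothetical):
from ray.air import session
from ray.air.config import ScalingConfig
from ray.train.data_parallel_trainer import DataParallelTrainer


class AssertingDataParallelTrainer(DataParallelTrainer):
    def training_loop(self) -> None:
        scaling_config = self._validate_scaling_config(self.scaling_config)
        pgf = scaling_config.as_placement_group_factory()
        tr = session.get_trial_resources()
        assert pgf == tr, (pgf, tr)
        return super().training_loop()


trainer = AssertingDataParallelTrainer(
    train_loop_per_worker=lambda: None,
    scaling_config=ScalingConfig(num_workers=2),
)
trainer.fit()  # Raises AssertionError if granted and requested resources diverge.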