def start_train(
    self,
    checkpoint: Optional[str] = None,
    restart_pipeline: bool = False,
    max_sampler_processes_per_worker: Optional[int] = None,
):
    """Spawn one training process per worker device and, when validation is
    enabled, a single validation process; returns this run's start-time string.

    checkpoint : optional checkpoint to restore training from.
    restart_pipeline : forwarded to each training worker.
    max_sampler_processes_per_worker : optional cap on per-worker sampler processes.
    """
    if not self.disable_config_saving:
        self.save_project_state()

    devices = self.worker_devices("train")
    num_workers = len(devices)

    # Seed before model construction so every worker receives an
    # identically-initialized model state dict.
    set_seed(self.seed)
    machine_params = MachineParams.instance_from(
        self.config.machine_params(self.mode)
    )
    initial_model_state_dict = self.config.create_model(
        sensor_preprocessor_graph=machine_params.sensor_preprocessor_graph
    ).state_dict()

    # A rendezvous port is only needed when multiple workers coordinate.
    distributed_port = find_free_port() if num_workers > 1 else 0

    for worker_ind in range(num_workers):
        train_process: BaseProcess = self.mp_ctx.Process(
            target=self.train_loop,
            kwargs=dict(
                id=worker_ind,
                checkpoint=checkpoint,
                restart_pipeline=restart_pipeline,
                experiment_name=self.experiment_name,
                config=self.config,
                results_queue=self.queues["results"],
                checkpoints_queue=(
                    self.queues["checkpoints"] if self.running_validation else None
                ),
                checkpoints_dir=self.checkpoint_dir(),
                seed=self.seed,
                deterministic_cudnn=self.deterministic_cudnn,
                mp_ctx=self.mp_ctx,
                num_workers=num_workers,
                device=devices[worker_ind],
                distributed_port=distributed_port,
                max_sampler_processes_per_worker=max_sampler_processes_per_worker,
                initial_model_state_dict=initial_model_state_dict,
            ),
        )
        train_process.start()
        self.processes["train"].append(train_process)

    get_logger().info(
        "Started {} train processes".format(len(self.processes["train"]))
    )

    # Validation
    if self.running_validation:
        valid_device = self.worker_devices("valid")[0]
        self.init_visualizer("valid")
        valid_process: BaseProcess = self.mp_ctx.Process(
            target=self.valid_loop,
            args=(0,),
            kwargs=dict(
                config=self.config,
                results_queue=self.queues["results"],
                checkpoints_queue=self.queues["checkpoints"],
                seed=12345,  # TODO allow same order for randomly sampled tasks? Is this any useful anyway?
                deterministic_cudnn=self.deterministic_cudnn,
                deterministic_agents=self.deterministic_agents,
                mp_ctx=self.mp_ctx,
                device=valid_device,
                max_sampler_processes_per_worker=max_sampler_processes_per_worker,
            ),
        )
        valid_process.start()
        self.processes["valid"].append(valid_process)

        get_logger().info(
            "Started {} valid processes".format(len(self.processes["valid"]))
        )
    else:
        get_logger().info(
            "No processes allocated to validation, no validation will be run."
        )

    self.log(self.local_start_time_str, num_workers)

    return self.local_start_time_str
def start_test(
    self,
    experiment_date: str,
    checkpoint_name_fragment: Optional[str] = None,
    approx_ckpt_steps_count: Optional[Union[float, int]] = None,
    skip_checkpoints: int = 0,
    max_sampler_processes_per_worker: Optional[int] = None,
):
    """Launch tester processes, enqueue every matching checkpoint for
    evaluation, create the metrics output file, and return the log result.

    experiment_date : identifies which experiment's checkpoints to test.
    checkpoint_name_fragment : optional substring filter on checkpoint names.
    approx_ckpt_steps_count : optional target step count for checkpoint selection.
    skip_checkpoints : number of checkpoints to skip between tested ones.
    max_sampler_processes_per_worker : optional cap on per-tester sampler processes.
    """
    devices = self.worker_devices("test")
    self.init_visualizer("test")
    num_testers = len(devices)

    # A rendezvous port is only needed when multiple testers coordinate.
    distributed_port = find_free_port() if num_testers > 1 else 0

    for tester_ind in range(num_testers):
        test_process: BaseProcess = self.mp_ctx.Process(
            target=self.test_loop,
            args=(tester_ind,),
            kwargs=dict(
                config=self.config,
                results_queue=self.queues["results"],
                checkpoints_queue=self.queues["checkpoints"],
                seed=12345,  # TODO allow same order for randomly sampled tasks? Is this any useful anyway?
                deterministic_cudnn=self.deterministic_cudnn,
                deterministic_agents=self.deterministic_agents,
                mp_ctx=self.mp_ctx,
                num_workers=num_testers,
                device=devices[tester_ind],
                max_sampler_processes_per_worker=max_sampler_processes_per_worker,
                distributed_port=distributed_port,
            ),
        )
        test_process.start()
        self.processes["test"].append(test_process)

    get_logger().info(
        "Started {} test processes".format(len(self.processes["test"]))
    )

    checkpoint_paths = self.get_checkpoint_files(
        experiment_date=experiment_date,
        checkpoint_name_fragment=checkpoint_name_fragment,
        approx_ckpt_steps_count=approx_ckpt_steps_count,
        skip_checkpoints=skip_checkpoints,
    )
    steps = [self.step_from_checkpoint(cp) for cp in checkpoint_paths]

    get_logger().info("Running test on {} steps {}".format(len(steps), steps))

    # Every tester evaluates every checkpoint.
    for checkpoint_path in checkpoint_paths:
        for _ in range(num_testers):
            self.queues["checkpoints"].put(("eval", checkpoint_path))

    # One sentinel per tester so that each terminates cleanly.
    for _ in range(num_testers):
        self.queues["checkpoints"].put(("quit", None))

    metrics_dir = self.metric_path(experiment_date)
    os.makedirs(metrics_dir, exist_ok=True)
    suffix = "__test_{}".format(self.local_start_time_str)
    metrics_file_path = os.path.join(metrics_dir, "metrics" + suffix + ".json")

    get_logger().info("Saving metrics in {}".format(metrics_file_path))

    # Fail fast if the metrics file cannot be written to.
    with open(metrics_file_path, "w") as f:
        json.dump([], f, indent=4, sort_keys=True, cls=NumpyJSONEncoder)

    return self.log(
        start_time_str=self.checkpoint_start_time_str(checkpoint_paths[0]),
        nworkers=num_testers,
        test_steps=steps,
        metrics_file=metrics_file_path,
    )
def start_train(
    self,
    checkpoint: Optional[str] = None,
    restart_pipeline: bool = False,
    max_sampler_processes_per_worker: Optional[int] = None,
):
    """Start one training process per worker device (plus, optionally, a
    single validation process) and return this run's start-time string.

    checkpoint : optional checkpoint to restore training from.
    restart_pipeline : if `True`, the training pipeline is restarted when
        loading `checkpoint`.
    max_sampler_processes_per_worker : optional cap on the number of task
        sampler processes each worker may spawn.
    """
    self._initialize_start_train_or_start_test()

    if not self.disable_config_saving:
        self.save_project_state()

    devices = self.worker_devices(TRAIN_MODE_STR)
    num_workers = len(devices)

    # Be extra careful to ensure that all models start
    # with the same initializations.
    set_seed(self.seed)
    initial_model_state_dict = self.config.create_model(
        sensor_preprocessor_graph=MachineParams.instance_from(
            self.config.machine_params(self.mode)
        ).sensor_preprocessor_graph
    ).state_dict()

    distributed_port = 0
    if num_workers > 1:
        distributed_port = find_free_port()

    model_hash = None
    for trainer_it in range(num_workers):
        training_kwargs = dict(
            id=trainer_it,
            checkpoint=checkpoint,
            restart_pipeline=restart_pipeline,
            experiment_name=self.experiment_name,
            config=self.config,
            results_queue=self.queues["results"],
            checkpoints_queue=self.queues["checkpoints"]
            if self.running_validation
            else None,
            checkpoints_dir=self.checkpoint_dir(),
            seed=self.seed,
            deterministic_cudnn=self.deterministic_cudnn,
            mp_ctx=self.mp_ctx,
            num_workers=num_workers,
            device=devices[trainer_it],
            distributed_port=distributed_port,
            max_sampler_processes_per_worker=max_sampler_processes_per_worker,
            # Once one worker has failed to receive the full state dict we
            # switch to sending the hash for all remaining workers too.
            initial_model_state_dict=initial_model_state_dict
            if model_hash is None
            else model_hash,
        )
        train: BaseProcess = self.mp_ctx.Process(
            target=self.train_loop,
            kwargs=training_kwargs,
        )
        try:
            train.start()
        except ValueError as e:
            # If the `initial_model_state_dict` is too large we sometimes
            # run into errors passing it with multiprocessing. In such cases
            # we instead hash the state_dict and confirm, in each engine
            # worker, that this hash equals the hash of the model the engine
            # worker instantiates.
            # Guard `e.args` access: a ValueError raised without arguments
            # would otherwise turn into an unrelated IndexError here.
            if e.args and e.args[0] == "too many fds":
                model_hash = md5_hash_of_state_dict(initial_model_state_dict)
                training_kwargs["initial_model_state_dict"] = model_hash
                train = self.mp_ctx.Process(
                    target=self.train_loop,
                    kwargs=training_kwargs,
                )
                train.start()
            else:
                # Unrelated failure: bare `raise` preserves the original
                # traceback/context.
                raise
        self.processes[TRAIN_MODE_STR].append(train)

    get_logger().info(
        "Started {} train processes".format(len(self.processes[TRAIN_MODE_STR]))
    )

    # Validation
    if self.running_validation:
        device = self.worker_devices("valid")[0]
        self.init_visualizer("valid")
        valid: BaseProcess = self.mp_ctx.Process(
            target=self.valid_loop,
            args=(0,),
            kwargs=dict(
                config=self.config,
                results_queue=self.queues["results"],
                checkpoints_queue=self.queues["checkpoints"],
                seed=12345,  # TODO allow same order for randomly sampled tasks? Is this any useful anyway?
                deterministic_cudnn=self.deterministic_cudnn,
                deterministic_agents=self.deterministic_agents,
                mp_ctx=self.mp_ctx,
                device=device,
                max_sampler_processes_per_worker=max_sampler_processes_per_worker,
            ),
        )
        valid.start()
        self.processes["valid"].append(valid)

        get_logger().info(
            "Started {} valid processes".format(len(self.processes["valid"]))
        )
    else:
        get_logger().info(
            "No processes allocated to validation, no validation will be run."
        )

    self.log_and_close(self.local_start_time_str, num_workers)

    return self.local_start_time_str