def start_test(
    self,
    experiment_date: str,
    checkpoint: Optional[str] = None,
    skip_checkpoints: int = 0,
    max_sampler_processes_per_worker: Optional[int] = None,
):
    devices = self.worker_devices("test")
    self.init_visualizer("test")
    num_testers = len(devices)

    # A distributed rendezvous port is only needed when several testers must coordinate.
    distributed_port = 0
    if num_testers > 1:
        distributed_port = find_free_port()

    # Spawn one test worker per device.
    for tester_it in range(num_testers):
        test: BaseProcess = self.mp_ctx.Process(
            target=self.test_loop,
            args=(tester_it,),
            kwargs=dict(
                config=self.config,
                results_queue=self.queues["results"],
                checkpoints_queue=self.queues["checkpoints"],
                seed=12345,  # TODO: allow the same order for randomly sampled tasks? Is this useful at all?
                deterministic_cudnn=self.deterministic_cudnn,
                deterministic_agents=self.deterministic_agents,
                mp_ctx=self.mp_ctx,
                num_workers=num_testers,
                device=devices[tester_it],
                max_sampler_processes_per_worker=max_sampler_processes_per_worker,
                distributed_port=distributed_port,
            ),
        )
        test.start()
        self.processes["test"].append(test)

    get_logger().info(
        "Started {} test processes".format(len(self.processes["test"]))
    )

    checkpoints = self.get_checkpoint_files(
        experiment_date, checkpoint, skip_checkpoints
    )
    steps = [self.step_from_checkpoint(cp) for cp in checkpoints]
    get_logger().info("Running test on {} steps {}".format(len(steps), steps))

    # Make all testers work on each checkpoint.
    for checkpoint_file in checkpoints:
        for tester_it in range(num_testers):
            self.queues["checkpoints"].put(("eval", checkpoint_file))

    # Signal all testers to terminate cleanly.
    for _ in range(num_testers):
        self.queues["checkpoints"].put(("quit", None))

    metric_folder = self.metric_path(experiment_date)
    os.makedirs(metric_folder, exist_ok=True)
    suffix = "__test_{}".format(self.local_start_time_str)
    fname = os.path.join(metric_folder, "metrics" + suffix + ".json")

    get_logger().info("Saving metrics in {}".format(fname))

    # Check that the output file can be written before running the evaluation.
    with open(fname, "w") as f:
        json.dump([], f, indent=4, sort_keys=True)

    return self.log(
        self.checkpoint_start_time_str(checkpoints[0]), num_testers, steps, fname
    )
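# A minimal usage sketch (an assumption, not part of this class): with an
# already-constructed runner instance, evaluating the checkpoints of a prior
# training run might look like the following. The timestamp is hypothetical;
# in practice it would be the start-time string returned by `start_train`
# below.
#
#     runner.start_test(
#         experiment_date="2021-02-25_12-00-00",  # hypothetical run timestamp
#         skip_checkpoints=0,  # evaluate every saved checkpoint
#         max_sampler_processes_per_worker=4,
#     )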
def start_train(
    self,
    checkpoint: Optional[str] = None,
    restart_pipeline: bool = False,
    max_sampler_processes_per_worker: Optional[int] = None,
):
    self.save_project_state()

    devices = self.worker_devices("train")
    num_workers = len(devices)

    # Same seed for all workers; used during initialization of the model.
    seed = self.seed

    # A distributed rendezvous port is only needed when several trainers must coordinate.
    distributed_port = 0
    if num_workers > 1:
        distributed_port = find_free_port()

    # Spawn one train worker per device.
    for trainer_it in range(num_workers):
        train: BaseProcess = self.mp_ctx.Process(
            target=self.train_loop,
            kwargs=dict(
                id=trainer_it,
                checkpoint=checkpoint,
                restart_pipeline=restart_pipeline,
                experiment_name=self.experiment_name,
                config=self.config,
                results_queue=self.queues["results"],
                checkpoints_queue=self.queues["checkpoints"]
                if self.running_validation
                else None,
                checkpoints_dir=self.checkpoint_dir(),
                seed=seed,
                deterministic_cudnn=self.deterministic_cudnn,
                mp_ctx=self.mp_ctx,
                num_workers=num_workers,
                device=devices[trainer_it],
                distributed_port=distributed_port,
                max_sampler_processes_per_worker=max_sampler_processes_per_worker,
            ),
        )
        train.start()
        self.processes["train"].append(train)

    get_logger().info(
        "Started {} train processes".format(len(self.processes["train"]))
    )

    # Validation
    if self.running_validation:
        device = self.worker_devices("valid")[0]
        self.init_visualizer("valid")
        valid: BaseProcess = self.mp_ctx.Process(
            target=self.valid_loop,
            args=(0,),
            kwargs=dict(
                config=self.config,
                results_queue=self.queues["results"],
                checkpoints_queue=self.queues["checkpoints"],
                seed=12345,  # TODO: allow the same order for randomly sampled tasks? Is this useful at all?
                deterministic_cudnn=self.deterministic_cudnn,
                deterministic_agents=self.deterministic_agents,
                mp_ctx=self.mp_ctx,
                device=device,
                max_sampler_processes_per_worker=max_sampler_processes_per_worker,
            ),
        )
        valid.start()
        self.processes["valid"].append(valid)

        get_logger().info(
            "Started {} valid processes".format(len(self.processes["valid"]))
        )
    else:
        get_logger().info(
            "No processes allocated to validation, no validation will be run."
        )

    self.log(self.local_start_time_str, num_workers)

    return self.local_start_time_str
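# A minimal end-to-end sketch (hypothetical; the constructor name and config
# object below are assumptions, not shown in this module): a driver script
# might start training and then evaluate the saved checkpoints, reusing the
# start-time string returned by `start_train` as the `experiment_date`
# argument that `start_test` above expects.
#
#     runner = Runner(config=my_experiment_config)  # hypothetical constructor
#     start_time_str = runner.start_train(max_sampler_processes_per_worker=4)
#     runner.start_test(experiment_date=start_time_str)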