Example #1
0
    def start_train(
        self,
        checkpoint: Optional[str] = None,
        restart_pipeline: bool = False,
        max_sampler_processes_per_worker: Optional[int] = None,
    ):
        """Launch one training process per train device and, when configured,
        a single validation process.

        # Parameters

        checkpoint : Optional checkpoint to resume training from.
        restart_pipeline : When True, the training pipeline is restarted
            after loading `checkpoint`.
        max_sampler_processes_per_worker : Optional cap on the number of
            task-sampler processes each worker may use.

        # Returns

        The start-time string identifying this run.
        """
        if not self.disable_config_saving:
            self.save_project_state()

        devices = self.worker_devices("train")
        num_workers = len(devices)

        # Seed *before* building the model so every worker is handed an
        # identically initialized set of weights.
        set_seed(self.seed)
        machine_params = MachineParams.instance_from(
            self.config.machine_params(self.mode))
        initial_model_state_dict = self.config.create_model(
            sensor_preprocessor_graph=machine_params.sensor_preprocessor_graph
        ).state_dict()

        # A rendezvous port is only needed when several workers coordinate.
        distributed_port = find_free_port() if num_workers > 1 else 0

        for worker_id, worker_device in enumerate(devices):
            worker_kwargs = dict(
                id=worker_id,
                checkpoint=checkpoint,
                restart_pipeline=restart_pipeline,
                experiment_name=self.experiment_name,
                config=self.config,
                results_queue=self.queues["results"],
                # Checkpoints are only forwarded when a validator is
                # running to consume them.
                checkpoints_queue=(self.queues["checkpoints"]
                                   if self.running_validation else None),
                checkpoints_dir=self.checkpoint_dir(),
                seed=self.seed,
                deterministic_cudnn=self.deterministic_cudnn,
                mp_ctx=self.mp_ctx,
                num_workers=num_workers,
                device=worker_device,
                distributed_port=distributed_port,
                max_sampler_processes_per_worker=max_sampler_processes_per_worker,
                initial_model_state_dict=initial_model_state_dict,
            )
            train: BaseProcess = self.mp_ctx.Process(
                target=self.train_loop, kwargs=worker_kwargs)
            train.start()
            self.processes["train"].append(train)

        get_logger().info("Started {} train processes".format(
            len(self.processes["train"])))

        # Validation (optional single process)
        if self.running_validation:
            valid_device = self.worker_devices("valid")[0]
            self.init_visualizer("valid")
            valid_kwargs = dict(
                config=self.config,
                results_queue=self.queues["results"],
                checkpoints_queue=self.queues["checkpoints"],
                # Fixed seed. TODO allow same order for randomly sampled
                # tasks? Is this any useful anyway?
                seed=12345,
                deterministic_cudnn=self.deterministic_cudnn,
                deterministic_agents=self.deterministic_agents,
                mp_ctx=self.mp_ctx,
                device=valid_device,
                max_sampler_processes_per_worker=max_sampler_processes_per_worker,
            )
            valid: BaseProcess = self.mp_ctx.Process(
                target=self.valid_loop, args=(0, ), kwargs=valid_kwargs)
            valid.start()
            self.processes["valid"].append(valid)

            get_logger().info("Started {} valid processes".format(
                len(self.processes["valid"])))
        else:
            get_logger().info(
                "No processes allocated to validation, no validation will be run."
            )

        self.log(self.local_start_time_str, num_workers)

        return self.local_start_time_str
Example #2
0
    def start_test(
        self,
        experiment_date: str,
        checkpoint_name_fragment: Optional[str] = None,
        approx_ckpt_steps_count: Optional[Union[float, int]] = None,
        skip_checkpoints: int = 0,
        max_sampler_processes_per_worker: Optional[int] = None,
    ):
        """Spawn tester processes, queue every matching checkpoint for
        evaluation, and return the logged results.

        # Parameters

        experiment_date : Date string selecting which experiment's
            checkpoints to test.
        checkpoint_name_fragment : Optional substring used to filter
            checkpoint files.
        approx_ckpt_steps_count : Optional approximate step count used to
            select a checkpoint.
        skip_checkpoints : Number of checkpoints to skip between the ones
            that are evaluated.
        max_sampler_processes_per_worker : Optional cap on the number of
            task-sampler processes each tester may use.
        """
        devices = self.worker_devices("test")
        self.init_visualizer("test")
        num_testers = len(devices)

        # A rendezvous port is only needed when several testers coordinate.
        distributed_port = find_free_port() if num_testers > 1 else 0

        for tester_id, tester_device in enumerate(devices):
            tester_kwargs = dict(
                config=self.config,
                results_queue=self.queues["results"],
                checkpoints_queue=self.queues["checkpoints"],
                # Fixed seed. TODO allow same order for randomly sampled
                # tasks? Is this any useful anyway?
                seed=12345,
                deterministic_cudnn=self.deterministic_cudnn,
                deterministic_agents=self.deterministic_agents,
                mp_ctx=self.mp_ctx,
                num_workers=num_testers,
                device=tester_device,
                max_sampler_processes_per_worker=max_sampler_processes_per_worker,
                distributed_port=distributed_port,
            )
            test: BaseProcess = self.mp_ctx.Process(
                target=self.test_loop, args=(tester_id, ),
                kwargs=tester_kwargs)
            test.start()
            self.processes["test"].append(test)

        get_logger().info("Started {} test processes".format(
            len(self.processes["test"])))

        checkpoint_paths = self.get_checkpoint_files(
            experiment_date=experiment_date,
            checkpoint_name_fragment=checkpoint_name_fragment,
            approx_ckpt_steps_count=approx_ckpt_steps_count,
            skip_checkpoints=skip_checkpoints,
        )
        steps = [self.step_from_checkpoint(cp) for cp in checkpoint_paths]

        get_logger().info("Running test on {} steps {}".format(
            len(steps), steps))

        # Every tester evaluates every checkpoint; afterwards each receives
        # a quit sentinel so it can shut down cleanly.
        for checkpoint_path in checkpoint_paths:
            for _ in range(num_testers):
                self.queues["checkpoints"].put(("eval", checkpoint_path))
        for _ in range(num_testers):
            self.queues["checkpoints"].put(("quit", None))

        metrics_dir = self.metric_path(experiment_date)
        os.makedirs(metrics_dir, exist_ok=True)
        suffix = "__test_{}".format(self.local_start_time_str)
        metrics_file_path = os.path.join(metrics_dir,
                                         "metrics" + suffix + ".json")

        get_logger().info("Saving metrics in {}".format(metrics_file_path))

        # Fail fast if the metrics file cannot be written.
        with open(metrics_file_path, "w") as f:
            json.dump([], f, indent=4, sort_keys=True, cls=NumpyJSONEncoder)

        return self.log(
            start_time_str=self.checkpoint_start_time_str(checkpoint_paths[0]),
            nworkers=num_testers,
            test_steps=steps,
            metrics_file=metrics_file_path,
        )
Example #3
0
    def start_train(
        self,
        checkpoint: Optional[str] = None,
        restart_pipeline: bool = False,
        max_sampler_processes_per_worker: Optional[int] = None,
    ):
        """Launch one training process per train device and, when configured,
        a single validation process.

        If passing the full initial model state dict to a worker process
        fails (the "too many fds" `ValueError`), the worker is relaunched
        with an md5 hash of the state dict instead, and all subsequent
        workers receive the hash as well.

        # Parameters

        checkpoint : Optional checkpoint to resume training from.
        restart_pipeline : When True, the training pipeline is restarted
            after loading `checkpoint`.
        max_sampler_processes_per_worker : Optional cap on the number of
            task-sampler processes each worker may use.

        # Returns

        The start-time string identifying this run.
        """
        self._initialize_start_train_or_start_test()

        if not self.disable_config_saving:
            self.save_project_state()

        devices = self.worker_devices(TRAIN_MODE_STR)
        num_workers = len(devices)

        # Be extra careful to ensure that all models start
        # with the same initializations.
        set_seed(self.seed)
        initial_model_state_dict = self.config.create_model(
            sensor_preprocessor_graph=MachineParams.instance_from(
                self.config.machine_params(
                    self.mode)).sensor_preprocessor_graph).state_dict()

        # A rendezvous port is only needed when several workers coordinate.
        distributed_port = 0
        if num_workers > 1:
            distributed_port = find_free_port()

        # Once set (by the fds-failure fallback below), the hash replaces the
        # full state dict for every remaining worker.
        model_hash = None
        for trainer_it in range(num_workers):
            training_kwargs = dict(
                id=trainer_it,
                checkpoint=checkpoint,
                restart_pipeline=restart_pipeline,
                experiment_name=self.experiment_name,
                config=self.config,
                results_queue=self.queues["results"],
                # Checkpoints are only forwarded when a validator is
                # running to consume them.
                checkpoints_queue=self.queues["checkpoints"]
                if self.running_validation else None,
                checkpoints_dir=self.checkpoint_dir(),
                seed=self.seed,
                deterministic_cudnn=self.deterministic_cudnn,
                mp_ctx=self.mp_ctx,
                num_workers=num_workers,
                device=devices[trainer_it],
                distributed_port=distributed_port,
                max_sampler_processes_per_worker=
                max_sampler_processes_per_worker,
                initial_model_state_dict=initial_model_state_dict
                if model_hash is None else model_hash,
            )
            train: BaseProcess = self.mp_ctx.Process(
                target=self.train_loop,
                kwargs=training_kwargs,
            )
            try:
                train.start()
            except ValueError as e:
                # If the `initial_model_state_dict` is too large we sometimes
                # run into errors passing it with multiprocessing. In such cases
                # we instead hash the state_dict and confirm, in each engine
                # worker, that this hash equals the model the engine worker
                # instantiates.
                if e.args[0] == "too many fds":
                    model_hash = md5_hash_of_state_dict(
                        initial_model_state_dict)
                    training_kwargs["initial_model_state_dict"] = model_hash
                    # Retry with the hash in place of the full state dict.
                    train = self.mp_ctx.Process(
                        target=self.train_loop,
                        kwargs=training_kwargs,
                    )
                    train.start()
                else:
                    raise e

            self.processes[TRAIN_MODE_STR].append(train)

        get_logger().info("Started {} train processes".format(
            len(self.processes[TRAIN_MODE_STR])))

        # Validation (optional single process)
        if self.running_validation:
            device = self.worker_devices("valid")[0]
            self.init_visualizer("valid")
            valid: BaseProcess = self.mp_ctx.Process(
                target=self.valid_loop,
                args=(0, ),
                kwargs=dict(
                    config=self.config,
                    results_queue=self.queues["results"],
                    checkpoints_queue=self.queues["checkpoints"],
                    seed=
                    12345,  # TODO allow same order for randomly sampled tasks? Is this any useful anyway?
                    deterministic_cudnn=self.deterministic_cudnn,
                    deterministic_agents=self.deterministic_agents,
                    mp_ctx=self.mp_ctx,
                    device=device,
                    max_sampler_processes_per_worker=
                    max_sampler_processes_per_worker,
                ),
            )
            valid.start()
            self.processes["valid"].append(valid)

            get_logger().info("Started {} valid processes".format(
                len(self.processes["valid"])))
        else:
            get_logger().info(
                "No processes allocated to validation, no validation will be run."
            )

        self.log_and_close(self.local_start_time_str, num_workers)

        return self.local_start_time_str