Example #1
    def worker_devices(self, mode: str):
        machine_params: MachineParams = MachineParams.instance_from(
            self.config.machine_params(mode))
        devices = machine_params.devices

        assert all_equal(devices) or all(
            d.index >= 0 for d in devices
        ), f"Cannot have a mix of CPU and GPU devices (`devices == {devices}`)"

        get_logger().info("Using {} {} workers on devices {}".format(
            len(devices), mode, devices))
        return devices
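
A minimal sketch of the device-consistency rule the assert above enforces, rewritten outside the class with plain torch.device objects. The all_equal helper and the None-index handling for CPU devices are assumptions made for this example, not the library's code.

import torch

def all_equal(seq):
    # True when the sequence is empty or every element equals the first one.
    return len(seq) == 0 or all(x == seq[0] for x in seq)

cpu_only = [torch.device("cpu")] * 4                      # accepted: all identical
gpus = [torch.device("cuda:0"), torch.device("cuda:1")]   # accepted: every index >= 0
mixed = [torch.device("cpu"), torch.device("cuda:0")]     # rejected: CPU and GPU mixed

for devices in (cpu_only, gpus, mixed):
    ok = all_equal(devices) or all(
        d.index is not None and d.index >= 0 for d in devices
    )
    print(devices, "->", "accepted" if ok else "rejected")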
Example #2
    def start_train(
        self,
        checkpoint: Optional[str] = None,
        restart_pipeline: bool = False,
        max_sampler_processes_per_worker: Optional[int] = None,
    ):
        if not self.disable_config_saving:
            self.save_project_state()

        devices = self.worker_devices("train")
        num_workers = len(devices)

        # Be extra careful to ensure that all models start
        # with the same initializations.
        set_seed(self.seed)
        initial_model_state_dict = self.config.create_model(
            sensor_preprocessor_graph=MachineParams.instance_from(
                self.config.machine_params(
                    self.mode)).sensor_preprocessor_graph).state_dict()

        distributed_port = 0
        if num_workers > 1:
            distributed_port = find_free_port()

        for trainer_it in range(num_workers):
            train: BaseProcess = self.mp_ctx.Process(
                target=self.train_loop,
                kwargs=dict(
                    id=trainer_it,
                    checkpoint=checkpoint,
                    restart_pipeline=restart_pipeline,
                    experiment_name=self.experiment_name,
                    config=self.config,
                    results_queue=self.queues["results"],
                    checkpoints_queue=self.queues["checkpoints"]
                    if self.running_validation else None,
                    checkpoints_dir=self.checkpoint_dir(),
                    seed=self.seed,
                    deterministic_cudnn=self.deterministic_cudnn,
                    mp_ctx=self.mp_ctx,
                    num_workers=num_workers,
                    device=devices[trainer_it],
                    distributed_port=distributed_port,
                    max_sampler_processes_per_worker=
                    max_sampler_processes_per_worker,
                    initial_model_state_dict=initial_model_state_dict,
                ),
            )
            train.start()
            self.processes["train"].append(train)

        get_logger().info("Started {} train processes".format(
            len(self.processes["train"])))

        # Validation
        if self.running_validation:
            device = self.worker_devices("valid")[0]
            self.init_visualizer("valid")
            valid: BaseProcess = self.mp_ctx.Process(
                target=self.valid_loop,
                args=(0, ),
                kwargs=dict(
                    config=self.config,
                    results_queue=self.queues["results"],
                    checkpoints_queue=self.queues["checkpoints"],
                    seed=12345,  # TODO allow same order for randomly sampled tasks? Is this useful anyway?
                    deterministic_cudnn=self.deterministic_cudnn,
                    deterministic_agents=self.deterministic_agents,
                    mp_ctx=self.mp_ctx,
                    device=device,
                    max_sampler_processes_per_worker=
                    max_sampler_processes_per_worker,
                ),
            )
            valid.start()
            self.processes["valid"].append(valid)

            get_logger().info("Started {} valid processes".format(
                len(self.processes["valid"])))
        else:
            get_logger().info(
                "No processes allocated to validation, no validation will be run."
            )

        self.log(self.local_start_time_str, num_workers)

        return self.local_start_time_str
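
The launch pattern above boils down to one spawned process per device, all sharing a results queue created from the same multiprocessing context. A self-contained sketch of that pattern follows; train_loop, the queue payloads, and the device strings are placeholders for illustration rather than the library's worker code.

import multiprocessing as mp

def train_loop(id: int, results_queue, device: str):
    # Stand-in for the real per-worker training loop.
    results_queue.put((id, f"finished on {device}"))

if __name__ == "__main__":
    mp_ctx = mp.get_context("spawn")   # same kind of context as self.mp_ctx above
    results_queue = mp_ctx.Queue()
    devices = ["cuda:0", "cuda:1"]

    processes = []
    for worker_id, device in enumerate(devices):
        p = mp_ctx.Process(
            target=train_loop,
            kwargs=dict(id=worker_id, results_queue=results_queue, device=device),
        )
        p.start()
        processes.append(p)

    for _ in processes:
        print(results_queue.get())     # collect one result per worker
    for p in processes:
        p.join()

With a spawn-style context, everything a worker needs must be picklable and passed explicitly, which is why the example above threads the config, queues, and initial model state dict through the process kwargs.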
Example #3
    def init_visualizer(self, mode: str):
        if not self.disable_tensorboard:
            # Note: Avoid instantiating anything in machine_params (use Builder if needed)
            machine_params = MachineParams.instance_from(
                self.config.machine_params(mode))
            self.visualizer = machine_params.visualizer
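
The note about avoiding instantiation inside machine_params points at deferred construction: return a factory and only build the visualizer when it is actually needed. The LazyBuilder and DummyVisualizer below are stand-ins written for this sketch; they are not the library's Builder.

from typing import Any, Callable, Dict, Optional

class LazyBuilder:
    def __init__(self, cls: Callable[..., Any], kwargs: Optional[Dict[str, Any]] = None):
        self.cls = cls
        self.kwargs = kwargs or {}

    def __call__(self) -> Any:
        # Construction is delayed until the builder is called.
        return self.cls(**self.kwargs)

class DummyVisualizer:
    def __init__(self, log_dir: str):
        self.log_dir = log_dir

visualizer_builder = LazyBuilder(DummyVisualizer, {"log_dir": "./tb"})
visualizer = visualizer_builder()  # instantiated only here, e.g. inside init_visualizer
print(type(visualizer).__name__, visualizer.log_dir)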
Example #4
    @property
    def running_validation(self):
        return (sum(
            MachineParams.instance_from(
                self.config.machine_params("valid")).nprocesses) > 0)
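
The property above reduces to a single question: does the "valid" configuration allocate any processes? A toy version of that check, with a plain list standing in for the nprocesses value that MachineParams.instance_from(...) would provide:

from typing import Sequence

def running_validation(nprocesses: Sequence[int]) -> bool:
    # One entry per validation device; validation runs if any of them is > 0.
    return sum(nprocesses) > 0

print(running_validation([1, 1]))  # True: two validation sampler processes
print(running_validation([0]))     # False: validation is skipped entirely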
Example #5
    def start_train(
        self,
        checkpoint: Optional[str] = None,
        restart_pipeline: bool = False,
        max_sampler_processes_per_worker: Optional[int] = None,
    ):
        self._initialize_start_train_or_start_test()

        if not self.disable_config_saving:
            self.save_project_state()

        devices = self.worker_devices(TRAIN_MODE_STR)
        num_workers = len(devices)

        # Be extra careful to ensure that all models start
        # with the same initializations.
        set_seed(self.seed)
        initial_model_state_dict = self.config.create_model(
            sensor_preprocessor_graph=MachineParams.instance_from(
                self.config.machine_params(
                    self.mode)).sensor_preprocessor_graph).state_dict()

        distributed_port = 0
        if num_workers > 1:
            distributed_port = find_free_port()

        model_hash = None
        for trainer_it in range(num_workers):
            training_kwargs = dict(
                id=trainer_it,
                checkpoint=checkpoint,
                restart_pipeline=restart_pipeline,
                experiment_name=self.experiment_name,
                config=self.config,
                results_queue=self.queues["results"],
                checkpoints_queue=self.queues["checkpoints"]
                if self.running_validation else None,
                checkpoints_dir=self.checkpoint_dir(),
                seed=self.seed,
                deterministic_cudnn=self.deterministic_cudnn,
                mp_ctx=self.mp_ctx,
                num_workers=num_workers,
                device=devices[trainer_it],
                distributed_port=distributed_port,
                max_sampler_processes_per_worker=
                max_sampler_processes_per_worker,
                initial_model_state_dict=initial_model_state_dict
                if model_hash is None else model_hash,
            )
            train: BaseProcess = self.mp_ctx.Process(
                target=self.train_loop,
                kwargs=training_kwargs,
            )
            try:
                train.start()
            except ValueError as e:
                # If the `initial_model_state_dict` is too large, we sometimes
                # run into errors passing it with multiprocessing. In such cases
                # we instead hash the state_dict and confirm, in each engine worker,
                # that this hash matches that of the model the worker instantiates.
                if e.args[0] == "too many fds":
                    model_hash = md5_hash_of_state_dict(
                        initial_model_state_dict)
                    training_kwargs["initial_model_state_dict"] = model_hash
                    train = self.mp_ctx.Process(
                        target=self.train_loop,
                        kwargs=training_kwargs,
                    )
                    train.start()
                else:
                    raise e

            self.processes[TRAIN_MODE_STR].append(train)

        get_logger().info("Started {} train processes".format(
            len(self.processes[TRAIN_MODE_STR])))

        # Validation
        if self.running_validation:
            device = self.worker_devices("valid")[0]
            self.init_visualizer("valid")
            valid: BaseProcess = self.mp_ctx.Process(
                target=self.valid_loop,
                args=(0, ),
                kwargs=dict(
                    config=self.config,
                    results_queue=self.queues["results"],
                    checkpoints_queue=self.queues["checkpoints"],
                    seed=12345,  # TODO allow same order for randomly sampled tasks? Is this useful anyway?
                    deterministic_cudnn=self.deterministic_cudnn,
                    deterministic_agents=self.deterministic_agents,
                    mp_ctx=self.mp_ctx,
                    device=device,
                    max_sampler_processes_per_worker=
                    max_sampler_processes_per_worker,
                ),
            )
            valid.start()
            self.processes["valid"].append(valid)

            get_logger().info("Started {} valid processes".format(
                len(self.processes["valid"])))
        else:
            get_logger().info(
                "No processes allocated to validation, no validation will be run."
            )

        self.log_and_close(self.local_start_time_str, num_workers)

        return self.local_start_time_str
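
The "too many fds" fallback above swaps the full state dict for a hash that each worker can recompute after building its own copy of the model. Below is a simplified sketch of that idea, assuming a torch model; md5_hash_of_state_dict_sketch is written for this example and need not match the library's md5_hash_of_state_dict.

import hashlib
import torch

def md5_hash_of_state_dict_sketch(state_dict) -> str:
    md5 = hashlib.md5()
    for key in sorted(state_dict.keys()):
        # Hash both parameter names and values so any mismatch is detected.
        md5.update(key.encode("utf-8"))
        md5.update(state_dict[key].detach().cpu().numpy().tobytes())
    return md5.hexdigest()

torch.manual_seed(0)
parent_model = torch.nn.Linear(4, 2)
parent_hash = md5_hash_of_state_dict_sketch(parent_model.state_dict())

# In each worker: rebuild the model the same way (same seed/config) and
# confirm the hash matches what the parent sent instead of the full weights.
torch.manual_seed(0)
worker_model = torch.nn.Linear(4, 2)
assert md5_hash_of_state_dict_sketch(worker_model.state_dict()) == parent_hash
print("hashes match:", parent_hash)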