Example 1
def setup_training_environment(
    run_name,
    base_directory,
    cfg_filename,
    device,
    machine_rank,
    mixed_precision,
    debug=False,
):

    env = setup_common_environment(
        run_name,
        base_directory,
        cfg_filename,
        device,
        machine_rank,
        mixed_precision,
        debug=debug,
    )
    # Write config file to experiment directory.
    config_file_in_project_folder = env.experiment_dir / "config.yaml"
    logger.info(
        f"Writing configuration file to: {config_file_in_project_folder}.")
    if communication.is_main_process():
        with open(config_file_in_project_folder, "w") as f:
            f.write(OmegaConf.to_yaml(env.cfg))
    communication.synchronize()

    return env
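
For context, a minimal usage sketch of setup_training_environment as shown above. The run name, directories, and config filename are hypothetical placeholders, and setup_common_environment is assumed to return an environment object exposing cfg and experiment_dir, as the function body implies.

# Hypothetical usage sketch; paths and names below are placeholders, not part of DIRECT.
import pathlib

env = setup_training_environment(
    run_name="baseline_run",
    base_directory=pathlib.Path("./experiments"),
    cfg_filename=pathlib.Path("config.yaml"),
    device="cuda",
    machine_rank=0,
    mixed_precision=0,
    debug=False,
)
# Only the main process writes config.yaml into env.experiment_dir;
# all processes synchronize before training continues.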
Example 2
def setup_logging(machine_rank, output_directory, run_name, cfg_filename, cfg, debug):
    # Setup logging
    log_file = (
        output_directory / f"log_{machine_rank}_{communication.get_local_rank()}.txt"
    )

    direct.utils.logging.setup(
        use_stdout=communication.is_main_process() or debug,
        filename=log_file,
        log_level=("INFO" if not debug else "DEBUG"),
    )
    logger.info(f"Machine rank: {machine_rank}.")
    logger.info(f"Local rank: {communication.get_local_rank()}.")
    logger.info(f"Logging: {log_file}.")
    logger.info(f"Saving to: {output_directory}.")
    logger.info(f"Run name: {run_name}.")
    logger.info(f"Config file: {cfg_filename}.")
    logger.info(f"CUDA {torch.version.cuda} - cuDNN {torch.backends.cudnn.version()}.")
    logger.info(f"Environment information: {collect_env.get_pretty_env_info()}.")
    logger.info(f"DIRECT version: {direct.__version__}.")  # noqa
    git_hash = direct.utils.git_hash()
    logger.info(f"Git hash: {git_hash if git_hash else 'N/A'}.")  # noqa
    logger.info(f"Configuration: {OmegaConf.to_yaml(cfg)}.")
Example 3
    def train(
            self,
            optimizer: torch.optim.Optimizer,
            lr_scheduler: torch.optim.lr_scheduler._LRScheduler,  # noqa
            training_data: Dataset,
            experiment_directory: pathlib.Path,
            validation_data: Dataset = None,
            resume: bool = False,
            num_workers: int = 0) -> None:

        # TODO: Does not need to be member of self.
        self.__optimizer = optimizer
        # TODO: Optimizer and LR scheduler need to be resumed too.
        self.__lr_scheduler = lr_scheduler

        training_sampler = self.build_sampler(training_data, 'random')
        # TODO: Configurable
        training_loader = self.build_loader(
            training_data,
            sampler=training_sampler,
            batch_size=self.cfg.training.batch_size,
            num_workers=num_workers,
            drop_last=True)

        if validation_data:
            validation_sampler = self.build_sampler(
                validation_data, 'sequential', limit_number_of_volumes=None)
            # TODO: Batch size can be much larger, perhaps have a different batch size during evaluation.
            batch_sampler = BatchSampler(
                validation_sampler,
                batch_size=8 * self.cfg.training.batch_size,
                drop_last=False)
            validation_loader = self.build_loader(
                validation_data,
                batch_sampler=batch_sampler,
                num_workers=num_workers,
            )
        else:
            # Keep validation_loader defined so the training_loop() call below does not raise a NameError.
            validation_loader = None

        self.model = self.model.to(self.device)

        # Optimizer
        self.__optimizer.zero_grad()  # type: ignore

        # Mixed precision setup. This requires the model to be on the gpu.
        extra_checkpointing = {}
        opt_level = 'O0'  # Default apex opt-level; overwritten below when mixed precision is enabled.
        if self.mixed_precision > 0:
            opt_level = f'O{self.mixed_precision}'
            self.logger.info(f'Using apex level {opt_level}.')
            self.model, self.__optimizer = amp.initialize(self.model,
                                                          self.__optimizer,
                                                          opt_level=opt_level)
            extra_checkpointing['amp'] = amp
            extra_checkpointing['opt_level'] = opt_level
            git_hash = direct.utils.git_hash()
            extra_checkpointing['__author__'] = git_hash if git_hash else 'N/A'

        self.checkpointer = Checkpointer(
            self.model,
            experiment_directory,
            save_to_disk=communication.is_main_process(),
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            **extra_checkpointing)

        # Load checkpoint
        start_iter = 0
        if resume:
            self.logger.info('Attempting to resume...')
            # This changes the model inplace
            checkpoint = self.checkpointer.load(
                iteration='latest',
                checkpointable_objects=['amp']
                if self.mixed_precision > 0 else [])
            if not checkpoint:
                self.logger.info('No checkpoint found. Starting from scratch.')
            else:
                start_iter = checkpoint['iteration'] + 1
                self.logger.info(f'Starting from iteration: {start_iter}.')

            if '__author__' in checkpoint:
                self.logger.info(
                    f"Git hash of checkpoint: {checkpoint['__author__']}")
                if checkpoint['__author__'] != direct.utils.git_hash():
                    self.logger.warning(
                        f"Current git hash {direct.utils.git_hash()} is different from the one "
                        f"this checkpoint is saved with ({checkpoint['__author__']}). This can be fine, "
                        f"but beware that this can be a source of confusion.")

            if '__datetime__' in checkpoint:
                self.logger.info(
                    f"Checkpoint created at: {checkpoint['__datetime__']}")
            if 'opt_level' in checkpoint:
                if checkpoint['opt_level'] != opt_level:
                    self.logger.warning(
                        f"Mixed precision opt-levels do not match. "
                        f"Requested {opt_level} got {checkpoint['opt_level']} from checkpoint. "
                        f"This will almost surely lead to performance degradation."
                    )

        self.logger.info(f'World size: {communication.get_world_size()}.')
        if communication.get_world_size() > 1:
            self.model = DistributedDataParallel(
                self.model,
                device_ids=[communication.get_rank()],
                broadcast_buffers=False)

        # World size > 1 only in distributed mode; otherwise allow a DataParallel fallback, which can be convenient for debugging.
        elif torch.cuda.device_count() > 1 and communication.get_world_size() == 1:
            self.model = DataParallel(self.model)

        self.__writers = ([
            JSONWriter(experiment_directory / 'metrics.json'),
            CommonMetricPrinter(self.cfg.training.num_iterations),
            TensorboardWriter(experiment_directory / 'tensorboard')
        ] if communication.is_main_process() else [])

        with EventStorage(start_iter):
            self.training_loop(training_loader, start_iter, validation_loader)

        self.logger.info('Training completed.')
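
A hedged sketch of how a caller might drive the train() method above. The engine object, its model, and training_data are placeholders; torch.optim.Adam and StepLR are standard PyTorch components and stand in for whatever optimizer and scheduler the configuration actually builds.

# Hypothetical caller; `engine` and `training_data` are placeholders.
import pathlib

import torch

optimizer = torch.optim.Adam(engine.model.parameters(), lr=1e-4)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50000, gamma=0.2)

engine.train(
    optimizer,
    lr_scheduler,
    training_data,                        # a torch Dataset
    pathlib.Path("./experiments/run_0"),  # experiment directory (placeholder)
    validation_data=None,                 # validation_loader then stays None
    resume=False,
    num_workers=4,
)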
Example 4
    def train(
        self,
        optimizer: torch.optim.Optimizer,
        lr_scheduler: torch.optim.lr_scheduler._LRScheduler,  # noqa
        training_datasets: List[Dataset],
        experiment_directory: pathlib.Path,
        validation_data: Optional[List[Dataset]] = None,
        resume: bool = False,
        initialization: Optional[PathOrString] = None,
        num_workers: int = 6,
    ) -> None:
        self.logger.info("Starting training.")
        # TODO: Does not need to be member of self.
        self.__optimizer = optimizer
        # TODO: Optimizer and LR scheduler need to be resumed too.
        self.__lr_scheduler = lr_scheduler
        training_data = ConcatDataset(training_datasets)
        self.logger.info(f"Concatenated dataset length: {len(training_data)}.")
        self.logger.info(
            f"Building batch sampler for training set with batch size {self.cfg.training.batch_size}."
        )
        training_sampler = self.build_batch_sampler(
            training_datasets, self.cfg.training.batch_size, "random")
        training_loader = self.build_loader(
            training_data,
            batch_sampler=training_sampler,
            num_workers=num_workers,
        )

        if validation_data:
            validation_loaders = []
            for idx, curr_validation_data in enumerate(validation_data):
                text_dataset_description = curr_validation_data.text_description
                self.logger.info(
                    f"Building dataloader for dataset: {text_dataset_description}."
                )
                curr_batch_sampler = self.build_batch_sampler(
                    curr_validation_data,
                    batch_size=self.cfg.validation.batch_size,
                    sampler_type="sequential",
                    limit_number_of_volumes=None,
                )
                validation_loaders.append((
                    text_dataset_description,
                    self.build_loader(
                        curr_validation_data,
                        batch_sampler=curr_batch_sampler,
                        num_workers=0,  # TODO(jt): Using `num_workers` here seems to choke the validation.
                    ),
                ))
        else:
            validation_loaders = None

        self.models_to_device()

        # Optimizer
        self.__optimizer.zero_grad()  # type: ignore

        # Mixed precision setup. This requires the model to be on the gpu.
        git_hash = direct.utils.git_hash()
        extra_checkpointing = {
            "__author__": git_hash if git_hash else "N/A",
            "__version__": direct.__version__,
            "__mixed_precision__": self.mixed_precision,
        }
        if self.mixed_precision:
            # TODO(jt): Check if on GPU
            self.logger.info(f"Using mixed precision training.")

        self.checkpointer = Checkpointer(
            self.model,
            experiment_directory,
            save_to_disk=communication.is_main_process(),
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            scaler=self._scaler,
            **self.models,
            **extra_checkpointing,
        )

        # Load checkpoint
        start_iter = 0
        checkpoint = {}
        if resume:
            self.logger.info("Attempting to resume...")
            # This changes the model inplace
            checkpoint = self.checkpointer.load(iteration="latest")
            if not checkpoint:
                self.logger.info("No checkpoint found. Starting from scratch.")
            else:
                start_iter = checkpoint["iteration"] + 1
                self.logger.info(f"Starting from iteration: {start_iter}.")

        if start_iter > 0 and initialization:
            self.logger.warning(
                f"Initialization checkpoint set to {initialization},"
                f" but model will resume training from previous checkpoint. Initialization ignored."
            )
        elif initialization:
            self.logger.info(f"Initializing from {initialization}...")
            self.checkpointer.load_from_file(initialization)

        if "__version__" in checkpoint:
            self.logger.info(
                f"DIRECT version of checkpoint: {checkpoint['__version__']}.")
            if checkpoint["__version__"] != direct.__version__:
                self.logger.warning(
                    f"Current DIRECT version {direct.__version__} is different from the one "
                    f"this checkpoint is saved with ({checkpoint['__version__']}). This can be fine, "
                    f"but beware that this can be a source of confusion.")

        if "__author__" in checkpoint:
            self.logger.info(
                f"Git hash of checkpoint: {checkpoint['__author__']}.")
            if checkpoint["__author__"] != direct.utils.git_hash():
                self.logger.warning(
                    f"Current git hash {direct.utils.git_hash()} is different from the one "
                    f"this checkpoint is saved with ({checkpoint['__author__']}). This can be fine, "
                    f"but beware that this can be a source of confusion.")

        if "__datetime__" in checkpoint:
            self.logger.info(
                f"Checkpoint created at: {checkpoint['__datetime__']}.")

        if "__mixed_precision__" in checkpoint:
            if (not self.mixed_precision
                ) and checkpoint["__mixed_precision__"]:
                self.logger.warning(
                    f"Mixed precision training is not enabled, yet saved checkpoint requests this"
                    f"Will now enable mixed precision.")
                self.mixed_precision = True
            elif not checkpoint["__mixed_precision__"] and self.mixed_precision:
                self.logger.warning(
                    f"Mixed precision levels of training and loading checkpoint do not match. "
                    f"Requested mixed precision but checkpoint is saved without. "
                    f"This will almost surely lead to performance degradation."
                )

        self.logger.info(f"World size: {communication.get_world_size()}.")
        self.logger.info(f"Device count: {torch.cuda.device_count()}.")
        if communication.get_world_size() > 1:
            self.model = DistributedDataParallel(
                self.model,
                device_ids=[communication.get_local_rank()],
                broadcast_buffers=False,
            )

        # World size > 1 only in distributed mode; otherwise allow a DataParallel fallback, which can be convenient for debugging.
        elif torch.cuda.device_count() > 1 and communication.get_world_size() == 1:
            self.model = DataParallel(self.model)

        self.__writers = ([
            JSONWriter(experiment_directory / "metrics.json"),
            CommonMetricPrinter(self.cfg.training.num_iterations),
            TensorboardWriter(experiment_directory / "tensorboard"),
        ] if communication.is_main_process() else [])

        with EventStorage(start_iter):
            self.training_loop(
                training_loader,
                start_iter,
                validation_loaders,
                experiment_directory=experiment_directory,
            )

        self.logger.info("Training completed.")
Example 5
    def validation_loop(
        self,
        validation_datasets,
        loss_fns,
        experiment_directory,
        iter_idx,
        num_workers: int = 6,
    ):
        if not validation_datasets:
            return

        storage = get_event_storage()

        data_loaders = self.build_validation_loaders(
            validation_data=validation_datasets,
            num_workers=num_workers,
        )
        for curr_dataset_name, curr_data_loader in data_loaders:
            self.logger.info(f"Evaluating: {curr_dataset_name}...")
            (
                curr_loss_dict,
                curr_metrics_per_case,
                visualize_slices,
                visualize_target,
            ) = self.evaluate(
                curr_data_loader,
                loss_fns,
                is_validation_process=True,
            )

            if experiment_directory:
                json_output_fn = (
                    experiment_directory
                    / f"metrics_val_{curr_dataset_name}_{iter_idx}.json"
                )
                json_output_fn.parent.mkdir(
                    exist_ok=True, parents=True
                )  # A / in the filename can create a folder
                if communication.is_main_process():
                    write_json(
                        json_output_fn,
                        curr_metrics_per_case,
                    )
                self.logger.info(f"Wrote per image logs to: {json_output_fn}.")

            # Metric dict still needs to be reduced as it gives values *per* data
            curr_metric_dict = reduce_list_of_dicts(
                list(curr_metrics_per_case.values()), mode="average"
            )

            key_prefix = (
                "val/" if not curr_dataset_name else f"val/{curr_dataset_name}/"
            )
            loss_reduced = sum(curr_loss_dict.values())
            storage.add_scalars(
                **{key_prefix + "loss": loss_reduced},
                **{
                    **prefix_dict_keys(curr_metric_dict, key_prefix),
                    **prefix_dict_keys(curr_loss_dict, key_prefix),
                },
                smoothing_hint=False,
            )
            visualize_slices = self.process_slices_for_visualization(
                visualize_slices, visualize_target
            )
            storage.add_image(f"{key_prefix}prediction", visualize_slices)

            # Only visualize the ground-truth target during the first validation round.
            if iter_idx // self.cfg.training.validation_steps - 1 == 0:
                visualize_target = make_grid(
                    crop_to_largest(visualize_target, pad_value=0),
                    nrow=self.cfg.logging.tensorboard.num_images,
                    scale_each=True,
                )
                storage.add_image(f"{key_prefix}target", visualize_target)

            self.logger.info(
                f"Done evaluation of {curr_dataset_name} at iteration {iter_idx}."
            )
        self.model.train()
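
The loop above leans on two small helpers: reduce_list_of_dicts, which collapses the per-case metric dictionaries into a single averaged dictionary, and prefix_dict_keys, which namespaces keys (e.g. with "val/<dataset>/") before they reach the event storage. Their real implementations live in direct.utils and are not shown here; the sketch below only captures the behavior this code assumes of them.

from typing import Dict, List


def prefix_dict_keys(data: Dict[str, float], prefix: str) -> Dict[str, float]:
    """Assumed behavior: prepend a prefix such as 'val/dataset_name/' to every key."""
    return {prefix + key: value for key, value in data.items()}


def reduce_list_of_dicts(data: List[Dict[str, float]], mode: str = "average") -> Dict[str, float]:
    """Assumed behavior: reduce a list of dicts sharing the same keys into one dict."""
    if not data:
        return {}
    if mode != "average":
        raise ValueError(f"Unsupported reduction mode: {mode}")
    return {key: sum(d[key] for d in data) / len(data) for key in data[0]}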
Example 6
    def train(
        self,
        optimizer: torch.optim.Optimizer,
        lr_scheduler: torch.optim.lr_scheduler._LRScheduler,  # noqa
        training_datasets: List[Dataset],
        experiment_directory: pathlib.Path,
        validation_datasets: Optional[List[Dataset]] = None,
        resume: bool = False,
        start_with_validation: bool = False,
        initialization: Optional[PathOrString] = None,
        num_workers: int = 6,
    ) -> None:
        self.logger.info("Starting training.")
        # Can consider not to make this a member of self, but that requires that optimizer is passed to
        # training_loop()
        self.__optimizer = optimizer
        self.__lr_scheduler = lr_scheduler

        self.models_to_device()

        # Optimizer
        self.__optimizer.zero_grad()  # type: ignore

        # Mixed precision setup. This requires the model to be on the gpu.
        git_hash = direct.utils.git_hash()
        checkpointing_metadata = {
            "__author__": git_hash if git_hash else "N/A",
            "__version__": direct.__version__,
            "__mixed_precision__": self.mixed_precision,
        }
        if self.mixed_precision:
            # TODO(jt): Check if on GPU
            self.logger.info("Using mixed precision training.")

        self.checkpointer = Checkpointer(
            save_directory=experiment_directory,
            save_to_disk=communication.is_main_process(),
            model=self.model,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            scaler=self._scaler,
            **checkpointing_metadata,  # type: ignore
            **self.models,  # type: ignore
        )

        # Load checkpoint
        start_iter = 0
        checkpoint = {}

        if resume:
            self.logger.info("Attempting to resume...")
            # This changes the model inplace
            checkpoint = self.checkpointer.load(iteration="latest")
            if not checkpoint:
                self.logger.info("No checkpoint found. Starting from scratch.")
            else:
                start_iter = checkpoint["iteration"] + 1
                self.logger.info(f"Starting from iteration: {start_iter}.")

        if start_iter > 0 and initialization:
            self.logger.warning(
                f"Initialization checkpoint set to {initialization},"
                f" but model will resume training from previous checkpoint. Initialization ignored."
            )
        elif initialization:
            self.logger.info(f"Initializing from {initialization}...")
            self.checkpointer.load_models_from_file(initialization)
            start_with_validation = True
            self.logger.info("Setting start_with_validation to True.")

        if "__version__" in checkpoint:
            self.logger.info(
                f"DIRECT version of checkpoint: {checkpoint['__version__']}.")
            if checkpoint["__version__"] != direct.__version__:
                self.logger.warning(
                    f"Current DIRECT version {direct.__version__} is different from the one "
                    f"this checkpoint is saved with: {checkpoint['__version__']}. This can be fine, "
                    f"but beware that this can be a source of confusion.")

        if "__author__" in checkpoint:
            self.logger.info(
                f"Git hash of checkpoint: {checkpoint['__author__']}.")
            if checkpoint["__author__"] != direct.utils.git_hash():
                self.logger.warning(
                    f"Current git hash {direct.utils.git_hash()} is different from the one "
                    f"this checkpoint is saved with: {checkpoint['__author__']}. This can be fine, "
                    f"but beware that this can be a source of confusion.")

        if "__datetime__" in checkpoint:
            self.logger.info(
                f"Checkpoint created at: {checkpoint['__datetime__']}.")

        if "__mixed_precision__" in checkpoint:
            if (not self.mixed_precision
                ) and checkpoint["__mixed_precision__"]:
                self.logger.warning(
                    "Mixed precision training is not enabled, yet saved checkpoint requests this"
                    f"Will now enable mixed precision.")
                self.mixed_precision = True
            elif not checkpoint["__mixed_precision__"] and self.mixed_precision:
                self.logger.warning(
                    "Mixed precision levels of training and loading checkpoint do not match. "
                    f"Requested mixed precision but checkpoint is saved without. "
                    f"This will almost surely lead to performance degradation."
                )

        if start_with_validation:
            self.logger.info("Requested to start with validation.")

        self.logger.info(f"World size: {communication.get_world_size()}.")
        self.logger.info(f"Device count: {torch.cuda.device_count()}.")
        if communication.get_world_size() > 1:
            self.model = DistributedDataParallel(
                self.model,
                device_ids=[communication.get_local_rank()],
                broadcast_buffers=False,
            )

        # World size > 1 only in distributed mode; otherwise allow a DataParallel fallback, which can be convenient for debugging.
        elif torch.cuda.device_count() > 1 and communication.get_world_size() == 1:
            self.model = DataParallel(self.model)

        self.__writers = ([
            JSONWriter(experiment_directory / "metrics.json"),
            CommonMetricPrinter(
                self.cfg.training.num_iterations),  # type: ignore
            TensorboardWriter(experiment_directory / "tensorboard"),
        ] if communication.is_main_process() else [])

        with EventStorage(start_iter):
            self.training_loop(
                training_datasets,
                start_iter,
                validation_datasets,
                experiment_directory=experiment_directory,
                num_workers=num_workers,
                start_with_validation=start_with_validation,
            )

        self.logger.info("Training completed.")