Example #1
    def __init__(self,
                 dataset,
                 num_replicas=None,
                 rank=None,
                 limit_number_of_volumes=None):
        if num_replicas is None:
            num_replicas = communication.get_world_size()
        if rank is None:
            rank = communication.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank

        filenames = list(self.dataset.volume_indices.keys())  # volume_indices is an OrderedDict.
        if limit_number_of_volumes:
            filenames = filenames[:limit_number_of_volumes]

        chunked_filenames = list(chunks(filenames, self.num_replicas))

        # Collect the indices belonging to the filenames assigned to this rank.
        self.indices = []
        if self.rank < len(chunked_filenames):  # Otherwise there is nothing to fill.
            for filename in chunked_filenames[self.rank]:
                self.indices.extend(list(
                    self.dataset.volume_indices[filename]))
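
The snippet above relies on a chunks helper to split the volume filenames across replicas; that helper is not shown on this page. A minimal sketch of what it could look like, assuming roughly equal contiguous slices (so fewer than num_replicas chunks may be produced when there are few filenames, which is exactly what the rank guard above anticipates):

import math
from typing import Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def chunks(sequence: Sequence[T], num_chunks: int) -> Iterator[List[T]]:
    # Yield up to num_chunks contiguous, roughly equal slices of sequence.
    if not sequence:
        return
    chunk_size = math.ceil(len(sequence) / num_chunks)
    for start in range(0, len(sequence), chunk_size):
        yield list(sequence[start:start + chunk_size])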
Example #2
    def __init__(
        self,
        size: int,
        shuffle: bool = True,
        seed: Optional[int] = None,
    ):
        """
        Parameters
        ----------
        size : int
            Size of underlying dataset.
        shuffle : bool
            If True, the indices will be shuffled.
        seed : int, optional
            Initial seed of the shuffle; must be the same across all workers.
        """
        self._size = size
        if self._size <= 0:
            raise AssertionError(f"Dataset size must be positive, got {self._size}.")
        self._shuffle = shuffle
        if seed is None:
            seed = communication.shared_random_seed()
        self._seed = int(seed)

        self._rank = communication.get_rank()
        self._world_size = communication.get_world_size()
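
Example #2 only shows the constructor; the sampler's __iter__ is not part of the snippet. The stored seed, rank, and world size suggest the usual pattern of an infinite shuffled index stream sharded by rank. Below is a sketch of that pattern as standalone helpers; the function names and the itertools.islice-based sharding are assumptions, not the library's code:

import itertools
from typing import Iterator

import torch


def infinite_shuffled_indices(size: int, seed: int, shuffle: bool) -> Iterator[int]:
    # Endlessly yield indices over the dataset, reshuffling each pass with the shared seed.
    generator = torch.Generator()
    generator.manual_seed(seed)
    while True:
        if shuffle:
            yield from torch.randperm(size, generator=generator).tolist()
        else:
            yield from range(size)


def rank_sharded_indices(size: int, seed: int, shuffle: bool,
                         rank: int, world_size: int) -> Iterator[int]:
    # Each process keeps every world_size-th index, offset by its rank; the shared
    # seed keeps the shards disjoint but mutually consistent across processes.
    yield from itertools.islice(
        infinite_shuffled_indices(size, seed, shuffle), rank, None, world_size)

An __iter__ built on this would simply yield from rank_sharded_indices(self._size, self._seed, self._shuffle, self._rank, self._world_size).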
Example #3
    def train(
            self,
            optimizer: torch.optim.Optimizer,
            lr_scheduler: torch.optim.lr_scheduler._LRScheduler,  # noqa
            training_data: Dataset,
            experiment_directory: pathlib.Path,
            validation_data: Optional[Dataset] = None,
            resume: bool = False,
            num_workers: int = 0) -> None:

        # TODO: Does not need to be member of self.
        self.__optimizer = optimizer
        # TODO: Optimizer and LR scheduler need to be resumed too.
        self.__lr_scheduler = lr_scheduler

        training_sampler = self.build_sampler(training_data, 'random')
        # TODO: Configurable
        training_loader = self.build_loader(
            training_data,
            sampler=training_sampler,
            batch_size=self.cfg.training.batch_size,
            num_workers=num_workers,
            drop_last=True)

        if validation_data:
            validation_sampler = self.build_sampler(
                validation_data, 'sequential', limit_number_of_volumes=None)
            # TODO: Batch size can be much larger, perhaps have a different batch size during evaluation.
            batch_sampler = BatchSampler(
                validation_sampler,
                batch_size=8 * self.cfg.training.batch_size,
                drop_last=False)
            validation_loader = self.build_loader(
                validation_data,
                batch_sampler=batch_sampler,
                num_workers=num_workers,
            )
        else:
            validation_loader = None

        self.model = self.model.to(self.device)

        # Optimizer
        self.__optimizer.zero_grad()  # type: ignore

        # Mixed precision setup. This requires the model to be on the gpu.
        extra_checkpointing = {}
        opt_level = None  # Also consulted below when resuming from a checkpoint.
        if self.mixed_precision > 0:
            opt_level = f'O{self.mixed_precision}'
            self.logger.info(f'Using apex level {opt_level}.')
            self.model, self.__optimizer = amp.initialize(self.model,
                                                          self.__optimizer,
                                                          opt_level=opt_level)
            extra_checkpointing['amp'] = amp
            extra_checkpointing['opt_level'] = opt_level
            git_hash = direct.utils.git_hash()
            extra_checkpointing['__author__'] = git_hash if git_hash else 'N/A'

        self.checkpointer = Checkpointer(
            self.model,
            experiment_directory,
            save_to_disk=communication.is_main_process(),
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            **extra_checkpointing)

        # Load checkpoint
        start_iter = 0
        if resume:
            self.logger.info('Attempting to resume...')
            # This changes the model in place.
            checkpoint = self.checkpointer.load(
                iteration='latest',
                checkpointable_objects=['amp'] if self.mixed_precision > 0 else [])
            if not checkpoint:
                self.logger.info('No checkpoint found. Starting from scratch.')
            else:
                start_iter = checkpoint['iteration'] + 1
                self.logger.info(f'Starting from iteration: {start_iter}.')

            if '__author__' in checkpoint:
                self.logger.info(
                    f"Git hash of checkpoint: {checkpoint['__author__']}")
                if checkpoint['__author__'] != direct.utils.git_hash():
                    self.logger.warning(
                        f"Current git hash {direct.utils.git_hash()} is different from the one "
                        f"this checkpoint is saved with ({checkpoint['__author__']}). This can be fine, "
                        f"but beware that this can be a source of confusion.")

            if '__datetime__' in checkpoint:
                self.logger.info(
                    f"Checkpoint created at: {checkpoint['__datetime__']}")
            if 'opt_level' in checkpoint:
                if checkpoint['opt_level'] != opt_level:
                    self.logger.warning(
                        f"Mixed precision opt-levels do not match. "
                        f"Requested {opt_level} got {checkpoint['opt_level']} from checkpoint. "
                        f"This will almost surely lead to performance degradation."
                    )

        self.logger.info(f'World size: {communication.get_world_size()}.')
        if communication.get_world_size() > 1:
            self.model = DistributedDataParallel(
                self.model,
                device_ids=[communication.get_rank()],
                broadcast_buffers=False)

        # World size > 1 in distributed mode; otherwise allow a DataParallel fallback,
        # which can be convenient for debugging.
        elif torch.cuda.device_count() > 1 and communication.get_world_size() == 1:
            self.model = DataParallel(self.model)

        self.__writers = ([
            JSONWriter(experiment_directory / 'metrics.json'),
            CommonMetricPrinter(self.cfg.training.num_iterations),
            TensorboardWriter(experiment_directory / 'tensorboard')
        ] if communication.is_main_process() else [])

        with EventStorage(start_iter):
            self.training_loop(training_loader, start_iter, validation_loader)

        self.logger.info('Training completed.')
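
A hypothetical invocation of this train method is sketched below; it is not runnable as-is, since engine, training_dataset, and validation_dataset stand in for objects constructed elsewhere, and the hyperparameters are placeholders:

import pathlib

import torch

# engine, training_dataset, and validation_dataset are assumed to exist already.
optimizer = torch.optim.Adam(engine.model.parameters(), lr=1e-3)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50_000, gamma=0.2)

engine.train(
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    training_data=training_dataset,
    experiment_directory=pathlib.Path("experiments/run_0"),
    validation_data=validation_dataset,
    resume=False,
    num_workers=4,
)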
Example #4
    def train(
        self,
        optimizer: torch.optim.Optimizer,
        lr_scheduler: torch.optim.lr_scheduler._LRScheduler,  # noqa
        training_datasets: List[Dataset],
        experiment_directory: pathlib.Path,
        validation_data: Optional[List[Dataset]] = None,
        resume: bool = False,
        initialization: Optional[PathOrString] = None,
        num_workers: int = 6,
    ) -> None:
        self.logger.info("Starting training.")
        # TODO: Does not need to be member of self.
        self.__optimizer = optimizer
        # TODO: Optimizer and LR scheduler need to be resumed too.
        self.__lr_scheduler = lr_scheduler
        training_data = ConcatDataset(training_datasets)
        self.logger.info(f"Concatenated dataset length: {len(training_data)}.")
        self.logger.info(
            f"Building batch sampler for training set with batch size {self.cfg.training.batch_size}."
        )
        training_sampler = self.build_batch_sampler(
            training_datasets, self.cfg.training.batch_size, "random")
        training_loader = self.build_loader(
            training_data,
            batch_sampler=training_sampler,
            num_workers=num_workers,
        )

        if validation_data:
            validation_loaders = []
            for idx, curr_validation_data in enumerate(validation_data):
                text_dataset_description = curr_validation_data.text_description
                self.logger.info(
                    f"Building dataloader for dataset: {text_dataset_description}."
                )
                curr_batch_sampler = self.build_batch_sampler(
                    curr_validation_data,
                    batch_size=self.cfg.validation.batch_size,
                    sampler_type="sequential",
                    limit_number_of_volumes=None,
                )
                validation_loaders.append((
                    text_dataset_description,
                    self.build_loader(
                        curr_validation_data,
                        batch_sampler=curr_batch_sampler,
                        num_workers=0,  # TODO(jt): num_workers seems to choke the validation.
                    ),
                ))
        else:
            validation_loaders = None

        self.models_to_device()

        # Optimizer
        self.__optimizer.zero_grad()  # type: ignore

        # Mixed precision setup. This requires the model to be on the gpu.
        git_hash = direct.utils.git_hash()
        extra_checkpointing = {
            "__author__": git_hash if git_hash else "N/A",
            "__version__": direct.__version__,
            "__mixed_precision__": self.mixed_precision,
        }
        if self.mixed_precision:
            # TODO(jt): Check if on GPU
            self.logger.info("Using mixed precision training.")

        self.checkpointer = Checkpointer(
            self.model,
            experiment_directory,
            save_to_disk=communication.is_main_process(),
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            scaler=self._scaler,
            **self.models,
            **extra_checkpointing,
        )

        # Load checkpoint
        start_iter = 0
        checkpoint = {}
        if resume:
            self.logger.info("Attempting to resume...")
            # This changes the model in place.
            checkpoint = self.checkpointer.load(iteration="latest")
            if not checkpoint:
                self.logger.info("No checkpoint found. Starting from scratch.")
            else:
                start_iter = checkpoint["iteration"] + 1
                self.logger.info(f"Starting from iteration: {start_iter}.")

        if start_iter > 0 and initialization:
            self.logger.warning(
                f"Initialization checkpoint set to {initialization},"
                f" but model will resume training from previous checkpoint. Initialization ignored."
            )
        elif initialization:
            self.logger.info(f"Initializing from {initialization}...")
            self.checkpointer.load_from_file(initialization)

        if "__version__" in checkpoint:
            self.logger.info(
                f"DIRECT version of checkpoint: {checkpoint['__version__']}.")
            if checkpoint["__version__"] != direct.__version__:
                self.logger.warning(
                    f"Current DIRECT version {direct.__version__} is different from the one "
                    f"this checkpoint is saved with ({checkpoint['__version__']}). This can be fine, "
                    f"but beware that this can be a source of confusion.")

        if "__author__" in checkpoint:
            self.logger.info(
                f"Git hash of checkpoint: {checkpoint['__author__']}.")
            if checkpoint["__author__"] != direct.utils.git_hash():
                self.logger.warning(
                    f"Current git hash {direct.utils.git_hash()} is different from the one "
                    f"this checkpoint is saved with ({checkpoint['__author__']}). This can be fine, "
                    f"but beware that this can be a source of confusion.")

        if "__datetime__" in checkpoint:
            self.logger.info(
                f"Checkpoint created at: {checkpoint['__datetime__']}.")

        if "__mixed_precision__" in checkpoint:
            if not self.mixed_precision and checkpoint["__mixed_precision__"]:
                self.logger.warning(
                    "Mixed precision training is not enabled, yet the saved checkpoint requests it. "
                    "Will now enable mixed precision.")
                self.mixed_precision = True
            elif not checkpoint["__mixed_precision__"] and self.mixed_precision:
                self.logger.warning(
                    "Mixed precision settings of training and the loaded checkpoint do not match. "
                    "Mixed precision was requested, but the checkpoint was saved without it. "
                    "This will almost surely lead to performance degradation.")

        self.logger.info(f"World size: {communication.get_world_size()}.")
        self.logger.info(f"Device count: {torch.cuda.device_count()}.")
        if communication.get_world_size() > 1:
            self.model = DistributedDataParallel(
                self.model,
                device_ids=[communication.get_local_rank()],
                broadcast_buffers=False,
            )

        # World size > 1 in distributed mode; otherwise allow a DataParallel fallback,
        # which can be convenient for debugging.
        elif torch.cuda.device_count() > 1 and communication.get_world_size() == 1:
            self.model = DataParallel(self.model)

        self.__writers = ([
            JSONWriter(experiment_directory / "metrics.json"),
            CommonMetricPrinter(self.cfg.training.num_iterations),
            TensorboardWriter(experiment_directory / "tensorboard"),
        ] if communication.is_main_process() else [])

        with EventStorage(start_iter):
            self.training_loop(
                training_loader,
                start_iter,
                validation_loaders,
                experiment_directory=experiment_directory,
            )

        self.logger.info("Training completed.")
Example #5
    def train(
        self,
        optimizer: torch.optim.Optimizer,
        lr_scheduler: torch.optim.lr_scheduler._LRScheduler,  # noqa
        training_datasets: List[Dataset],
        experiment_directory: pathlib.Path,
        validation_datasets: Optional[List[Dataset]] = None,
        resume: bool = False,
        start_with_validation: bool = False,
        initialization: Optional[PathOrString] = None,
        num_workers: int = 6,
    ) -> None:
        self.logger.info("Starting training.")
        # Consider not making this a member of self; that would require passing the
        # optimizer to training_loop().
        self.__optimizer = optimizer
        self.__lr_scheduler = lr_scheduler

        self.models_to_device()

        # Optimizer
        self.__optimizer.zero_grad()  # type: ignore

        # Mixed precision setup. This requires the model to be on the gpu.
        git_hash = direct.utils.git_hash()
        checkpointing_metadata = {
            "__author__": git_hash if git_hash else "N/A",
            "__version__": direct.__version__,
            "__mixed_precision__": self.mixed_precision,
        }
        if self.mixed_precision:
            # TODO(jt): Check if on GPU
            self.logger.info("Using mixed precision training.")

        self.checkpointer = Checkpointer(
            save_directory=experiment_directory,
            save_to_disk=communication.is_main_process(),
            model=self.model,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            scaler=self._scaler,
            **checkpointing_metadata,  # type: ignore
            **self.models,  # type: ignore
        )

        # Load checkpoint
        start_iter = 0
        checkpoint = {}

        if resume:
            self.logger.info("Attempting to resume...")
            # This changes the model in place.
            checkpoint = self.checkpointer.load(iteration="latest")
            if not checkpoint:
                self.logger.info("No checkpoint found. Starting from scratch.")
            else:
                start_iter = checkpoint["iteration"] + 1
                self.logger.info(f"Starting from iteration: {start_iter}.")

        if start_iter > 0 and initialization:
            self.logger.warning(
                f"Initialization checkpoint set to {initialization},"
                f" but model will resume training from previous checkpoint. Initialization ignored."
            )
        elif initialization:
            self.logger.info(f"Initializing from {initialization}...")
            self.checkpointer.load_models_from_file(initialization)
            start_with_validation = True
            self.logger.info("Setting start_with_validation to True.")

        if "__version__" in checkpoint:
            self.logger.info(
                f"DIRECT version of checkpoint: {checkpoint['__version__']}.")
            if checkpoint["__version__"] != direct.__version__:
                self.logger.warning(
                    f"Current DIRECT version {direct.__version__} is different from the one "
                    f"this checkpoint is saved with: {checkpoint['__version__']}. This can be fine, "
                    f"but beware that this can be a source of confusion.")

        if "__author__" in checkpoint:
            self.logger.info(
                f"Git hash of checkpoint: {checkpoint['__author__']}.")
            if checkpoint["__author__"] != direct.utils.git_hash():
                self.logger.warning(
                    f"Current git hash {direct.utils.git_hash()} is different from the one "
                    f"this checkpoint is saved with: {checkpoint['__author__']}. This can be fine, "
                    f"but beware that this can be a source of confusion.")

        if "__datetime__" in checkpoint:
            self.logger.info(
                f"Checkpoint created at: {checkpoint['__datetime__']}.")

        if "__mixed_precision__" in checkpoint:
            if not self.mixed_precision and checkpoint["__mixed_precision__"]:
                self.logger.warning(
                    "Mixed precision training is not enabled, yet the saved checkpoint requests it. "
                    "Will now enable mixed precision.")
                self.mixed_precision = True
            elif not checkpoint["__mixed_precision__"] and self.mixed_precision:
                self.logger.warning(
                    "Mixed precision settings of training and the loaded checkpoint do not match. "
                    "Mixed precision was requested, but the checkpoint was saved without it. "
                    "This will almost surely lead to performance degradation.")

        if start_with_validation:
            self.logger.info("Requested to start with validation.")

        self.logger.info(f"World size: {communication.get_world_size()}.")
        self.logger.info(f"Device count: {torch.cuda.device_count()}.")
        if communication.get_world_size() > 1:
            self.model = DistributedDataParallel(
                self.model,
                device_ids=[communication.get_local_rank()],
                broadcast_buffers=False,
            )

        # World size > 1 in distributed mode; otherwise allow a DataParallel fallback,
        # which can be convenient for debugging.
        elif torch.cuda.device_count() > 1 and communication.get_world_size() == 1:
            self.model = DataParallel(self.model)

        self.__writers = ([
            JSONWriter(experiment_directory / "metrics.json"),
            CommonMetricPrinter(
                self.cfg.training.num_iterations),  # type: ignore
            TensorboardWriter(experiment_directory / "tensorboard"),
        ] if communication.is_main_process() else [])

        with EventStorage(start_iter):
            self.training_loop(
                training_datasets,
                start_iter,
                validation_datasets,
                experiment_directory=experiment_directory,
                num_workers=num_workers,
                start_with_validation=start_with_validation,
            )

        self.logger.info("Training completed.")