Example #1
    def __init__(self,
                 dataset,
                 num_replicas=None,
                 rank=None,
                 limit_number_of_volumes=None):
        if num_replicas is None:
            num_replicas = communication.get_world_size()
        if rank is None:
            rank = communication.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank

        filenames = list(
            self.dataset.volume_indices.keys())  # This is an OrderedDict
        if limit_number_of_volumes:
            filenames = filenames[:limit_number_of_volumes]

        chunked_filenames = list(chunks(filenames, self.num_replicas))
        # Guard against ranks that received no chunk (fewer volumes than replicas);
        # indexing unconditionally would raise an IndexError in that case.
        if self.rank < len(chunked_filenames):
            filenames = chunked_filenames[self.rank]
        else:
            filenames = []

        # Collect the indices belonging to these filenames.
        self.indices = []
        for filename in filenames:
            self.indices.extend(list(self.dataset.volume_indices[filename]))
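
Example #1 relies on a chunks helper that is not shown in the snippet. Below is a minimal sketch of what such a helper plausibly does, assuming it splits the filename list into at most num_replicas roughly equal consecutive pieces; the source library's actual implementation may differ.

# Hypothetical sketch of the chunks() helper the sampler relies on; an assumption
# about its behavior, not the library's actual implementation.
def chunks(list_to_chunk, number_of_chunks):
    """Split list_to_chunk into at most number_of_chunks consecutive, roughly equal pieces."""
    base, remainder = divmod(len(list_to_chunk), number_of_chunks)
    start = 0
    for chunk_index in range(number_of_chunks):
        length = base + (1 if chunk_index < remainder else 0)
        if length == 0:
            # Fewer items than chunks: later ranks receive no chunk, which is why
            # the sampler above guards on rank < len(chunked_filenames).
            continue
        yield list_to_chunk[start:start + length]
        start += length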
Example #2
    def __init__(
        self,
        size: int,
        shuffle: bool = True,
        seed: Optional[int] = None,
    ):
        """
        Parameters
        ----------
        size : int
            Size of underlying dataset.
        shuffle : bool
            If True, the indices will be shuffled.
        seed : int, optional
            Initial seed of the shuffle; must be the same across all workers.
        """
        self._size = size
        if self._size <= 0:
            raise ValueError(f"size must be a positive integer, got {size}.")
        self._shuffle = shuffle
        if seed is None:
            seed = communication.shared_random_seed()
        self._seed = int(seed)

        self._rank = communication.get_rank()
        self._world_size = communication.get_world_size()
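
The constructor above only stores state; a sampler of this kind typically yields an infinite, per-rank sharded stream of indices. The following is a minimal self-contained sketch of that pattern under those assumptions; the class name and method bodies are illustrative, not the library's own code.

# Minimal sketch of an infinite "training sampler" matching the stored attributes above.
import itertools
import torch

class _TrainingSamplerSketch:  # hypothetical name, not the library's class
    def __init__(self, size, shuffle=True, seed=0, rank=0, world_size=1):
        self._size = size
        self._shuffle = shuffle
        self._seed = seed
        self._rank = rank
        self._world_size = world_size

    def _infinite_indices(self):
        g = torch.Generator()
        g.manual_seed(self._seed)  # identical seed on every worker keeps the permutations in sync
        while True:
            if self._shuffle:
                yield from torch.randperm(self._size, generator=g).tolist()
            else:
                yield from torch.arange(self._size).tolist()

    def __iter__(self):
        # Each rank takes every world_size-th index, so ranks draw disjoint samples.
        yield from itertools.islice(
            self._infinite_indices(), self._rank, None, self._world_size)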
Example #3
    def train(
            self,
            optimizer: torch.optim.Optimizer,
            lr_scheduler: torch.optim.lr_scheduler._LRScheduler,  # noqa
            training_data: Dataset,
            experiment_directory: pathlib.Path,
            validation_data: Optional[Dataset] = None,
            resume: bool = False,
            num_workers: int = 0) -> None:

        # TODO: Does not need to be member of self.
        self.__optimizer = optimizer
        # TODO: Optimizer and LR scheduler need to be resumed too.
        self.__lr_scheduler = lr_scheduler

        training_sampler = self.build_sampler(training_data, 'random')
        # TODO: Configurable
        training_loader = self.build_loader(
            training_data,
            sampler=training_sampler,
            batch_size=self.cfg.training.batch_size,
            num_workers=num_workers,
            drop_last=True)

        validation_loader = None  # training_loop must handle the case without validation data
        if validation_data:
            validation_sampler = self.build_sampler(
                validation_data, 'sequential', limit_number_of_volumes=None)
            batch_sampler = BatchSampler(
                validation_sampler,
                batch_size=8 * self.cfg.training.batch_size,
                drop_last=False)
            # TODO: Batch size can be much larger, perhaps have a different batch size during evaluation.
            validation_loader = self.build_loader(
                validation_data,
                batch_sampler=batch_sampler,
                num_workers=num_workers,
            )

        self.model = self.model.to(self.device)

        # Optimizer
        self.__optimizer.zero_grad()  # type: ignore

        # Mixed precision setup. This requires the model to already be on the GPU.
        extra_checkpointing = {}
        # opt_level is defined unconditionally so the checkpoint comparison below
        # does not hit an undefined name when mixed precision is disabled.
        opt_level = f'O{self.mixed_precision}'  # 'O0' disables mixed precision in apex
        if self.mixed_precision > 0:
            self.logger.info(f'Using apex level {opt_level}.')
            self.model, self.__optimizer = amp.initialize(self.model,
                                                          self.__optimizer,
                                                          opt_level=opt_level)
            extra_checkpointing['amp'] = amp
            extra_checkpointing['opt_level'] = opt_level
            git_hash = direct.utils.git_hash()
            extra_checkpointing['__author__'] = git_hash if git_hash else 'N/A'

        self.checkpointer = Checkpointer(
            self.model,
            experiment_directory,
            save_to_disk=communication.is_main_process(),
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            **extra_checkpointing)

        # Load checkpoint
        start_iter = 0
        if resume:
            self.logger.info('Attempting to resume...')
            # This changes the model in-place.
            checkpoint = self.checkpointer.load(
                iteration='latest',
                checkpointable_objects=['amp'] if self.mixed_precision > 0 else [])
            if not checkpoint:
                self.logger.info('No checkpoint found. Starting from scratch.')
            else:
                start_iter = checkpoint['iteration'] + 1
                self.logger.info(f'Starting from iteration: {start_iter}.')

            if '__author__' in checkpoint:
                self.logger.info(
                    f"Git hash of checkpoint: {checkpoint['__author__']}")
                if checkpoint['__author__'] != direct.utils.git_hash():
                    self.logger.warning(
                        f"Current git hash {direct.utils.git_hash()} is different from the one "
                        f"this checkpoint is saved with ({checkpoint['__author__']}. This can be fine, "
                        f"but beware that this can be a source of confusion.")

            if '__datetime__' in checkpoint:
                self.logger.info(
                    f"Checkpoint created at: {checkpoint['__datetime__']}")
            if 'opt_level' in checkpoint:
                if checkpoint['opt_level'] != opt_level:
                    self.logger.warning(
                        f"Mixed precision opt-levels do not match. "
                        f"Requested {opt_level} got {checkpoint['opt_level']} from checkpoint. "
                        f"This will almost surely lead to performance degradation."
                    )

        self.logger.info(f'World size: {communication.get_world_size()}.')
        if communication.get_world_size() > 1:
            self.model = DistributedDataParallel(
                self.model,
                device_ids=[communication.get_rank()],
                broadcast_buffers=False)

        # World size > 1 in distributed mode; otherwise allow a DataParallel fallback,
        # which can be convenient for debugging.
        elif torch.cuda.device_count() > 1 and communication.get_world_size() == 1:
            self.model = DataParallel(self.model)

        self.__writers = ([
            JSONWriter(experiment_directory / 'metrics.json'),
            CommonMetricPrinter(self.cfg.training.num_iterations),
            TensorboardWriter(experiment_directory / 'tensorboard')
        ] if communication.is_main_process() else [])

        with EventStorage(start_iter):
            self.training_loop(training_loader, start_iter, validation_loader)

        self.logger.info('Training completed.')
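
A hedged usage sketch for the train() method above. The engine instance, the datasets, and the config layout (cfg.training.batch_size, cfg.training.num_iterations) are assumptions inferred from the snippet, not a documented API.

# Hypothetical driver code; `engine`, `train_dataset`, and `val_dataset` are placeholders
# for an already constructed training engine and datasets from the surrounding codebase.
import pathlib
import torch

optimizer = torch.optim.Adam(engine.model.parameters(), lr=1e-4)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5000, gamma=0.5)

engine.train(
    optimizer,
    lr_scheduler,
    training_data=train_dataset,
    experiment_directory=pathlib.Path('experiments/run_0'),
    validation_data=val_dataset,   # may be None; validation is then skipped
    resume=False,                  # set True to pick up from the latest checkpoint
    num_workers=4,
)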