Beispiel #1
0
    def _load_image_sequence(self, segment: AbstractSegment) -> torch.Tensor:
        """Return the image tensor of ``segment``, going through an on-disk cache.

        The tensor is first looked up in
        ``<dataset_directory>/segment_image_tensor_cache/<segment.__hash__()>.pkl``.
        If that load raises, or does not complete within 2 seconds, the tensor
        is rebuilt: every image from ``segment.get_images()`` goes through
        ``self.transformer``, is optionally shifted by -0.5 (when
        ``self.minus_point_5`` is set), goes through ``self.normalizer`` and
        the results are concatenated along a new leading dimension.

        A rebuilt tensor is written back to the cache on a best-effort basis,
        only when ``shutil.disk_usage`` reports at least 2 GiB free
        (``free // 2**30 > 1``). Failures while probing the disk or saving are
        reported and otherwise ignored.

        When ``self.augment_dataset`` is set, the tensor is passed through
        ``self._augment_image_sequence`` before being returned.

        :param segment: segment whose images are loaded.
        :return: the (possibly augmented) image tensor of the segment.
        """
        cache_directory = self.dataset_directory + "/segment_image_tensor_cache"

        self._create_cache_dir(cache_directory)

        loaded_from_cache = False
        try:
            with ThreadingTimeout(2.0) as load_timeout:
                # segment.__hash__() is called directly on purpose: hash() may
                # alter the value, which would change the cache file names.
                # NOTE(review): torch.load unpickles the file, so the cache
                # directory must only contain files written by this method.
                images = torch.load("{}/{}.pkl".format(cache_directory, segment.__hash__()))

            if bool(load_timeout):
                loaded_from_cache = True
            else:
                CometLogger.print('Took too long when loading a cache image. '
                                  'We will load the image directly form the dataset instead.')
        except Exception:
            # Any load failure (missing or unreadable cache file, ...) is a
            # cache miss. KeyboardInterrupt and SystemExit are deliberately
            # not caught here so the process can still be interrupted.
            loaded_from_cache = False

        if not loaded_from_cache:
            image_sequence = []

            with ThreadingTimeout(3600.0) as read_timeout:
                for raw_image in segment.get_images():
                    image_tensor = self.transformer(raw_image)
                    if self.minus_point_5:
                        image_tensor = image_tensor - 0.5  # from [0, 1] -> [-0.5, 0.5]
                    image_tensor = self.normalizer(image_tensor)
                    image_sequence.append(image_tensor.unsqueeze(0))
                images = torch.cat(image_sequence, 0)
            if not bool(read_timeout):
                # NOTE(review): presumably fatalprint aborts the run; if it
                # returns, `images` is unbound below -- confirm.
                CometLogger.fatalprint('Encountered fatal delay when reading the uncached images from the dataset')

            # None means "free space unknown": caching is skipped in that case.
            free = None
            try:
                with ThreadingTimeout(2.0) as disk_timeout:
                    _, _, free = shutil.disk_usage(cache_directory)
                if not bool(disk_timeout):
                    CometLogger.print('Took too long to measure disk space. Skipping caching.')
                    # Honour the message even if the value was assigned right
                    # before the timeout was flagged.
                    free = None
            except Exception as error:
                print("Warning: unable to cache the segment's image tensor, there was an error while getting "
                      "disk usage: {}".format(error), file=sys.stderr)

            if free is not None and free // (2**30) > 1:
                try:
                    with ThreadingTimeout(5.0) as save_timeout:
                        torch.save(images, "{}/{}.pkl".format(cache_directory, segment.__hash__()))

                    if not bool(save_timeout):
                        CometLogger.print('Took too long when saving to cache folder. Deadlock possible. Skipping caching.')

                except Exception as error:
                    print("Warning: unable to cache the segment's image tensor: {}".format(error), file=sys.stderr)

        if self.augment_dataset:
            images = self._augment_image_sequence(images)

        return images
Beispiel #2
0
    def __getitem__(self, item: int):
        """Return the ``(image_sequence, pose)`` pair for index ``item``.

        The segment and its image sequence are obtained from the parent
        class' ``__getitem__``; the pose is computed from the segment by
        ``self._get_segment_pose``. Each of the two steps runs inside a
        ``ThreadingTimeout`` of 3600 seconds, and a step whose timeout context
        evaluates to false is reported through ``CometLogger.fatalprint``.

        Any exception raised while fetching the sample is logged with
        ``CometLogger.print`` and then re-raised.
        """
        with ThreadingTimeout(3600.0) as fetch_guard:
            # The try block stays inside the timeout context so the logging
            # happens before the context manager sees the exception.
            try:
                segment, image_sequence = super().__getitem__(item)
            except Exception as error:
                CometLogger.print(str(error))
                raise error

        fetch_completed = bool(fetch_guard)
        if not fetch_completed:
            CometLogger.fatalprint('Encountered fatal delay while getting the image sequence')

        with ThreadingTimeout(3600.0) as pose_guard:
            pose = self._get_segment_pose(segment)

        pose_completed = bool(pose_guard)
        if not pose_completed:
            CometLogger.fatalprint('Encountered fatal delay while getting the pose of the sequence')

        return image_sequence, pose
    def _train(self) -> tuple:
        """Run one training pass over ``self.train_dataloader``.

        For every batch: the batch index and batch count are logged as
        metrics, the tensors are moved to the GPU when ``cuda_is_available()``
        is true, the model is evaluated, the benchmark loss and the custom
        loss are computed, and the custom loss is handed to
        ``self._backpropagate``. The potentially blocking steps run inside
        ``ThreadingTimeout`` contexts of 14400 seconds; a step whose context
        evaluates to false is reported through ``CometLogger.fatalprint``.

        The elapsed wall-clock time is logged as "Epoch training time".

        :return: ``(mean custom loss, mean benchmark loss)``, each being the
            sum of the per-batch loss values divided by the number of batches.
        :raises ZeroDivisionError: if ``self.train_dataloader`` has length 0.
        """
        timer_start_time = time.time()

        self.model.train()

        # The number of batches is read once and reused for logging and for
        # the final averages.
        batch_count = len(self.train_dataloader)

        losses_sum = 0
        benchmark_losses_sum = 0

        for batch_number, (batch_input, target) in enumerate(self.train_dataloader, start=1):
            CometLogger.get_experiment().log_metric("Current batch", batch_number)
            CometLogger.get_experiment().log_metric("Total nbr of batches", batch_count)

            # Only log this if we are NOT in a multiprocessing session
            if CometLogger.gpu_id is None:
                print("--> processing batch {}/{} of size {}".format(
                    batch_number, batch_count, len(batch_input)))

            if cuda_is_available():
                with ThreadingTimeout(14400.0) as transfer_timeout:
                    batch_input = batch_input.cuda(
                        non_blocking=self.train_dataloader.pin_memory)
                    target = target.cuda(
                        non_blocking=self.train_dataloader.pin_memory)
                if not bool(transfer_timeout):
                    CometLogger.fatalprint(
                        'Encountered fatally long delay when moving tensors to GPUs'
                    )

            # Call the module itself rather than .forward() so that any hooks
            # registered on the module are executed.
            prediction = self.model(batch_input)

            with ThreadingTimeout(14400.0) as benchmark_timeout:
                # When the model returns a tuple, only its first element is
                # compared against the target by the benchmark loss.
                if isinstance(prediction, tuple):
                    benchmark_prediction = prediction[0]
                else:
                    benchmark_prediction = prediction
                benchmark_loss = self.benchmark_MSE_loss.compute(
                    benchmark_prediction, target)
            if not bool(benchmark_timeout):
                CometLogger.fatalprint(
                    'Encountered fatally long delay during computation of benchmark loss'
                )

            with ThreadingTimeout(14400.0) as benchmark_sum_timeout:
                benchmark_losses_sum += float(benchmark_loss.item())
            if not bool(benchmark_sum_timeout):
                CometLogger.fatalprint(
                    'Encountered fatally long delay during summation of benchmark losses'
                )

            with ThreadingTimeout(14400.0) as loss_timeout:
                # The custom loss receives the full prediction (tuple or not).
                loss = self.custom_loss.compute(prediction, target)
            if not bool(loss_timeout):
                CometLogger.fatalprint(
                    'Encountered fatally long delay during computation of the custom loss'
                )

            self._backpropagate(loss)

            with ThreadingTimeout(14400.0) as loss_sum_timeout:
                losses_sum += float(loss.item())
            if not bool(loss_sum_timeout):
                CometLogger.fatalprint(
                    'Encountered fatally long delay during loss addition')

        timer_end_time = time.time()

        CometLogger.get_experiment().log_metric(
            "Epoch training time", timer_end_time - timer_start_time)

        return losses_sum / batch_count, benchmark_losses_sum / batch_count