def _load_image_sequence(self, segment: AbstractSegment) -> torch.Tensor:
    cache_directory = self.dataset_directory + "/segment_image_tensor_cache"
    self._create_cache_dir(cache_directory)

    try:
        # Fast path: load the preprocessed tensor from the on-disk cache.
        with ThreadingTimeout(2.0) as timeout_ctx1:
            images = torch.load("{}/{}.pkl".format(cache_directory, hash(segment)))
        if not bool(timeout_ctx1):
            CometLogger.print('Took too long when loading a cached image. '
                              'We will load the image directly from the dataset instead.')
            raise Exception('Cache read timed out')
    except Exception:
        # Slow path: read the images from the dataset and preprocess them.
        image_sequence = []
        with ThreadingTimeout(3600.0) as timeout_ctx2:
            for img_as_img in segment.get_images():
                img_as_tensor = self.transformer(img_as_img)
                if self.minus_point_5:
                    img_as_tensor = img_as_tensor - 0.5  # from [0, 1] -> [-0.5, 0.5]
                img_as_tensor = self.normalizer(img_as_tensor)
                img_as_tensor = img_as_tensor.unsqueeze(0)
                image_sequence.append(img_as_tensor)
            images = torch.cat(image_sequence, 0)
        if not bool(timeout_ctx2):
            CometLogger.fatalprint('Encountered fatal delay when reading the uncached images from the dataset')

        # Measure the free disk space before attempting to cache the tensor.
        free = -1
        try:
            with ThreadingTimeout(2.0) as timeout_ctx3:
                _, _, free = shutil.disk_usage(cache_directory)
            if not bool(timeout_ctx3):
                CometLogger.print('Took too long to measure disk space. Skipping caching.')
        except Exception as e:
            print("Warning: unable to cache the segment's image tensor, there was an error while getting "
                  "disk usage: {}".format(e), file=sys.stderr)

        # Cache the tensor only if the disk check succeeded and more than 1 GiB is free.
        if free != -1 and free // (2 ** 30) > 1:
            try:
                with ThreadingTimeout(5.0) as timeout_ctx4:
                    torch.save(images, "{}/{}.pkl".format(cache_directory, hash(segment)))
                if not bool(timeout_ctx4):
                    CometLogger.print('Took too long when saving to cache folder. Deadlock possible. Skipping caching.')
            except Exception as e:
                print("Warning: unable to cache the segment's image tensor: {}".format(e), file=sys.stderr)

    if self.augment_dataset:
        images = self._augment_image_sequence(images)

    return images
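# A minimal, self-contained sketch of the timeout-guarded caching pattern used above, assuming
# ThreadingTimeout is stopit's context manager (bool(ctx) is False when the guarded block was
# interrupted by the timeout). The cache path and the placeholder "recompute" step are
# hypothetical; this is illustrative only and not part of the dataset class.
import torch
from stopit import ThreadingTimeout


def load_or_recompute(cache_path: str) -> torch.Tensor:
    try:
        # Fast path: try to read the cached tensor, but give up after 2 seconds.
        with ThreadingTimeout(2.0) as ctx:
            tensor = torch.load(cache_path)
        if not bool(ctx):
            raise TimeoutError('cache read timed out')
    except Exception:
        # Slow path: recompute the tensor (placeholder) and write it back to the cache.
        tensor = torch.zeros(3, 224, 224)
        torch.save(tensor, cache_path)
    return tensor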
def __getitem__(self, item: int):
    # Retrieve the segment and its preprocessed image sequence from the parent dataset.
    with ThreadingTimeout(3600.0) as timeout_ctx1:
        try:
            segment, image_sequence = super().__getitem__(item)
        except Exception as e:
            CometLogger.print(str(e))
            raise
    if not bool(timeout_ctx1):
        CometLogger.fatalprint('Encountered fatal delay while getting the image sequence')

    # Retrieve the ground-truth pose associated with the segment.
    with ThreadingTimeout(3600.0) as timeout_ctx2:
        pose = self._get_segment_pose(segment)
    if not bool(timeout_ctx2):
        CometLogger.fatalprint('Encountered fatal delay while getting the pose of the sequence')

    return image_sequence, pose
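# Illustrative sketch (assumption, not from the original repo): because __getitem__ returns an
# (image_sequence, pose) pair of tensors, the dataset can be consumed by a standard PyTorch
# DataLoader. The "segment_dataset" argument stands in for an instance of the dataset class
# defined above; the batch size and worker count are arbitrary.
from torch.utils.data import DataLoader, Dataset


def make_train_dataloader(segment_dataset: Dataset) -> DataLoader:
    return DataLoader(segment_dataset,
                      batch_size=2,
                      shuffle=True,
                      num_workers=4,
                      pin_memory=True)

# Each iteration then yields batched tensors:
#     for image_sequences, poses in make_train_dataloader(dataset):
#         ...  # image_sequences: (batch, seq_len, C, H, W); poses: ground-truth poses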
def _train(self) -> tuple:
    timer_start_time = time.time()
    self.model.train()
    losses_sum = 0
    benchmark_losses_sum = 0

    for i, (input, target) in enumerate(self.train_dataloader):
        CometLogger.get_experiment().log_metric("Current batch", i + 1)
        CometLogger.get_experiment().log_metric("Total nbr of batches", len(self.train_dataloader))

        # Only log this if we are NOT in a multiprocessing session
        if CometLogger.gpu_id is None:
            print("--> processing batch {}/{} of size {}".format(i + 1, len(self.train_dataloader), len(input)))

        # Move the batch to the GPU when one is available.
        if cuda_is_available():
            with ThreadingTimeout(14400.0) as timeout_ctx1:
                input = input.cuda(non_blocking=self.train_dataloader.pin_memory)
                target = target.cuda(non_blocking=self.train_dataloader.pin_memory)
            if not bool(timeout_ctx1):
                CometLogger.fatalprint('Encountered fatally long delay when moving tensors to GPUs')

        prediction = self.model(input)

        # The benchmark MSE loss is tracked for comparison; models that return a tuple expose
        # their main prediction as the first element.
        with ThreadingTimeout(14400.0) as timeout_ctx3:
            if isinstance(prediction, tuple):
                benchmark_loss = self.benchmark_MSE_loss.compute(prediction[0], target)
            else:
                benchmark_loss = self.benchmark_MSE_loss.compute(prediction, target)
        if not bool(timeout_ctx3):
            CometLogger.fatalprint('Encountered fatally long delay during computation of benchmark loss')

        with ThreadingTimeout(14400.0) as timeout_ctx4:
            benchmark_losses_sum += float(benchmark_loss.data.cpu().numpy())
        if not bool(timeout_ctx4):
            CometLogger.fatalprint('Encountered fatally long delay during summation of benchmark losses')

        # The custom loss is the one actually backpropagated.
        with ThreadingTimeout(14400.0) as timeout_ctx5:
            loss = self.custom_loss.compute(prediction, target)
        if not bool(timeout_ctx5):
            CometLogger.fatalprint('Encountered fatally long delay during computation of the custom loss')

        self._backpropagate(loss)

        with ThreadingTimeout(14400.0) as timeout_ctx6:
            losses_sum += float(loss.data.cpu().numpy())
        if not bool(timeout_ctx6):
            CometLogger.fatalprint('Encountered fatally long delay during loss addition')

    timer_end_time = time.time()
    CometLogger.get_experiment().log_metric("Epoch training time", timer_end_time - timer_start_time)

    # Return the average custom loss and average benchmark loss over the epoch.
    return losses_sum / len(self.train_dataloader), benchmark_losses_sum / len(self.train_dataloader)
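# Illustrative sketch (assumption): a typical outer epoch loop around _train(), logging the two
# averaged losses it returns. "trainer" and "num_epochs" are hypothetical names; the
# CometLogger.get_experiment().log_metric calls mirror the ones used inside _train() above.
def run_training(trainer, num_epochs: int) -> None:
    for epoch in range(num_epochs):
        avg_loss, avg_benchmark_loss = trainer._train()
        CometLogger.get_experiment().log_metric("Average epoch loss", avg_loss)
        CometLogger.get_experiment().log_metric("Average epoch benchmark MSE loss", avg_benchmark_loss)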