def generate_init_samples(self, im: torch.Tensor) -> TensorList:
    """Perform data augmentation to generate initial training samples.

    Args:
        im: Input image tensor; the spatial size is read from im.shape[2:4].

    Returns:
        Backbone features extracted from the augmented image patches.

    Side effects:
        Sets self.init_sample_scale, self.init_sample_pos and self.transforms,
        which later sampling/update steps rely on.
    """
    if getattr(self.params, 'border_mode', 'replicate') == 'inside':
        # Shrink the sample region so it fits entirely inside the image,
        # and compute the shift needed to move it back in-bounds.
        image_sz = torch.Tensor([im.shape[2], im.shape[3]])
        sample_sz = self.target_scale * self.img_sample_sz
        shrink_factor = (sample_sz.float() / image_sz).max().clamp(1)
        sample_sz = sample_sz.float() / shrink_factor
        self.init_sample_scale = (sample_sz / self.img_sample_sz).prod().sqrt()
        tl = self.pos - (sample_sz - 1) / 2
        br = self.pos + sample_sz / 2 + 1
        # Shift (in sample coordinates) that pushes the crop inside the image.
        global_shift = -((-tl).clamp(0) - (br - image_sz).clamp(0)) / self.init_sample_scale
    else:
        self.init_sample_scale = self.target_scale
        global_shift = torch.zeros(2)

    self.init_sample_pos = self.pos.round()

    # Size of the (possibly expanded) region the augmented crops are cut from.
    aug_expansion_factor = getattr(self.params, 'augmentation_expansion_factor', None)
    aug_expansion_sz = self.img_sample_sz.clone()
    aug_output_sz = None
    if aug_expansion_factor is not None and aug_expansion_factor != 1:
        aug_expansion_sz = (self.img_sample_sz * aug_expansion_factor).long()
        # Keep the same parity as the target size so the crop stays centered.
        aug_expansion_sz += (aug_expansion_sz - self.img_sample_sz.long()) % 2
        aug_expansion_sz = aug_expansion_sz.float()
        aug_output_sz = self.img_sample_sz.long().tolist()

    random_shift_factor = getattr(self.params, 'random_shift_factor', 0)

    def get_rand_shift():
        # Fresh uniform random shift inside the sample region (plus the
        # global in-bounds shift); None disables shifting entirely.
        if random_shift_factor > 0:
            return ((torch.rand(2) - 0.5) * self.img_sample_sz * random_shift_factor
                    + global_shift).long().tolist()
        return None

    base_shift = global_shift.long().tolist()

    # The identity transform goes first: the unaugmented sample is always used.
    self.transforms = [augmentation.Identity(aug_output_sz, base_shift)]

    augs = self.params.augmentation if getattr(self.params, 'use_augmentation', True) else {}

    # Append each configured augmentation family.
    if 'shift' in augs:
        for shift in augs['shift']:
            self.transforms.append(
                augmentation.Translation(shift, aug_output_sz, base_shift))
    if 'relativeshift' in augs:
        for rel_shift in augs['relativeshift']:
            # Relative shifts are fractions of half the sample size.
            abs_shift = (torch.Tensor(rel_shift) * self.img_sample_sz / 2).long().tolist()
            self.transforms.append(
                augmentation.Translation(abs_shift, aug_output_sz, base_shift))
    if 'fliplr' in augs and augs['fliplr']:
        self.transforms.append(augmentation.FlipHorizontal(aug_output_sz, get_rand_shift()))
    if 'blur' in augs:
        for sigma in augs['blur']:
            self.transforms.append(augmentation.Blur(sigma, aug_output_sz, get_rand_shift()))
    if 'scale' in augs:
        for scale_factor in augs['scale']:
            self.transforms.append(
                augmentation.Scale(scale_factor, aug_output_sz, get_rand_shift()))
    if 'rotate' in augs:
        for angle in augs['rotate']:
            self.transforms.append(augmentation.Rotate(angle, aug_output_sz, get_rand_shift()))

    # Cut out the augmented patches and run them through the backbone.
    im_patches = sample_patch_transformed(im, self.init_sample_pos, self.init_sample_scale,
                                          aug_expansion_sz, self.transforms)

    with torch.no_grad():
        init_backbone_feat = self.net.extract_backbone(im_patches)

    return init_backbone_feat
def generate_init_samples(self, im: torch.Tensor, init_mask):
    """Generate the initial training sample and its segmentation mask.

    Args:
        im: Input image tensor; the spatial size is read from im.shape[2:4].
        init_mask: First-frame segmentation mask, cropped the same way as im.

    Returns:
        Tuple (init_backbone_feat, init_masks): backbone features of the
        extracted patch and the corresponding mask patch on self.params.device.

    Side effects:
        Sets self.transforms (currently just the identity transform).
    """
    mode = self.params.get('border_mode', 'replicate')
    if 'inside' in mode:
        # Force the sample region inside the image: shrink it and compute
        # the shift needed to move it back in-bounds.
        image_sz = torch.Tensor([im.shape[2], im.shape[3]])
        sample_sz = self.target_scale * self.img_sample_sz
        shrink_factor = sample_sz.float() / image_sz
        if mode == 'inside':
            # Fit the larger overflowing dimension.
            shrink_factor = shrink_factor.max()
        elif mode == 'inside_major':
            # Only require the smaller dimension to fit.
            shrink_factor = shrink_factor.min()
        shrink_factor.clamp_(min=1, max=self.params.get('patch_max_scale_change', None))
        sample_sz = sample_sz.float() / shrink_factor
        init_sample_scale = (sample_sz / self.img_sample_sz).prod().sqrt()
        tl = self.pos - (sample_sz - 1) / 2
        br = self.pos + sample_sz / 2 + 1
        # Shift (in sample coordinates) that pushes the crop inside the image.
        global_shift = -((-tl).clamp(0) - (br - image_sz).clamp(0)) / init_sample_scale
    else:
        init_sample_scale = self.target_scale
        global_shift = torch.zeros(2)

    init_sample_pos = self.pos.round()

    # Extraction region is expanded by a fixed factor of 2 around the target.
    aug_expansion_factor = 2.0
    aug_expansion_sz = self.img_sample_sz.clone()
    aug_output_sz = None
    if aug_expansion_factor is not None and aug_expansion_factor != 1:
        aug_expansion_sz = (self.img_sample_sz * aug_expansion_factor).long()
        # Keep the same parity as the target size so the crop stays centered.
        aug_expansion_sz += (aug_expansion_sz - self.img_sample_sz.long()) % 2
        aug_expansion_sz = aug_expansion_sz.float()
        aug_output_sz = self.img_sample_sz.long().tolist()

    # Only the identity transform for now; data augmentation on the initial
    # frame could be added here later.
    self.transforms = [augmentation.Identity(aug_output_sz, global_shift.long().tolist())]

    # Extract the image patch and the matching mask patch.
    im_patches = sample_patch_transformed(
        im, init_sample_pos, init_sample_scale, aug_expansion_sz, self.transforms)
    init_masks = sample_patch_transformed(
        init_mask, init_sample_pos, init_sample_scale, aug_expansion_sz,
        self.transforms, is_mask=True)
    init_masks = init_masks.to(self.params.device)

    with torch.no_grad():
        init_backbone_feat = self.net.extract_backbone(im_patches)

    return init_backbone_feat, init_masks