def get_mixmo_mix_method_at_ratio_epoch(self, batch_seed=None):
    """Pick the MixMo mixing method for the current point in training.

    Draws ``self.dict_mixmo_mix_method["method_name"]`` with probability
    ``self.dict_mixmo_mix_method["prob"]``; that probability is decayed
    linearly towards 0 once training passes
    ``config.cfg.RATIO_EPOCH_DECREASE`` of its total length. When the draw
    fails, the replacement method (in general mixup) is returned instead.

    Args:
        batch_seed: optional seed forwarded to the random draw so the
            choice can be reproduced per batch.

    Returns:
        The name of the mixing method to apply for this batch.
    """
    main_method = self.dict_mixmo_mix_method["method_name"]
    fallback = self.dict_mixmo_mix_method["replacement_method_name"]
    if main_method == fallback:
        # Only one candidate method: no draw needed.
        return main_method

    # Scheduled switch probability for the current epoch ratio.
    base_prob = self.dict_mixmo_mix_method["prob"]
    ratio = self.ratio_epoch_current
    decay_start = config.cfg.RATIO_EPOCH_DECREASE
    if ratio < decay_start:
        switch_prob = base_prob
    else:
        # Linear decay from base_prob down to 0 at the end of training.
        decay = max(0, (1 - ratio) / (1 - decay_start))
        switch_prob = base_prob * decay

    # Bernoulli draw decides between the main and the replacement method.
    if misc.random_lower_than(switch_prob, seed=batch_seed):
        return main_method
    return fallback
def call_msda(self, index_0, mixmo_mask=None, seed_da=None):
    """Fetch a sample and, with probability ``self.msda_prob``, mix it with a
    second randomly drawn sample using the configured MSDA method.

    Args:
        index_0: dataset index of the primary sample.
        mixmo_mask: optional MixMo mixing mask applied downstream; when given,
            the interpolation ratios (lams) are re-estimated on the region that
            survives that mask. Assumes shape (channel, H, W) — TODO confirm.
        seed_da: seed forwarded to the per-sample data augmentation so both
            mixed samples share the same augmentation draw.

    Returns:
        Tuple ``(pixels, target)``: the (possibly mixed) image tensor and the
        correspondingly interpolated label.
    """
    # Gather the two image/label pairs used by the augmentation
    pixels_0, target_0 = self.call_dataset(index_0, seed=seed_da)
    skip_msda = (self.msda_mix_method is None
                 or not misc.random_lower_than(self.msda_prob))
    if skip_msda:
        # Early exit if we are not mixing: return the untouched sample.
        return pixels_0, target_0
    # Second partner is drawn uniformly at random (unseeded on purpose).
    index_1 = misc.get_random(seed=None).choice(range(len(self)))
    pixels_1, target_1 = self.call_dataset(index_1, seed=seed_da)
    targets = [target_0, target_1]

    # Get mixing masks: one per sample, plus the (possibly re-normalized) lams.
    msda_lams = misc.sample_lams(self.msda_beta, n=2)
    msda_masks, msda_lams = mixing_blocks.mix(
        method=self.msda_mix_method,
        lams=msda_lams,
        input_size=pixels_0.size(),
    )

    # Adjust the lams to account for later mixmo mixing that might alter masks
    if mixmo_mask is not None:
        ## approx for computational issues: mask should be symmetrical in channels,
        ## so only the first channel slice is used.
        mixmo_mask_0 = mixmo_mask[0, :, :]
        if self.properties("conv1_is_half_size"):
            # Downscale the MSDA mask by 2x2 average pooling so it lines up
            # with the half-resolution mixmo mask; only the first mask is
            # kept — its complement lam is reconstructed below as 1 - lam.
            _msda_mask_0 = torch.nn.AvgPool2d(kernel_size=(2, 2))(
                msda_masks[0][:1, :, :])
            msda_masks_for_lam = [_msda_mask_0.to(torch.float16)]
            # NOTE(review): here mixmo_mask_0 keeps its incoming dtype while
            # the pooled MSDA mask is cast to float16, whereas the else branch
            # casts mixmo_mask_0 to float32 — presumably intentional for
            # memory/speed, but the asymmetry is worth confirming.
        else:
            mixmo_mask_0 = mixmo_mask_0.to(torch.float32)
            msda_masks_for_lam = msda_masks
        ## Compute the adjusted ratios after mixmo mixing: fraction of the
        ## surviving mixmo region that each MSDA mask covers. The 1e-8 guards
        ## against an all-zero mixmo mask.
        mean_mixmo_mask_0 = mixmo_mask_0.mean()
        msda_lams = [(msda_mask[0, :, :] * mixmo_mask_0).mean() /
                     (mean_mixmo_mask_0 + 1e-8)
                     for msda_mask in msda_masks_for_lam]
        if self.properties("conv1_is_half_size"):
            # Only one lam was computed above; rebuild the complementary pair.
            lam = msda_lams[0].to(torch.float32)
            msda_lams = [lam, 1 - lam]

    # Randomly reverse the roles of mixed samples (important to symmetrize CutMix, Patch-Up, ...)
    if self.reverse_if_first_minor and msda_lams[0] < 0.5:
        msda_pixels = msda_masks[1] * pixels_0 + msda_masks[0] * pixels_1
        msda_lams = [msda_lams[1], msda_lams[0]]
    else:
        msda_pixels = msda_masks[0] * pixels_0 + msda_masks[1] * pixels_1

    # Standard MSDA label interpolation: lam-weighted sum of the two targets.
    msda_targets = sum(
        [lam * target for lam, target in zip(msda_lams, targets)])
    return msda_pixels, msda_targets
def _stack_mask(input_size, lam, config_mix):
    """Build a binary mask for Channel/Horizontal/Vertical concatenation.

    The mask has shape ``input_size`` (number of images x channel x width x
    height) and is all-ones over a ``lam`` fraction of the chosen dimension
    and all-zeros over the rest. With ``stack_rdflip`` enabled, the ones/zeros
    halves are swapped with probability 0.5.

    Args:
        input_size: full size of the mask to produce.
        lam: fraction of the stacking dimension covered by ones.
        config_mix: mixing config; missing keys default to
            ``{"stack_dim": 1, "stack_rdflip": True}``.

    Returns:
        A float tensor of zeros and ones of shape ``input_size``.
    """
    # Fill in defaults for any missing config entries.
    misc.ifnotfound_update(config_mix, {
        "stack_dim": 1,
        "stack_rdflip": True,
    })
    axis = config_mix["stack_dim"]
    allow_flip = config_mix["stack_rdflip"]

    flipped = allow_flip and misc.random_lower_than(
        prob=0.5, seed=None, r=None)
    if flipped:
        # Mirror the split so the ones end up on the other side.
        lam = 1 - lam

    # Split the stacking dimension into a ones part and a zeros part.
    split = int(lam * input_size[axis])
    shape_ones = list(input_size)
    shape_ones[axis] = split
    shape_zeros = list(input_size)
    shape_zeros[axis] = input_size[axis] - split

    parts = [torch.ones(shape_ones), torch.zeros(shape_zeros)]
    if flipped:
        # Zeros first when flipped, ones first otherwise.
        parts.reverse()
    return torch.cat(parts, dim=axis)