def __call__(
    self,
    image,
    preprocess,
):
    """Blend ``image`` with several randomly augmented copies (AugMix-style).

    ``self.mixture_width`` augmentation chains are built, each starting from a
    copy of ``image`` and applying ``depth`` ops drawn from
    ``self.augment_list``. The preprocessed chains are summed with
    Dirichlet(1, ..., 1) weights, and that sum is blended with the
    preprocessed original using a Beta(1, 1) coefficient.

    Args:
        image: input image; must support ``.copy()`` and be accepted by the
            ops in ``self.augment_list`` and by ``preprocess``.
        preprocess: callable mapping an image to a torch tensor.

    Returns:
        ``(1 - m) * preprocess(image) + m * mix``, where ``m ~ Beta(1, 1)`` and
        ``mix`` is the Dirichlet-weighted sum of the preprocessed chains.
    """
    # Both generators are derived from the same self.seed
    py_rng = misc.get_random(self.seed)
    np_rng = misc.get_nprandom(self.seed)
    chain_weights = np.float32(np_rng.dirichlet([1] * self.mixture_width))
    blend_weight = np.float32(np_rng.beta(1, 1))

    mix = torch.zeros_like(preprocess(image))
    for chain_weight in chain_weights:
        augmented = image.copy()
        if self.mixture_depth > 0:
            depth = self.mixture_depth
        else:
            # NOTE(review): whether the upper bound 4 is inclusive depends on
            # the RNG type returned by misc.get_random — confirm.
            depth = py_rng.randint(1, 4)
        for _ in range(depth):
            op_index = np_rng.choice(range(0, len(self.augment_list)))
            op, minval, maxval = self.augment_list[op_index]
            # Linear in aug_severity: minval at severity 0, maxval at severity 30
            magnitude = (float(self.aug_severity) / 30) * float(maxval - minval) + minval
            augmented = op(augmented, magnitude, myrandom=py_rng)
        # Preprocessing commutes since all coefficients are convex
        mix += chain_weight * preprocess(augmented)

    return (1 - blend_weight) * preprocess(image) + blend_weight * mix
def _init_dict_output_mixmo(self, batch_seed):
    """Sample the MixMo masks and lams for one batch and wrap them in a dict.

    Args:
        batch_seed: seed of the current batch. It is forwarded to
            ``self.get_mixmo_mix_method_at_ratio_epoch`` and, offset by
            ``config.cfg.RANDOM.SEED_OFFSET_MIXMO``, seeds the shuffle of the
            members' order. Must not be None (asserted before the shuffle).

    Returns:
        ``{"metadata": {"mixmo_lams": [...], "mixmo_masks": [...]}}`` where the
        two lists are shuffled together, so index k of each list belongs to
        the same (lam, mask) pair.
    """
    # Pick the MixMo mixing method, then draw the lams and the matching masks
    mix_method = self.get_mixmo_mix_method_at_ratio_epoch(
        batch_seed=batch_seed
    )
    drawn_lams = misc.sample_lams(self.mixmo_alpha, n=self.num_members)
    conv1_input_size = self.properties("conv1_input_size")
    masks, lams = mixing_blocks.mix(
        method=mix_method,
        lams=drawn_lams,
        input_size=conv1_input_size,
    )

    # Shuffle the roles of the inputs (same for every sample in the batch,
    # since the RNG is seeded from batch_seed); mostly useful for
    # asymmetrical mixing (CutMix, ...)
    assert batch_seed is not None
    shuffle_rng = misc.get_random(
        seed=batch_seed + config.cfg.RANDOM.SEED_OFFSET_MIXMO
    )
    lam_mask_pairs = list(zip(lams, masks))
    shuffle_rng.shuffle(lam_mask_pairs)
    shuffled_lams = [lam for lam, _ in lam_mask_pairs]
    shuffled_masks = [mask for _, mask in lam_mask_pairs]

    # For methods outside LIST_METHODS_NOT_INVARIANT_CHANNELS (presumably
    # channel-invariant masks), keep only the first channel, as float16.
    if mix_method not in mixing_blocks.LIST_METHODS_NOT_INVARIANT_CHANNELS:
        shuffled_masks = [
            mask[:1, :, :].to(torch.float16) for mask in shuffled_masks
        ]

    return {
        "metadata": {
            "mixmo_lams": shuffled_lams,
            "mixmo_masks": shuffled_masks,
        }
    }
def call_msda(self, index_0, mixmo_mask=None, seed_da=None):
    """Load sample ``index_0`` and possibly mix it with a second random sample.

    Args:
        index_0: dataset index of the primary sample.
        mixmo_mask: optional MixMo mask. Only ``mixmo_mask[0, :, :]`` is used,
            to re-weight the MSDA lams by the area of each MSDA mask that
            overlaps it.
        seed_da: seed forwarded to ``self.call_dataset`` for both samples.

    Returns:
        A ``(pixels, target)`` tuple. When MSDA is skipped (no mix method, or
        ``misc.random_lower_than(self.msda_prob)`` is false) this is the
        primary sample unchanged. Otherwise it is the mask-mixed pixels and
        the lam-weighted sum of both targets.
    """
    # Gather the primary image/label pair
    pixels_0, target_0 = self.call_dataset(index_0, seed=seed_da)
    skip_msda = (self.msda_mix_method is None
                 or not misc.random_lower_than(self.msda_prob))
    if skip_msda:
        # Early exit if we are not mixing
        return pixels_0, target_0

    # Second sample: uniform index drawn from an unseeded RNG
    index_1 = misc.get_random(seed=None).choice(range(len(self)))
    pixels_1, target_1 = self.call_dataset(index_1, seed=seed_da)
    targets = [target_0, target_1]

    # Get mixing masks
    msda_lams = misc.sample_lams(self.msda_beta, n=2)
    msda_masks, msda_lams = mixing_blocks.mix(
        method=self.msda_mix_method,
        lams=msda_lams,
        input_size=pixels_0.size(),
    )

    # Adjust the lams to account for later mixmo mixing that might alter masks
    if mixmo_mask is not None:
        conv1_is_half_size = self.properties("conv1_is_half_size")
        # Approximation for computational reasons: masks are assumed to be
        # symmetrical in channels, so only channel 0 is used.
        # Compute in float32: MixMo masks may be float16, where the 1e-8
        # guard below rounds to 0 and an all-zero mask would yield NaN lams.
        mixmo_mask_0 = mixmo_mask[0, :, :].to(torch.float32)
        if conv1_is_half_size:
            # 2x2 average pooling halves the MSDA mask's spatial size,
            # presumably to match a half-size MixMo mask — confirm.
            msda_mask_0 = torch.nn.AvgPool2d(kernel_size=(2, 2))(
                msda_masks[0][:1, :, :])
            msda_masks_for_lam = [msda_mask_0.to(torch.float32)]
        else:
            msda_masks_for_lam = msda_masks

        # Each lam becomes the share of the MixMo mask covered by that MSDA mask
        mean_mixmo_mask_0 = mixmo_mask_0.mean()
        msda_lams = [
            (msda_mask[0, :, :] * mixmo_mask_0).mean()
            / (mean_mixmo_mask_0 + 1e-8)
            for msda_mask in msda_masks_for_lam
        ]
        if conv1_is_half_size:
            # Only the first lam was computed; the second is its complement
            lam = msda_lams[0].to(torch.float32)
            msda_lams = [lam, 1 - lam]

    # When the primary sample is the minor component, swap the roles so that
    # sample 0 ends up as the major one (important to symmetrize CutMix,
    # Patch-Up, ...)
    if self.reverse_if_first_minor and msda_lams[0] < 0.5:
        msda_pixels = msda_masks[1] * pixels_0 + msda_masks[0] * pixels_1
        msda_lams = [msda_lams[1], msda_lams[0]]
    else:
        msda_pixels = msda_masks[0] * pixels_0 + msda_masks[1] * pixels_1

    # Standard MSDA label interpolation
    msda_targets = sum(lam * target for lam, target in zip(msda_lams, targets))
    return msda_pixels, msda_targets