예제 #1
0
    def __call__(
        self,
        image,
        preprocess,
    ):
        """
        Blend `image` with several randomly augmented copies of itself
        (AugMix-style mixing).

        `self.mixture_width` augmentation chains are run, each on a fresh copy
        of `image`. Every chain applies a sequence of operations drawn from
        `self.augment_list`; the preprocessed chain outputs are accumulated
        with Dirichlet-sampled weights. The accumulated mixture is finally
        interpolated with the preprocessed original image using a coefficient
        sampled from Beta(1, 1).

        Both random generators are created from `self.seed`.

        Args:
            image: input sample; it must expose a `.copy()` method
                (presumably a PIL image -- confirm against the caller).
            preprocess: callable applied to the original image and to every
                augmented copy; its output is handed to `torch.zeros_like`.

        Returns:
            `(1 - m) * preprocess(image) + m * mix`, where `m` is the sampled
            blending coefficient and `mix` the weighted sum over the chains.
        """
        py_rng = misc.get_random(self.seed)
        np_rng = misc.get_nprandom(self.seed)

        # One convex weight per chain, plus the final blending coefficient
        chain_weights = np.float32(np_rng.dirichlet([1] * self.mixture_width))
        blend_coef = np.float32(np_rng.beta(1, 1))

        accumulated = torch.zeros_like(preprocess(image))
        for chain_idx in range(self.mixture_width):
            augmented = image.copy()

            # A non-positive mixture_depth means the chain length is drawn at random
            if self.mixture_depth > 0:
                chain_length = self.mixture_depth
            else:
                chain_length = py_rng.randint(1, 4)

            for _ in range(chain_length):
                op_idx = np_rng.choice(range(len(self.augment_list)))
                operation, low, high = self.augment_list[op_idx]
                # Linear interpolation inside [low, high]:
                # severity 0 maps to `low`, severity 30 maps to `high`
                magnitude = (float(self.aug_severity) / 30) * float(high - low) + low
                augmented = operation(augmented, magnitude, myrandom=py_rng)

            # Preprocessing commutes since all coefficients are convex
            accumulated += chain_weights[chain_idx] * preprocess(augmented)

        original_part = (1 - blend_coef) * preprocess(image)
        return original_part + blend_coef * accumulated
예제 #2
0
    def _init_dict_output_mixmo(self, batch_seed):
        """
        Sample the MixMo mixing variables (masks and lams) for one batch and
        wrap them in the output dictionary.

        Args:
            batch_seed: seed of the current batch. It is forwarded to
                `get_mixmo_mix_method_at_ratio_epoch` and, offset by
                `config.cfg.RANDOM.SEED_OFFSET_MIXMO`, seeds the shuffle of
                the (lam, mask) pairs. Must not be None.

        Returns:
            A dictionary `{"metadata": {"mixmo_lams": [...], "mixmo_masks": [...]}}`
            whose two lists are shuffled with the same permutation.
        """
        # Pick the mixing method, then draw the matching masks/lams
        mix_method = self.get_mixmo_mix_method_at_ratio_epoch(
            batch_seed=batch_seed
        )
        sampled_lams = misc.sample_lams(self.mixmo_alpha, n=self.num_members)
        masks, lams = mixing_blocks.mix(
            method=mix_method,
            lams=sampled_lams,
            input_size=self.properties("conv1_input_size"),
        )

        # Shuffle which input gets which (lam, mask) pair; the permutation is
        # shared by every sample of the batch. Mostly useful for asymmetrical
        # mixing (CutMix, ...)
        assert batch_seed is not None
        shuffler = misc.get_random(
            seed=batch_seed + config.cfg.RANDOM.SEED_OFFSET_MIXMO
        )
        lam_mask_pairs = list(zip(lams, masks))
        shuffler.shuffle(lam_mask_pairs)

        shuffled_lams = [lam for lam, _ in lam_mask_pairs]
        shuffled_masks = [mask for _, mask in lam_mask_pairs]

        # For methods outside LIST_METHODS_NOT_INVARIANT_CHANNELS, keep only the
        # first slice along dim 0 and store it in half precision
        # (presumably the mask is identical across channels -- confirm)
        if mix_method not in mixing_blocks.LIST_METHODS_NOT_INVARIANT_CHANNELS:
            shuffled_masks = [
                mask[:1, :, :].to(torch.float16) for mask in shuffled_masks
            ]

        return {
            "metadata": {
                "mixmo_lams": shuffled_lams,
                "mixmo_masks": shuffled_masks,
            }
        }
예제 #3
0
    def call_msda(self, index_0, mixmo_mask=None, seed_da=None):
        """
        Fetch the sample at `index_0` and, unless mixing is skipped, mix it
        with a second randomly chosen sample.

        Args:
            index_0: index of the first sample, passed to `self.call_dataset`.
            mixmo_mask: optional mask of a later MixMo mixing step. When given,
                the label ratios are recomputed on the area kept by its first
                slice (`mixmo_mask[0, :, :]`).
            seed_da: seed forwarded to `self.call_dataset` for both samples.

        Returns:
            A tuple `(pixels, target)`. This is the unmodified first sample
            when `self.msda_mix_method` is None or the probability draw fails,
            and the mixed pixels with the interpolated targets otherwise.
        """
        pixels_0, target_0 = self.call_dataset(index_0, seed=seed_da)

        # Early exits: no mixing method configured, or the probability draw
        # says not to mix (the draw only happens when a method is configured)
        if self.msda_mix_method is None:
            return pixels_0, target_0
        if not misc.random_lower_than(self.msda_prob):
            return pixels_0, target_0

        # Second sample, drawn from an unseeded generator
        partner_index = misc.get_random(seed=None).choice(range(len(self)))
        pixels_1, target_1 = self.call_dataset(partner_index, seed=seed_da)

        # Mixing masks and their ratios
        sampled_lams = misc.sample_lams(self.msda_beta, n=2)
        msda_masks, msda_lams = mixing_blocks.mix(
            method=self.msda_mix_method,
            lams=sampled_lams,
            input_size=pixels_0.size(),
        )

        # Adjust the lams to account for later mixmo mixing that might alter masks
        if mixmo_mask is not None:
            # Approximation for computational reasons: only the first slice of
            # the mixmo mask is used (the mask should be symmetrical in channels)
            mixmo_slice = mixmo_mask[0, :, :]

            if self.properties("conv1_is_half_size"):
                # Downsample the first msda mask by 2x2 average pooling
                pooling = torch.nn.AvgPool2d(kernel_size=(2, 2))
                pooled_mask = pooling(msda_masks[0][:1, :, :])
                masks_for_lam = [pooled_mask.to(torch.float16)]
            else:
                mixmo_slice = mixmo_slice.to(torch.float32)
                masks_for_lam = msda_masks

            # Ratio of each msda mask inside the area kept by the mixmo mask
            denominator = mixmo_slice.mean() + 1e-8
            msda_lams = [
                (mask[0, :, :] * mixmo_slice).mean() / denominator
                for mask in masks_for_lam
            ]

            if self.properties("conv1_is_half_size"):
                # Only the first ratio was computed; the second is its complement
                first_lam = msda_lams[0].to(torch.float32)
                msda_lams = [first_lam, 1 - first_lam]

        # Swap the roles of the two masks when the first sample would be the
        # minority one (important to symmetrize CutMix, Patch-Up, ...)
        mask_for_0, mask_for_1 = msda_masks[0], msda_masks[1]
        swap_roles = self.reverse_if_first_minor and msda_lams[0] < 0.5
        if swap_roles:
            mask_for_0, mask_for_1 = mask_for_1, mask_for_0
            msda_lams = [msda_lams[1], msda_lams[0]]
        msda_pixels = mask_for_0 * pixels_0 + mask_for_1 * pixels_1

        # Standard MSDA label interpolation
        msda_targets = sum(
            lam * target
            for lam, target in zip(msda_lams, (target_0, target_1))
        )

        return msda_pixels, msda_targets