def get_mixmo_mix_method_at_ratio_epoch(self, batch_seed=None):
        """
        Pick the mixing method for the current epoch according to the schedule.

        The configured ``method_name`` is drawn with probability
        ``self.dict_mixmo_mix_method["prob"]``, which stays constant until
        ``config.cfg.RATIO_EPOCH_DECREASE`` of training and is then linearly
        annealed towards 0 by the end of training. When the draw fails, the
        ``replacement_method_name`` (in general mixup) is returned instead.
        """
        schedule = self.dict_mixmo_mix_method
        chosen = schedule["method_name"]
        fallback = schedule["replacement_method_name"]

        # Degenerate schedule: both names coincide, no draw needed.
        if chosen == fallback:
            return chosen

        # Effective probability: flat plateau, then linear decay to zero.
        prob = schedule["prob"]
        decay_start = config.cfg.RATIO_EPOCH_DECREASE
        if self.ratio_epoch_current >= decay_start:
            decay = (1 - self.ratio_epoch_current) / (1 - decay_start)
            prob *= max(0, decay)

        # Bernoulli draw (optionally seeded per batch) decides the method.
        if misc.random_lower_than(prob, seed=batch_seed):
            return chosen
        return fallback
    def call_msda(self, index_0, mixmo_mask=None, seed_da=None):
        """
        Fetch the sample at ``index_0`` and, with probability ``self.msda_prob``,
        mix its pixels and targets with a second randomly drawn sample.

        Args:
            index_0: dataset index of the primary sample.
            mixmo_mask: optional mixmo mixing mask applied later in the
                pipeline; when given, the MSDA lam ratios are re-estimated on
                the region this mask keeps.
            seed_da: seed forwarded to ``call_dataset`` so both samples share
                the same data-augmentation draw.

        Returns:
            ``(pixels, targets)`` — either the untouched first sample (when no
            mixing happens) or the mixed pixels with linearly interpolated
            targets.
        """
        # Gather the two image/label pairs used by the augmentation
        pixels_0, target_0 = self.call_dataset(index_0, seed=seed_da)
        skip_msda = (self.msda_mix_method is None
                     or not misc.random_lower_than(self.msda_prob))
        if skip_msda:  # Early exit if we are not mixing
            return pixels_0, target_0

        # Partner sample drawn uniformly over the dataset (unseeded draw).
        index_1 = misc.get_random(seed=None).choice(range(len(self)))
        pixels_1, target_1 = self.call_dataset(index_1, seed=seed_da)

        targets = [target_0, target_1]

        # Get mixing masks
        msda_lams = misc.sample_lams(self.msda_beta, n=2)
        msda_masks, msda_lams = mixing_blocks.mix(
            method=self.msda_mix_method,
            lams=msda_lams,
            input_size=pixels_0.size(),
        )

        # Adjust the lams to account for later mixmo mixing that might alter masks
        if mixmo_mask is not None:
            ## approx for computational issues: mask should be symmetrical in channels
            ## (only the first channel slice of the mixmo mask is used)
            mixmo_mask_0 = mixmo_mask[0, :, :]

            if self.properties("conv1_is_half_size"):
                # Downsample the msda mask to match the mixmo mask resolution.
                # NOTE(review): assumes mixmo_mask is already at half spatial
                # resolution when conv1 halves the input — confirm upstream.
                _msda_mask_0 = torch.nn.AvgPool2d(kernel_size=(2, 2))(
                    msda_masks[0][:1, :, :])
                msda_masks_for_lam = [_msda_mask_0.to(torch.float16)]
            else:
                mixmo_mask_0 = mixmo_mask_0.to(torch.float32)
                msda_masks_for_lam = msda_masks

            ## Compute the adjusted ratios after mixmo mixing:
            ## fraction of the mixmo-kept region covered by each msda mask
            ## (1e-8 guards against an all-zero mixmo mask).
            mean_mixmo_mask_0 = mixmo_mask_0.mean()
            msda_lams = [(msda_mask[0, :, :] * mixmo_mask_0).mean() /
                         (mean_mixmo_mask_0 + 1e-8)
                         for msda_mask in msda_masks_for_lam]

            if self.properties("conv1_is_half_size"):
                # Only one lam was computed in the half-size branch; derive
                # the complementary one so the pair still sums to 1.
                lam = msda_lams[0].to(torch.float32)
                msda_lams = [lam, 1 - lam]

        # Randomly reverse the roles of mixed samples (important to symmetrize CutMix, Patch-Up, ...)
        if self.reverse_if_first_minor and msda_lams[0] < 0.5:
            msda_pixels = msda_masks[1] * pixels_0 + msda_masks[0] * pixels_1
            msda_lams = [msda_lams[1], msda_lams[0]]
        else:
            msda_pixels = msda_masks[0] * pixels_0 + msda_masks[1] * pixels_1

        # Standard MSDA label interpolation
        msda_targets = sum(
            [lam * target for lam, target in zip(msda_lams, targets)])

        return msda_pixels, msda_targets
# Example #3
# 0
def _stack_mask(input_size, lam, config_mix):
    """
    Build a binary mask for Channel/Horizontal/Vertical concatenation mixing.

    The mask has shape ``input_size`` —
    (number of images) x channel x (image width) x (image height).
    A fraction ``lam`` of the selected dimension is ones, the rest zeros;
    with ``stack_rdflip`` enabled the split is mirrored half of the time.
    """
    # Fill in any missing config entries with the defaults.
    misc.ifnotfound_update(config_mix,
                           {
                               "stack_dim": 1,
                               "stack_rdflip": True,
                           })

    axis = config_mix["stack_dim"]

    # Optionally mirror the split: complement lam and swap the two parts.
    flipped = config_mix["stack_rdflip"] and misc.random_lower_than(
        prob=0.5, seed=None, r=None)
    if flipped:
        lam = 1 - lam

    # Index at which the selected dimension is cut in two.
    cut = int(lam * input_size[axis])

    shape_ones = list(input_size)
    shape_ones[axis] = cut
    shape_zeros = list(input_size)
    shape_zeros[axis] = input_size[axis] - cut

    # Assemble the mask from the two constant parts, honoring the flip.
    parts = [torch.ones(shape_ones), torch.zeros(shape_zeros)]
    if flipped:
        parts.reverse()

    return torch.cat(parts, dim=axis)