import copy

import torch

# BucketingSampler, SimilarSizeSampler, BatchSampler, RandomSamplerValues,
# BatchSamplerClassif and BatchSamplerClassWithNegInSameCluster are assumed to
# be importable from the surrounding project.


def test_batch_sampler(self):
        input_data = torch.randn((103, 1, 4, 10))

        buck_sampler_no_drop_last = BucketingSampler(input_data,
                                                     batch_size=20,
                                                     drop_last=False,
                                                     num_bins=3,
                                                     replacement=False)
        buck_batch_sampler_no_drop_last = BatchSampler(
            buck_sampler_no_drop_last, batch_size=20, drop_last=False)

        sim_sampler_no_drop_last = SimilarSizeSampler(input_data,
                                                      batch_size=20,
                                                      drop_last=False,
                                                      replacement=False)
        sim_batch_sampler_no_drop_last = BatchSampler(sim_sampler_no_drop_last,
                                                      batch_size=20,
                                                      drop_last=False)

        # 103 samples with batch_size 20 -> ceil(103 / 20) = 6 batches when the
        # final partial batch is kept (drop_last=False).
        assert len(sim_batch_sampler_no_drop_last) == 6
        assert len(buck_batch_sampler_no_drop_last) == 6

        counter = 0
        for _ in buck_batch_sampler_no_drop_last:
            counter += 1
        assert counter == 6

        counter = 0
        for _ in sim_batch_sampler_no_drop_last:
            counter += 1
        assert counter == 6

        buck_sampler_drop_last = BucketingSampler(input_data,
                                                  batch_size=20,
                                                  drop_last=True,
                                                  num_bins=3,
                                                  replacement=False)
        buck_batch_sampler_drop_last = BatchSampler(buck_sampler_drop_last,
                                                    batch_size=20,
                                                    drop_last=True)

        sim_sampler_drop_last = SimilarSizeSampler(input_data,
                                                   batch_size=20,
                                                   drop_last=True,
                                                   replacement=False)
        sim_batch_sampler_drop_last = BatchSampler(sim_sampler_drop_last,
                                                   batch_size=20,
                                                   drop_last=True)

        # With drop_last=True the trailing partial batch of 3 samples is
        # discarded: floor(103 / 20) = 5 batches.
        assert len(sim_batch_sampler_drop_last) == 5
        assert len(buck_batch_sampler_drop_last) == 5

        counter = 0
        for _ in buck_batch_sampler_drop_last:
            counter += 1
        assert counter == 5

        counter = 0
        for _ in sim_batch_sampler_drop_last:
            counter += 1
        assert counter == 5
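
# A minimal, self-contained sketch using only stock torch.utils.data samplers
# (an assumption: the custom BucketingSampler / SimilarSizeSampler above are not
# needed to see where the 6-vs-5 batch counts come from). BatchSampler keeps the
# final partial batch when drop_last=False and discards it when drop_last=True.
from torch.utils.data import BatchSampler as TorchBatchSampler
from torch.utils.data import SequentialSampler as TorchSequentialSampler

_seq = TorchSequentialSampler(range(103))
assert len(TorchBatchSampler(_seq, batch_size=20, drop_last=False)) == 6  # ceil(103 / 20)
assert len(TorchBatchSampler(_seq, batch_size=20, drop_last=True)) == 5   # floor(103 / 20)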
class BatchSamplerTripletClassif(object):
    """Wraps BatchSampler for items associated to background and BatchSamplerClassif for items with classes.

    Args:
        indices_by_class (list of list of int): indices grouped by class; position 0
            holds the background (no-class) indices.
        indices_by_cluster (list of list of int): indices grouped by cluster.
        cluster_by_index (dict or list): maps each index to the cluster it belongs to.
        batch_size (int): Size of mini-batch.
        use_cluster (bool): if True, draw negatives for class-labelled items from
            the same cluster.
        pc_noclassif (float): Fraction of the batch made of background (no-class) items.
        nb_indices_same_class (int): number of indices drawn from the same class
            each time a class is sampled.
        nb_indices_same_cluster (int): number of indices drawn from the same cluster
            each time a cluster is sampled.
        init (bool): if True, remove class-labelled indices from the cluster lists
            at construction time.

    Warning: `indices_by_class` assumes that the list in position 0 contains
             indices associated to items without classes (background).

    Warning: `pc_noclassif` is used to compute the number of class-labelled items
             in the batch; that number must be a multiple of `nb_indices_same_class`.

    Example:
        >>> list(BatchSamplerTripletClassif([
                list(range(8)), # indices of background
                list(range(10,14)), # class 1
                list(range(20,25)), # class 2
                list(range(30,36))], # class 3
                4, # batch_size
                pc_noclassif=0.5,
                nb_indices_same_class=2))
        [[13, 12, 2, 5], [31, 32, 4, 0], [33, 30, 6, 3], [23, 22, 7, 1]]
    """

    def __init__(self, indices_by_class, indices_by_cluster, cluster_by_index,
                 batch_size, use_cluster, pc_noclassif=0.25,
                 nb_indices_same_class=2, nb_indices_same_cluster=2, init=False):
        self.indices_by_cluster = copy.copy(indices_by_cluster)
        self.indices_by_class = copy.copy(indices_by_class)
        self.indices_no_class = self.indices_by_class.pop(0)
        self.cluster_by_index = cluster_by_index
        self.use_cluster = use_cluster

        # Optionally strip class-labelled indices out of the cluster lists so the
        # clusters only keep background (no-class) items.
        if init:
            remove_target = copy.copy(self.indices_by_class)  # switch back and forth between indices_no_class and indices_by_class
            for indices in remove_target:
                for index in indices:
                    cluster = self.cluster_by_index[index]
                    try:
                        self.indices_by_cluster[cluster].remove(index)
                    except ValueError:
                        print("index {} not found in cluster {}".format(index, cluster))
                        continue

        self.batch_size = batch_size
        self.pc_noclassif = pc_noclassif
        self.nb_indices_same_class = nb_indices_same_class
        self.nb_indices_same_cluster = nb_indices_same_cluster

        # Split the batch: (1 - pc_noclassif) of it for class-labelled items,
        # the remainder for background items.
        self.batch_size_classif = round((1 - self.pc_noclassif) * self.batch_size)
        self.batch_size_noclassif = self.batch_size - self.batch_size_classif

        print("no-class batch size: {}".format(self.batch_size_noclassif))

        # Batch Sampler Same Clusters
        # self.batch_sampler_noclassif = BatchSamplerCluster(
        #     RandomSamplerValues(self.indices_by_cluster),
        #     self.batch_size_noclassif,
        #     self.nb_indices_same_cluster) if self.use_cluster else 

        # Batch Sampler NoClassif (background items only)
        self.batch_sampler_noclassif = BatchSampler(
            RandomSamplerValues(self.indices_no_class),
            self.batch_size_noclassif, True)

        # Batch Sampler Classif
        if use_cluster:
            self.batch_sampler_classif = BatchSamplerClassWithNegInSameCluster(
                RandomSamplerValues(self.indices_by_class),
                self.indices_by_cluster,  # list of 30 clusters, each a list of indices
                self.cluster_by_index,
                self.batch_size_classif,
                self.nb_indices_same_class)
        else:
            self.batch_sampler_classif = BatchSamplerClassif(
                RandomSamplerValues(self.indices_by_class),
                self.batch_size_classif,
                self.nb_indices_same_class)

        

    def __iter__(self):
        gen_classif = iter(self.batch_sampler_classif)
        gen_noclassif = iter(self.batch_sampler_noclassif)
        for _ in range(len(self)):
            batch = []
            batch += next(gen_classif)
            batch += next(gen_noclassif)[:25]  # cap the background part of the batch at 25 indices
            yield batch

    def __len__(self):
        length = min([len(self.batch_sampler_classif), len(self.batch_sampler_noclassif)])
        return length
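
# A hedged construction sketch for the cluster-aware variant above (not from the
# original source): the class layout, cluster layout, and cluster_by_index
# mapping below are invented, and RandomSamplerValues / BatchSamplerClassif /
# BatchSamplerClassWithNegInSameCluster are assumed importable from the
# surrounding project.
if __name__ == "__main__":
    indices_by_class = [list(range(8)),      # position 0: background indices
                        [10, 11, 12, 13],    # class 1
                        [20, 21, 22, 23]]    # class 2
    indices_by_cluster = [list(range(8)) + [10, 11, 20, 21],  # cluster 0
                          [12, 13, 22, 23]]                   # cluster 1
    cluster_by_index = {i: 0 for i in indices_by_cluster[0]}
    cluster_by_index.update({i: 1 for i in indices_by_cluster[1]})

    sampler = BatchSamplerTripletClassif(
        indices_by_class, indices_by_cluster, cluster_by_index,
        batch_size=4, use_cluster=False, pc_noclassif=0.5,
        nb_indices_same_class=2, init=True)
    for batch in sampler:
        print(batch)  # two class-labelled indices followed by two background indices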
class BatchSamplerTripletClassif(object):
    """Wraps BatchSampler for items associated to background and BatchSamplerClassif for items with classes.

    Args:
        indices_by_class (list of list of int): indices grouped by class; position 0
            holds the background (no-class) indices.
        batch_size (int): Size of mini-batch.
        pc_noclassif (float): Fraction of the batch made of background (no-class) items.
        nb_indices_same_class (int): number of indices drawn from the same class
            each time a class is sampled.

    Warning: `indices_by_class` assumes that the list in position 0 contains
             indices associated to items without classes (background).

    Warning: `pc_noclassif` is used to compute the number of class-labelled items
             in the batch; that number must be a multiple of `nb_indices_same_class`.

    Example:
        >>> list(BatchSamplerTripletClassif([
                list(range(8)), # indices of background
                list(range(10,14)), # class 1
                list(range(20,25)), # class 2
                list(range(30,36))], # class 3
                4, # batch_size
                pc_noclassif=0.5,
                nb_indices_same_class=2))
        [[13, 12, 2, 5], [31, 32, 4, 0], [33, 30, 6, 3], [23, 22, 7, 1]]
    """
    def __init__(self,
                 indices_by_class,
                 batch_size,
                 pc_noclassif=0.5,
                 nb_indices_same_class=2):
        self.indices_by_class = copy.copy(indices_by_class)
        self.indices_no_class = self.indices_by_class.pop(0)
        self.batch_size = batch_size
        self.pc_noclassif = pc_noclassif
        self.nb_indices_same_class = nb_indices_same_class

        self.batch_size_classif = round(
            (1 - self.pc_noclassif) * self.batch_size)
        self.batch_size_noclassif = self.batch_size - self.batch_size_classif

        # Batch Sampler NoClassif
        self.batch_sampler_noclassif = BatchSampler(
            RandomSamplerValues(self.indices_no_class),
            self.batch_size_noclassif, True)

        # Batch Sampler Classif
        self.batch_sampler_classif = BatchSamplerClassif(
            RandomSamplerValues(self.indices_by_class),
            self.batch_size_classif, self.nb_indices_same_class)

    def __iter__(self):
        gen_classif = iter(self.batch_sampler_classif)
        gen_noclassif = iter(self.batch_sampler_noclassif)
        for _ in range(len(self)):
            batch = []
            batch += next(gen_classif)
            batch += next(gen_noclassif)
            yield batch

    def __len__(self):
        return min([
            len(self.batch_sampler_classif),
            len(self.batch_sampler_noclassif)
        ])
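
# A hedged wiring sketch (not from the original source): because the sampler
# yields plain lists of dataset indices, it should plug into torch's DataLoader
# through the `batch_sampler` argument. The dataset below is invented for
# illustration, and RandomSamplerValues / BatchSamplerClassif are assumed to be
# importable from the surrounding project.
if __name__ == "__main__":
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    dataset = TensorDataset(torch.randn(36, 3))          # covers indices 0..35
    batch_sampler = BatchSamplerTripletClassif(
        [list(range(8)),        # background indices
         list(range(10, 14)),   # class 1
         list(range(20, 25)),   # class 2
         list(range(30, 36))],  # class 3
        4,                      # batch_size
        pc_noclassif=0.5,
        nb_indices_same_class=2)
    loader = DataLoader(dataset, batch_sampler=batch_sampler)
    for (features,) in loader:
        print(features.shape)   # torch.Size([4, 3])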