示例#1
0
    def upsample_data(self):
        """Generate augmented samples so that all classes reach an equal sample count.

        Scans ``self.anno_path`` for ``*_exp.npy`` label files, groups the
        annotation, image and landmark paths of every sample by its class
        index, asks ``DataHelper.calculate_augmentation_rate`` for per-class
        augmentation factors and finally calls
        ``DataHelper.do_random_augment`` once per class. The save locations
        passed to the helper are ``self.img_path_aug`` and
        ``self.anno_path_aug``.

        Raises:
            IndexError: if a label file holds a class index that is not
                smaller than the number of classes used for this split.
        """
        dhl = DataHelper()

        # The train split is bucketed into 8 classes, every other split into 7.
        num_classes = 8 if self.ds_type == DatasetType.train else 7
        exp_suffix = "_exp.npy"

        # Per-class sample counters and per-class path lists.
        sample_count_by_class = np.zeros([num_classes])
        img_addr_by_class = [[] for _ in range(num_classes)]
        anno_addr_by_class = [[] for _ in range(num_classes)]
        lnd_addr_by_class = [[] for _ in range(num_classes)]

        print("counting classes:")
        # Wrap the list itself (not an enumerate generator) so tqdm can see
        # its length and report a total.
        for file_name in tqdm(os.listdir(self.anno_path)):
            if not file_name.endswith(exp_suffix):
                continue
            # Each label file is expected to hold a single integer class index.
            # NOTE(review): a negative index would silently wrap around to a
            # class at the end of the lists -- confirm labels are never negative.
            exp = int(np.load(os.path.join(self.anno_path, file_name)))
            sample_count_by_class[exp] += 1

            # Sibling files share the name that precedes the "_exp.npy" suffix.
            stem = file_name[:-len(exp_suffix)]
            anno_addr_by_class[exp].append(
                os.path.join(self.anno_path, file_name))
            img_addr_by_class[exp].append(
                os.path.join(self.img_path, stem + '.jpg'))
            lnd_addr_by_class[exp].append(
                os.path.join(self.anno_path, stem + '_slnd.npy'))

        print("sample_count_by_category: ====>>")
        print(sample_count_by_class)

        # Calculate the augmentation factor for each class.
        aug_factor_by_class, aug_factor_by_class_freq = dhl.calculate_augmentation_rate(
            sample_count_by_class=sample_count_by_class,
            base_aug_factor=AffectnetConf.augmentation_factor)

        # With both arrays available, augment the samples class by class.
        per_class_addrs = zip(img_addr_by_class,
                              anno_addr_by_class,
                              lnd_addr_by_class)
        for class_index, (img_addrs, anno_addrs, lnd_addrs) in enumerate(per_class_addrs):
            dhl.do_random_augment(img_addrs=img_addrs,
                                  anno_addrs=anno_addrs,
                                  lnd_addrs=lnd_addrs,
                                  aug_factor=int(
                                      aug_factor_by_class[class_index]),
                                  aug_factor_freq=int(
                                      aug_factor_by_class_freq[class_index]),
                                  img_save_path=self.img_path_aug,
                                  anno_save_path=self.anno_path_aug,
                                  class_index=class_index)
    def upsample_data(self):
        """Generate augmented samples so that all classes reach an equal sample count.

        Scans ``self.anno_path`` for ``*_exp.npy`` label files, groups the
        annotation, image and landmark paths of every sample into one of 7
        classes, asks ``DataHelper.calculate_augmentation_rate`` for per-class
        augmentation factors and finally calls
        ``DataHelper.do_random_augment`` once per class. The save locations
        passed to the helper are ``self.img_path_aug`` and
        ``self.anno_path_aug``.

        Raises:
            IndexError: if a label file holds a class index of 7 or more.
        """
        dhl = DataHelper()

        num_classes = 7
        exp_suffix = "_exp.npy"

        # Per-class sample counters and per-class path lists.
        sample_count_by_class = np.zeros([num_classes])
        img_addr_by_class = [[] for _ in range(num_classes)]
        anno_addr_by_class = [[] for _ in range(num_classes)]
        lnd_addr_by_class = [[] for _ in range(num_classes)]

        print("counting classes:")
        # Wrap the list itself (not an enumerate generator) so tqdm can see
        # its length and report a total.
        for file_name in tqdm(os.listdir(self.anno_path)):
            if not file_name.endswith(exp_suffix):
                continue
            # Each label file is expected to hold a single integer class index.
            # NOTE(review): a negative index would silently wrap around to a
            # class at the end of the lists -- confirm labels are never negative.
            exp = int(np.load(os.path.join(self.anno_path, file_name)))
            sample_count_by_class[exp] += 1

            # Sibling files share the name that precedes the "_exp.npy" suffix.
            stem = file_name[:-len(exp_suffix)]
            anno_addr_by_class[exp].append(
                os.path.join(self.anno_path, file_name))
            img_addr_by_class[exp].append(
                os.path.join(self.img_path, stem + '.jpg'))
            lnd_addr_by_class[exp].append(
                os.path.join(self.anno_path, stem + '_slnd.npy'))

        print("sample_count_by_category: ====>>")
        print(sample_count_by_class)
        # Counts noted by the original author (presumably from one run over
        # the training split -- TODO confirm): Surprise 1290, Fear 281,
        # Disgust 717, Happiness 4772, Sadness 1982, Anger 705, Neutral 2524.

        # Calculate the augmentation factor for each class.
        aug_factor_by_class, aug_factor_by_class_freq = dhl.calculate_augmentation_rate(
            sample_count_by_class=sample_count_by_class,
            base_aug_factor=RafDBConf.augmentation_factor)

        # With both arrays available, augment the samples class by class.
        per_class_addrs = zip(img_addr_by_class,
                              anno_addr_by_class,
                              lnd_addr_by_class)
        for class_index, (img_addrs, anno_addrs, lnd_addrs) in enumerate(per_class_addrs):
            dhl.do_random_augment(img_addrs=img_addrs,
                                  anno_addrs=anno_addrs,
                                  lnd_addrs=lnd_addrs,
                                  aug_factor=int(
                                      aug_factor_by_class[class_index]),
                                  aug_factor_freq=int(
                                      aug_factor_by_class_freq[class_index]),
                                  img_save_path=self.img_path_aug,
                                  anno_save_path=self.anno_path_aug,
                                  class_index=class_index)