コード例 #1
0
ファイル: cutmix.py プロジェクト: valemore/leaf-disease
    def __getitem__(self, index):
        """Return the sample at ``index``, optionally CutMix-ed with other samples.

        Up to ``self.num_mix`` times, and each time with probability
        ``self.prob``, a rectangular region of the image is overwritten with
        the same region of a randomly drawn sample, and the label vector is
        blended in proportion to the image area that was kept.

        Args:
            index: Position of the base sample in ``self.dataset``.

        Returns:
            A tuple ``(img, label, index)``. ``label`` is the (possibly
            blended) label vector: the dataset label itself when
            ``self.soft_targets`` is truthy, otherwise its ``onehot`` encoding.
        """
        img, lb, _ = self.dataset[index]
        lb_onehot = lb if self.soft_targets else onehot(self.num_class, lb)

        for _ in range(self.num_mix):
            # Skip this mixing round when the uniform draw exceeds ``prob``.
            r = np.random.rand(1)
            if r > self.prob:
                continue

            # generate mixed sample
            lam = np.random.beta(self.beta, self.beta)
            rand_index = np.random.randint(0, len(self))

            img2, lb2, _ = self.dataset[rand_index]
            lb2_onehot = lb2 if self.soft_targets else onehot(self.num_class, lb2)

            bbx1, bby1, bbx2, bby2 = rand_bbox(img.size(), lam)
            # NOTE(review): this writes into the object returned by
            # self.dataset[index]; if the wrapped dataset hands out cached
            # tensors the pasted patch would persist across calls - confirm.
            img[:, bbx1:bbx2, bby1:bby2] = img2[:, bbx1:bbx2, bby1:bby2]
            # Recompute lam from the box actually returned by rand_bbox so the
            # label weights match the pasted area exactly.
            box_area = (bbx2 - bbx1) * (bby2 - bby1)
            lam = 1 - (box_area / (img.size()[-1] * img.size()[-2]))
            lb_onehot = lb_onehot * lam + lb2_onehot * (1. - lam)

        # Apply the post-mix transform exactly once per returned sample.
        # Doing this inside the loop would leave samples that skipped mixing
        # untransformed and would re-transform samples mixed several times.
        if self.transform:
            img = self.transform(image=img)["image"]

        return img, lb_onehot, index
コード例 #2
0
 def __to_oh(self, lb):
     """Return ``lb`` in the label form used for mixing.

     When ``self.soft_targets`` is truthy the label is passed through
     unchanged; otherwise it is converted with
     ``onehot(self.num_class, lb)``.
     """
     return lb if self.soft_targets else onehot(self.num_class, lb)