예제 #1
0
    def __call__(self, input):
        """Paste a random rectangle from one imaging site into the other (CutMix-style).

        ``input`` is split with ``SplitInSites`` into two sites, and with
        probability 0.5 their roles are swapped. A box is then sampled:
        ``lam ~ U(0, 1)``, side lengths ``w * sqrt(1 - lam)`` and
        ``h * sqrt(1 - lam)``, centred at a uniformly random point. After
        clipping to the image bounds, that region is copied from the second
        site into every channel of the first.

        Args:
            input: sequence of PIL images that must all share the same
                ``mode``. Presumably these are the per-channel images of both
                sites; TODO confirm against ``SplitInSites``.

        Returns:
            list of PIL images: the channels of the first site, with the
            sampled box taken from the second site.
        """
        w, h = input[0].size  # PIL reports (width, height)
        s1, s2 = SplitInSites()(input)
        if np.random.rand() > 0.5:
            s1, s2 = s2, s1

        lam = np.random.uniform(0, 1)
        r_x = np.random.uniform(0, w)
        r_y = np.random.uniform(0, h)
        r_w = w * np.sqrt(1 - lam)
        r_h = h * np.sqrt(1 - lam)
        # Use np.clip/np.round on plain floats instead of relying on the
        # operands happening to be NumPy scalars with .clip/.astype methods.
        x1 = int(np.round(np.clip(r_x - r_w / 2, 0, w)))
        x2 = int(np.round(np.clip(r_x + r_w / 2, 0, w)))
        y1 = int(np.round(np.clip(r_y - r_h / 2, 0, h)))
        y2 = int(np.round(np.clip(r_y + r_h / 2, 0, h)))

        mode = s1[0].mode
        for c in input:
            assert c.mode == mode

        s1 = [np.array(c) for c in s1]
        s2 = [np.array(c) for c in s2]

        for c1, c2 in zip(s1, s2):
            # Arrays built from PIL images are indexed (row, column) == (y, x);
            # indexing [x, y] would transpose the box on non-square images.
            c1[y1:y2, x1:x2] = c2[y1:y2, x1:x2]

        s1 = [Image.fromarray(c) for c in s1]

        assert s1[0].mode == mode

        return s1
예제 #2
0
def main(dataset_path, workers):
    """Compute image mean/std for every (experiment, plate) pair and save them.

    Reads ``train.csv`` and ``test.csv`` from ``dataset_path``. Each row gets a
    ``root`` column holding its image directory (``<dataset_path>/train`` or
    ``<dataset_path>/test``), and all rows are grouped by ``experiment`` and
    ``plate``.

    For each group, every sample is loaded through ``TestDataset``: the
    ``image`` field is split into sites and stacked into one tensor. The
    batches are concatenated, and mean and standard deviation are reduced
    over dims 0, 1, 3 and 4. That is presumably (sample, site, channel,
    height, width), i.e. per-channel statistics; TODO confirm the layout.

    The resulting dict maps ``(experiment, plate)`` to ``(mean, std)``. It is
    written with ``torch.save`` to ``plate_stats.pth`` in the current working
    directory.

    Args:
        dataset_path: directory containing ``train.csv``, ``test.csv`` and the
            ``train``/``test`` image folders.
        workers: ``num_workers`` passed to each ``DataLoader``.
    """
    def stack_sites(sites):
        # Convert each site to a tensor and stack them along a new leading dim.
        return torch.stack([ToTensor()(site) for site in sites], 0)

    transform = T.Compose([
        ApplyTo(['image'], T.Compose([SplitInSites(), T.Lambda(stack_sites)])),
        Extract(['image']),
    ])

    frames = []
    for csv_name, subdir in (('train.csv', 'train'), ('test.csv', 'test')):
        frame = pd.read_csv(os.path.join(dataset_path, csv_name))
        frame['root'] = os.path.join(dataset_path, subdir)
        frames.append(frame)
    data = pd.concat(frames)

    stats = {}
    for (experiment, plate), group in tqdm(data.groupby(['experiment', 'plate'])):
        loader = torch.utils.data.DataLoader(
            TestDataset(group, transform=transform),
            batch_size=32,
            num_workers=workers,
        )

        with torch.no_grad():
            # The temporary list of batches is dropped once torch.cat returns,
            # so only the concatenated plate tensor stays alive.
            images = torch.cat([batch for batch, in loader], 0)
            reduce_dims = (0, 1, 3, 4)
            stats[(experiment, plate)] = (
                images.mean(reduce_dims),
                images.std(reduce_dims),
            )

            # Release this plate's pixels before loading the next plate.
            del images
            gc.collect()

    torch.save(stats, 'plate_stats.pth')
예제 #3
0
        torch.load('./experiment_stats.pth'))
elif config.normalize == 'plate':
    normalize = NormalizeByPlateStats(torch.load('./plate_stats.pth'))
else:
    raise AssertionError('invalid normalization {}'.format(config.normalize))

# Evaluation pipeline: keep a single site, resize, center-crop, to tensor.
# NOTE(review): RandomSite() presumably picks the site at random, which would
# make evaluation non-deterministic across runs. Confirm this is intended.
eval_image_transform = T.Compose(
    [RandomSite(), Resize(config.resize_size), center_crop, to_tensor]
)

# Test pipeline: resize and center-crop, then split into sites and stack the
# per-site tensors along a new leading dimension.
test_image_transform = T.Compose(
    [
        Resize(config.resize_size),
        center_crop,
        SplitInSites(),
        T.Lambda(lambda sites: torch.stack([to_tensor(site) for site in sites], 0)),
    ]
)
train_transform = T.Compose([
    ApplyTo(['image'],
            T.Compose([
                RandomSite(),
                Resize(config.resize_size),
                random_crop,
                RandomFlip(),
                RandomTranspose(),
                to_tensor,
                ChannelReweight(config.aug.channel_reweight),
            ])),
    normalize,
    Extract(['image', 'exp', 'label', 'id']),