def __call__(self, input):
    """Paste a random rectangle from one site's images into the other's.

    ``input`` is split into two groups with ``SplitInSites`` (presumably one
    group per imaging site), and the pair is swapped with probability 0.5.
    A box is then drawn:
    - ``lam ~ U(0, 1)``;
    - nominal area ``w * h * (1 - lam)``;
    - centred on a uniformly random point;
    - clipped to the image bounds.
    That box is copied from each image of the second group into the matching
    image of the first group.

    Args:
        input: sequence of PIL images; all must share the same mode.

    Returns:
        list of PIL images for the (possibly swapped) first group, with the
        box pasted in.
    """
    # PIL reports size as (width, height).
    w, h = input[0].size
    s1, s2 = SplitInSites()(input)
    if np.random.rand() > 0.5:
        s1, s2 = s2, s1

    lam = np.random.uniform(0, 1)
    r_x = np.random.uniform(0, w)
    r_y = np.random.uniform(0, h)
    r_w = w * np.sqrt(1 - lam)
    r_h = h * np.sqrt(1 - lam)
    # Clip to the image, then round to integer pixel bounds.
    x1 = int(np.round(np.clip(r_x - r_w / 2, 0, w)))
    x2 = int(np.round(np.clip(r_x + r_w / 2, 0, w)))
    y1 = int(np.round(np.clip(r_y - r_h / 2, 0, h)))
    y2 = int(np.round(np.clip(r_y + r_h / 2, 0, h)))

    mode = s1[0].mode
    for c in input:
        assert c.mode == mode

    s1 = [np.array(c) for c in s1]
    s2 = [np.array(c) for c in s2]
    for c1, c2 in zip(s1, s2):
        # np.array(PIL image) is indexed [row, col] == [y, x], so the
        # vertical (y) range must be the first index. The previous
        # [x1:x2, y1:y2] transposed the box on non-square images.
        c1[y1:y2, x1:x2] = c2[y1:y2, x1:x2]
    s1 = [Image.fromarray(c) for c in s1]
    assert s1[0].mode == mode
    return s1
def _plate_channel_stats(batches):
    """Return (mean, unbiased std) over dims (0, 1, 3, 4) of a stream of batches.

    Each batch is a 5-D tensor reduced over every dimension except dim 2
    (presumably the channel axis of (batch, site, channel, H, W) — confirm).
    Batches are merged one at a time with Chan's parallel mean/M2 update in
    float64, so only one batch is held in memory at once. Results are cast
    back to the dtype of the input batches.

    Raises:
        ValueError: if ``batches`` yields no elements.
    """
    count = 0
    mean = m2 = None
    dtype = None
    for batch in batches:
        dtype = batch.dtype
        batch = batch.double()
        n = batch.numel() // batch.size(2)
        batch_mean = batch.mean((0, 1, 3, 4))
        batch_m2 = ((batch - batch_mean.view(1, 1, -1, 1, 1)) ** 2).sum((0, 1, 3, 4))
        if count == 0:
            mean, m2 = batch_mean, batch_m2
        else:
            total = count + n
            delta = batch_mean - mean
            mean = mean + delta * (n / total)
            m2 = m2 + batch_m2 + delta ** 2 * (count * n / total)
        count += n
    if count == 0:
        raise ValueError('no images to compute statistics from')
    # Divide by (count - 1) to match the default (unbiased) torch.std; this
    # yields nan for a single sample, as torch.std does.
    std = (m2 / (count - 1)).sqrt()
    return mean.to(dtype), std.to(dtype)


def main(dataset_path, workers, batch_size=32, output_path='plate_stats.pth'):
    """Compute image statistics per (experiment, plate) and save them with torch.save.

    Rows of ``train.csv`` and ``test.csv`` under ``dataset_path`` are grouped
    by ``['experiment', 'plate']``. For each group, the mean and unbiased std
    over dims (0, 1, 3, 4) of the stacked image tensors are stored in a dict
    keyed by ``(experiment, plate)``.

    Args:
        dataset_path: directory containing ``train.csv``, ``test.csv`` and the
            ``train`` / ``test`` image roots.
        workers: ``num_workers`` passed to the DataLoader.
        batch_size: DataLoader batch size (default 32, as before).
        output_path: where the stats dict is saved (default
            ``'plate_stats.pth'``, as before).
    """
    transform = T.Compose([
        ApplyTo(
            ['image'],
            T.Compose([
                SplitInSites(),
                T.Lambda(lambda xs: torch.stack([ToTensor()(x) for x in xs], 0)),
            ])),
        Extract(['image']),
    ])

    train_data = pd.read_csv(os.path.join(dataset_path, 'train.csv'))
    train_data['root'] = os.path.join(dataset_path, 'train')
    test_data = pd.read_csv(os.path.join(dataset_path, 'test.csv'))
    test_data['root'] = os.path.join(dataset_path, 'test')
    data = pd.concat([train_data, test_data])

    stats = {}
    for (exp, plate), group in tqdm(data.groupby(['experiment', 'plate'])):
        dataset = TestDataset(group, transform=transform)
        data_loader = torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, num_workers=workers)
        with torch.no_grad():
            # Each loader item is a 1-tuple (images,); stream the batches
            # instead of concatenating the whole plate in memory.
            stats[(exp, plate)] = _plate_channel_stats(
                images for images, in data_loader)

    torch.save(stats, output_path)
torch.load('./experiment_stats.pth')) elif config.normalize == 'plate': normalize = NormalizeByPlateStats(torch.load('./plate_stats.pth')) else: raise AssertionError('invalid normalization {}'.format(config.normalize)) eval_image_transform = T.Compose([ RandomSite(), Resize(config.resize_size), center_crop, to_tensor, ]) test_image_transform = T.Compose([ Resize(config.resize_size), center_crop, SplitInSites(), T.Lambda(lambda xs: torch.stack([to_tensor(x) for x in xs], 0)), ]) train_transform = T.Compose([ ApplyTo(['image'], T.Compose([ RandomSite(), Resize(config.resize_size), random_crop, RandomFlip(), RandomTranspose(), to_tensor, ChannelReweight(config.aug.channel_reweight), ])), normalize, Extract(['image', 'exp', 'label', 'id']),