def slice_imgs(imgs, count, size=224, transform=None, overscan=False, micro=None):
    """Cut `count` random square crops from every image and resize them.

    NOTE(review): another `slice_imgs` with a different signature is defined
    later in this source; if both live in one module, the later one wins.

    Args:
        imgs: Sequence of image tensors. They are indexed as
            ``img[:, :, y0:y1, x0:x1]`` and passed to bicubic
            ``F.interpolate``, so 4-D (N, C, H, W) input is expected.
        count: Number of crops taken from each image.
        size: Side length of every returned crop after resizing.
        transform: Optional callable applied to each resized crop.
        overscan: If exactly ``True``, each image is first padded to twice its
            height and width via ``pad_up_to(..., type='centr')`` so crops may
            extend beyond the original borders.
        micro: Crop-size regime, relative to the shorter image side ``m``:
            ``True``  -> sizes in [64, max(size, 0.25 * m)],
            ``False`` -> sizes in [0.5 * m, 0.98 * m],
            otherwise -> sizes in [112, 0.98 * m].

    Returns:
        A list with one tensor per input image; each tensor is the `count`
        crops concatenated along dim 0.
    """
    def _lerp(t, lo, hi):
        # Map t in [0, 1) linearly onto the interval starting at lo, ending at hi.
        return t * (hi - lo) + lo

    # Draw order (size, x, y) is kept so seeded runs consume the RNG unchanged.
    rnd_size = torch.rand(count)
    rnd_offx = torch.rand(count)
    rnd_offy = torch.rand(count)

    sz = [img.shape[2:] for img in imgs]
    sz_min = [min(s) for s in sz]  # shorter side of each *unpadded* image
    if overscan is True:
        sz = [[2 * s[0], 2 * s[1]] for s in sz]
        imgs = [pad_up_to(img, s, type='centr') for img, s in zip(imgs, sz)]

    sliced = []
    for i, img in enumerate(imgs):
        height, width = sz[i][0], sz[i][1]
        # Largest square that still fits inside the (possibly padded) image.
        max_csize = min(height, width)

        if micro is True:  # both scales, micro mode
            lo, hi = 64, max(size, 0.25 * sz_min[i])
        elif micro is False:  # both scales, macro mode
            lo, hi = 0.5 * sz_min[i], 0.98 * sz_min[i]
        else:  # single scale
            lo, hi = 112, 0.98 * sz_min[i]

        cuts = []
        for c in range(count):
            csize = int(_lerp(rnd_size[c], lo, hi))
            # The fixed lower bounds (64 / 112 / size) can exceed a small
            # image. Without this clamp the offset range becomes negative and
            # the slice below silently returns truncated, non-square crops.
            csize = max(1, min(csize, max_csize))
            offsetx = int(_lerp(rnd_offx[c], 0, width - csize))
            offsety = int(_lerp(rnd_offy[c], 0, height - csize))
            cut = img[:, :, offsety:offsety + csize, offsetx:offsetx + csize]
            cut = F.interpolate(cut, (size, size), mode='bicubic', align_corners=False)
            if transform is not None:
                cut = transform(cut)
            cuts.append(cut)
        sliced.append(torch.cat(cuts, 0))
    return sliced
def slice_imgs(imgs, count, size=224, transform=None, align='uniform', micro=1.):
    """Cut `count` random square crops from every image and resize them.

    Args:
        imgs: Sequence of image tensors. They are indexed as
            ``img[:, :, y0:y1, x0:x1]`` and passed to bicubic
            ``F.interpolate``, so 4-D (N, C, H, W) input is expected.
        count: Number of crops taken from each image.
        size: Side length of every returned crop after resizing.
        transform: Optional callable applied to each resized crop.
        align: Crop placement mode.
            ``'central'``  -> offsets drawn from a normal distribution around
            the centre (std 0.2, clipped to [0, 1]);
            ``'overscan'`` -> uniform offsets over an image padded to twice its
            height and width via ``pad_up_to(..., type='centr')``;
            anything else  -> uniform offsets.
        micro: Crop-size regime, relative to the shorter image side ``m``:
            ``True``  -> sizes in [size // 4, max(size, 0.25 * m)],
            ``False`` -> sizes in [0.5 * m, m],
            otherwise -> treated as a probability: per image, the lower bound
            is ``size`` with probability `micro`, else ``0.9 * m``; the upper
            bound is ``m``.

    Returns:
        A list with one tensor per input image; each tensor is the `count`
        crops concatenated along dim 0.
    """
    def _lerp(t, lo, hi):
        # Map t in [0, 1] linearly onto the interval starting at lo, ending at hi.
        return t * (hi - lo) + lo

    # Draw order (size, x, y) is kept so seeded runs consume the RNG unchanged.
    rnd_size = torch.rand(count)
    if align == 'central':  # normal around center
        rnd_offx = torch.clip(torch.randn(count) * 0.2 + 0.5, 0., 1.)
        rnd_offy = torch.clip(torch.randn(count) * 0.2 + 0.5, 0., 1.)
    else:  # uniform
        rnd_offx = torch.rand(count)
        rnd_offy = torch.rand(count)

    sz = [img.shape[2:] for img in imgs]
    sz_max = [min(s) for s in sz]  # shorter side of each *unpadded* image
    if align == 'overscan':  # add space
        sz = [[2 * s[0], 2 * s[1]] for s in sz]
        imgs = [pad_up_to(img, s, type='centr') for img, s in zip(imgs, sz)]

    sliced = []
    for i, img in enumerate(imgs):
        height, width = sz[i][0], sz[i][1]
        # Largest square that still fits inside the (possibly padded) image.
        max_csize = min(height, width)

        if micro is True:  # both scales, micro mode
            lo, hi = size // 4, max(size, 0.25 * sz_max[i])
        elif micro is False:  # both scales, macro mode
            lo, hi = 0.5 * sz_max[i], sz_max[i]
        else:  # single scale, chosen per image with probability `micro`
            lo = size if torch.rand(1) < micro else 0.9 * sz_max[i]
            hi = sz_max[i]

        cuts = []
        for c in range(count):
            csize = int(_lerp(rnd_size[c], lo, hi))
            # `size` may exceed a small image, and `size // 4` may be 0.
            # Without this clamp the offset range becomes negative and the
            # slice below silently returns truncated or empty crops.
            csize = max(1, min(csize, max_csize))
            offsetx = int(_lerp(rnd_offx[c], 0, width - csize))
            offsety = int(_lerp(rnd_offy[c], 0, height - csize))
            cut = img[:, :, offsety:offsety + csize, offsetx:offsetx + csize]
            cut = F.interpolate(cut, (size, size), mode='bicubic', align_corners=False)
            if transform is not None:
                cut = transform(cut)
            cuts.append(cut)
        sliced.append(torch.cat(cuts, 0))
    return sliced