def _build_byol_dataset(imfiles, imshape=(256, 256), batch_size=256,
                        num_parallel_calls=None, norm=255, num_channels=3,
                        augment=True, single_channel=False):
    """
    Build a tf.data.Dataset of augmented image pairs for BYOL training.

    :imfiles: image source; forwarded unchanged to _image_file_dataset()
    :imshape: constant shape to resize images to
    :batch_size: number of pairs per batch
    :num_parallel_calls: parallelism used for loading and augmenting
    :norm: value for normalizing images
    :num_channels: channel depth of images
    :augment: augmentation parameters forwarded to augment_function();
        must not be False
    :single_channel: forwarded to _image_file_dataset()

    Returns a batched dataset with elements ((view_a, view_b), label), where
    both views are augmentations of the same loaded image and label is the
    constant array [1].
    """
    # NOTE: this only rejects values that compare equal to False; other
    # falsy values such as None or {} still pass.
    assert augment != False, "don't you need to augment your data?"

    loader_kwargs = dict(imshape=imshape,
                         num_parallel_calls=num_parallel_calls,
                         norm=norm, num_channels=num_channels,
                         shuffle=True, single_channel=single_channel,
                         augment=False)
    # images are loaded un-augmented; augmentation happens in the pair map
    images = _image_file_dataset(imfiles, **loader_kwargs)
    _aug = augment_function(imshape, augment)

    @tf.function
    def pair_augment(x):
        # both views come from the same element, so they stay paired even
        # though the upstream dataset is shuffled
        first_view = _aug(x)
        second_view = _aug(x)
        return (first_view, second_view), np.array([1])

    pairs = images.map(pair_augment, num_parallel_calls=num_parallel_calls)
    return pairs.batch(batch_size).prefetch(1)
def _build_simclr_dataset(imfiles, imshape=(256, 256), batch_size=256,
                          num_parallel_calls=None, norm=255, num_channels=3,
                          augment=True, single_channel=False):
    """
    Build a tf.data.Dataset for SimCLR training.

    :imfiles: image source; forwarded unchanged to _image_file_dataset()
    :imshape: constant shape to resize images to
    :batch_size: number of source images per batch (each batch holds
        2*batch_size augmented images)
    :num_parallel_calls: parallelism used for loading and augmenting
    :norm: value for normalizing images
    :num_channels: channel depth of images
    :augment: augmentation parameters forwarded to augment_function();
        must be truthy
    :single_channel: forwarded to _image_file_dataset()

    Returns a dataset of (image, label) batches. Each source image
    contributes two consecutive augmented copies labeled 1 and -1;
    incomplete final batches are dropped.
    """
    assert augment, "don't you need to augment your data?"
    _aug = augment_function(imshape, augment)

    images = _image_file_dataset(imfiles, imshape=imshape,
                                 num_parallel_calls=num_parallel_calls,
                                 norm=norm, num_channels=num_channels,
                                 shuffle=True,
                                 single_channel=single_channel)

    @tf.function
    def _augment_and_stack(x):
        pair_labels = tf.constant([1, -1], dtype=tf.int32)
        views = [_aug(x), _aug(x)]
        return tf.stack(views), pair_labels

    stacked = images.map(_augment_and_stack,
                         num_parallel_calls=num_parallel_calls)
    # unbatch() flattens each stacked pair into two consecutive elements, so
    # a batch of 2*batch_size always contains whole pairs
    flat = stacked.unbatch()
    batched = flat.batch(2 * batch_size, drop_remainder=True)
    return batched.prefetch(1)
def dataset(fps, ys = None, imshape=(256,256), num_channels=3,
            num_parallel_calls=None, norm=255, batch_size=256,
            augment=False, unlab_fps=None, shuffle=False,
            sobel=False, single_channel=False):
    """
    return a tf dataset that iterates over a list of images once

    :fps: list of filepaths
    :ys: array of corresponding labels
    :imshape: constant shape to resize images to
    :num_channels: channel depth of images
    :num_parallel_calls: parallelism used for loading and mapping
    :norm: value for normalizing images
    :batch_size: just what you think it is
    :augment: augmentation parameters (or True for defaults, or
        False to disable)
    :unlab_fps: list of filepaths (same length as fps) for semi-
        supervised learning
    :shuffle: whether to shuffle the dataset
    :sobel: whether to replace the input image with its sobel edges
    :single_channel: if True, expect a single-channel input image and
        stack it num_channels times.

    Returns
    :ds: tf.data.Dataset object to iterate over data. The dataset returns
        (x,y) tuples unless unlab_fps is included, in which case the
        structure will be ((x, x_unlab), y)
    :num_steps: number of steps (for passing to tf.keras.Model.fit());
        computed as ceil(len(fps)/batch_size)
    """
    # When fps is a plain sequence of paths, hand the labels to the loader so
    # that shuffling reorders (path, label) pairs together. Zipping a
    # separately built label dataset onto an already-shuffled image dataset
    # would attach labels to the wrong images. For pre-built datasets and
    # tfrecord directories the loader ignores ys, so labels are zipped on
    # afterwards as before.
    labels_in_loader = (ys is not None
                        and not isinstance(fps, (tf.data.Dataset, str)))

    def _map_images(d, fn, labeled):
        # apply fn to the image only, passing any label through untouched
        if labeled:
            return d.map(lambda x, y: (fn(x), y),
                         num_parallel_calls=num_parallel_calls)
        return d.map(fn, num_parallel_calls=num_parallel_calls)

    if augment:
        _aug = augment_function(imshape, augment)

    ds = _image_file_dataset(fps, ys=ys if labels_in_loader else None,
                             imshape=imshape, num_channels=num_channels,
                             num_parallel_calls=num_parallel_calls,
                             norm=norm, shuffle=shuffle,
                             single_channel=single_channel)
    if augment:
        ds = _map_images(ds, _aug, labels_in_loader)
    if sobel:
        ds = _map_images(ds, _sobelize, labels_in_loader)

    if unlab_fps is not None:
        u_ds = _image_file_dataset(unlab_fps, imshape=imshape,
                                   num_channels=num_channels,
                                   num_parallel_calls=num_parallel_calls,
                                   norm=norm, single_channel=single_channel)
        if augment:
            u_ds = _map_images(u_ds, _aug, False)
        if sobel:
            u_ds = _map_images(u_ds, _sobelize, False)
        ds = tf.data.Dataset.zip((ds, u_ds))
        if labels_in_loader:
            # elements are ((x, y), x_unlab); regroup as ((x, x_unlab), y)
            ds = ds.map(lambda xy, xu: ((xy[0], xu), xy[1]),
                        num_parallel_calls=num_parallel_calls)

    if ys is not None and not labels_in_loader:
        label_ds = tf.data.Dataset.from_tensor_slices(ys)
        ds = tf.data.Dataset.zip((ds, label_ds))

    ds = ds.batch(batch_size)
    ds = ds.prefetch(1)

    num_steps = int(np.ceil(len(fps)/batch_size))
    return ds, num_steps
def clip_dataset(imfiles, labels, encoder, maxlen=76, imshape=(256, 256),
                 num_channels=3, num_parallel_calls=None, norm=255,
                 batch_size=256, augment=False, shuffle=True,
                 single_channel=False):
    """
    Build a tf.data.Dataset object for training CLIP

    :imfiles: list of paths to training images
    :labels: list of strings containing a caption for each image
    :encoder: sentencepiece object for mapping strings to byte pair
        encoded arrays
    :maxlen: int; length of sequences. BPE arrays will be padded (with 0)
        or truncated to this.
    :imshape: constant shape to resize images to
    :num_channels: channel depth of images
    :num_parallel_calls: parallelism used for loading and mapping
    :norm: value for normalizing images
    :batch_size: number of (image, caption) pairs per batch
    :augment: augmentation parameters forwarded to the image loader; when
        truthy an additional rot90 augmentation is applied afterwards
    :shuffle: whether to shuffle the dataset
    :single_channel: if True, expect a single-channel input image and
        stack it num_channels times.

    Returns a batched dataset of (image, token_ids) where token_ids is an
    int64 vector of length maxlen.
    """
    # forward every loader option; previously num_channels, norm,
    # num_parallel_calls and single_channel were accepted but ignored
    ds = _image_file_dataset(imfiles, ys=labels, imshape=imshape,
                             num_parallel_calls=num_parallel_calls,
                             norm=norm, num_channels=num_channels,
                             shuffle=shuffle, single_channel=single_channel,
                             augment=augment)

    if augment:
        aug_func = augment_function(imshape, {"rot90": True})

    def _encode_text(y):
        # tf.py_function passes an EagerTensor; str() on it would yield the
        # tensor's repr ("tf.Tensor(b'...', ...)") rather than the caption,
        # so unwrap and decode the underlying value first.
        raw = y.numpy()
        text = raw.decode("utf-8") if isinstance(raw, bytes) else str(raw)
        tokens = encoder.encode(text, out_type=int, add_bos=True,
                                add_eos=True)
        num_tokens = len(tokens)
        if num_tokens > maxlen:
            tokens = tokens[:maxlen]
        elif num_tokens < maxlen:
            tokens += [0] * (maxlen - num_tokens)
        # explicit dtype so the result always matches Tout=tf.int64
        return np.array(tokens, dtype=np.int64)

    def _augment_and_encode(x, y):
        if augment:
            x = aug_func(x)
        y = tf.py_function(_encode_text, (y, ), Tout=tf.int64)
        # py_function loses static shape information; restore it
        y.set_shape([maxlen])
        return x, y

    ds = ds.map(_augment_and_encode, num_parallel_calls=num_parallel_calls)
    ds = ds.batch(batch_size)
    ds = ds.prefetch(1)
    return ds
def _fixmatch_unlab_dataset(fps, weak_aug, str_aug, imshape=(256,256),
                            num_parallel_calls=None, norm=255,
                            num_channels=3, single_channel=False,
                            batch_size=64):
    """
    Build the weakly/strongly augmented image pairs used for FixMatch
    semisupervised learning.

    :fps: image source; forwarded unchanged to _image_file_dataset()
    :weak_aug: augmentation parameters for the weak view
    :str_aug: augmentation parameters for the strong view
    :imshape: constant shape to resize images to
    :num_parallel_calls: parallelism used for loading and augmenting
    :norm: value for normalizing images
    :num_channels: channel depth of images
    :single_channel: forwarded to _image_file_dataset()
    :batch_size: number of pairs per batch

    Returns a batched dataset of (weak_view, strong_view) tuples, both
    derived from the same loaded image.
    """
    _weakaug = augment_function(imshape, weak_aug)
    _strongaug = augment_function(imshape, str_aug)

    def aug_pair(x):
        weak_view = _weakaug(x)
        strong_view = _strongaug(x)
        return weak_view, strong_view

    loader_kwargs = dict(imshape=imshape,
                         num_parallel_calls=num_parallel_calls,
                         norm=norm, num_channels=num_channels,
                         shuffle=True, single_channel=single_channel,
                         augment=False)
    images = _image_file_dataset(fps, **loader_kwargs)
    pairs = images.map(aug_pair, num_parallel_calls=num_parallel_calls)
    return pairs.batch(batch_size).prefetch(1)
def _single_augplot(filepath, aug_params=True, norm=255, num_channels=3,
                    resize=None):
    """
    Load one image and display it next to 15 sampled augmentations in a
    4x4 matplotlib grid.

    :filepath: path to the image to load
    :aug_params: augmentation parameters forwarded to augment_function()
    :norm: forwarded to _load_img()
    :num_channels: forwarded to _load_img()
    :resize: forwarded to _load_img()
    """
    img = _load_img(filepath, norm=norm, num_channels=num_channels,
                    resize=resize)
    # the augmenter is built for the loaded image's own height and width
    aug_func = augment_function(img.shape[:2], aug_params)

    def _panel(position, picture):
        plt.subplot(4, 4, position)
        plt.imshow(picture)
        plt.axis(False)

    plt.figure()
    # first cell: the untouched image
    _panel(1, img)
    plt.title("original")
    # remaining 15 cells: one fresh augmentation each
    for position in range(2, 17):
        _panel(position, aug_func(img).numpy())
def _build_context_encoder_dataset(filepaths, input_shape=(256, 256, 3),
                                   norm=255, shuffle=True,
                                   num_parallel_calls=4, batch_size=32,
                                   prefetch=True, augment=False, sobel=False,
                                   single_channel=False):
    """
    Build a tf.data.Dataset object to use for training a context encoder.

    :filepaths: image source; forwarded unchanged to _image_file_dataset()
    :input_shape: (height, width, channels) of the images and masks
    :norm: value for normalizing images
    :shuffle: whether to shuffle the image dataset
    :num_parallel_calls: parallelism used for loading and mapping
    :batch_size: number of (image, mask) pairs per batch
    :prefetch: if True, prefetch one batch
    :augment: augmentation parameters (falsy to disable)
    :sobel: whether to replace the input image with its sobel edges
    :single_channel: forwarded to _image_file_dataset()

    Returns a batched dataset of (image, mask) tuples.
    """
    # first build a Dataset that generates masks
    def _gen():
        return mask_generator(*input_shape)

    mask_ds = tf.data.Dataset.from_generator(_gen,
                                             output_types=(tf.float32),
                                             output_shapes=input_shape)
    # now a Dataset to load images. The caller's shuffle flag is honored
    # here; it used to be ignored in favor of a hard-coded shuffle=True.
    img_ds = _image_file_dataset(filepaths, imshape=input_shape[:2],
                                 num_channels=input_shape[2], norm=norm,
                                 num_parallel_calls=num_parallel_calls,
                                 shuffle=shuffle,
                                 single_channel=single_channel)
    if augment:
        _aug = augment_function(input_shape[:2], augment)
        img_ds = img_ds.map(_aug, num_parallel_calls=num_parallel_calls)
    if sobel:
        img_ds = img_ds.map(_sobelize, num_parallel_calls=num_parallel_calls)

    # combine the image and mask datasets, then batch the pairs
    zipped_ds = tf.data.Dataset.zip((img_ds, mask_ds))
    masked_batched_ds = zipped_ds.batch(batch_size)
    if prefetch:
        masked_batched_ds = masked_batched_ds.prefetch(1)
    return masked_batched_ds
def _load_img(x, y):
    """
    Read, normalize, optionally augment and resize one image.

    :x: path of the file to read
    :y: label; returned unchanged

    Relies on decode, single_channel, num_channels, norm, augment and
    imshape from the enclosing scope (not visible in this block).

    Returns (img, y) with img reshaped to
    [imshape[0], imshape[1], num_channels].
    """
    raw_bytes = tf.io.read_file(x)
    img = decode(raw_bytes)
    if single_channel:
        # replicate the lone channel num_channels times
        img = tf.concat([img] * num_channels, -1)
    # truncate extra channels, then normalize
    truncated = img[:, :, :num_channels]
    img = tf.cast(truncated, tf.float32) / norm
    if augment:
        augmenter = augment_function(imshape, augment)
        img = augmenter(img)
    # Two shape operations follow. tf.image.resize() resamples the image if
    # its dimensions are wrong, but can leave the channel dimension unknown,
    # so the reshape pins the full static shape explicitly.
    resampled = tf.image.resize(img, imshape)
    target_shape = [imshape[0], imshape[1], num_channels]
    img = tf.reshape(resampled, target_shape)
    return img, y
def _multiple_augplot(filepaths, aug_params=True, num_resamps=5, norm=255,
                      num_channels=3, resize=None):
    """
    Plot several images, one per row, each followed by num_resamps
    sampled augmentations.

    :filepaths: list of image paths; the first one determines the shape
        the augmenter is built for
    :aug_params: augmentation parameters forwarded to augment_function()
    :num_resamps: number of augmentations to draw per image
    :norm: forwarded to _load_img()
    :num_channels: forwarded to _load_img()
    :resize: forwarded to _load_img()
    """
    load_kwargs = dict(norm=norm, num_channels=num_channels, resize=resize)
    num_files = len(filepaths)
    # the grid is allocated with one more row than there are files
    grid_rows = num_files + 1
    grid_cols = num_resamps + 1

    # build the augmenter from the first image's height and width
    img = _load_img(filepaths[0], **load_kwargs)
    aug_func = augment_function(img.shape[:2], aug_params)

    plt.figure()
    for row, path in enumerate(filepaths[k] for k in range(num_files)):
        img = _load_img(path, **load_kwargs)
        row_start = row * grid_cols
        # leftmost cell of the row: the unmodified image
        plt.subplot(grid_rows, grid_cols, 1 + row_start)
        plt.imshow(img)
        plt.axis(False)
        if row == 0:
            plt.title("original")
        # remaining cells: independent augmentations of that image
        for col in range(num_resamps):
            plt.subplot(grid_rows, grid_cols, col + 2 + row_start)
            plt.imshow(aug_func(img).numpy())
            plt.axis(False)
def build_augment_pair_dataset(imfiles, imshape=(256,256), batch_size=256,
                               num_parallel_calls=None, norm=255,
                               num_channels=3, augment=True,
                               single_channel=False):
    """
    Build a tf.data.Dataset object for training momentum contrast.

    Generates pairs of augmentations from a single image.

    :imfiles: image source; forwarded unchanged to _image_file_dataset()
    :imshape: constant shape to resize images to
    :batch_size: number of pairs per batch
    :num_parallel_calls: parallelism used for loading and augmenting
    :norm: value for normalizing images
    :num_channels: channel depth of images
    :augment: augmentation parameters forwarded to augment_function();
        must be truthy
    :single_channel: forwarded to _image_file_dataset()

    Returns a batched dataset of (aug_1, aug_2) tuples.
    """
    assert augment, "don't you need to augment your data?"
    _aug = augment_function(imshape, augment)

    ds = _image_file_dataset(imfiles, imshape=imshape,
                             num_parallel_calls=num_parallel_calls,
                             norm=norm, num_channels=num_channels,
                             shuffle=True, single_channel=single_channel)

    def _pair(x):
        return _aug(x), _aug(x)

    # Augment twice inside ONE map call. Mapping the shuffled dataset twice
    # and zipping the results creates two independent shuffle iterators, so
    # the two halves of a "pair" would not be guaranteed to come from the
    # same image.
    ds = ds.map(_pair, num_parallel_calls=num_parallel_calls)
    ds = ds.batch(batch_size)
    ds = ds.prefetch(1)
    return ds
def build_iic_dataset(imfiles, r=5, imshape=(256,256), batch_size=256,
                      num_parallel_calls=None, norm=255, num_channels=3,
                      augment=True, single_channel=False):
    """
    Build a tf.data.Dataset object for training IIC.

    :imfiles: image source; forwarded unchanged to _image_file_dataset()
    :r: number of times each image is repeated (each repeat gets its own
        augmentation); values <= 1 disable repetition
    :imshape: constant shape to resize images to
    :batch_size: number of pairs per batch
    :num_parallel_calls: parallelism used for loading and augmenting
    :norm: value for normalizing images
    :num_channels: channel depth of images
    :augment: augmentation parameters forwarded to augment_function();
        must be truthy
    :single_channel: forwarded to _image_file_dataset()

    Returns a batched dataset of (image, augmented_image) tuples.
    """
    assert augment, "don't you need to augment your data?"
    _aug = augment_function(imshape, augment)

    ds = _image_file_dataset(imfiles, imshape=imshape,
                             num_parallel_calls=num_parallel_calls,
                             norm=norm, num_channels=num_channels,
                             shuffle=True, single_channel=single_channel)

    if r > 1:
        ds = ds.flat_map(lambda x: tf.data.Dataset.from_tensors(x).repeat(r))

    def _with_augmented(x):
        return x, _aug(x)

    # Pair each image with its augmentation inside ONE map call. Zipping the
    # shuffled dataset with a separately mapped copy of itself creates two
    # independent shuffle iterators, so an image could be paired with the
    # augmentation of a different image.
    ds = ds.map(_with_augmented, num_parallel_calls=num_parallel_calls)
    ds = ds.batch(batch_size)
    ds = ds.prefetch(1)
    return ds
def _image_file_dataset(fps, ys=None, imshape=(256,256),
                        num_parallel_calls=None, norm=255,
                        num_channels=3, shuffle=False,
                        single_channel=False, augment=False):
    """
    Basic tool to load images into a tf.data.Dataset. Three kinds of input
    are accepted: an existing tf.data.Dataset, a string (handed to
    load_dataset_from_tfrecords()), or a list of filepaths.

    :fps: list of filepaths (or a dataset / tfrecord location, see above)
    :ys: optional list of labels; only used in the list-of-filepaths case
    :imshape: constant shape to resize images to
    :num_parallel_calls: number of processes to use for loading
    :norm: value for normalizing images
    :num_channels: channel depth to truncate images to
    :shuffle: whether to shuffle the dataset; only applied in the
        list-of-filepaths case
    :single_channel: if True, expect a single-channel input image and
        stack it num_channels times.
    :augment: optional, dictionary of augmentation parameters

    Returns tf.data.Dataset object with structure (x,y) if labels were
    passed, and (x) otherwise. images (x) are a 3D float32 tensor and labels
    should be a 0D int64 tensor

    Raises AssertionError if fps is none of the supported input kinds.
    """
    # PRE-BUILT DATASET CASE
    if isinstance(fps, tf.data.Dataset):
        ds = fps
        if augment:
            _aug = augment_function(imshape, augment)
            ds = ds.map(_aug, num_parallel_calls=num_parallel_calls)
        return ds

    # DIRECTORY OF TFRECORD FILES CASE
    if isinstance(fps, str):
        # a falsy augment value is passed through unchanged as map_fn
        map_fn = augment_function(imshape, augment) if augment else augment
        return load_dataset_from_tfrecords(
            fps, imshape, num_channels,
            num_parallel_calls=num_parallel_calls, map_fn=map_fn)

    # LIST OF FILES CASE: list of filepaths (probably what almost
    # always will get used)
    if isinstance(fps[0], str):
        no_labels = ys is None
        if no_labels:
            # dummy labels so the loader always sees (path, label) pairs
            ys = np.zeros(len(fps), dtype=np.int64)
        ds = tf.data.Dataset.from_tensor_slices((fps, ys))
        # do the shuffling before loading so we can have a big queue without
        # taking up much memory
        if shuffle:
            ds = ds.shuffle(len(fps))
        _load_img = _build_load_function(fps[0], imshape, norm, num_channels,
                                         single_channel, augment)
        ds = ds.map(_load_img, num_parallel_calls=num_parallel_calls)
        # if no labels were passed, strip out the y.
        if no_labels:
            ds = ds.map(lambda x, y: x)
        return ds

    # Raised explicitly instead of `assert False`: an assert is stripped
    # under `python -O`, which would leave `ds` unbound at the return.
    raise AssertionError("what are these inputs")
def stratified_training_dataset(fps, y, imshape=(256,256), num_channels=3,
                                num_parallel_calls=None, batch_size=256,
                                mult=10, augment=True, norm=255,
                                sobel=False, single_channel=False):
    """
    Training dataset for DeepCluster. Samples the same number of images
    (with replacement) from every non-empty cluster, so batches are
    stratified over labels.

    :fps: list of strings containing paths to image files
    :y: array of cluster assignments- should have same length as fps
    :imshape: constant shape to resize images to
    :num_channels: channel depth of images
    :num_parallel_calls: parallelism used for loading and mapping
    :batch_size: just what you think it is
    :mult: not in paper; multiplication factor to increase number of
        steps/epoch. set to 1 to get paper algorithm
    :augment: augmentation parameters (or True for defaults, or
        False to disable)
    :norm: value for normalizing images
    :sobel: whether to replace the input image with its sobel edges
    :single_channel: if True, expect a single-channel input image and
        stack it num_channels times.

    Returns
    :ds: tf.data.Dataset object to iterate over data
    :num_steps: number of steps (for passing to tf.keras.Model.fit())
    """
    num_files = len(fps)
    indices = np.arange(num_files)
    K = y.max()+1
    samples_per_cluster = mult*int(num_files/K)

    picked_indices = []
    picked_labels = []
    for k in range(K):
        members = indices[y == k]
        # clusters with nothing assigned are skipped. note that the
        # deepcluster paper takes an extra step here.
        if len(members) == 0:
            continue
        drawn = np.random.choice(members, size=samples_per_cluster,
                                 replace=True)
        picked_indices.append(drawn)
        picked_labels.append(np.full(len(drawn), k, dtype=np.int64))

    # flatten the per-cluster draws
    sampled_indices = np.concatenate(picked_indices, 0)
    sampled_labels = np.concatenate(picked_labels, 0)
    # shuffle indices and labels with one shared permutation
    num_sampled = len(sampled_indices)
    reorder = np.random.choice(np.arange(num_sampled), size=num_sampled,
                               replace=False)
    sampled_indices = sampled_indices[reorder]
    sampled_labels = sampled_labels[reorder]
    fps = np.array(fps)[sampled_indices]

    # NOW CREATE THE DATASET. Order was fixed above, so the loader must not
    # shuffle or the images would drift away from their labels.
    im_ds = _image_file_dataset(fps, imshape=imshape,
                                num_channels=num_channels,
                                num_parallel_calls=num_parallel_calls,
                                norm=norm, shuffle=False,
                                single_channel=single_channel)
    if augment:
        _aug = augment_function(imshape, augment)
        im_ds = im_ds.map(_aug, num_parallel_calls=num_parallel_calls)

    lab_ds = tf.data.Dataset.from_tensor_slices(sampled_labels)
    ds = tf.data.Dataset.zip((im_ds, lab_ds))
    if sobel:
        ds = ds.map(lambda x,y: (_sobelize(x),y),
                    num_parallel_calls=num_parallel_calls)
    ds = ds.batch(batch_size).prefetch(1)

    num_steps = int(np.ceil(num_sampled/batch_size))
    return ds, num_steps
def detcon_input_pipeline(imfiles, augment, mean_scale=1000, num_samples=16,
                          outputsize=None, imshape=(256, 256), **kwargs):
    """
    Run every image through the DetCon segment-and-augment steps, report
    how often segmentation ran into trouble, and plot up to five examples.

    :imfiles: list of image paths (len() is used for the progress bar and
        the printed summary)
    :augment: dictionary of augmentation parameters. The full dictionary is
        used for the joint image/segmentation augmentation; keys listed in
        SEG_AUG_FUNCTIONS are removed for the image-only augmentation
    :mean_scale: forwarded to _get_segments(); if not > 0, grid segments
        from _get_grid_segments() are used instead
    :num_samples: number of segments requested per image
    :outputsize: forwarded to _segment_aug()
    :imshape: constant shape to resize images to
    :kwargs: forwarded to _image_file_dataset()

    Prints two summary lines and draws into the current matplotlib figure.
    Returns None.
    """
    # build a dataset to load images one at a time
    ds = _image_file_dataset(imfiles, shuffle=False, imshape=imshape,
                             **kwargs).prefetch(1)
    # place to store some examples
    imgs = []
    segs = []
    img1s = []
    img2s = []
    seg1s = []
    seg2s = []

    N = len(imfiles)
    not_enough_segs_count = 0
    augment_removed_segmentation = 0

    aug2 = {k: augment[k] for k in augment if k not in SEG_AUG_FUNCTIONS}
    _aug = augment_function(imshape, aug2)

    progressbar = tqdm(total=N)
    # for each image
    for x in ds:
        # get the segments
        if mean_scale > 0:
            seg, enough_segs = _get_segments(x, mean_scale=mean_scale,
                                             num_samples=num_samples,
                                             return_enough_segments=True)
            # count how many times we had to sample with replacement
            if not enough_segs:
                not_enough_segs_count += 1
        else:
            seg = _get_grid_segments(imshape, num_samples)
        # now augment image and segmentation together, twice
        img1, seg1 = _segment_aug(x, seg, augment, outputsize=outputsize)
        img2, seg2 = _segment_aug(x, seg, augment, outputsize=outputsize)

        # check to see if any segments were pushed out of the image by
        # augmentation
        segmentation_ok = _filter_out_bad_segments(img1, seg1, img2, seg2)
        if not segmentation_ok:
            augment_removed_segmentation += 1

        # finally, augment images separately
        img1 = _aug(img1).numpy()
        img2 = _aug(img2).numpy()
        seg1 = seg1.numpy()
        seg2 = seg2.numpy()

        # only the first 16 examples are kept for plotting
        if len(img1s) < 16:
            imgs.append(x)
            segs.append(seg)
            img1s.append(img1)
            img2s.append(img2)
            seg1s.append(seg1)
            seg2s.append(seg2)
        progressbar.update()

    progressbar.close()
    print(
        f"Had to sample with replacement for {not_enough_segs_count} of {N} images"
    )
    print(
        f"At least one missing segment in {augment_removed_segmentation} of {N} images due to augmentation"
    )

    # nothing was loaded: there is nothing to stack or plot
    if not imgs:
        return

    img = np.stack(imgs, 0)
    seg = np.stack(segs, 0)
    img1 = np.stack(img1s, 0)
    img2 = np.stack(img2s, 0)
    seg1 = np.stack(seg1s, 0)
    seg2 = np.stack(seg2s, 0)

    def _segshow(s):
        plt.imshow(s, alpha=0.4, extent=[0, imshape[0], imshape[1], 0],
                   cmap="tab20", vmin=0, vmax=num_samples - 1)

    # the grid has five rows; with fewer than five stored examples only
    # those rows are filled (indexing past the end used to raise IndexError)
    for j in range(min(5, len(imgs))):
        plt.subplot(5, 4, 4 * j + 1)
        plt.imshow(img[j])
        plt.axis(False)
        if j == 0:
            plt.title("original")

        plt.subplot(5, 4, 4 * j + 2)
        plt.imshow(img[j])
        plt.axis(False)
        _segshow(seg[j].argmax(-1))
        if j == 0:
            plt.title("segments")

        plt.subplot(5, 4, 4 * j + 3)
        plt.imshow(img1[j])
        plt.axis(False)
        _segshow(seg1[j].argmax(-1))
        if j == 0:
            plt.title("augmentation 1")

        plt.subplot(5, 4, 4 * j + 4)
        plt.imshow(img2[j])
        plt.axis(False)
        _segshow(seg2[j].argmax(-1))
        if j == 0:
            plt.title("augmentation 2")
def _build_simclr_dataset(imfiles, imshape=(256, 256), batch_size=256,
                          num_parallel_calls=None, norm=255, num_channels=3,
                          augment=True, single_channel=False, stratify=None):
    """
    Build a tf.data.Dataset for SimCLR training, optionally stratified.

    :imfiles: image source. Either a single input forwarded to
        _image_file_dataset(), or (when stratifying) a pair of file lists
        for the dual-input case
    :imshape: constant shape to resize images to
    :batch_size: number of source elements per batch (each batch holds
        2*batch_size augmented images)
    :num_parallel_calls: parallelism used for loading and augmenting
    :norm: value for normalizing images
    :num_channels: channel depth of images
    :augment: augmentation parameters forwarded to augment_function();
        must not be False
    :single_channel: forwarded to _image_file_dataset()
    :stratify: if not None, a list of categories for each element in
        imfiles. One dataset is built per category and batches are sampled
        across them.

    Returns a dataset of (image, label) batches in which each source element
    contributes two consecutive images labeled 1 and -1.
    """
    if stratify is not None:
        categories = list(set(stratify))

        def _members(files, category):
            # files whose stratify entry matches this category
            return [f for i, f in enumerate(files)
                    if stratify[i] == category]

        if isinstance(imfiles[0], str):
            # SINGLE-INPUT CASE
            file_lists = [_members(imfiles, c) for c in categories]
        else:
            # DUAL-INPUT
            file_lists = [(_members(imfiles[0], c), _members(imfiles[1], c))
                          for c in categories]

        shared_kwargs = dict(imshape=imshape, batch_size=batch_size,
                             num_parallel_calls=num_parallel_calls,
                             norm=norm, num_channels=num_channels,
                             augment=augment, single_channel=single_channel,
                             stratify=None)
        # one un-stratified dataset per category (recursive call)
        datasets = [_build_simclr_dataset(f, **shared_kwargs)
                    for f in file_lists]
        return tf.data.experimental.sample_from_datasets(datasets)

    # NOTE: this only rejects values that compare equal to False
    assert augment != False, "don't you need to augment your data?"

    images = _image_file_dataset(imfiles, imshape=imshape,
                                 num_parallel_calls=num_parallel_calls,
                                 norm=norm, num_channels=num_channels,
                                 shuffle=True, single_channel=single_channel,
                                 augment=False)
    _aug = augment_function(imshape, augment)
    element_shape = (imshape[0], imshape[1], num_channels)

    @tf.function
    def _augment_and_stack(*x):
        # if there's only one input, augment it twice (standard SimCLR).
        # if there are two, augment them separately (case where user is
        # trying to express some specific semantics)
        x0 = tf.reshape(x[0], element_shape)
        x1 = tf.reshape(x[1], element_shape) if len(x) == 2 else x0
        y = tf.constant(np.array([1, -1]).astype(np.int32))
        return tf.stack([_aug(x0), _aug(x1)]), y

    stacked = images.map(_augment_and_stack,
                         num_parallel_calls=num_parallel_calls)
    # unbatch() flattens each stacked pair into two consecutive elements, so
    # a batch of 2*batch_size always contains whole pairs
    flat = stacked.unbatch()
    batched = flat.batch(2 * batch_size, drop_remainder=True)
    return batched.prefetch(1)
def test_default_augment():
    """
    augment_function() called with only a shape should return a callable
    that maps the test image to a tf.Tensor with the same static shape.
    """
    default_aug = augment_function(test_shape[:2])
    result = default_aug(test_img_tensor)
    expected_shape = test_img_tensor.get_shape()

    assert isinstance(result, tf.Tensor)
    assert result.get_shape() == expected_shape