Esempio n. 1
0
def _build_byol_dataset(imfiles,
                        imshape=(256, 256),
                        batch_size=256,
                        num_parallel_calls=None,
                        norm=255,
                        num_channels=3,
                        augment=True,
                        single_channel=False):
    """
    Build a tf.data.Dataset of augmented view pairs for BYOL training.

    Images are loaded shuffled and un-augmented, then each image is augmented
    twice independently inside a single map, so both views always come from
    the same source image.

    :imfiles: image source passed to _image_file_dataset (list of filepaths,
        directory of TFRecord files, or a prebuilt tf.data.Dataset)
    :imshape: (height, width) to resize images to
    :batch_size: number of view pairs per batch
    :num_parallel_calls: parallelism for the loading and augmentation maps
    :norm: value for normalizing images
    :num_channels: channel depth of images
    :augment: augmentation parameters (or True for defaults); must not be False
    :single_channel: if True, expect a single-channel input image and
        stack it num_channels times.

    Returns a batched, prefetched dataset whose elements are
    ((view1, view2), dummy) where dummy is the constant array([1]) for every
    example (NOTE(review): presumably a placeholder target for Keras fit() —
    confirm).
    """
    assert augment != False, "don't you need to augment your data?"

    ds = _image_file_dataset(imfiles,
                             imshape=imshape,
                             num_parallel_calls=num_parallel_calls,
                             norm=norm,
                             num_channels=num_channels,
                             shuffle=True,
                             single_channel=single_channel,
                             augment=False)

    _aug = augment_function(imshape, augment)

    @tf.function
    def pair_augment(x):
        # augment the same image twice -> two independent views
        return (_aug(x), _aug(x)), np.array([1])

    ds = ds.map(pair_augment, num_parallel_calls=num_parallel_calls)

    ds = ds.batch(batch_size)
    ds = ds.prefetch(1)
    return ds
Esempio n. 2
0
def _build_simclr_dataset(imfiles,
                          imshape=(256, 256),
                          batch_size=256,
                          num_parallel_calls=None,
                          norm=255,
                          num_channels=3,
                          augment=True,
                          single_channel=False):
    """
    Build a tf.data.Dataset for SimCLR-style contrastive training.

    Every image is augmented twice; the two views are stacked, then the
    stream is unbatched and rebatched so that each batch holds
    2 * batch_size individual views (batch_size positive pairs, adjacent to
    each other). An incomplete final batch is dropped.

    :imfiles: image source passed to _image_file_dataset
    :imshape: (height, width) to resize images to
    :batch_size: number of positive pairs per batch
    :num_parallel_calls: parallelism for the loading and augmentation maps
    :norm: value for normalizing images
    :num_channels: channel depth of images
    :augment: augmentation parameters (or True for defaults); must be truthy
    :single_channel: if True, expect a single-channel input image and
        stack it num_channels times.

    Returns a dataset of (views, labels); labels alternate 1, -1, one per view.
    """
    assert augment, "don't you need to augment your data?"
    augmenter = augment_function(imshape, augment)

    base = _image_file_dataset(imfiles,
                               imshape=imshape,
                               num_parallel_calls=num_parallel_calls,
                               norm=norm,
                               num_channels=num_channels,
                               shuffle=True,
                               single_channel=single_channel)

    @tf.function
    def _two_views(img):
        views = tf.stack([augmenter(img), augmenter(img)])
        # one label per view; they get split apart by unbatch() below
        view_labels = tf.constant(np.array([1, -1]).astype(np.int32))
        return views, view_labels

    paired = base.map(_two_views, num_parallel_calls=num_parallel_calls)
    # flatten the stacked pairs into a stream of single views, then rebatch
    singles = paired.unbatch()
    batched = singles.batch(2 * batch_size, drop_remainder=True)
    return batched.prefetch(1)
Esempio n. 3
0
def dataset(fps, ys = None, imshape=(256,256), num_channels=3, 
                 num_parallel_calls=None, norm=255, batch_size=256,
                 augment=False, unlab_fps=None, shuffle=False,
                 sobel=False, single_channel=False):
    """
    return a tf dataset that iterates over a list of images once
    
    :fps: list of filepaths
    :ys: array of corresponding labels
    :imshape: constant shape to resize images to
    :num_channels: channel depth of images
    :batch_size: just what you think it is
    :augment: augmentation parameters (or True for defaults, or False to disable)
    :unlab_fps: list of filepaths (same length as fps) for semi-
        supervised learning
    :shuffle: whether to shuffle the dataset. When fps is a list of
        filepaths, labels are shuffled together with their images.
    :sobel: whether to replace the input image with its sobel edges
    :single_channel: if True, expect a single-channel input image and 
        stack it num_channels times.
    
    Returns
    :ds: tf.data.Dataset object to iterate over data. The dataset returns
        (x,y) tuples unless unlab_fps is included, in which case the structure
        will be ((x, x_unlab), y)
    :num_steps: number of steps (for passing to tf.keras.Model.fit())
    """
    if augment:
        _aug = augment_function(imshape, augment)

    # For a list of filepaths, hand the labels to _image_file_dataset so they
    # are shuffled together with the paths. Zipping a separate label dataset
    # onto an already-shuffled image dataset would mispair images and labels.
    labels_inline = ys is not None and not isinstance(fps, (str, tf.data.Dataset))

    ds = _image_file_dataset(fps, ys=ys if labels_inline else None,
                             imshape=imshape, num_channels=num_channels,
                             num_parallel_calls=num_parallel_calls, norm=norm,
                             shuffle=shuffle, single_channel=single_channel)

    def _on_images(fn):
        # apply fn to the image only, passing any attached label through
        if labels_inline:
            return lambda x, y: (fn(x), y)
        return fn

    if augment:
        ds = ds.map(_on_images(_aug), num_parallel_calls=num_parallel_calls)
    if sobel:
        ds = ds.map(_on_images(_sobelize), num_parallel_calls=num_parallel_calls)

    if unlab_fps is not None:
        u_ds = _image_file_dataset(unlab_fps, imshape=imshape, num_channels=num_channels, 
                      num_parallel_calls=num_parallel_calls, norm=norm,
                      single_channel=single_channel)
        if augment:
            u_ds = u_ds.map(_aug, num_parallel_calls=num_parallel_calls)
        if sobel:
            u_ds = u_ds.map(_sobelize, num_parallel_calls=num_parallel_calls)
        ds = tf.data.Dataset.zip((ds, u_ds))
        if labels_inline:
            # ((x, y), x_unlab) -> ((x, x_unlab), y)
            ds = ds.map(lambda xy, xu: ((xy[0], xu), xy[1]))

    if ys is not None and not labels_inline:
        # prebuilt-Dataset / TFRecord inputs: _image_file_dataset does not
        # attach labels there, so fall back to zipping them on. This is only
        # aligned if the image dataset is not shuffled.
        ds = tf.data.Dataset.zip((ds, tf.data.Dataset.from_tensor_slices(ys)))

    ds = ds.batch(batch_size)
    ds = ds.prefetch(1)
    
    num_steps = int(np.ceil(len(fps)/batch_size))
    return ds, num_steps
Esempio n. 4
0
def clip_dataset(imfiles,
                 labels,
                 encoder,
                 maxlen=76,
                 imshape=(256, 256),
                 num_channels=3,
                 num_parallel_calls=None,
                 norm=255,
                 batch_size=256,
                 augment=False,
                 shuffle=True,
                 single_channel=False):
    """
    Build a tf.data.Dataset object for training CLIP
    
    :imfiles: list of paths to training images
    :labels: list of strings containing a caption for each image 
    :encoder: sentencepiece object for mapping strings to byte pair encoded arrays
    :maxlen: int; length of sequences. BPE arrays will be padded or truncated to this.
        Truncation keeps the first maxlen tokens, so long captions lose
        their EOS token.
    :imshape: constant shape to resize images to
    :num_channels: channel depth of images
    :num_parallel_calls: parallelism for loading and encoding
    :norm: value for normalizing images
    :batch_size: number of (image, caption) pairs per batch
    :augment: augmentation parameters applied while loading images; if
        truthy, a rot90 augmentation is additionally applied afterwards
    :shuffle: whether to shuffle (image, caption) pairs
    :single_channel: if True, expect a single-channel input image and
        stack it num_channels times.

    Returns a dataset of (images, tokens) batches, where tokens are int64
    with shape (batch, maxlen) and padded with 0.
    """
    # forward all loading options (they were previously accepted but ignored)
    ds = _image_file_dataset(imfiles,
                             ys=labels,
                             imshape=imshape,
                             num_channels=num_channels,
                             num_parallel_calls=num_parallel_calls,
                             norm=norm,
                             augment=augment,
                             shuffle=shuffle,
                             single_channel=single_channel)

    if augment:
        aug_func = augment_function(imshape, {"rot90": True})

    def _encode_text(y):
        # inside tf.py_function y is an eager tensor; str(y) would produce its
        # repr ("tf.Tensor(b'...', ...)"), so extract and decode the bytes.
        y = y.numpy()
        if isinstance(y, bytes):
            y = y.decode("utf-8")
        tokens = encoder.encode(str(y), out_type=int, add_bos=True, add_eos=True)
        tokens = tokens[:maxlen]
        tokens = tokens + [0] * (maxlen - len(tokens))
        # explicit dtype: must match Tout=tf.int64 on every platform
        return np.array(tokens, dtype=np.int64)

    def _augment_and_encode(x, y):
        if augment:
            x = aug_func(x)
        y = tf.py_function(_encode_text, (y, ), Tout=tf.int64)
        # py_function drops static shape info; restore it for batching
        y.set_shape([maxlen])
        return x, y

    ds = ds.map(_augment_and_encode, num_parallel_calls=num_parallel_calls)
    ds = ds.batch(batch_size)
    ds = ds.prefetch(1)
    return ds
Esempio n. 5
0
def _fixmatch_unlab_dataset(fps, weak_aug, str_aug, imshape=(256,256),
                            num_parallel_calls=None, norm=255, 
                            num_channels=3, single_channel=False,
                            batch_size=64):
    """
    Build the unlabeled stream for FixMatch semi-supervised learning.

    Each shuffled, un-augmented image is mapped to a (weak view, strong view)
    pair of the same image, then batched and prefetched.

    :fps: image source passed to _image_file_dataset
    :weak_aug: augmentation parameters for the weakly augmented view
    :str_aug: augmentation parameters for the strongly augmented view
    :imshape: (height, width) to resize images to
    :num_parallel_calls: parallelism for the loading and augmentation maps
    :norm: value for normalizing images
    :num_channels: channel depth of images
    :single_channel: if True, expect a single-channel input image and
        stack it num_channels times.
    :batch_size: number of pairs per batch
    """
    weak = augment_function(imshape, weak_aug)
    strong = augment_function(imshape, str_aug)

    images = _image_file_dataset(fps, imshape=imshape,
                                 num_parallel_calls=num_parallel_calls,
                                 norm=norm, num_channels=num_channels,
                                 shuffle=True,
                                 single_channel=single_channel,
                                 augment=False)
    pairs = images.map(lambda img: (weak(img), strong(img)),
                       num_parallel_calls=num_parallel_calls)
    return pairs.batch(batch_size).prefetch(1)
Esempio n. 6
0
def _single_augplot(filepath, aug_params=True, norm=255, num_channels=3, resize=None):
    """
    Plot one image next to 15 random augmentations of it.

    Loads the image at `filepath`, builds an augmentation function from
    `aug_params`, and fills a 4x4 matplotlib grid: the first cell shows the
    original (titled "original"), the other 15 each show a fresh augmentation.
    """
    image = _load_img(filepath, norm=norm, num_channels=num_channels, resize=resize)
    augmenter = augment_function(image.shape[:2], aug_params)

    plt.figure()
    for cell in range(1, 17):
        plt.subplot(4, 4, cell)
        is_original = cell == 1
        plt.imshow(image if is_original else augmenter(image).numpy())
        plt.axis(False)
        if is_original:
            plt.title("original")
Esempio n. 7
0
def _build_context_encoder_dataset(filepaths,
                                   input_shape=(256, 256, 3),
                                   norm=255,
                                   shuffle=True,
                                   num_parallel_calls=4,
                                   batch_size=32,
                                   prefetch=True,
                                   augment=False,
                                   sobel=False,
                                   single_channel=False):
    """
    Build a tf.data.Dataset object to use for training a context encoder.

    Images are loaded (optionally augmented and/or sobel-filtered), zipped
    with masks produced by mask_generator, and batched. The masks are not
    applied to the images here; each element is an (image, mask) pair.

    :filepaths: image source passed to _image_file_dataset
    :input_shape: (height, width, channels) of images and masks
    :norm: value for normalizing images
    :shuffle: whether to shuffle the images
    :num_parallel_calls: parallelism for the map steps
    :batch_size: number of (image, mask) pairs per batch
    :prefetch: whether to prefetch one batch
    :augment: augmentation parameters (or True for defaults, False to disable)
    :sobel: whether to replace images with their sobel edges
    :single_channel: if True, expect a single-channel input image and
        stack it input_shape[2] times.
    """

    # first build a Dataset that generates masks
    # NOTE(review): assumes mask_generator(...) returns an iterator of float
    # masks of shape input_shape — confirm.
    def _gen():
        return mask_generator(*input_shape)

    mask_ds = tf.data.Dataset.from_generator(_gen,
                                             output_types=tf.float32,
                                             output_shapes=input_shape)
    # now a Dataset to load images (respecting the caller's shuffle flag,
    # which was previously ignored)
    img_ds = _image_file_dataset(filepaths,
                                 imshape=input_shape[:2],
                                 num_channels=input_shape[2],
                                 norm=norm,
                                 num_parallel_calls=num_parallel_calls,
                                 shuffle=shuffle,
                                 single_channel=single_channel)

    if augment:
        _aug = augment_function(input_shape[:2], augment)
        img_ds = img_ds.map(_aug, num_parallel_calls=num_parallel_calls)
    if sobel:
        img_ds = img_ds.map(_sobelize, num_parallel_calls=num_parallel_calls)
    # combine the image and mask datasets into (image, mask) pairs
    zipped_ds = tf.data.Dataset.zip((img_ds, mask_ds))
    batched_ds = zipped_ds.batch(batch_size)

    if prefetch:
        batched_ds = batched_ds.prefetch(1)
    return batched_ds
Esempio n. 8
0
 def _load_img(x,y):
     """
     Read, decode, normalize, optionally augment and resize one image.

     Closure helper: `decode`, `single_channel`, `num_channels`, `norm`,
     `augment` and `imshape` come from an enclosing scope not visible here.
     Returns (image, y) with the label passed through untouched.
     """
     raw_bytes = tf.io.read_file(x)
     image = decode(raw_bytes)
     if single_channel:
         # replicate the lone channel so there are num_channels channels
         image = tf.concat([image] * num_channels, -1)
     # keep the first num_channels channels and scale to float
     image = tf.cast(image[:, :, :num_channels], tf.float32) / norm
     if augment:
         # NOTE(review): augmentation runs before the resize below, on an image
         # that may not be imshape yet — confirm augment_function allows this.
         image = augment_function(imshape, augment)(image)
     # resample to the target spatial size if needed
     image = tf.image.resize(image, imshape)
     # tf.image.resize() can leave the channel dimension as None, so pin the
     # full shape explicitly
     image = tf.reshape(image, [imshape[0], imshape[1], num_channels])
     return image, y
Esempio n. 9
0
def _multiple_augplot(filepaths, aug_params=True, num_resamps=5, 
                      norm=255, num_channels=3, resize=None):
    """
    Plot several images, each followed by random augmentations of it.

    Draws a grid with exactly one row per file: the original image in the
    first column (the top-left cell is titled "original") and `num_resamps`
    augmentations in the remaining columns.

    :filepaths: list of image paths
    :aug_params: augmentation parameters passed to augment_function
    :num_resamps: number of augmentations drawn per image
    :norm, num_channels, resize: passed through to _load_img

    NOTE(review): the augmentation function is built from the first image's
    shape and reused for every file, so presumably all images should share a
    shape (e.g. via `resize`) — confirm.
    """
    num_files = len(filepaths)
    num_cols = num_resamps + 1
    first = _load_img(filepaths[0], norm=norm, num_channels=num_channels, resize=resize)
    aug_func = augment_function(first.shape[:2], aug_params)

    plt.figure()
    for i, filepath in enumerate(filepaths):
        # the first image was already loaded to size the augmenter
        if i == 0:
            img = first
        else:
            img = _load_img(filepath, norm=norm, num_channels=num_channels, resize=resize)
        row_start = 1 + i * num_cols
        # num_files rows (not num_files + 1) so no empty row is left at the bottom
        plt.subplot(num_files, num_cols, row_start)
        plt.imshow(img)
        plt.axis(False)
        if i == 0:
            plt.title("original")
        for j in range(num_resamps):
            plt.subplot(num_files, num_cols, row_start + 1 + j)
            plt.imshow(aug_func(img).numpy())
            plt.axis(False)
Esempio n. 10
0
def build_augment_pair_dataset(imfiles, imshape=(256,256), batch_size=256, 
                      num_parallel_calls=None, norm=255,
                      num_channels=3, augment=True,
                      single_channel=False):
    """
    Build a tf.data.Dataset object for training momentum 
    contrast. Generates pairs of augmentations from a single
    image.

    :imfiles: image source passed to _image_file_dataset
    :imshape: (height, width) to resize images to
    :batch_size: number of pairs per batch
    :num_parallel_calls: parallelism for the loading and augmentation maps
    :norm: value for normalizing images
    :num_channels: channel depth of images
    :augment: augmentation parameters (or True for defaults); must be truthy
    :single_channel: if True, expect a single-channel input image and
        stack it num_channels times.

    Returns a batched dataset of (view1, view2) tuples.
    """
    assert augment, "don't you need to augment your data?"
    _aug = augment_function(imshape, augment)
    
    ds = _image_file_dataset(imfiles, imshape=imshape, 
                             num_parallel_calls=num_parallel_calls,
                             norm=norm, num_channels=num_channels,
                             shuffle=True, single_channel=single_channel)  

    # Produce both views in ONE map over the loaded image. Zipping two
    # separate maps over the same shuffled dataset would iterate it twice
    # with independent shuffle orders, pairing views of different images
    # (and loading every image twice).
    ds = ds.map(lambda x: (_aug(x), _aug(x)),
                num_parallel_calls=num_parallel_calls)
    ds = ds.batch(batch_size)
    ds = ds.prefetch(1)
    return ds
Esempio n. 11
0
def build_iic_dataset(imfiles, r=5, imshape=(256,256), batch_size=256, 
                      num_parallel_calls=None, norm=255,
                      num_channels=3, augment=True,
                      single_channel=False):
    """
    Build a tf.data.Dataset object for training IIC.

    :imfiles: image source passed to _image_file_dataset
    :r: if > 1, each loaded image is repeated r times consecutively (each
        repeat gets its own augmentation)
    :imshape: (height, width) to resize images to
    :batch_size: number of (image, augmented image) pairs per batch
    :num_parallel_calls: parallelism for the loading and augmentation maps
    :norm: value for normalizing images
    :num_channels: channel depth of images
    :augment: augmentation parameters (or True for defaults); must be truthy
    :single_channel: if True, expect a single-channel input image and
        stack it num_channels times.

    Returns a batched dataset of (x, augmented x) tuples.
    """
    assert augment, "don't you need to augment your data?"
    _aug = augment_function(imshape, augment)
    
    ds = _image_file_dataset(imfiles, imshape=imshape, 
                             num_parallel_calls=num_parallel_calls,
                             norm=norm, num_channels=num_channels,
                             shuffle=True, single_channel=single_channel)
    
    if r > 1:
        ds = ds.flat_map(lambda x: tf.data.Dataset.from_tensors(x).repeat(r))

    # Pair each image with its own augmentation in a single map. Zipping ds
    # with a separately-mapped copy would iterate the shuffled dataset twice
    # with different orders, pairing an image with a different image's
    # augmentation.
    ds = ds.map(lambda x: (x, _aug(x)), num_parallel_calls=num_parallel_calls)
    ds = ds.batch(batch_size)
    ds = ds.prefetch(1)
    return ds
Esempio n. 12
0
def _image_file_dataset(fps, ys=None, imshape=(256,256), 
                 num_parallel_calls=None, norm=255,
                 num_channels=3, shuffle=False,
                 single_channel=False, augment=False):
    """
    Basic tool to load images into a tf.data.Dataset using
    PIL.Image or gdal instead of the tensorflow decode functions
    
    :fps: list of filepaths, a string (directory of TFRecord files), or a
        prebuilt tf.data.Dataset
    :ys: optional list of labels (only used for a list of filepaths)
    :imshape: constant shape to resize images to
    :num_parallel_calls: number of processes to use for loading
    :norm: value for normalizing images
    :num_channels: channel depth to truncate images to
    :shuffle: whether to shuffle the dataset (only used for a list of
        filepaths; paths and labels are shuffled together before loading)
    :single_channel: if True, expect a single-channel input image and 
        stack it num_channels times.
    :augment: optional, dictionary of augmentation parameters
    
    Returns tf.data.Dataset object with structure (x,y) if labels were passed, 
        and (x) otherwise. images (x) are a 3D float32 tensor and labels
        should be a 0D int64 tensor

    Raises ValueError if fps is an empty list, TypeError if fps is not one
    of the supported input kinds.
    """    
    # PRE-BUILT DATASET CASE
    if isinstance(fps, tf.data.Dataset):
        ds = fps
        if augment:
            _aug = augment_function(imshape, augment)
            ds = ds.map(_aug, num_parallel_calls=num_parallel_calls)
        return ds

    # DIRECTORY OF TFRECORD FILES CASE
    if isinstance(fps, str):
        # a falsy `augment` is passed through unchanged as map_fn
        map_fn = augment_function(imshape, augment) if augment else augment
        return load_dataset_from_tfrecords(fps, imshape, num_channels,
                                           num_parallel_calls=num_parallel_calls,
                                           map_fn=map_fn)

    # LIST OF FILES CASE: list of filepaths (probably what almost 
    # always will get used). Raise explicitly rather than via assert, which
    # is stripped under `python -O`.
    if len(fps) == 0:
        raise ValueError("fps is empty; no image files to load")
    if not isinstance(fps[0], str):
        raise TypeError("what are these inputs")

    no_labels = ys is None
    if no_labels:
        # dummy labels so the loader can always work on (path, label) pairs
        ys = np.zeros(len(fps), dtype=np.int64)
    ds = tf.data.Dataset.from_tensor_slices((fps, ys))
    # do the shuffling before loading so we can have a big queue without
    # taking up much memory
    if shuffle:
        ds = ds.shuffle(len(fps))
    load_fn = _build_load_function(fps[0], imshape, norm, num_channels, 
                                   single_channel, augment)
    ds = ds.map(load_fn, num_parallel_calls=num_parallel_calls)
    # if no labels were passed, strip out the y.
    if no_labels:
        ds = ds.map(lambda x, y: x)
    return ds
Esempio n. 13
0
def stratified_training_dataset(fps, y, imshape=(256,256), num_channels=3, 
                 num_parallel_calls=None, batch_size=256, mult=10,
                    augment=True, norm=255, sobel=False, single_channel=False):
    """
    Training dataset for DeepCluster: draw a class-balanced sample of images.

    Every non-empty cluster contributes the same number of samples (drawn
    with replacement); the samples are shuffled with one shared permutation
    so images and cluster labels stay aligned.

    :fps: list of strings containing paths to image files
    :y: numpy array of integer cluster assignments, same length as fps
    :imshape: constant shape to resize images to
    :num_channels: channel depth of images
    :batch_size: number of examples per batch
    :mult: not in paper; multiplication factor to increase
        number of steps/epoch. set to 1 to get paper algorithm
    :augment: augmentation parameters (or True for defaults, or False to disable)
    :sobel: whether to replace the input image with its sobel edges
    :single_channel: if True, expect a single-channel input image and 
        stack it num_channels times.

    Returns
    :ds: tf.data.Dataset of (image, cluster label) batches
    :num_steps: number of steps (for passing to tf.keras.Model.fit())
    """
    # --- draw a class-balanced sample of file indices ---
    all_idx = np.arange(len(fps))
    num_clusters = y.max() + 1
    per_cluster = mult * int(len(fps) / num_clusters)

    chosen_idx = []
    chosen_lab = []
    for cluster in range(num_clusters):
        members = all_idx[y == cluster]
        # empty clusters are simply skipped; the DeepCluster paper takes an
        # extra step here that is not implemented
        if len(members) == 0:
            continue
        draws = np.random.choice(members, size=per_cluster, replace=True)
        chosen_idx.append(draws)
        chosen_lab.append(cluster * np.ones(len(draws), dtype=np.int64))

    chosen_idx = np.concatenate(chosen_idx, 0)
    chosen_lab = np.concatenate(chosen_lab, 0)
    total = len(chosen_idx)
    # one random permutation applied to both arrays keeps them aligned
    order = np.random.choice(np.arange(total), size=total, replace=False)
    chosen_idx = chosen_idx[order]
    chosen_lab = chosen_lab[order]
    sampled_fps = np.array(fps)[chosen_idx]

    # --- build the tf.data pipeline (already shuffled above) ---
    im_ds = _image_file_dataset(sampled_fps, imshape=imshape, num_channels=num_channels,
                                num_parallel_calls=num_parallel_calls, norm=norm,
                                shuffle=False, single_channel=single_channel)
    if augment:
        im_ds = im_ds.map(augment_function(imshape, augment),
                          num_parallel_calls=num_parallel_calls)

    ds = tf.data.Dataset.zip((im_ds, tf.data.Dataset.from_tensor_slices(chosen_lab)))
    if sobel:
        ds = ds.map(lambda img, lab: (_sobelize(img), lab),
                    num_parallel_calls=num_parallel_calls)
    ds = ds.batch(batch_size).prefetch(1)

    num_steps = int(np.ceil(total / batch_size))
    return ds, num_steps
Esempio n. 14
0
def detcon_input_pipeline(imfiles,
                          augment,
                          mean_scale=1000,
                          num_samples=16,
                          outputsize=None,
                          imshape=(256, 256),
                          **kwargs):
    """
    Visually sanity-check the DetCon segmentation + augmentation pipeline.

    For every image: compute segments, augment image and segments jointly
    twice with _segment_aug, then apply the non-segmentation augmentations to
    each view separately. Prints how often segment sampling needed
    replacement and how often augmentation removed a segment, and plots up
    to five examples (original, segments, and both augmented views) into the
    current matplotlib figure.

    :imfiles: image source passed to _image_file_dataset; len(imfiles) is used
        for the progress bar and statistics
    :augment: dict of augmentation parameters. The full dict is passed to
        _segment_aug; keys not in SEG_AUG_FUNCTIONS are additionally applied
        to each view separately.
    :mean_scale: if > 0, passed to _get_segments; otherwise a grid
        segmentation (_get_grid_segments) is used
    :num_samples: number of segments per image
    :outputsize: passed to _segment_aug
    :imshape: (height, width) images are resized to
    :kwargs: passed through to _image_file_dataset
    """
    # build a dataset to load images one at a time
    ds = _image_file_dataset(imfiles, shuffle=False, imshape=imshape,
                             **kwargs).prefetch(1)
    # only this many examples are kept, since only this many are plotted
    max_plot = 5
    imgs = []
    segs = []
    img1s = []
    img2s = []
    seg1s = []
    seg2s = []

    N = len(imfiles)
    not_enough_segs_count = 0
    augment_removed_segmentation = 0

    # augmentations that don't touch the segmentation, applied per view
    aug2 = {k: augment[k] for k in augment if k not in SEG_AUG_FUNCTIONS}
    _aug = augment_function(imshape, aug2)

    progressbar = tqdm(total=N)
    # for each image
    for x in ds:
        # get the segments
        if mean_scale > 0:
            seg, enough_segs = _get_segments(x,
                                             mean_scale=mean_scale,
                                             num_samples=num_samples,
                                             return_enough_segments=True)
            # count how many times we had to sample with replacement
            if not enough_segs:
                not_enough_segs_count += 1
        else:
            seg = _get_grid_segments(imshape, num_samples)

        # now augment image and segmentation together, twice
        img1, seg1 = _segment_aug(x, seg, augment, outputsize=outputsize)
        img2, seg2 = _segment_aug(x, seg, augment, outputsize=outputsize)

        # check to see if any segments were pushed out of the image by augmentation
        segmentation_ok = _filter_out_bad_segments(img1, seg1, img2, seg2)
        if not segmentation_ok:
            augment_removed_segmentation += 1

        # finally, augment images separately
        img1 = _aug(img1).numpy()
        img2 = _aug(img2).numpy()
        seg1 = seg1.numpy()
        seg2 = seg2.numpy()
        if len(img1s) < max_plot:
            imgs.append(x)
            segs.append(seg)
            img1s.append(img1)
            img2s.append(img2)
            seg1s.append(seg1)
            seg2s.append(seg2)
        progressbar.update()
    progressbar.close()

    img = np.stack(imgs, 0)
    seg = np.stack(segs, 0)
    img1 = np.stack(img1s, 0)
    img2 = np.stack(img2s, 0)
    seg1 = np.stack(seg1s, 0)
    seg2 = np.stack(seg2s, 0)
    print(
        f"Had to sample with replacement for {not_enough_segs_count} of {N} images"
    )
    print(
        f"At least one missing segment in {augment_removed_segmentation} of {N} images due to augmentation"
    )

    def _segshow(s):
        # extent is (left, right, bottom, top) = (0, width, height, 0);
        # imshape is (height, width)
        plt.imshow(s,
                   alpha=0.4,
                   extent=[0, imshape[1], imshape[0], 0],
                   cmap="tab20",
                   vmin=0,
                   vmax=num_samples - 1)

    # (images, segmentations or None, column title) for each plot column
    panels = [(img, None, "original"),
              (img, seg, "segments"),
              (img1, seg1, "augmentation 1"),
              (img2, seg2, "augmentation 2")]
    # plot only as many rows as we have examples (fewer than 5 images
    # previously raised IndexError)
    for j in range(len(imgs)):
        for col, (panel_imgs, panel_segs, title) in enumerate(panels):
            plt.subplot(5, 4, 4 * j + col + 1)
            plt.imshow(panel_imgs[j])
            plt.axis(False)
            if panel_segs is not None:
                _segshow(panel_segs[j].argmax(-1))
            if j == 0:
                plt.title(title)
Esempio n. 15
0
def _build_simclr_dataset(imfiles,
                          imshape=(256, 256),
                          batch_size=256,
                          num_parallel_calls=None,
                          norm=255,
                          num_channels=3,
                          augment=True,
                          single_channel=False,
                          stratify=None):
    """
    Build a tf.data.Dataset for SimCLR-style contrastive training.

    Each element yields two augmented views that are stacked, unbatched and
    rebatched, so each batch holds 2 * batch_size views (adjacent positive
    pairs) with labels alternating 1, -1. Incomplete final batches are
    dropped. With a single input the same image is augmented twice; with two
    inputs each one is augmented separately.

    :imfiles: image source for _image_file_dataset, or (when stratifying) a
        list of paths or a pair of parallel path lists
    :imshape: (height, width) images are reshaped to before augmentation
    :batch_size: number of positive pairs per batch
    :num_parallel_calls: parallelism for the loading and augmentation maps
    :norm: value for normalizing images
    :num_channels: channel depth of images
    :augment: augmentation parameters (or True for defaults); must not be False
    :single_channel: if True, expect a single-channel input image and
        stack it num_channels times.
    :stratify: if not None, a list of categories for each element in
        imfile. One sub-dataset is built per category and batches are drawn
        from them with tf.data.experimental.sample_from_datasets, so each
        batch comes from a single category.
    """
    if stratify is not None:
        categories = list(set(stratify))

        def _select(files, category):
            # files whose stratify label equals `category`, in original order
            return [files[i] for i in range(len(files)) if stratify[i] == category]

        if isinstance(imfiles[0], str):
            # SINGLE-INPUT CASE
            file_lists = [_select(imfiles, c) for c in categories]
        else:
            # DUAL-INPUT CASE
            file_lists = [(_select(imfiles[0], c), _select(imfiles[1], c))
                          for c in categories]

        per_category = [
            _build_simclr_dataset(files,
                                  imshape=imshape,
                                  batch_size=batch_size,
                                  num_parallel_calls=num_parallel_calls,
                                  norm=norm,
                                  num_channels=num_channels,
                                  augment=augment,
                                  single_channel=single_channel,
                                  stratify=None) for files in file_lists
        ]
        return tf.data.experimental.sample_from_datasets(per_category)

    assert augment != False, "don't you need to augment your data?"

    ds = _image_file_dataset(imfiles,
                             imshape=imshape,
                             num_parallel_calls=num_parallel_calls,
                             norm=norm,
                             num_channels=num_channels,
                             shuffle=True,
                             single_channel=single_channel,
                             augment=False)

    augmenter = augment_function(imshape, augment)
    view_shape = (imshape[0], imshape[1], num_channels)

    @tf.function
    def _augment_and_stack(*x):
        # one input -> augment it twice (standard SimCLR);
        # two inputs -> augment each separately (user-specified pairing)
        first = tf.reshape(x[0], view_shape)
        second = tf.reshape(x[1], view_shape) if len(x) == 2 else first
        labels = tf.constant(np.array([1, -1]).astype(np.int32))
        return tf.stack([augmenter(first), augmenter(second)]), labels

    paired = ds.map(_augment_and_stack, num_parallel_calls=num_parallel_calls)
    views = paired.unbatch()
    return views.batch(2 * batch_size, drop_remainder=True).prefetch(1)
Esempio n. 16
0
def test_default_augment():
    """Default augment_function returns a tf.Tensor shaped like its input."""
    augmenter = augment_function(test_shape[:2])
    result = augmenter(test_img_tensor)

    assert isinstance(result, tf.Tensor)
    assert result.get_shape() == test_img_tensor.get_shape()