Example #1
def main(args):
    with open(args.lst, 'r') as f:
        srclist = [x.strip() for x in f]

    for src in srclist:

        dst = repext(src, args.suffix)
        if os.path.exists(dst):
            print('Exists', src, '-->', dst)
            continue

        # Loading data from ramdisk incurs a one-time copy cost
        rdsrc = cpramdisk(src, args.ramdisk)
        print('File:', rdsrc)

        try:
            slide = Slide(rdsrc, args)
            slide.initialize_output('wsi',
                                    3,
                                    mode='full',
                                    compute_fn=compute_fn)
            ret = slide.compute('wsi', args)
            print('Saving {} --> {}'.format(ret.shape, dst))
            cv2.imwrite(dst, ret)
        except Exception as e:
            traceback.print_tb(e.__traceback__)
        finally:
            print('Removing {}'.format(rdsrc))
            os.remove(rdsrc)
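
The helpers repext and cpramdisk appear throughout these examples but are never shown. A minimal sketch consistent with their call sites (hypothetical implementations; the real svsutils versions may differ):

import os
import shutil

def repext(path, suffix):
    # Swap the file extension of `path` for `suffix`,
    # e.g. repext('slide.svs', '.prob.npy') -> 'slide.prob.npy'
    return os.path.splitext(path)[0] + suffix

def cpramdisk(src, ramdisk):
    # Copy `src` onto a ramdisk mount (e.g. /dev/shm) and return the
    # temporary path; callers remove it in a `finally` block.
    dst = os.path.join(ramdisk, os.path.basename(src))
    shutil.copyfile(src, dst)
    return dst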
Example #2
def main(args):

    print(f'Load model from {args.snapshot}')
    model = tf.keras.models.load_model(args.snapshot)

    print(f'Processing {args.slide}')
    dst_base = os.path.splitext(os.path.basename(args.slide))[0] + '.npy'

    if not os.path.isdir(args.dest):
        os.makedirs(args.dest)

    dst = os.path.join(args.dest, dst_base)
    print(f'Destination {dst}')

    ramdisk_file = cp_ramdisk(args.slide)
    slide = Slide(ramdisk_file, args)
    try:
        slide.initialize_output('prob', 4, mode='tile', compute_fn=compute_fn)
        ret = slide.compute('prob', args, model=model)
        np.save(dst, ret)

    except Exception as e:
        print('Caught error processing')
        traceback.print_tb(e.__traceback__)
        print(e)

    finally:
        os.remove(ramdisk_file)
Example #3
def main(args, sess):
    with open(args.slides, 'r') as f:
        srclist = [x.strip() for x in f]

    # image_op = tf.placeholder(tf.float32, (args.batchsize, args.process_size,
    #                           args.process_size, 3))
    # module = hub.Module(args.snapshot)
    # predict_op = module(image_op)

    image_op, predict_op = get_input_output_ops(sess, args.snapshot)

    for src in srclist:
        dst = os.path.splitext(src)[0] + '.{}'.format(args.ext)
        dst_base = os.path.basename(dst)
        dst = os.path.join(args.dest, dst_base)
        print(dst)

        slide = Slide(src, args)
        slide.initialize_output('prob', 4, mode='tile', compute_fn=compute_fn)

        ret = slide.compute('prob',
                            args,
                            predict_op=predict_op,
                            image_op=image_op)

        np.save(dst, ret)
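
get_input_output_ops is not defined in this excerpt. Judging from the commented-out lines above, a plausible sketch loads a TF-Hub module and returns the input placeholder and prediction op (hypothetical implementation; the fixed batch shape from the comment is relaxed to None):

import tensorflow as tf
import tensorflow_hub as hub

def get_input_output_ops(sess, snapshot):
    # Load a TF-Hub module from `snapshot` and expose its
    # input placeholder together with the prediction op.
    image_op = tf.placeholder(tf.float32, (None, None, None, 3))
    module = hub.Module(snapshot)
    predict_op = module(image_op)
    sess.run([tf.global_variables_initializer(), tf.tables_initializer()])
    return image_op, predict_op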
Example #4
def main(args):
    ## Read in the slides
    slides = []
    with open(args.slides, 'r') as f:
        for line in f:
            lstrip = line.strip()
            if os.path.exists(lstrip):
                slides.append(lstrip)
            else:
                print('Missing:', lstrip)
    print('Working with {} slides'.format(len(slides)))

    ## Initialize the dataset; store the metadata file
    with open(args.meta_file, 'r') as f:
        lines = [line for line in f]

    ## Read in the labels
    attrs = ['case_id', 'stage_str', 'stage_code']
    labels = read_labels(args.labels, attrs)
    print('Working with labels:', len(labels))

    ## Create the dataset
    meta_string = ''.join(lines)
    if not os.path.exists(args.data_h5):
        create_dataset(args.data_h5, meta_string)
    else:
        print('HDF5 {} already exists'.format(args.data_h5))

    ## Load the dataset
    mildataset = MILDataset(args.data_h5, meta_string)

    # slides = [s for s in slides if s not in mildataset.data_group_names]
    for i, src in enumerate(slides):
        print('\n[\t{}/{}\t]'.format(i, len(slides)))
        print('File {:04d} {} --> {}'.format(i, src, args.ramdisk))
        basename = os.path.splitext(os.path.basename(src))[0]

        try:
            lab = labels[basename]
            print(basename, lab)
            if lab[-1] > 1:
                print('Skipping unused labels')
                continue
        except Exception as e:
            print('basename {} has no labels; skipping.'.format(basename))
            traceback.print_tb(e.__traceback__)
            continue

        rdsrc = cpramdisk(src, args.ramdisk)
        try:
            slide = Slide(rdsrc, args)
            tile_stack = stack_tiles(slide, args)
            mildataset.new_dataset(basename, tile_stack, attrs, lab)

        except Exception as e:
            print('Error processing {}'.format(basename))
            traceback.print_tb(e.__traceback__)

        finally:
            print('Removing {}'.format(rdsrc))
            os.remove(rdsrc)
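
read_labels and stack_tiles are not shown here. A minimal sketch of read_labels consistent with its call site, assuming a CSV label file whose case_id column matches the slide basenames (the file layout and the int cast are assumptions):

import csv

def read_labels(label_file, attrs):
    # Hypothetical sketch: one row per case; the final attribute
    # (stage_code) is cast to int so the `lab[-1] > 1` filter in
    # main() works numerically.
    labels = {}
    with open(label_file, 'r') as f:
        for row in csv.DictReader(f):
            vals = [row[a] for a in attrs[:-1]] + [int(row[attrs[-1]])]
            labels[row['case_id']] = vals
    return labels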
Example #5
def main(args):
    # Define a compute_fn that should do three things:
    # 1. define an iterator over the slide's tiles
    # 2. compute an output with the given model
    # 3. return the assembled output

    if args.iter_type == 'python':

        def compute_fn(slide, args, model=None):
            print('Slide with {}'.format(len(slide.tile_list)))
            it_factory = PythonIterator(slide, args)
            for k, (img, idx) in enumerate(it_factory.yield_batch()):
                prob = model(img)
                if k % 50 == 0:
                    print('Batch #{:04d} idx:{} img:{} prob:{}'.format(
                        k, idx.shape, img.shape, prob.shape))
                slide.place_batch(prob, idx, 'prob', mode='tile')
            ret = slide.output_imgs['prob']
            return ret

    # Tensorflow multithreaded queue-based iterator (in eager mode)
    elif args.iter_type == 'tf':

        def compute_fn(slide, args, model=None):
            assert tf.executing_eagerly()
            print('Slide with {}'.format(len(slide.tile_list)))

            # In eager mode, we return a tf.contrib.eager.Iterator
            eager_iterator = TensorflowIterator(slide, args).make_iterator()

            # The iterator can be used directly. Queueing and multithreading
            # are handled in the backend by the tf.data.Dataset ops
            features, indices = [], []
            for k, (img, idx) in enumerate(eager_iterator):
                # img = tf.expand_dims(img, axis=0)
                features.append(
                    model.encode_bag(img, training=False, return_z=True))
                indices.append(idx.numpy())

                img, idx = img.numpy(), idx.numpy()
                if k % 50 == 0:
                    print('Batch #{:04d}\t{}'.format(k, img.shape))

            features = tf.concat(features, axis=0)
            z_att, att = model.mil_attention(features,
                                             training=False,
                                             return_raw_att=True)
            att = np.squeeze(att)
            indices = np.concatenate(indices)
            slide.place_batch(att, indices, 'att', mode='tile')
            ret = slide.output_imgs['att']
            return ret

    else:
        raise ValueError('Unrecognized --iter_type: {}'.format(args.iter_type))

    # Set up the model first
    encoder_args = get_encoder_args(args.encoder)
    model = MilkEager(encoder_args=encoder_args,
                      mil_type=args.mil,
                      deep_classifier=args.deep_classifier,
                      batch_size=args.batchsize,
                      temperature=args.temperature,
                      heads=args.heads)

    x = tf.zeros((1, 1, args.process_size, args.process_size, 3))
    _ = model(x, verbose=True, head='all', training=True)
    model.load_weights(args.snapshot, by_name=True)

    # keras Model subclass
    model.summary()

    # Read list of inputs
    with open(args.slides, 'r') as f:
        slides = [x.strip() for x in f]

    # Loop over slides
    for src in slides:
        # Dirty substitution of the file extension gives us the
        # destination. Do this first so we can just skip the slide
        # if this destination already exists.
        # Set the --suffix option to reflect the model / type of processed output
        dst = repext(src, args.suffix)
        if os.path.exists(dst):
            print('{} Exists.'.format(dst))
            continue

        # Loading data from ramdisk incurs a one-time copy cost
        rdsrc = cpramdisk(src, args.ramdisk)
        print('File:', rdsrc)

        # Wrapped inside of a try-except-finally.
        # We want to make sure the slide gets cleaned from
        # memory in case there's an error or stop signal in the
        # middle of processing.
        try:
            # Initialize the slide from our temporary path, with
            # the arguments passed in from the command line.
            # This returns an svsutils.Slide object
            slide = Slide(rdsrc, args)

            # This step will eventually be included in slide creation
            # with some default compute_fn's provided by svsutils
            # For now, do it case-by-case, and use the compute_fn
            # that we defined just above.
            slide.initialize_output('att',
                                    args.n_classes,
                                    mode='tile',
                                    compute_fn=compute_fn)

            # Call the compute function to compute this output.
            # Again, this may change to something like...
            #     slide.compute_all
            # which would loop over all the defined output types.
            ret = slide.compute('att', args, model=model)
            print('{} --> {}'.format(ret.shape, dst))
            np.save(dst, ret[:, :, ::-1])
        except Exception as e:
            print(e)
            traceback.print_tb(e.__traceback__)
        finally:
            print('Removing {}'.format(rdsrc))
            os.remove(rdsrc)
Example #6
def main(args):


  # Define a compute_fn that should do three things:
  # 1. define an iterator over the slide's tiles
  # 2. compute an output with a given model / arguments
  # 3. return a reconstructed slide
  def compute_fn(slide, args, model=None, n_dropout=10 ):
    assert tf.executing_eagerly()
    print('Slide with {}'.format(len(slide.tile_list)))

    # In eager mode, we return a tf.contrib.eager.Iterator
    eager_iterator = TensorflowIterator(slide, args).make_iterator()

    # The iterator can be used directly. Queueing and multithreading
    # are handled in the backend by the tf.data.Dataset ops
    features, indices = [], []
    for k, (img, idx) in enumerate(eager_iterator):
      # img = tf.expand_dims(img, axis=0)
      features.append( model.encode_bag(img, training=False, return_z=True) )
      indices.append(idx.numpy())

      img, idx = img.numpy(), idx.numpy()
      if k % 50 == 0:
        print('Batch #{:04d}\t{}'.format(k, img.shape))

    features = tf.concat(features, axis=0)

    ## Sample-dropout
    # features = features.numpy()
    # print(features.shape)
    # n_instances = features.shape[0]
    # att = np.zeros(n_instances)
    # n_choice = int(n_instances * 0.7)
    # all_heads = list(range(args.heads))
    # for j in range(n_dropout):
    #   idx = np.random.choice(range(n_instances), n_choice, replace=False)
    #   print(idx)
    #   fdrop = features[idx, :]

    z_att, att = model.mil_attention(features,
                                     training=False, 
                                     return_raw_att=True)

    # att[idx] += np.squeeze(attdrop)
    yhat_multihead = model.apply_classifier(z_att, heads=all_heads,
                                            training=False)
    print('yhat mean {}'.format(np.mean(yhat_multihead, axis=0)))

    indices = np.concatenate(indices)
    att = np.squeeze(att)
    slide.place_batch(att, indices, 'att', mode='tile')
    ret = slide.output_imgs['att']
    print('Got attention image: {}'.format(ret.shape))

    return ret, features.numpy()




  ## Begin main script:
  # Set up the model first
  encoder_args = get_encoder_args(args.encoder)
  model = MilkEager(encoder_args=encoder_args,
                    mil_type=args.mil,
                    deep_classifier=args.deep_classifier,
                    batch_size=args.batchsize,
                    temperature=args.temperature,
                    heads=args.heads)

  x = tf.zeros((1, 1, args.process_size,
                args.process_size, 3))
  all_heads = list(range(args.heads))
  _ = model(x, verbose=True, heads=all_heads, training=True)
  model.load_weights(args.snapshot, by_name=True)

  # keras Model subclass
  model.summary()

  # Read list of inputs
  with open(args.slides, 'r') as f:
    slides = [x.strip() for x in f]

  # Loop over slides
  for src in slides:
    # Dirty substitution of the file extension gives us the
    # destination. Do this first so we can just skip the slide
    # if this destination already exists.
    # Set the --suffix option to reflect the model / type of processed output
    dst = repext(src, args.suffix)
    featdst = repext(src, args.suffix + '.feat.npy')
    if os.path.exists(dst):
      print('{} Exists.'.format(dst))
      continue

    # Loading data from ramdisk incurs a one-time copy cost
    rdsrc = cpramdisk(src, args.ramdisk)
    print('\n\nFile:', rdsrc)

    # Wrapped inside of a try-except-finally.
    # We want to make sure the slide gets cleaned from 
    # memory in case there's an error or stop signal in the 
    # middle of processing.
    try:
      # Initialize the slide from our temporary path, with
      # the arguments passed in from the command line.
      # This returns an svsutils.Slide object
      slide = Slide(rdsrc, args)

      # This step will eventually be included in slide creation
      # with some default compute_fn's provided by svsutils
      # For now, do it case-by-case, and use the compute_fn
      # that we defined just above.
      slide.initialize_output('att', args.n_classes, mode='tile',
        compute_fn=compute_fn)

      # Call the compute function to compute this output.
      # Again, this may change to something like...
      #     slide.compute_all
      # which would loop over all the defined output types.
      ret, features = slide.compute('att', args, model=model)
      print('{} --> {}'.format(ret.shape, dst))
      print('{} --> {}'.format(features.shape, featdst))
      np.save(dst, ret)
      np.save(featdst, features)
    except Exception as e:
      print(e)
      traceback.print_tb(e.__traceback__)
    finally:
      print('Removing {}'.format(rdsrc))
      os.remove(rdsrc)
Example #7
from __future__ import print_function
import cv2
import numpy as np
import sys

from matplotlib import pyplot as plt

from svsutils import Slide

s = Slide(
    slide_path=
    '/home/nathan/data/ccrcc/TCGA_KIRC/TCGA-A3-3346-01Z-00-DX1.95280216-fd71-4a03-b452-6e3d667f2542.svs',
    process_mag=5,
    process_size=512,
    oversample_factor=1.25)
s.initialize_output(n_classes=3)
s.print_info()

for idx, img in enumerate(s.generator()):
    s.place(img[:, :, ::-1], idx)

reconstruction = s.output_img
print(reconstruction.shape)

plt.imshow(reconstruction)
plt.show()
Example #8
import numpy as np
import tensorflow as tf

from svsutils import Slide

config = tf.ConfigProto()
config.gpu_options.allow_growth = True

"""
https://stackoverflow.com/questions/47086599/parallelising-tf-data-dataset-from-generator
https://www.tensorflow.org/programmers_guide/datasets
https://www.tensorflow.org/api_docs/python/tf/data/Dataset
"""

slide_path = '/media/ing/D/svs/TCGA_KIRC/TCGA-A3-3306-01Z-00-DX1.bfd320d3-f3ec-4015-b34a-98e9967ea80d.svs'

preprocess_fn = lambda x: ((x * (2/255.)) - 1).astype(np.float32)
svs = Slide(slide_path=slide_path,
            process_mag=5,
            process_size=512,
            preprocess_fn=preprocess_fn)
svs.print_info()
svs.initialize_output('features', dim=3, mode='tile')


def wrapped_fn(idx):
    coords = svs.tile_list[idx]
    img = svs._read_tile(coords)
    return img, idx

def read_region_at_index(idx):
    return tf.py_func(func     = wrapped_fn,
                      inp      = [idx],
                      Tout     = [tf.float32, tf.int64],
                      stateful = False)
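
The example stops before the dataset is actually built. A minimal continuation in the spirit of the linked tf.data guides, mapping read_region_at_index over tile indices (the parallelism and prefetch values are illustrative, not from the original file):

# Build a dataset of tile indices, read tiles in parallel through the
# py_func above, and drain it in a TF1 session.
ds = tf.data.Dataset.range(len(svs.tile_list))
ds = ds.map(read_region_at_index, num_parallel_calls=4)
ds = ds.prefetch(32)
img_op, idx_op = ds.make_one_shot_iterator().get_next()

with tf.Session(config=config) as sess:
    while True:
        try:
            img, idx = sess.run([img_op, idx_op])
        except tf.errors.OutOfRangeError:
            break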
Example #9
def main(args, sess):
    # Define a compute_fn that should do three things:
    # 1. define an iterator over the slide's tiles
    # 2. compute an output with the given model
    # 3. assemble / gather the output
    #
    # A compute_fn may define part of a computation graph, in eager
    # mode or possibly in graph mode. If so, we should completely reset
    # the graph on each call; it is unclear how nodes are actually
    # represented in memory, or whether keeping them around has a real
    # cost.

    def compute_fn(slide, args, sess=None):
        # assert tf.executing_eagerly()
        print('\n\nSlide with {}'.format(len(slide.tile_list)))

        # It's unclear whether spinning up new ops on every call is
        # costly. In this example the iterator is separate from the
        # inference function; the two could also be connected directly
        # to skip the feed_dict.
        tf_iterator = TensorflowIterator(slide, args).make_iterator()
        img_op, idx_op = tf_iterator.get_next()
        # prob_op = model(img_op)
        # sess.run(tf.global_variables_initializer())

        # The iterator can be used directly. Queueing and multithreading
        # are handled in the backend by the tf.data.Dataset ops
        # for k, (img, idx) in enumerate(eager_iterator):
        k, nk = 0, 0
        while True:
            try:
                img, idx = sess.run([
                    img_op,
                    idx_op,
                ])
                prob = model.inference(img)
                nk += img.shape[0]
                slide.place_batch(prob, idx, 'prob', mode='full', clobber=True)
                k += 1

                if k % 50 == 0:
                    print('Batch #{:04d} idx:{} img:{} ({:2.2f}-{:2.2f}) '
                          'prob:{} T {}'.format(k, idx.shape, img.shape,
                                                img.min(), img.max(),
                                                prob.shape, nk))
                    if args.verbose:
                        print('More info: ')
                        print('img: ', img.dtype, img.min(), img.max(),
                              img.mean())
                        pmax = np.argmax(prob, axis=-1).ravel()
                        for u in range(args.n_classes):
                            count_u = (pmax == u).sum()
                            print('- class {:02d} : {}'.format(u, count_u))

            except tf.errors.OutOfRangeError:
                print('Finished.')
                print('Total: {}'.format(nk))
                break

            except Exception as e:
                print(e)
                traceback.print_tb(e.__traceback__)
                break

        # We've exited the loop. Clean up the iterator
        del tf_iterator, idx_op, img_op

        slide.make_outputs()
        ret = slide.output_imgs['prob']
        return ret

    # Set up the model first
    model = gg.get_model(args.model, sess, args.process_size, args.n_classes)
    # NOTE big time wasted because you have to initialize,
    # THEN run the restore op to replace the already-created weights
    sess.run(tf.global_variables_initializer())
    model.restore(args.snapshot)

    # Read list of inputs
    with open(args.slides, 'r') as f:
        slides = [x.strip() for x in f]

    # Loop over slides; Record times
    nslides = len(slides)
    successes, ntiles, total_time, fpss = [], [], [], []
    for i, src in enumerate(slides):
        # Dirty substitution of the file extension gives us the
        # destination. Do this first so we can just skip the slide
        # if this destination already exists.
        # Set the --suffix option to reflect the model / type of processed output
        dst = repext(src, args.suffix)
        if os.path.exists(dst):
            print('{} Exists.'.format(dst))
            continue

        # Loading data from ramdisk incurs a one-time copy cost
        rdsrc = cpramdisk(src, args.ramdisk)

        # Wrapped inside of a try-except-finally.
        # We want to make sure the slide gets cleaned from
        # memory in case there's an error or stop signal in the
        # middle of processing.
        try:
            # Initialize the slide from our temporary path, with
            # the arguments passed in from the command line.
            # This returns an svsutils.Slide object
            print('\n\n-------------------------------')
            print('File:', rdsrc, '{:04d} / {:04d}'.format(i, nslides))
            t0 = time.time()
            slide = Slide(rdsrc, args)

            # This step will eventually be included in slide creation
            # with some default compute_fn's provided by svsutils
            # For now, do it case-by-case, and use the compute_fn
            # that we defined just above.
            # TODO pull the expected output size from the model.. ?
            # support common model types - keras, tfmodels, tfhub..
            slide.initialize_output('prob',
                                    args.n_classes,
                                    mode='full',
                                    compute_fn=compute_fn)

            # Call the compute function to compute this output.
            # Again, this may change to something like...
            #     slide.compute_all
            # which would loop over all the defined output types.
            ret = slide.compute('prob', args, sess=sess)
            print('{} --> {}'.format(ret.shape, dst))
            ret = (ret * 255).astype(np.uint8)
            np.save(dst, ret)

            # If it finishes, record some stats
            tend = time.time()
            deltat = tend - t0
            fps = len(slide.tile_list) / float(deltat)
            successes.append(rdsrc)
            ntiles.append(len(slide.tile_list))
            total_time.append(deltat)
            fpss.append(fps)
        except Exception as e:
            print(e)
            traceback.print_tb(e.__traceback__)
        finally:
            print('Removing {}'.format(rdsrc))
            os.remove(rdsrc)
            try:
                print('Cleaning slide object')
                slide.close()
                del slide
            except Exception:
                print('No slide object found to clean up')

    write_times(args.timefile, successes, ntiles, total_time, fpss)
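
write_times is referenced but not defined in this excerpt. A minimal sketch consistent with its call signature, recording one row of throughput stats per processed slide (the CSV layout is an assumption):

import csv

def write_times(timefile, successes, ntiles, total_time, fpss):
    # Hypothetical helper: dump per-slide throughput stats to CSV.
    with open(timefile, 'w') as f:
        writer = csv.writer(f)
        writer.writerow(['slide', 'n_tiles', 'seconds', 'fps'])
        for row in zip(successes, ntiles, total_time, fpss):
            writer.writerow(row)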
Example #10
import time
import numpy as np

from svsutils import Slide

"""
https://stackoverflow.com/questions/47086599/parallelising-tf-data-dataset-from-generator
https://www.tensorflow.org/programmers_guide/datasets
https://www.tensorflow.org/api_docs/python/tf/data/Dataset
"""

# slide_path = '/media/ing/D/svs/TCGA_KIRC/TCGA-A3-3306-01Z-00-DX1.bfd320d3-f3ec-4015-b34a-98e9967ea80d.svs'
slide_path = '/mnt/slowdata/slide_data/VA_PNBX/SP 02-4466 L3.svs'

print('Testing "accurate" background method')
tstart = time.time()
preprocess_fn = lambda x: (x * 1 / 255.).astype(np.float32)
svs = Slide(slide_path=slide_path,
            process_mag=5,
            process_size=96,
            oversample_factor=1.75,
            preprocess_fn=preprocess_fn,
            background_speed='accurate',
            background_threshold=210,
            background_pct=0.15,
            verbose=True)
svs.print_info()
svs.initialize_output('features', dim=3, mode='tile')
print('Initialized slide object in {}s'.format(time.time() - tstart))


def wrapped_fn(idx):
    coords = svs.tile_list[idx]
    img = svs._read_tile(coords)
    return img, idx

Example #11
"""
https://stackoverflow.com/questions/47086599/parallelising-tf-data-dataset-from-generator
https://www.tensorflow.org/programmers_guide/datasets
https://www.tensorflow.org/api_docs/python/tf/data/Dataset
"""

# slide_path = '/home/nathan/data/ccrcc/TCGA_KIRC/'
# slide_path += 'TCGA-A3-3346-01Z-00-DX1.95280216-fd71-4a03-b452-6e3d667f2542.svs'

# slide_path = '/home/nathan/data/gleason_grade/wsi/s10-3220-001.svs'
# slide_path = '/home/nathan/biobank/_DURHAM_SLIDES_/_Cases_/S10-3220/s10-3220-001.svs'
slide_path = '/dev/shm/s13_2243-016.svs'

preprocess_fn = lambda x: ((x * (2 / 255.)) - 1).astype(np.float32)
s = Slide(slide_path=slide_path,
          process_mag=10,
          process_size=256,
          preprocess_fn=preprocess_fn)
s.print_info()

# for ix in s.generate_index():
#     print ix


def wrapped_fn(idx):
    coords = s.tile_list[idx]
    img = s._read_tile(coords)
    return img, idx


def read_region_at_index(idx):
    return tf.py_func(func=wrapped_fn,
                      inp=[idx],
                      Tout=[tf.float32, tf.int64],
                      stateful=False)
Example #12
from __future__ import print_function
import cv2
import numpy as np
import sys

from svsutils import Slide

print('\nSlide at 20x')
s = Slide(
    slide_path=
    '/home/nathan/data/ccrcc/TCGA_KIRC/TCGA-A3-3346-01Z-00-DX1.95280216-fd71-4a03-b452-6e3d667f2542.svs',
    process_mag=20,
)

print('\nSlide at 10x')
s = Slide(
    slide_path=
    '/home/nathan/data/ccrcc/TCGA_KIRC/TCGA-A3-3346-01Z-00-DX1.95280216-fd71-4a03-b452-6e3d667f2542.svs',
    process_mag=10,
)

print('\nSlide at 5x')
s = Slide(
    slide_path=
    '/home/nathan/data/ccrcc/TCGA_KIRC/TCGA-A3-3346-01Z-00-DX1.95280216-fd71-4a03-b452-6e3d667f2542.svs',
    process_mag=5,
)

s.print_info()

for idx, img in enumerate(s.generator()):
    # The excerpt ends here; a minimal loop body inspects each tile.
    print(idx, img.shape)
Example #13
from __future__ import print_function
import numpy as np
import tensorflow as tf
import tensorflow.contrib.eager as tfe
import sys
import cv2

from svsutils import Slide

config = tf.ConfigProto()
config.gpu_options.allow_growth = True
tfe.enable_eager_execution(config=config)

print('\nslide at 5x')
slide_path = '/home/nathan/data/ccrcc/TCGA_KIRC/'
slide_path += 'TCGA-A3-3346-01Z-00-DX1.95280216-fd71-4a03-b452-6e3d667f2542.svs'
preprocess_fn = lambda x: (x * (2 / 255.)) - 1
s = Slide(slide_path=slide_path,
          process_mag=5,
          process_size=128,
          preprocess_fn=preprocess_fn)
s.print_info()

ds = tf.data.Dataset.from_generator(generator=s.generator,
                                    output_types=tf.float32)
ds = tfe.Iterator(ds)

for idx, x in enumerate(ds):
    print(x.shape, x.dtype)
    # cv2.imwrite('debug/{}.jpg'.format(idx), x.numpy()[:,:,::-1])
Example #14
def main(args, sess):
    # Define a compute_fn that should do three things:
    # 1. define an iterator over the slide's tiles
    # 2. compute an output with the given model
    # 3. return the assembled output

    # def compute_fn(slide, args, model=None):
    #   print('Slide with {}'.format(len(slide.tile_list)))
    #   it_factory = PythonIterator(slide, args)
    #   for k, (img, idx) in enumerate(it_factory.yield_batch()):
    #     prob = model.predict_on_batch(img)
    #     if k % 50 == 0:
    #       print('Batch #{:04d} idx:{} img:{} prob:{} \
    #       '.format(k, idx.shape, img.shape, prob.shape))
    #     slide.place_batch(prob, idx, 'prob', mode='tile')
    #   ret = slide.output_imgs['prob']
    #   return ret

    # Tensorflow multithreaded queue-based iterator (in eager mode)
    # elif args.iter_type == 'tf':

    def compute_fn(slide, args, sess=None, img_pl=None, prob_op=None):
        # assert tf.executing_eagerly()
        print('\n\nSlide with {}'.format(len(slide.tile_list)))

        # It's unclear whether spinning up new ops on every call is costly.
        tf_iterator = TensorflowIterator(slide, args).make_iterator()
        img_op, idx_op = tf_iterator.get_next()
        # prob_op = model(img_op)
        # sess.run(tf.global_variables_initializer())

        # The iterator can be used directly. Queueing and multithreading
        # are handled in the backend by the tf.data.Dataset ops
        # for k, (img, idx) in enumerate(eager_iterator):
        k, nk = 0, 0
        while True:
            try:
                img, idx = sess.run([
                    img_op,
                    idx_op,
                ])
                prob = sess.run(prob_op, {img_pl: img})
                nk += img.shape[0]
                if k % 50 == 0:
                    print('Batch #{:04d} idx:{} img:{} ({:2.2f}) '
                          'prob:{} T {}'.format(k, idx.shape, img.shape,
                                                img.max(), prob.shape, nk))
                slide.place_batch(prob, idx, 'prob', mode='tile')
                k += 1
            except tf.errors.OutOfRangeError:
                print('Finished.')
                print('Total: {}'.format(nk))
                break

        # Gather the reconstructed output once the iterator is exhausted
        ret = slide.output_imgs['prob']
        return ret

    # Set up the model first

    # Set up a placeholder for the input
    img_pl = tf.placeholder(tf.float32,
                            (None, args.process_size, args.process_size, 3))
    model = load_model(args.snapshot)
    prob_op = model(img_pl)
    sess.run(tf.global_variables_initializer())

    # Read list of inputs
    with open(args.slides, 'r') as f:
        slides = [x.strip() for x in f]

    # Loop over slides
    for src in slides:
        # Dirty substitution of the file extension gives us the
        # destination. Do this first so we can just skip the slide
        # if this destination already exists.
        # Set the --suffix option to reflect the model / type of processed output
        dst = repext(src, args.suffix)
        if os.path.exists(dst):
            print('{} Exists.'.format(dst))
            continue

        # Loading data from ramdisk incurs a one-time copy cost
        rdsrc = cpramdisk(src, args.ramdisk)
        print('File:', rdsrc)

        # Wrapped inside of a try-except-finally.
        # We want to make sure the slide gets cleaned from
        # memory in case there's an error or stop signal in the
        # middle of processing.
        try:
            # Initialize the slide from our temporary path, with
            # the arguments passed in from the command line.
            # This returns an svsutils.Slide object
            print('\n\n-------------------------------')
            slide = Slide(rdsrc, args)

            # This step will eventually be included in slide creation
            # with some default compute_fn's provided by svsutils
            # For now, do it case-by-case, and use the compute_fn
            # that we defined just above.
            slide.initialize_output('prob',
                                    4,
                                    mode='tile',
                                    compute_fn=compute_fn)

            # Call the compute function to compute this output.
            # Again, this may change to something like...
            #     slide.compute_all
            # which would loop over all the defined output types.
            ret = slide.compute('prob',
                                args,
                                sess=sess,
                                img_pl=img_pl,
                                prob_op=prob_op)
            print('{} --> {}'.format(ret.shape, dst))
            np.save(dst, ret)
        except Exception as e:
            print(e)
            traceback.print_tb(e.__traceback__)
        finally:
            print('Removing {}'.format(rdsrc))
            os.remove(rdsrc)
Example #15
from __future__ import print_function
import numpy as np
import tensorflow as tf
import sys
import cv2
import time

from svsutils import Slide

config = tf.ConfigProto()
config.gpu_options.allow_growth = True

print('\nslide at 5x')
s = Slide(
    slide_path=
    '/home/nathan/data/ccrcc/TCGA_KIRC/TCGA-A3-3346-01Z-00-DX1.95280216-fd71-4a03-b452-6e3d667f2542.svs',
    process_mag=5,
)
s.print_info()

with tf.Session(config=config) as sess:
    ds = tf.data.Dataset.from_generator(generator=s.generator,
                                        output_types=tf.float32)
    # ds = ds.map(lambda x: x, num_parallel_calls=8)
    # ds = ds.prefetch(128)
    iterator = ds.make_one_shot_iterator()
    img = iterator.get_next()

    tstart = time.time()
    for x in range(len(s.tile_list)):
        img_ = sess.run(img)
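    # Hypothetical continuation; the excerpt cuts off here, and a natural
    # next step reports throughput using the tstart recorded above.
    print('{} tiles in {:.2f}s'.format(len(s.tile_list), time.time() - tstart))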