Esempio n. 1
0
def main(args, sess):
    """Run graph-mode inference on the .svs slides found in args.slide_dir.

    Each slide is copied to args.ramdisk, tiled with ``Slide``, fed through
    ``model`` batch by batch using ``sess``, and written to args.save_dir as:
      <basename>.npy   -- 'prob' output scaled by 255 and cast to uint8
      <basename>.jpg   -- reassembled 'rgb' output (channels reversed for cv2)
      <basename>_c.jpg -- the rgb output colorized with the probabilities

    Args:
      args: parsed arguments; reads slide_dir, shuffle, snapshot, ramdisk,
        mag, input_dim, n_classes, batch_size and save_dir.
      sess: tf.Session used to evaluate the input pipeline and the model.
    """
    ## Search for slides
    slide_list = sorted(glob.glob(os.path.join(args.slide_dir, '*.svs')))
    print('Found {} slides'.format(len(slide_list)))

    if args.shuffle:
        np.random.shuffle(slide_list)

    model = load_model(args.snapshot)
    # np.save / cv2.imwrite below do not create missing directories.
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    ## Loop over found slides:
    # NOTE(review): only the first 5 slides are processed; this looks like a
    # debugging limit -- confirm before running on a full slide set.
    for src in slide_list[:5]:
        ramdisk_path = transfer_to_ramdisk(
            src, args.ramdisk)  # never use the original src
        svs = None
        try:
            svs = Slide(slide_path=ramdisk_path,
                        preprocess_fn=lambda x: (x / 255.).astype(np.float32),
                        process_mag=args.mag,
                        process_size=args.input_dim,
                        oversample_factor=1.5)
            svs.initialize_output(name='prob', dim=args.n_classes, mode='full')
            svs.initialize_output(name='rgb', dim=3, mode='full')

            n_tiles = len(svs.tile_list)
            prefetch = min(512, n_tiles)

            # Get tensors for image and index.
            # NOTE(review): presumably this adds new input-pipeline and model
            # ops to the default graph for every slide -- confirm the graph
            # does not grow unboundedly over long slide lists.
            img, idx = get_img_idx(svs, args.batch_size, prefetch)
            yhat_op = model(img)

            batches = 0
            while True:
                try:
                    yhat, img_, idx_ = sess.run([yhat_op, img, idx])
                except tf.errors.OutOfRangeError:
                    # The input pipeline is exhausted.
                    print('Done')
                    break
                batches += 1
                yhat = vect_to_tile(yhat, args.input_dim)
                svs.place_batch(yhat, idx_, 'prob', mode='full')
                svs.place_batch(img_, idx_, 'rgb', mode='full')
                if batches % 50 == 0:
                    print('batch {:04d}'.format(batches))

            svs.make_outputs(reference='prob')
            prob_img = svs.output_imgs['prob']
            rgb_img = svs.output_imgs['rgb'] * 255
            color_img = colorize(rgb_img, prob_img)

            basename = os.path.basename(src).replace('.svs', '')
            dst = os.path.join(args.save_dir, '{}.npy'.format(basename))
            np.save(dst, (prob_img * 255).astype(np.uint8))
            # [:, :, ::-1] reverses the channel order for cv2 (RGB -> BGR).
            dst = os.path.join(args.save_dir, '{}.jpg'.format(basename))
            cv2.imwrite(dst, rgb_img[:, :, ::-1])
            dst = os.path.join(args.save_dir, '{}_c.jpg'.format(basename))
            cv2.imwrite(dst, color_img[:, :, ::-1])
        finally:
            # Always release the slide handle and the ramdisk copy, even when
            # processing fails, so failed slides do not accumulate on the
            # ramdisk.
            if svs is not None:
                svs.close()
            os.remove(ramdisk_path)
Esempio n. 2
0
def process_slide(slide_path, fg_path, model, sess, out_dir, process_mag,
                  process_size, oversample, batch_size, n_classes):
    """ Process a slide

  Args:
  slide_path: str
    absolute or relative path to svs formatted slide
  fg_path: str
    absolute or relative path to foreground png.
    NOTE: currently unused -- the foreground image code is commented out
    and background_speed='all' is used instead.
  model: tfmodels.SegmentationBasemodel object
    model definition to use. Weights must be restored first, 
    i.e. call model.restore() before passing model
  sess: tf.Session
  out_dir: str
    path to use for output (currently unused in this function)
  process_mag: int
    Usually one of: 5, 10, 20, 40.
    Other values may work but have not been tested
  process_size: int
    The input size required by model. 
  oversample: float. Usually in [1., 2.]
    How much to oversample between tiles. Larger values 
    will increase processing time.
  batch_size: int
    The batch size for inference. If the batch size is too 
    large given the model and process_size, then OOM errors
    will be raised
  n_classes: int
    The number of classes output by model. 
    i.e. shape(model.yhat) = (batch, h, w, n_classes)

  Returns:
  prob_img, rgb_img, fps
    The outputs of prob_output(svs) and rgb_output(svs), and the
    tiles-per-second throughput. If an error occurs while processing
    tiles, all three are None.
  """

    print('Working {}'.format(slide_path))
    svs = Slide(
        slide_path=slide_path,
        background_speed='all',
        preprocess_fn=preprocess_fn,
        normalize_fn=lambda x: x,
        process_mag=process_mag,
        process_size=process_size,
        oversample=oversample,
        verbose=False,
    )
    svs.initialize_output('prob', dim=n_classes)
    svs.initialize_output('rgb', dim=3)
    n_tiles = len(svs.tile_list)
    prefetch = min(len(svs.place_list), 256)

    def wrapped_fn(idx):
        # Executed inside tf.py_func: read the tile at position `idx`.
        # Errors are deliberately not swallowed here: a sentinel return value
        # would not match Tout and only hide the real cause. py_func surfaces
        # the error from sess.run, where it is handled below.
        coords = svs.tile_list[idx]
        img = svs._read_tile(coords)
        return img, idx

    def read_region_at_index(idx):
        return tf.py_func(func=wrapped_fn,
                          inp=[idx],
                          Tout=[tf.float32, tf.int64],
                          stateful=False)

    ds = tf.data.Dataset.from_generator(generator=svs.generate_index,
                                        output_types=tf.int64)
    ds = ds.map(read_region_at_index, num_parallel_calls=6)
    ds = ds.prefetch(prefetch)
    ds = ds.batch(batch_size)

    iterator = ds.make_one_shot_iterator()
    img, idx = iterator.get_next()

    print('Processing {} tiles'.format(n_tiles))
    tstart = time.time()
    n_processed = 0
    # Pre-bind everything read after the loop or in the error handler so no
    # path can hit an unbound local.
    idx_ = None
    prob_img = None
    rgb_img = None
    fps = None
    try:
        while True:
            try:
                tile, idx_ = sess.run([img, idx])
                output = model.inference(tile)
                svs.place_batch(output, idx_, 'prob')
                svs.place_batch(tile, idx_, 'rgb')

                # Count the real batch length; the final batch may be short.
                n_processed += len(idx_)
                if n_processed % PRINT_ITER == 0:
                    print('[{:06d}] elapsed time [{:3.3f}]'.format(
                        n_processed,
                        time.time() - tstart))

            except tf.errors.OutOfRangeError:
                # End of the tile dataset: report timing and build outputs.
                print('Finished')
                dt = time.time() - tstart
                spt = dt / float(max(n_tiles, 1))
                fps = n_tiles / dt if dt > 0 else 0.
                print('\nFinished. {:2.2f}min {:3.3f}s/tile\n'.format(
                    dt / 60., spt))
                print('\t {:3.3f} fps\n'.format(fps))

                svs.make_outputs()
                prob_img = prob_output(svs)
                rgb_img = rgb_output(svs)
                break

            except Exception as e:
                print('Caught exception at tiles {}'.format(idx_))
                print(e)
                prob_img = None
                rgb_img = None
                fps = None
                break
    finally:
        # Release the slide handle whatever happened above.
        svs.close()

    return prob_img, rgb_img, fps
Esempio n. 3
0
def main(args):
    """Run eager-mode inference on the slides listed in args.slide_list.

    The classifier is built, its weights are loaded from args.snapshot, and
    each slide is then processed in turn. A slide whose ``<basename>.npy``
    already exists in args.save_dir is skipped. Otherwise the slide is copied
    to args.ramdisk, tiled using its foreground mask
    ``<args.fgdir>/<basename>_fg.png``, and written to args.save_dir as
    ``<basename>.npy``, ``<basename>.jpg`` and ``<basename>_c.jpg``.
    A failure on one slide is printed and processing continues with the next.

    Args:
      args: parsed arguments; reads slide_list, shuffle, encoder, n_classes,
        input_dim, snapshot, save_dir, ramdisk, fgdir, mag and batch_size.
    """
    ## Search for slides
    slide_list = read_list(args.slide_list)
    print('Found {} slides'.format(len(slide_list)))
    if args.shuffle:
        np.random.shuffle(slide_list)

    encoder_args = get_encoder_args(args.encoder)
    model = ClassifierEager(encoder_args=encoder_args,
                            deep_classifier=True,
                            n_classes=args.n_classes)
    # Run one dummy batch through the model so its variables exist before
    # load_weights() restores them.
    fake_data = tf.constant(
        np.zeros((1, args.input_dim, args.input_dim, 3), dtype=np.float32))
    model(fake_data)
    model.load_weights(args.snapshot)
    model.summary()
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    ## Loop over found slides:
    for src in slide_list:
        basename = os.path.basename(src).replace('.svs', '')
        dst = os.path.join(args.save_dir, '{}.npy'.format(basename))
        if os.path.exists(dst):
            print('{} exists. Skipping'.format(dst))
            continue

        ramdisk_path = transfer_to_ramdisk(
            src, args.ramdisk)  # never use the original src
        # Reset every iteration so a failure before Slide() succeeds can never
        # close the previous slide's (already closed) handle again.
        svs = None
        try:
            fgpth = os.path.join(args.fgdir, '{}_fg.png'.format(basename))
            fgimg = cv2.imread(fgpth, 0)
            if fgimg is None:
                # cv2.imread signals a missing/unreadable file with None.
                raise IOError('Could not read foreground image {}'.format(fgpth))
            fgimg = fill_fg(fgimg)
            svs = Slide(slide_path=ramdisk_path,
                        preprocess_fn=lambda x: (x / 255.).astype(np.float32),
                        normalize_fn=lambda x: x,
                        background_speed='image',
                        background_image=fgimg,
                        process_mag=args.mag,
                        process_size=args.input_dim,
                        oversample_factor=1.1)
            svs.initialize_output(name='prob', dim=args.n_classes, mode='full')
            svs.initialize_output(name='rgb', dim=3, mode='full')
            n_tiles = len(svs.tile_list)
            prefetch = min(512, n_tiles)
            # Get an iterator over (image, index) batches
            iterator = get_img_idx(svs, args.batch_size, prefetch)
            batches = 0
            for img_, idx_ in iterator:
                batches += 1
                yhat = model(img_, training=False)
                yhat = yhat.numpy()
                idx_ = idx_.numpy()
                img_ = img_.numpy()
                yhat = vect_to_tile(yhat, args.input_dim)
                svs.place_batch(yhat, idx_, 'prob', mode='full')
                svs.place_batch(img_, idx_, 'rgb', mode='full')
                if batches % 50 == 0:
                    print('\tbatch {:04d}'.format(batches))

            svs.make_outputs(reference='prob')
            prob_img = svs.output_imgs['prob']
            rgb_img = svs.output_imgs['rgb'] * 255
            color_img = colorize(rgb_img, prob_img)
            dst = os.path.join(args.save_dir, '{}.npy'.format(basename))
            np.save(dst, (prob_img * 255).astype(np.uint8))
            # [:, :, ::-1] reverses the channel order for cv2 (RGB -> BGR).
            dst = os.path.join(args.save_dir, '{}.jpg'.format(basename))
            cv2.imwrite(dst, rgb_img[:, :, ::-1])
            dst = os.path.join(args.save_dir, '{}_c.jpg'.format(basename))
            cv2.imwrite(dst, color_img[:, :, ::-1])
        except Exception as e:
            # Best effort per slide: report the failure and continue.
            print('Failed on {}: {}'.format(src, e))
        finally:
            if svs is None:
                print('No SVS to close')
            else:
                print('Closing SVS')
                try:
                    svs.close()
                except Exception as close_err:
                    print('Failed to close SVS: {}'.format(close_err))

            os.remove(ramdisk_path)
Esempio n. 4
0
def main(args):
    """Run eager-mode inference on the .svs slides found in args.slide_dir.

    A slide whose ``<basename>.npy`` already exists in args.save_dir is
    skipped. Otherwise the slide is copied to args.ramdisk and tiled. If
    ``<args.fgdir>/<basename>_fg.png`` exists it is used as the foreground
    mask ('image' background speed); otherwise 'fast' is used. The slide is
    then run through the classifier and written to args.save_dir as
    ``<basename>.npy``, ``<basename>.jpg`` and ``<basename>_c.jpg``.
    A failure on one slide is printed and processing continues.

    Args:
      args: parsed arguments; reads slide_dir, shuffle, n_classes,
        batch_size, input_dim, snapshot, save_dir, fgdir, ramdisk and mag.
    """
    ## Search for slides
    slide_list = sorted(glob.glob(os.path.join(args.slide_dir, '*.svs')))
    print('Found {} slides'.format(len(slide_list)))

    if args.shuffle:
        np.random.shuffle(slide_list)

    # NOTE(review): `encoder_args` is not defined in this function; it must
    # be a module-level global. The sibling script builds it with
    # get_encoder_args(args.encoder) -- confirm which source is intended.
    model = ClassifierEager(encoder_args=encoder_args,
                            n_classes=args.n_classes)
    # Run one dummy batch through the model so its variables exist before
    # load_weights() restores them.
    xdummy = tf.constant(
        np.zeros((args.batch_size, args.input_dim, args.input_dim, 3),
                 dtype=np.float32))
    model(xdummy, verbose=True)
    model.load_weights(args.snapshot, by_name=True)
    model.summary()

    # np.save / cv2.imwrite below do not create missing directories.
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    ## Loop over found slides:
    for src in slide_list:
        basename = os.path.basename(src).replace('.svs', '')
        dst = os.path.join(args.save_dir, '{}.npy'.format(basename))

        if os.path.exists(dst):
            print('{} exists. continuing'.format(dst))
            continue

        fgpth = os.path.join(args.fgdir, '{}_fg.png'.format(basename))
        if os.path.exists(fgpth):
            print('fg image found:', fgpth)
            fgimg = cv2.imread(fgpth, 0)
            speed = 'image'
        else:
            speed = 'fast'
            fgimg = None

        ramdisk_path = transfer_to_ramdisk(
            src, args.ramdisk)  # never use the original src
        # Reset every iteration so cleanup only ever closes this slide.
        svs = None
        try:
            svs = Slide(slide_path=ramdisk_path,
                        preprocess_fn=lambda x:
                        (reinhard(x) / 255.).astype(np.float32),
                        background_speed=speed,
                        background_image=fgimg,
                        process_mag=args.mag,
                        process_size=args.input_dim,
                        oversample_factor=1.75)

            svs.initialize_output(name='prob', dim=args.n_classes, mode='full')
            svs.initialize_output(name='rgb', dim=3, mode='full')

            n_tiles = len(svs.tile_list)
            prefetch = min(128, n_tiles)
            print('Tiles:', n_tiles)

            # Iterator over (image, index) batches
            iterator = get_iterator(svs, args.batch_size, prefetch)

            batches = 0
            for img_, idx_ in iterator:
                batches += 1

                yhat = model(img_, training=False)
                yhat = yhat.numpy()
                idx_ = idx_.numpy()

                yhat = vect_to_tile(yhat, args.input_dim)
                svs.place_batch(yhat, idx_, 'prob', mode='full')
                svs.place_batch(img_.numpy(), idx_, 'rgb', mode='full')
                if batches % 50 == 0:
                    print('batch {:04d}'.format(batches))

            print('Making outputs')
            svs.make_outputs(reference='prob')
            prob_img = svs.output_imgs['prob']
            print('prob img', prob_img.shape)

            rgb_img = svs.output_imgs['rgb'] * 255
            print('rgb img', rgb_img.shape)

            color_img = colorize(rgb_img, prob_img)
            print('color img', color_img.shape)

            dst = os.path.join(args.save_dir, '{}.npy'.format(basename))
            np.save(dst, (prob_img * 255).astype(np.uint8))
            # [:, :, ::-1] reverses the channel order for cv2 (RGB -> BGR).
            dst = os.path.join(args.save_dir, '{}.jpg'.format(basename))
            cv2.imwrite(dst, rgb_img[:, :, ::-1])
            dst = os.path.join(args.save_dir, '{}_c.jpg'.format(basename))
            cv2.imwrite(dst, color_img[:, :, ::-1])

        except Exception as e:
            # Best effort per slide: report the failure and continue.
            print(e)

        finally:
            # Close the slide on every path (previously it leaked whenever
            # processing raised), then drop the ramdisk copy.
            if svs is not None:
                try:
                    svs.close()
                except Exception as close_err:
                    print('Failed to close SVS: {}'.format(close_err))
            print('Removing {}'.format(ramdisk_path))
            os.remove(ramdisk_path)