def main(args, sess, max_slides=5):
    """Run a session-based (graph-mode) model over the ``.svs`` slides in ``args.slide_dir``.

    Slides are found with ``glob``, sorted, optionally shuffled
    (``args.shuffle``), and only the first ``max_slides`` are processed.
    Each slide is copied to ``args.ramdisk`` first so the original file is
    never opened directly. For every processed slide three files are written
    to ``args.save_dir``:

    - ``<basename>.npy``   probability image scaled by 255 and cast to uint8
    - ``<basename>.jpg``   the reassembled RGB image
    - ``<basename>_c.jpg`` the overlay produced by ``colorize``

    Args:
        args: parsed command-line namespace; reads ``slide_dir``, ``shuffle``,
            ``snapshot``, ``ramdisk``, ``mag``, ``input_dim``, ``n_classes``,
            ``batch_size`` and ``save_dir``.
        sess: tf.Session used to evaluate the dataset and model tensors.
        max_slides: int or None. Maximum number of slides to process. The
            default of 5 matches the previously hard-coded limit; ``None``
            processes every slide found.

    Exceptions raised while processing a slide propagate to the caller, but
    the slide handle is closed and the ramdisk copy removed first.
    """
    ## Search for slides
    slide_list = sorted(glob.glob(os.path.join(args.slide_dir, '*.svs')))
    print('Found {} slides'.format(len(slide_list)))
    if args.shuffle:
        np.random.shuffle(slide_list)

    model = load_model(args.snapshot)

    ## Loop over found slides:
    for src in slide_list[:max_slides]:
        ramdisk_path = transfer_to_ramdisk(src, args.ramdisk)  # never use the original src
        svs = None
        try:
            svs = Slide(slide_path=ramdisk_path,
                        preprocess_fn=lambda x: (x / 255.).astype(np.float32),
                        process_mag=args.mag,
                        process_size=args.input_dim,
                        oversample_factor=1.5)
            svs.initialize_output(name='prob', dim=args.n_classes, mode='full')
            svs.initialize_output(name='rgb', dim=3, mode='full')
            n_tiles = len(svs.tile_list)
            prefetch = min(512, n_tiles)

            # Get tensors for image and index
            img, idx = get_img_idx(svs, args.batch_size, prefetch)
            # NOTE(review): with a Session this presumably adds new ops to the
            # graph for every slide, so the graph grows with slides processed.
            yhat_op = model(img)

            batches = 0
            while True:
                batches += 1
                try:
                    yhat, img_, idx_ = sess.run([yhat_op, img, idx])
                except tf.errors.OutOfRangeError:
                    # The input pipeline is exhausted: every tile has been seen.
                    print('Done')
                    break
                yhat = vect_to_tile(yhat, args.input_dim)
                svs.place_batch(yhat, idx_, 'prob', mode='full')
                svs.place_batch(img_, idx_, 'rgb', mode='full')
                if batches % 50 == 0:
                    print('batch {:04d}'.format(batches))

            svs.make_outputs(reference='prob')
            prob_img = svs.output_imgs['prob']
            rgb_img = svs.output_imgs['rgb'] * 255
            color_img = colorize(rgb_img, prob_img)

            basename = os.path.basename(src).replace('.svs', '')
            # NOTE(review): the *255 -> uint8 cast assumes probabilities in [0, 1].
            dst = os.path.join(args.save_dir, '{}.npy'.format(basename))
            np.save(dst, (prob_img * 255).astype(np.uint8))
            # [:, :, ::-1] reverses the channel axis (presumably RGB -> BGR for cv2).
            dst = os.path.join(args.save_dir, '{}.jpg'.format(basename))
            cv2.imwrite(dst, rgb_img[:, :, ::-1])
            dst = os.path.join(args.save_dir, '{}_c.jpg'.format(basename))
            cv2.imwrite(dst, color_img[:, :, ::-1])
        finally:
            # Always release the slide and the ramdisk copy, even on failure;
            # previously an exception left the copy behind on the ramdisk.
            try:
                if svs is not None:
                    svs.close()
            finally:
                os.remove(ramdisk_path)
def process_slide(slide_path, fg_path, model, sess, out_dir, process_mag,
                  process_size, oversample, batch_size, n_classes):
    """ Process a slide

    Args:
        slide_path: str
            absolute or relative path to svs formatted slide
        fg_path: str
            absolute or relative path to foreground png.
            NOTE: currently unused; the slide is opened with
            background_speed='all' and no foreground image.
        model: tfmodels.SegmentationBasemodel object
            model definition to use. Weights must be restored first,
            i.e. call model.restore() before passing model
        sess: tf.Session
        out_dir: str
            path to use for output. NOTE: currently unused here; nothing is
            written to disk by this function.
        process_mag: int
            Usually one of: 5, 10, 20, 40.
            Other values may work but have not been tested
        process_size: int
            The input size required by model.
        oversample: float.
            Usually in [1., 2.]
            How much to oversample between tiles. Larger values
            will increase processing time.
        batch_size: int
            The batch size for inference. If the batch size is too
            large given the model and process_size, then OOM errors
            will be raised
        n_classes: int
            The number of classes output by model.
            i.e. shape(model.yhat) = (batch, h, w, n_classes)

    Returns:
        (prob_img, rgb_img, fps): the outputs of ``prob_output(svs)`` and
        ``rgb_output(svs)`` and the throughput in tiles per second. If
        inference fails part-way, all three are None.
    """
    print('Working {}'.format(slide_path))
    svs = Slide(
        slide_path=slide_path,
        background_speed='all',
        preprocess_fn=preprocess_fn,
        normalize_fn=lambda x: x,
        process_mag=process_mag,
        process_size=process_size,
        oversample=oversample,
        verbose=False,
    )

    # Returned when inference fails. Previously `fps` was never bound on that
    # path and the final return raised UnboundLocalError.
    prob_img = None
    rgb_img = None
    fps = None
    try:
        svs.initialize_output('prob', dim=n_classes)
        svs.initialize_output('rgb', dim=3)
        prefetch = min(len(svs.place_list), 256)

        def wrapped_fn(idx):
            """Read the tile at position ``idx`` of ``svs.tile_list`` (runs inside tf.py_func)."""
            try:
                coords = svs.tile_list[idx]
                img = svs._read_tile(coords)
                return img, idx
            except Exception as e:
                # The former `return 0` did not match Tout=[float32, int64], so
                # it failed anyway with an unrelated-looking error. Report the
                # real cause and re-raise so it surfaces from sess.run below.
                print('Failed to read tile {}: {}'.format(idx, e))
                raise

        def read_region_at_index(idx):
            return tf.py_func(func=wrapped_fn, inp=[idx],
                              Tout=[tf.float32, tf.int64], stateful=False)

        ds = tf.data.Dataset.from_generator(generator=svs.generate_index,
                                            output_types=tf.int64)
        ds = ds.map(read_region_at_index, num_parallel_calls=6)
        ds = ds.prefetch(prefetch)
        ds = ds.batch(batch_size)
        iterator = ds.make_one_shot_iterator()
        img, idx = iterator.get_next()

        n_tiles = len(svs.tile_list)
        print('Processing {} tiles'.format(n_tiles))
        tstart = time.time()
        n_processed = 0
        idx_ = None  # bound up front so the error report below cannot fail
        while True:
            try:
                tile, idx_ = sess.run([img, idx])
                output = model.inference(tile)
                svs.place_batch(output, idx_, 'prob')
                svs.place_batch(tile, idx_, 'rgb')
                # Count the tiles actually returned (the last batch may be
                # short) rather than the global BATCH_SIZE, which need not
                # equal this call's batch_size.
                n_processed += len(idx_)
                if n_processed % PRINT_ITER == 0:
                    print('[{:06d}] elapsed time [{:3.3f}]'.format(
                        n_processed, time.time() - tstart))
            except tf.errors.OutOfRangeError:
                print('Finished')
                dt = time.time() - tstart
                # Guard against an empty tile list and a zero elapsed time.
                spt = dt / float(max(n_tiles, 1))
                fps = n_tiles / dt if dt > 0 else 0.
                print('\nFinished. {:2.2f}min {:3.3f}s/tile\n'.format(
                    dt / 60., spt))
                print('\t {:3.3f} fps\n'.format(fps))
                svs.make_outputs()
                prob_img = prob_output(svs)
                rgb_img = rgb_output(svs)
                break
            except Exception as e:
                print('Caught exception at tiles {}'.format(idx_))
                print(e)
                break
    finally:
        # Close the slide even if output setup or pipeline construction raised.
        svs.close()
    return prob_img, rgb_img, fps
def main(args):
    """Run the eager classifier over every slide listed in ``args.slide_list``.

    Slides whose ``<save_dir>/<basename>.npy`` already exists are skipped.
    Each remaining slide is copied to ``args.ramdisk``. It is tiled using its
    foreground mask ``<fgdir>/<basename>_fg.png`` (post-processed by
    ``fill_fg``) and run through the model. Three files are then written to
    ``args.save_dir``:

    - ``<basename>.npy``   probability image scaled by 255 and cast to uint8
    - ``<basename>.jpg``   the reassembled RGB image
    - ``<basename>_c.jpg`` the overlay produced by ``colorize``

    Errors on an individual slide are printed and the loop continues. The
    slide handle (if opened) is closed and the ramdisk copy is removed for
    every slide.
    """
    ## Search for slides
    slide_list = read_list(args.slide_list)
    print('Found {} slides'.format(len(slide_list)))
    if args.shuffle:
        np.random.shuffle(slide_list)

    encoder_args = get_encoder_args(args.encoder)
    model = ClassifierEager(encoder_args=encoder_args,
                            deep_classifier=True,
                            n_classes=args.n_classes)
    # Call the model once on a dummy batch, presumably so its variables exist
    # before load_weights restores into them.
    fake_data = tf.constant(
        np.zeros((1, args.input_dim, args.input_dim, 3), dtype=np.float32))
    model(fake_data)
    model.load_weights(args.snapshot)
    model.summary()

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    ## Loop over found slides:
    for src in slide_list:
        basename = os.path.basename(src).replace('.svs', '')
        dst = os.path.join(args.save_dir, '{}.npy'.format(basename))
        if os.path.exists(dst):
            print('{} exists. Skipping'.format(dst))
            continue

        ramdisk_path = transfer_to_ramdisk(src, args.ramdisk)  # never use the original src
        # Reset for every slide. Otherwise, when Slide(...) fails, the cleanup
        # below would re-close the previous iteration's already-closed slide.
        svs = None
        try:
            fgpth = os.path.join(args.fgdir, '{}_fg.png'.format(basename))
            fgimg = cv2.imread(fgpth, 0)
            fgimg = fill_fg(fgimg)
            svs = Slide(slide_path=ramdisk_path,
                        preprocess_fn=lambda x: (x / 255.).astype(np.float32),
                        normalize_fn=lambda x: x,
                        background_speed='image',
                        background_image=fgimg,
                        process_mag=args.mag,
                        process_size=args.input_dim,
                        oversample_factor=1.1)
            svs.initialize_output(name='prob', dim=args.n_classes, mode='full')
            svs.initialize_output(name='rgb', dim=3, mode='full')
            n_tiles = len(svs.tile_list)
            prefetch = min(512, n_tiles)

            # Iterate (image batch, tile-index batch) pairs
            iterator = get_img_idx(svs, args.batch_size, prefetch)
            batches = 0
            for img_, idx_ in iterator:
                batches += 1
                yhat = model(img_, training=False)
                yhat = yhat.numpy()
                idx_ = idx_.numpy()
                img_ = img_.numpy()

                yhat = vect_to_tile(yhat, args.input_dim)
                svs.place_batch(yhat, idx_, 'prob', mode='full')
                svs.place_batch(img_, idx_, 'rgb', mode='full')
                if batches % 50 == 0:
                    print('\tbatch {:04d}'.format(batches))

            svs.make_outputs(reference='prob')
            prob_img = svs.output_imgs['prob']
            rgb_img = svs.output_imgs['rgb'] * 255
            color_img = colorize(rgb_img, prob_img)

            # NOTE(review): the *255 -> uint8 cast assumes probabilities in [0, 1].
            dst = os.path.join(args.save_dir, '{}.npy'.format(basename))
            np.save(dst, (prob_img * 255).astype(np.uint8))
            # [:, :, ::-1] reverses the channel axis (presumably RGB -> BGR for cv2).
            dst = os.path.join(args.save_dir, '{}.jpg'.format(basename))
            cv2.imwrite(dst, rgb_img[:, :, ::-1])
            dst = os.path.join(args.save_dir, '{}_c.jpg'.format(basename))
            cv2.imwrite(dst, color_img[:, :, ::-1])
        except Exception as e:
            # Best effort per slide: report which slide failed and move on.
            print('Failed processing {}'.format(src))
            print(e)
        finally:
            if svs is None:
                print('No SVS to close')
            else:
                print('Closing SVS')
                try:
                    svs.close()
                except Exception as e:
                    print(e)
            os.remove(ramdisk_path)
def main(args):
    """Run the eager classifier over every ``.svs`` slide found in ``args.slide_dir``.

    Slides whose ``<save_dir>/<basename>.npy`` already exists are skipped.
    If ``<fgdir>/<basename>_fg.png`` exists, it is passed to ``Slide`` as the
    background image with ``background_speed='image'``. Otherwise
    ``background_speed='fast'`` is used with no image. Each slide is copied to
    ``args.ramdisk``, normalized with ``reinhard`` in the preprocess function,
    and run through the model. Three files are written to ``args.save_dir``:

    - ``<basename>.npy``   probability image scaled by 255 and cast to uint8
    - ``<basename>.jpg``   the reassembled RGB image
    - ``<basename>_c.jpg`` the overlay produced by ``colorize``

    Errors on an individual slide are printed and the loop continues. The
    slide handle (if opened) is closed and the ramdisk copy is removed for
    every slide.
    """
    ## Search for slides
    slide_list = sorted(glob.glob(os.path.join(args.slide_dir, '*.svs')))
    print('Found {} slides'.format(len(slide_list)))
    if args.shuffle:
        np.random.shuffle(slide_list)

    # NOTE(review): `encoder_args` is not defined in this function, so it must
    # be a module-level name; otherwise this line raises NameError.
    model = ClassifierEager(encoder_args=encoder_args, n_classes=args.n_classes)
    # Call the model once on a dummy batch, presumably so its variables exist
    # before load_weights restores into them.
    xdummy = tf.constant(
        np.zeros((args.batch_size, args.input_dim, args.input_dim, 3),
                 dtype=np.float32))
    model(xdummy, verbose=True)
    model.load_weights(args.snapshot, by_name=True)
    model.summary()

    ## Loop over found slides:
    for src in slide_list:
        basename = os.path.basename(src).replace('.svs', '')
        dst = os.path.join(args.save_dir, '{}.npy'.format(basename))
        if os.path.exists(dst):
            print('{} exists. continuing'.format(dst))
            continue

        # Use the precomputed foreground mask if there is one.
        fgpth = os.path.join(args.fgdir, '{}_fg.png'.format(basename))
        if os.path.exists(fgpth):
            print('fg image found:', fgpth)
            fgimg = cv2.imread(fgpth, 0)
            speed = 'image'
        else:
            speed = 'fast'
            fgimg = None

        ramdisk_path = transfer_to_ramdisk(src, args.ramdisk)  # never use the original src
        svs = None
        try:
            svs = Slide(slide_path=ramdisk_path,
                        preprocess_fn=lambda x: (reinhard(x) / 255.).astype(np.float32),
                        background_speed=speed,
                        background_image=fgimg,
                        process_mag=args.mag,
                        process_size=args.input_dim,
                        oversample_factor=1.75)
            svs.initialize_output(name='prob', dim=args.n_classes, mode='full')
            svs.initialize_output(name='rgb', dim=3, mode='full')
            n_tiles = len(svs.tile_list)
            prefetch = min(128, n_tiles)
            print('Tiles:', n_tiles)

            # Iterate (image batch, tile-index batch) pairs
            iterator = get_iterator(svs, args.batch_size, prefetch)
            batches = 0
            for img_, idx_ in iterator:
                batches += 1
                yhat = model(img_, training=False)
                yhat = yhat.numpy()
                idx_ = idx_.numpy()
                yhat = vect_to_tile(yhat, args.input_dim)
                svs.place_batch(yhat, idx_, 'prob', mode='full')
                svs.place_batch(img_.numpy(), idx_, 'rgb', mode='full')
                if batches % 50 == 0:
                    print('batch {:04d}'.format(batches))

            print('Making outputs')
            svs.make_outputs(reference='prob')
            prob_img = svs.output_imgs['prob']
            print('prob img', prob_img.shape)
            rgb_img = svs.output_imgs['rgb'] * 255
            print('rgb img', rgb_img.shape)
            color_img = colorize(rgb_img, prob_img)
            print('color img', color_img.shape)

            # NOTE(review): the *255 -> uint8 cast assumes probabilities in [0, 1].
            dst = os.path.join(args.save_dir, '{}.npy'.format(basename))
            np.save(dst, (prob_img * 255).astype(np.uint8))
            # [:, :, ::-1] reverses the channel axis (presumably RGB -> BGR for cv2).
            dst = os.path.join(args.save_dir, '{}.jpg'.format(basename))
            cv2.imwrite(dst, rgb_img[:, :, ::-1])
            dst = os.path.join(args.save_dir, '{}_c.jpg'.format(basename))
            cv2.imwrite(dst, color_img[:, :, ::-1])
        except Exception as e:
            # Best effort per slide: report which slide failed and move on.
            print('Failed processing {}'.format(src))
            print(e)
        finally:
            # Previously the slide was only closed on the success path, so the
            # handle leaked whenever processing raised.
            if svs is not None:
                try:
                    svs.close()
                except Exception as e:
                    print(e)
            print('Removing {}'.format(ramdisk_path))
            os.remove(ramdisk_path)