def main(sess, ramdisk_path, image_op, predict_op):
    """Run tiled inference over a single slide.

    Tiles are streamed through a ``tf.data`` pipeline, zero-padded up to the
    spatial size of ``image_op``, evaluated with ``predict_op`` and written
    into the slide's ``'prob'`` output.

    Args:
        sess: session used to run both the input pipeline and ``predict_op``.
        ramdisk_path: path of the slide file handed to ``Slide``.
        image_op: model input tensor; each padded batch is fed to it.
        predict_op: tensor evaluated for every batch.

    Returns:
        tuple: ``(prob_img, fps)`` where ``prob_img`` is the result of
        ``prob_output(svs)`` and ``fps`` is ``len(svs.tile_list)`` divided by
        the elapsed wall time in seconds (0 when nothing could be timed).
    """
    input_size = image_op.get_shape().as_list()
    print(input_size)
    x_size, _ = input_size[1:3]
    # Symmetric border that grows an INPUT_SIZE tile to the model input size.
    pad = int((x_size - INPUT_SIZE) / 2)

    print('Working {}'.format(ramdisk_path))
    svs = Slide(slide_path=ramdisk_path,
                preprocess_fn=preprocess_fn,
                background_speed='fast',
                process_mag=PROCESS_MAG,
                process_size=INPUT_SIZE,
                oversample_factor=OVERSAMPLE,
                verbose=True)

    # Everything below runs under try/finally so the slide handle is released
    # even when the pipeline or the output assembly raises.
    try:
        print('calculated foregroud: ', svs.foreground.shape)
        print('calculated ds_tile_map: ', svs.ds_tile_map.shape)
        svs.initialize_output('prob', dim=2, mode='tile')

        n_tiles = len(svs.tile_list)
        prefetch = min(n_tiles, 1024)

        def wrapped_fn(idx):
            coords = svs.tile_list[idx]
            img = svs._read_tile(coords)
            # (h, w, 3)
            img = np.pad(img,
                         pad_width=((pad, pad), (pad, pad), (0, 0)),
                         mode='constant',
                         constant_values=0.)
            return img, idx

        def read_region_at_index(idx):
            return tf.py_func(func=wrapped_fn,
                              inp=[idx],
                              Tout=[tf.float32, tf.int64],
                              stateful=False)

        ds = tf.data.Dataset.from_generator(generator=svs.generate_index,
                                            output_types=tf.int64)
        ds = ds.map(read_region_at_index, num_parallel_calls=12)
        ds = ds.prefetch(prefetch)
        ds = ds.batch(BATCH_SIZE)
        iterator = ds.make_one_shot_iterator()
        img, idx = iterator.get_next()

        print('Processing {} tiles'.format(n_tiles))
        tstart = time.time()
        n_processed = 0
        while True:
            try:
                tile, idx_ = sess.run([img, idx])
                output = sess.run(predict_op, {image_op: tile})
                svs.place_batch(output, idx_, 'prob', mode='tile')
                n_processed += BATCH_SIZE
                if n_processed % PRINT_ITER == 0:
                    print('[{:06d}] elapsed time [{:3.3f}] ({})'.format(
                        n_processed, time.time() - tstart, tile.shape))
            except tf.errors.OutOfRangeError:
                # The one-shot iterator is exhausted: every tile was consumed.
                print('Finished')
                break
            except Exception as e:
                # Best effort: report the failure and keep whatever was
                # placed so far.
                print(e)
                break

        dt = time.time() - tstart
        # Guard both ratios: an empty tile list (or a zero-length timing
        # window) must not turn into a ZeroDivisionError.
        spt = dt / float(n_tiles) if n_tiles else 0.
        fps = n_tiles / dt if dt > 0 else 0.
        print('\nFinished. {:2.2f}min {:3.3f}s/tile\n'.format(dt / 60., spt))
        print('\t {:3.3f} fps\n'.format(fps))

        prob_img = prob_output(svs)
        return prob_img, fps
    finally:
        svs.close()
def process_slide(slide_path, fg_path, model, sess, out_dir, process_mag,
                  process_size, oversample, batch_size, n_classes):
    """Process a slide.

    Args:
        slide_path: str
            absolute or relative path to svs formatted slide
        fg_path: str
            absolute or relative path to foreground png. Accepted for
            interface compatibility; the current body does not read it.
        model: tfmodels.SegmentationBasemodel object
            model definition to use. Weights must be restored first,
            i.e. call model.restore() before passing model
        sess: tf.Session
        out_dir: str
            path to use for output. Not used by the current body.
        process_mag: int
            Usually one of: 5, 10, 20, 40.
            Other values may work but have not been tested
        process_size: int
            The input size required by model.
        oversample: float. Usually in [1., 2.]
            How much to oversample between tiles. Larger values
            will increase processing time.
        batch_size: int
            The batch size for inference. If the batch size is too
            large given the model and process_size, then OOM errors
            will be raised
        n_classes: int
            The number of classes output by model.
            i.e. shape(model.yhat) = (batch, h, w, n_classes)

    Returns:
        tuple: ``(prob_img, rgb_img, fps)``. When processing fails part-way
        all three are ``None``.
    """
    print('Working {}'.format(slide_path))
    svs = Slide(
        slide_path=slide_path,
        background_speed='all',
        preprocess_fn=preprocess_fn,
        normalize_fn=lambda x: x,
        process_mag=process_mag,
        process_size=process_size,
        oversample=oversample,
        verbose=False,
    )

    # Defaults for the failure path, so the final return never touches an
    # unbound name.
    prob_img = None
    rgb_img = None
    fps = None
    idx_ = None

    try:
        svs.initialize_output('prob', dim=n_classes)
        svs.initialize_output('rgb', dim=3)
        prefetch = min(len(svs.place_list), 256)

        def wrapped_fn(idx):
            try:
                coords = svs.tile_list[idx]
                img = svs._read_tile(coords)
                return img, idx
            except Exception:
                # NOTE(review): a scalar does not match the two declared
                # outputs of the py_func below, so this presumably surfaces
                # as a pipeline error in sess.run -- confirm.
                return 0

        def read_region_at_index(idx):
            return tf.py_func(func=wrapped_fn,
                              inp=[idx],
                              Tout=[tf.float32, tf.int64],
                              stateful=False)

        ds = tf.data.Dataset.from_generator(generator=svs.generate_index,
                                            output_types=tf.int64)
        ds = ds.map(read_region_at_index, num_parallel_calls=6)
        ds = ds.prefetch(prefetch)
        ds = ds.batch(batch_size)
        iterator = ds.make_one_shot_iterator()
        img, idx = iterator.get_next()

        n_tiles = len(svs.tile_list)
        print('Processing {} tiles'.format(n_tiles))
        tstart = time.time()
        n_processed = 0
        while True:
            try:
                tile, idx_ = sess.run([img, idx])
                output = model.inference(tile)
                svs.place_batch(output, idx_, 'prob')
                svs.place_batch(tile, idx_, 'rgb')
                # Count with the batch size actually used by the pipeline.
                n_processed += batch_size
                if n_processed % PRINT_ITER == 0:
                    print('[{:06d}] elapsed time [{:3.3f}]'.format(
                        n_processed, time.time() - tstart))
            except tf.errors.OutOfRangeError:
                # Iterator exhausted: assemble the outputs.
                print('Finished')
                dt = time.time() - tstart
                spt = dt / float(n_tiles) if n_tiles else 0.
                fps = n_tiles / dt if dt > 0 else 0.
                print('\nFinished. {:2.2f}min {:3.3f}s/tile\n'.format(
                    dt / 60., spt))
                print('\t {:3.3f} fps\n'.format(fps))
                svs.make_outputs()
                prob_img = prob_output(svs)
                rgb_img = rgb_output(svs)
                break
            except Exception as e:
                print('Caught exception at tiles {}'.format(idx_))
                print(e)
                prob_img = None
                rgb_img = None
                fps = None
                break
    finally:
        svs.close()

    return prob_img, rgb_img, fps
def main(args):
    """Predict every listed slide with a MIL model and save the results.

    For each slide that has a foreground image ``<basename>_fg.png`` under
    ``args.fg``, the slide is copied to the ramdisk, run through
    ``process_slide``, and the prediction is written to ``args.o`` as
    ``<basename>_ypred.npy``. For the ``'attention'`` and ``'instance'`` MIL
    types the raw attention map (``_att.npy``) and a rendered image
    (``_img.png``) are written as well.

    Args:
        args: parsed command-line namespace. Attributes read here: ``f``,
            ``encoder``, ``mil``, ``batch_size``, ``temperature``,
            ``deep_classifier``, ``input_dim``, ``s``, ``fg``, ``ramdisk``,
            ``mag``, ``oversample``, ``o``.
    """
    # Translate obfuscated file names to paths if necessary
    slide_list = read_list(args.f)
    print('Found {} slides'.format(len(slide_list)))

    encoder_args = get_encoder_args(args.encoder)
    model = MilkEager(encoder_args=encoder_args,
                      mil_type=args.mil,
                      batch_size=args.batch_size,
                      temperature=args.temperature,
                      deep_classifier=args.deep_classifier)

    # One forward pass on zeros before the weights are loaded.
    x_pl = np.zeros((1, args.batch_size, args.input_dim, args.input_dim, 3),
                    dtype=np.float32)
    yhat = model(tf.constant(x_pl), verbose=True)
    print('yhat:', yhat.shape)

    print('setting model weights')
    model.load_weights(args.s, by_name=True)

    ## Loop over found slides:
    yhats = []
    for i, src in enumerate(slide_list):
        print('\nSlide {}'.format(i))
        basename = os.path.basename(src).replace('.svs', '')
        fgpth = os.path.join(args.fg, '{}_fg.png'.format(basename))

        # A precomputed foreground image is required; skip slides without.
        if not os.path.exists(fgpth):
            print(fgpth)
            continue

        ramdisk_path = transfer_to_ramdisk(
            src, args.ramdisk)  # never use the original src
        print('Using fg image at : {}'.format(fgpth))
        fgimg = cv2.imread(fgpth, 0)
        try:
            svs = Slide(
                slide_path=ramdisk_path,
                background_speed='image',
                background_image=fgimg,
                preprocess_fn=lambda x: (x / 255.).astype(np.float32),
                normalize_fn=lambda x: x,
                process_mag=args.mag,
                process_size=args.input_dim,
                oversample_factor=args.oversample,
                verbose=False)
        except Exception as e:
            print(e)
            print(
                'Caught SVS related error. Cleaning ramdisk and continuing.')
            print('Cleaning file: {}'.format(ramdisk_path))
            os.remove(ramdisk_path)
            continue

        # From here on the slide is open and the ramdisk copy exists; the
        # finally clause releases both even if inference or saving raises.
        try:
            svs.initialize_output(name='attention', dim=1, mode='tile')
            yhat, att, indices = process_slide(svs, model, args)
            yhats.append(yhat)
            print('\tSlide predicted: {}'.format(yhat))

            if args.mil == 'average':
                print('Average MIL; continuing')
            elif args.mil in ['attention', 'instance']:
                print('Placing values ranged {:3.3f} - {:3.3f}'.format(
                    att.min(), att.max()))
                print('Visualizing mean {:3.5f}'.format(np.mean(att)))
                print('Visualizing std {:3.5f}'.format(np.std(att)))
                svs.place_batch(att, indices, 'attention', mode='tile')

                attention_img_raw = np.squeeze(svs.output_imgs['attention'])
                # Scale so the maximum becomes 1 before rendering; the
                # multiplication makes a new array, leaving the raw map as is.
                attention_img = attention_img_raw * (
                    1. / attention_img_raw.max())
                attention_img = draw_attention(attention_img, n_bins=50)
                print('attention image:', attention_img.shape,
                      attention_img.dtype, attention_img.min(),
                      attention_img.max())

                dst = os.path.join(args.o, '{}_att.npy'.format(basename))
                np.save(dst, attention_img_raw)
                dst = os.path.join(args.o, '{}_img.png'.format(basename))
                cv2.imwrite(dst, attention_img)

            yhat_dst = os.path.join(args.o, '{}_ypred.npy'.format(basename))
            np.save(yhat_dst, yhat)
        finally:
            try:
                svs.close()
                os.remove(ramdisk_path)
                del svs
            except Exception:
                print('{} already removed'.format(ramdisk_path))
def process_slide(slide_path, sess, out_dir, process_mag, process_size,
                  oversample, batch_size, n_classes, fgspeed, fgimg):
    """Process a slide.

    # Pytorch - compatible mode: uses sess only for the loading pipeline

    Tiles are read through a ``tf.data`` pipeline and placed into the
    slide's ``'rgb'`` output; no model is evaluated in this function.

    Args:
        slide_path: str
            absolute or relative path to svs formatted slide
        sess: tf.Session
        out_dir: str
            path to use for output. Not used by the current body.
        process_mag: int
            Usually one of: 5, 10, 20, 40.
            Other values may work but have not been tested
        process_size: int
            The input size required by model.
        oversample: float. Usually in [1., 2.]
            How much to oversample between tiles. Larger values
            will increase processing time.
        batch_size: int
            The batch size used when reading tiles.
        n_classes: int
            Channel count used to initialize the ``'prob'`` output.
        fgspeed: forwarded to ``Slide`` as ``background_speed``.
        fgimg: forwarded to ``Slide`` as ``background_img``.

    Returns:
        tuple: ``(rgb_img, fps)``; ``(None, None)`` when the slide cannot be
        loaded or tile reading fails.
    """
    try:
        print('Working {}'.format(slide_path))
        # NOTE(review): this call uses oversample= / background_img=, while
        # other Slide(...) calls in this code base use oversample_factor= /
        # background_image= -- confirm against the Slide signature in use.
        svs = Slide(
            slide_path=slide_path,
            preprocess_fn=lambda x: x.astype(np.float32),
            normalize_fn=lambda x: x,
            process_mag=process_mag,
            process_size=process_size,
            oversample=oversample,
            background_speed=fgspeed,
            background_img=fgimg,
            verbose=False,
        )
        svs.initialize_output('rgb', dim=3)
        svs.initialize_output('prob', dim=n_classes, mode='tile')
        prefetch = min(len(svs.place_list), 1024)
    except Exception as e:
        print(e)
        print("Error loading slide at {}".format(slide_path))
        # Same arity as the success path, so callers can unpack uniformly.
        return None, None

    rgb_img = None
    fps = None
    idx_ = None

    try:
        def wrapped_fn(idx):
            try:
                coords = svs.tile_list[idx]
                img = svs._read_tile(coords)
                return img, idx
            except Exception:
                # NOTE(review): a scalar does not match the two declared
                # outputs of the py_func below, so this presumably surfaces
                # as a pipeline error in sess.run -- confirm.
                return 0

        def read_region_at_index(idx):
            return tf.py_func(func=wrapped_fn,
                              inp=[idx],
                              Tout=[tf.float32, tf.int64],
                              stateful=False)

        ds = tf.data.Dataset.from_generator(generator=svs.generate_index,
                                            output_types=tf.int64)
        ds = ds.map(read_region_at_index, num_parallel_calls=8)
        ds = ds.prefetch(prefetch)
        ds = ds.batch(batch_size)
        iterator = ds.make_one_shot_iterator()
        img, idx = iterator.get_next()

        n_tiles = len(svs.tile_list)
        print('Processing {} tiles'.format(n_tiles))
        tstart = time.time()
        n_processed = 0
        while True:
            try:
                tile, idx_ = sess.run([img, idx])
                svs.place_batch(tile, idx_, 'rgb', clobber=True)
                # Count with the batch size actually used by the pipeline.
                n_processed += batch_size
                if n_processed % PRINT_ITER == 0:
                    print('[{:06d}] elapsed time [{:3.3f}]'.format(
                        n_processed, time.time() - tstart))
            except tf.errors.OutOfRangeError:
                # Iterator exhausted: assemble the RGB output.
                dt = time.time() - tstart
                spt = dt / float(n_tiles) if n_tiles else 0.
                fps = n_tiles / dt if dt > 0 else 0.
                print('\nFinished. {:2.2f}min {:3.3f}s/tile\n'.format(
                    dt / 60., spt))
                print('{:3.3f} fps\n'.format(fps))
                rgb_img = rgb_output(svs)
                break
            except Exception as e:
                print('Caught exception at tiles {}'.format(idx_))
                print(e)
                rgb_img = None
                fps = None
                break
    finally:
        svs.close()

    if rgb_img is None:
        # Reading failed; there is no image to post-process.
        return None, None

    # Replace near-black pixels (channel mean below 10) with white
    bwimg = np.mean(rgb_img, axis=-1)
    img_b = bwimg < 10
    rgb_img[img_b, :] = 255
    return rgb_img, fps
def main(args):
    """Evaluate a trained MIL model on the test slides of one run.

    The test list ``<args.testdir>/<args.timestamp>.txt`` is resolved to
    slide paths and labels, each slide with a precomputed foreground image is
    processed, its attention map and rendered attention image are written
    under ``<args.odir>/<args.timestamp>/``, and the accuracy over all
    processed slides is printed at the end.

    Args:
        args: parsed command-line namespace. Attributes read here:
            ``testdir``, ``timestamp``, ``randomize``, ``max_slides``,
            ``savedir``, ``encoder``, ``mil``, ``batch_size``,
            ``deep_classifier``, ``temperature``, ``input_dim``, ``fgdir``,
            ``ramdisk``, ``mag``, ``oversample``, ``odir``.
    """
    # Translate obfuscated file names to paths if necessary
    test_list = os.path.join(args.testdir, '{}.txt'.format(args.timestamp))
    test_list = read_test_list(test_list)
    test_unique_ids = [
        os.path.basename(x).replace('.npy', '') for x in test_list
    ]
    if args.randomize:
        np.random.shuffle(test_unique_ids)
    if args.max_slides:
        test_unique_ids = test_unique_ids[:args.max_slides]

    slide_list, slide_labels = get_slidelist_from_uids(test_unique_ids)
    print('Found {} slides'.format(len(slide_list)))

    snapshot = os.path.join(args.savedir, '{}.h5'.format(args.timestamp))

    encoder_args = get_encoder_args(args.encoder)
    model = MilkEager(encoder_args=encoder_args,
                      mil_type=args.mil,
                      batch_size=args.batch_size,
                      deep_classifier=args.deep_classifier,
                      temperature=args.temperature)

    # One forward pass on zeros before the weights are loaded.
    x_pl = np.zeros((1, args.batch_size, args.input_dim, args.input_dim, 3),
                    dtype=np.float32)
    yhat = model(tf.constant(x_pl), verbose=True)
    print('yhat:', yhat.shape)

    print('setting model weights')
    model.load_weights(snapshot, by_name=True)

    ## Loop over found slides:
    yhats = []
    ytrues = []
    for i, (src, lab) in enumerate(zip(slide_list, slide_labels)):
        print('\nSlide {}'.format(i))
        basename = os.path.basename(src).replace('.svs', '')
        fgpth = os.path.join(args.fgdir, '{}_fg.png'.format(basename))

        if not os.path.exists(fgpth):
            ## require precomputed background; skip this slide.
            print('Required foreground image not found ({})'.format(fgpth))
            continue

        ramdisk_path = transfer_to_ramdisk(
            src, args.ramdisk)  # never use the original src
        svs = None
        # try/finally: the slide handle and its ramdisk copy are released
        # even when loading, inference or saving raises.
        try:
            print('Using fg image at : {}'.format(fgpth))
            fgimg = cv2.imread(fgpth, 0)
            svs = Slide(
                slide_path=ramdisk_path,
                background_speed='image',
                background_image=fgimg,
                preprocess_fn=lambda x: (x / 255.).astype(np.float32),
                process_mag=args.mag,
                process_size=args.input_dim,
                oversample_factor=args.oversample,
                verbose=False)

            svs.initialize_output(name='attention', dim=1, mode='tile')
            yhat, att, indices = process_slide(svs, model, args)
            print('returned attention:', np.min(att), np.max(att), att.shape)

            yhat = yhat.numpy()
            yhats.append(yhat)
            ytrues.append(lab)
            print('\tSlide label: {} predicted: {}'.format(lab, yhat))

            svs.place_batch(att, indices, 'attention', mode='tile')
            attention_img = np.squeeze(svs.output_imgs['attention'])
            # Scale so the maximum becomes 1 before rendering.
            attention_img = attention_img * (1. / attention_img.max())
            attention_img = draw_attention(attention_img, n_bins=25)
            print('attention image:', attention_img.shape,
                  attention_img.dtype, attention_img.min(),
                  attention_img.max())

            dst = os.path.join(
                args.odir, args.timestamp,
                '{}_{}_{:3.3f}_att.npy'.format(basename, lab, yhat[0, 1]))
            np.save(dst, att)
            dst = os.path.join(
                args.odir, args.timestamp,
                '{}_{}_{:3.3f}_img.png'.format(basename, lab, yhat[0, 1]))
            cv2.imwrite(dst, attention_img)
        finally:
            try:
                if svs is not None:
                    svs.close()
                os.remove(ramdisk_path)
            except Exception:
                print('{} already removed'.format(ramdisk_path))

    if not yhats:
        # np.concatenate raises on an empty list; nothing to score.
        print('No slides were processed')
        return

    yhats = np.concatenate(yhats, axis=0)
    ytrues = np.array(ytrues)
    acc = (np.argmax(yhats, axis=-1) == ytrues).mean()
    print(acc)
def main(args):
    """Classify every tile of each listed slide and save the result maps.

    For each slide not already present in ``args.save_dir`` the function
    writes ``<basename>.npy`` (probabilities scaled to uint8),
    ``<basename>.jpg`` (the RGB output) and ``<basename>_c.jpg`` (the output
    of ``colorize``). Errors on a slide are printed and the loop moves on.

    Args:
        args: parsed command-line namespace. Attributes read here:
            ``slide_list``, ``shuffle``, ``encoder``, ``n_classes``,
            ``input_dim``, ``snapshot``, ``save_dir``, ``ramdisk``,
            ``fgdir``, ``mag``, ``batch_size``.
    """
    ## Search for slides
    slide_list = read_list(args.slide_list)
    print('Found {} slides'.format(len(slide_list)))

    if args.shuffle:
        np.random.shuffle(slide_list)

    encoder_args = get_encoder_args(args.encoder)
    model = ClassifierEager(encoder_args=encoder_args,
                            deep_classifier=True,
                            n_classes=args.n_classes)
    # One forward pass on zeros before the weights are loaded; the output
    # itself is not needed.
    fake_data = tf.constant(
        np.zeros((1, args.input_dim, args.input_dim, 3), dtype=np.float32))
    model(fake_data)
    model.load_weights(args.snapshot)
    model.summary()

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    ## Loop over found slides:
    for src in slide_list:
        basename = os.path.basename(src).replace('.svs', '')
        dst = os.path.join(args.save_dir, '{}.npy'.format(basename))
        if os.path.exists(dst):
            print('{} exists. Skipping'.format(dst))
            continue

        ramdisk_path = transfer_to_ramdisk(
            src, args.ramdisk)  # never use the original src

        # Reset per slide so the cleanup below can never act on the handle
        # left over from a previous iteration when Slide() itself fails.
        svs = None
        try:
            fgpth = os.path.join(args.fgdir, '{}_fg.png'.format(basename))
            fgimg = cv2.imread(fgpth, 0)
            fgimg = fill_fg(fgimg)
            svs = Slide(slide_path=ramdisk_path,
                        preprocess_fn=lambda x: (x / 255.).astype(np.float32),
                        normalize_fn=lambda x: x,
                        background_speed='image',
                        background_image=fgimg,
                        process_mag=args.mag,
                        process_size=args.input_dim,
                        oversample_factor=1.1)
            svs.initialize_output(name='prob', dim=args.n_classes,
                                  mode='full')
            svs.initialize_output(name='rgb', dim=3, mode='full')

            n_tiles = len(svs.tile_list)
            prefetch = min(512, n_tiles)

            # Get tensors for image and index
            iterator = get_img_idx(svs, args.batch_size, prefetch)
            batches = 0
            for img_, idx_ in iterator:
                batches += 1
                yhat = model(img_, training=False)
                yhat = yhat.numpy()
                idx_ = idx_.numpy()
                img_ = img_.numpy()
                yhat = vect_to_tile(yhat, args.input_dim)
                svs.place_batch(yhat, idx_, 'prob', mode='full')
                svs.place_batch(img_, idx_, 'rgb', mode='full')
                if batches % 50 == 0:
                    print('\tbatch {:04d}'.format(batches))

            svs.make_outputs(reference='prob')
            prob_img = svs.output_imgs['prob']
            rgb_img = svs.output_imgs['rgb'] * 255
            color_img = colorize(rgb_img, prob_img)

            dst = os.path.join(args.save_dir, '{}.npy'.format(basename))
            np.save(dst, (prob_img * 255).astype(np.uint8))
            # Channel order is reversed before handing images to cv2.imwrite.
            dst = os.path.join(args.save_dir, '{}.jpg'.format(basename))
            cv2.imwrite(dst, rgb_img[:, :, ::-1])
            dst = os.path.join(args.save_dir, '{}_c.jpg'.format(basename))
            cv2.imwrite(dst, color_img[:, :, ::-1])
        except Exception as e:
            print(e)
        finally:
            print('Closing SVS')
            if svs is None:
                print('No SVS to close')
            else:
                try:
                    svs.close()
                except Exception as e:
                    print(e)
            os.remove(ramdisk_path)
def main(args):
    """Classify every tile of each ``*.svs`` slide in a directory.

    For each slide not already present in ``args.save_dir`` the function
    writes ``<basename>.npy`` (probabilities scaled to uint8),
    ``<basename>.jpg`` (the RGB output) and ``<basename>_c.jpg`` (the output
    of ``colorize``). A foreground image ``<basename>_fg.png`` under
    ``args.fgdir`` is used when present; otherwise ``Slide`` is created with
    ``background_speed='fast'``. Errors on a slide are printed and the loop
    moves on.

    Args:
        args: parsed command-line namespace. Attributes read here:
            ``slide_dir``, ``shuffle``, ``n_classes``, ``batch_size``,
            ``input_dim``, ``snapshot``, ``save_dir``, ``fgdir``,
            ``ramdisk``, ``mag``.
    """
    ## Search for slides
    slide_list = sorted(glob.glob(os.path.join(args.slide_dir, '*.svs')))
    print('Found {} slides'.format(len(slide_list)))

    if args.shuffle:
        np.random.shuffle(slide_list)

    # NOTE(review): encoder_args is not defined in this function; presumably
    # a module-level name -- confirm.
    model = ClassifierEager(encoder_args=encoder_args,
                            n_classes=args.n_classes)
    # One forward pass on zeros before the weights are loaded; the output
    # itself is not needed.
    xdummy = tf.constant(
        np.zeros((args.batch_size, args.input_dim, args.input_dim, 3),
                 dtype=np.float32))
    model(xdummy, verbose=True)
    model.load_weights(args.snapshot, by_name=True)
    model.summary()

    ## Loop over found slides:
    for src in slide_list:
        basename = os.path.basename(src).replace('.svs', '')
        dst = os.path.join(args.save_dir, '{}.npy'.format(basename))
        if os.path.exists(dst):
            print('{} exists. continuing'.format(dst))
            continue

        fgpth = os.path.join(args.fgdir, '{}_fg.png'.format(basename))
        if os.path.exists(fgpth):
            print('fg image found:', fgpth)
            fgimg = cv2.imread(fgpth, 0)
            speed = 'image'
        else:
            speed = 'fast'
            fgimg = None

        ramdisk_path = transfer_to_ramdisk(
            src, args.ramdisk)  # never use the original src

        svs = None
        try:
            svs = Slide(slide_path=ramdisk_path,
                        preprocess_fn=lambda x:
                        (reinhard(x) / 255.).astype(np.float32),
                        background_speed=speed,
                        background_image=fgimg,
                        process_mag=args.mag,
                        process_size=args.input_dim,
                        oversample_factor=1.75)
            svs.initialize_output(name='prob', dim=args.n_classes,
                                  mode='full')
            svs.initialize_output(name='rgb', dim=3, mode='full')

            n_tiles = len(svs.tile_list)
            prefetch = min(128, n_tiles)
            print('Tiles:', n_tiles)

            # Get tensors for image and index
            iterator = get_iterator(svs, args.batch_size, prefetch)
            batches = 0
            for img_, idx_ in iterator:
                batches += 1
                yhat = model(img_, training=False)
                yhat = yhat.numpy()
                idx_ = idx_.numpy()
                yhat = vect_to_tile(yhat, args.input_dim)
                svs.place_batch(yhat, idx_, 'prob', mode='full')
                svs.place_batch(img_.numpy(), idx_, 'rgb', mode='full')
                if batches % 50 == 0:
                    print('batch {:04d}'.format(batches))

            print('Making outputs')
            svs.make_outputs(reference='prob')
            prob_img = svs.output_imgs['prob']
            print('prob img', prob_img.shape)
            rgb_img = svs.output_imgs['rgb'] * 255
            print('rgb img', rgb_img.shape)
            color_img = colorize(rgb_img, prob_img)
            print('color img', color_img.shape)

            dst = os.path.join(args.save_dir, '{}.npy'.format(basename))
            np.save(dst, (prob_img * 255).astype(np.uint8))
            # Channel order is reversed before handing images to cv2.imwrite.
            dst = os.path.join(args.save_dir, '{}.jpg'.format(basename))
            cv2.imwrite(dst, rgb_img[:, :, ::-1])
            dst = os.path.join(args.save_dir, '{}_c.jpg'.format(basename))
            cv2.imwrite(dst, color_img[:, :, ::-1])
        except Exception as e:
            print(e)
        finally:
            # Close on every path (not only on success) so the handle is
            # released before the ramdisk copy is deleted.
            if svs is not None:
                try:
                    svs.close()
                except Exception as e:
                    print(e)
                svs = None
            print('Removing {}'.format(ramdisk_path))
            os.remove(ramdisk_path)
def main(args, sess):
    """Run the encode/predict/attention models over the test slides.

    The test list ``<args.testdir>/<args.timestamp>.txt`` is resolved to
    slide paths and labels. Each slide with a precomputed foreground image is
    processed with ``process_slide`` (or ``process_slide_mcdropout`` when
    ``args.mcdropout`` is set), and its attention map (``_att.npy``) and
    rendered attention image (``_img.png``) are written under
    ``<args.odir>/<args.timestamp>/``.

    Args:
        args: parsed command-line namespace. Attributes read here:
            ``testdir``, ``timestamp``, ``randomize``, ``savedir``,
            ``input_dim``, ``deep_classifier``, ``mil``,
            ``gated_attention``, ``fgdir``, ``ramdisk``, ``mag``,
            ``oversample``, ``mcdropout``, ``odir``.
        sess: session object. It is not referenced directly in this body.
    """
    # Translate obfuscated file names to paths if necessary
    test_list = os.path.join(args.testdir, '{}.txt'.format(args.timestamp))
    test_list = read_test_list(test_list)
    test_unique_ids = [
        os.path.basename(x).replace('.npy', '') for x in test_list
    ]
    if args.randomize:
        np.random.shuffle(test_unique_ids)

    slide_list, slide_labels = get_slidelist_from_uids(test_unique_ids)
    print('Found {} slides'.format(len(slide_list)))

    snapshot = os.path.join(args.savedir, '{}.h5'.format(args.timestamp))

    # NOTE(review): encoder_args is not defined in this function; presumably
    # a module-level name -- confirm.
    encode_model = MilkEncode(input_shape=(args.input_dim, args.input_dim, 3),
                              encoder_args=encoder_args,
                              deep_classifier=args.deep_classifier)
    x_pl = tf.placeholder(shape=(None, args.input_dim, args.input_dim, 3),
                          dtype=tf.float32)
    z_op = encode_model(x_pl)

    # The last dimension of the encoder output sizes the downstream models.
    input_shape = z_op.shape[-1]
    predict_model = MilkPredict(input_shape=[input_shape],
                                mode=args.mil,
                                use_gate=args.gated_attention,
                                deep_classifier=args.deep_classifier)
    attention_model = MilkAttention(input_shape=[input_shape],
                                    use_gate=args.gated_attention)

    print('setting encoder weights')
    encode_model.load_weights(snapshot, by_name=True)
    print('setting predict weights')
    predict_model.load_weights(snapshot, by_name=True)
    print('setting attention weights')
    attention_model.load_weights(snapshot, by_name=True)

    z_pl = tf.placeholder(shape=(None, input_shape), dtype=tf.float32)
    y_op = predict_model(z_pl)
    att_op = attention_model(z_pl)

    ## Loop over found slides:
    # NOTE(review): yhats / ytrues are filled but not read in this function.
    yhats = []
    ytrues = []
    for i, (src, lab) in enumerate(zip(slide_list, slide_labels)):
        print('\nSlide {}'.format(i))
        basename = os.path.basename(src).replace('.svs', '')
        fgpth = os.path.join(args.fgdir, '{}_fg.png'.format(basename))

        if not os.path.exists(fgpth):
            ## require precomputed background; skip this slide.
            print('Required foreground image not found ({})'.format(fgpth))
            continue

        ramdisk_path = transfer_to_ramdisk(
            src, args.ramdisk)  # never use the original src
        svs = None
        # try/finally: the slide handle and its ramdisk copy are released
        # even when loading, inference or saving raises.
        try:
            print('Using fg image at : {}'.format(fgpth))
            fgimg = cv2.imread(fgpth, 0)
            svs = Slide(
                slide_path=ramdisk_path,
                background_speed='image',
                background_image=fgimg,
                preprocess_fn=lambda x: (x / 255.).astype(np.float32),
                process_mag=args.mag,
                process_size=args.input_dim,
                oversample_factor=args.oversample,
                verbose=False)

            svs.initialize_output(name='attention', dim=1, mode='tile')
            if not args.mcdropout:
                yhat, att, indices = process_slide(svs, encode_model, y_op,
                                                   att_op, z_pl, args)
            else:
                yhat, att, yhat_sd, att_sd, indices = process_slide_mcdropout(
                    svs, encode_model, y_op, att_op, z_pl, args)

            yhats.append(yhat)
            ytrues.append(lab)
            print('\tSlide label: {} predicted: {}'.format(lab, yhat))

            svs.place_batch(att, indices, 'attention', mode='tile')
            attention_img = np.squeeze(svs.output_imgs['attention'])
            # Scale so the maximum becomes 1 before rendering.
            attention_img = attention_img * (1. / attention_img.max())
            attention_img = draw_attention(attention_img, n_bins=50)
            print('attention image:', attention_img.shape,
                  attention_img.dtype, attention_img.min(),
                  attention_img.max())

            dst = os.path.join(args.odir, args.timestamp,
                               '{}_att.npy'.format(basename))
            np.save(dst, att)
            dst = os.path.join(args.odir, args.timestamp,
                               '{}_img.png'.format(basename))
            cv2.imwrite(dst, attention_img)
        finally:
            try:
                if svs is not None:
                    svs.close()
                os.remove(ramdisk_path)
            except Exception:
                print('{} already removed'.format(ramdisk_path))
def main(args, sess):
    """Predict the test slides of one run and hand the results to auc_curve.

    Returns immediately when ``<args.odir>/auc_<args.timestamp>.png`` already
    exists. Otherwise each test slide with a precomputed foreground image is
    processed with ``process_slide`` (or ``process_slide_mcdropout`` when
    ``args.mcdropout`` is set), and the collected predictions and labels are
    passed to ``auc_curve`` with that path as ``savepath``.

    Args:
        args: parsed command-line namespace. Attributes read here: ``odir``,
            ``timestamp``, ``testdir``, ``savedir``, ``mcdropout``,
            ``input_dim``, ``mil``, ``gated_attention``, ``fgdir``,
            ``ramdisk``, ``mag``.
        sess: session object forwarded to the slide-processing helpers.
    """
    dst = os.path.join(args.odir, 'auc_{}.png'.format(args.timestamp))
    if os.path.exists(dst):
        print('{} exists. Exiting.'.format(dst))
        return

    # Translate obfuscated file names to paths if necessary
    test_list = os.path.join(args.testdir, '{}.txt'.format(args.timestamp))
    test_list = read_test_list(test_list)
    test_unique_ids = [
        os.path.basename(x).replace('.npy', '') for x in test_list
    ]
    slide_list, slide_labels = get_slidelist_from_uids(test_unique_ids)
    print('Found {} slides'.format(len(slide_list)))

    snapshot = os.path.join(args.savedir, '{}.h5'.format(args.timestamp))
    trained_model = load_model(snapshot)

    # NOTE(review): encoder_args is not defined in this function; presumably
    # a module-level dict, which this assignment mutates -- confirm.
    if args.mcdropout:
        encoder_args['mcdropout'] = True

    encode_model = MilkEncode(input_shape=(args.input_dim, args.input_dim, 3),
                              encoder_args=encoder_args)
    # NOTE(review): encode_shape is computed but the 512 below is hard-coded;
    # confirm whether the encoder output width was meant to be used instead.
    encode_shape = list(encode_model.output.shape)
    predict_model = MilkPredict(input_shape=[512],
                                mode=args.mil,
                                use_gate=args.gated_attention)

    models = model_utils.make_inference_functions(
        encode_model,
        predict_model,
        trained_model,
    )
    encode_model, predict_model = models

    z_pl = tf.placeholder(shape=(None, 512), dtype=tf.float32)
    y_op = predict_model(z_pl)

    # NOTE(review): fig is not referenced again here; presumably auc_curve
    # draws on the current figure -- confirm before removing.
    fig = plt.figure(figsize=(2, 2), dpi=180)

    ## Loop over found slides:
    yhats = []
    ytrues = []
    for i, (src, lab) in enumerate(zip(slide_list, slide_labels)):
        print('\nSlide {}'.format(i))
        basename = os.path.basename(src).replace('.svs', '')
        fgpth = os.path.join(args.fgdir, '{}_fg.png'.format(basename))

        if not os.path.exists(fgpth):
            ## Require precomputed background
            print('No fg image found ({})'.format(fgpth))
            continue

        ramdisk_path = transfer_to_ramdisk(
            src, args.ramdisk)  # never use the original src
        svs = None
        try:
            print('Using fg image at : {}'.format(fgpth))
            fgimg = cv2.imread(fgpth, 0)
            print('fgimg shape: ', fgimg.shape)
            svs = Slide(
                slide_path=ramdisk_path,
                background_speed='image',
                background_image=fgimg,
                preprocess_fn=lambda x: (x / 255.).astype(np.float32),
                process_mag=args.mag,
                process_size=args.input_dim,
                oversample_factor=1.5,
                verbose=True)
            print('calculated foregroud: ', svs.foreground.shape)
            print('calculated ds_tile_map: ', svs.ds_tile_map.shape)

            if not args.mcdropout:
                yhat, indices = process_slide(svs, sess, encode_model, y_op,
                                              z_pl, args)
            else:
                yhat, yhat_sd, indices = process_slide_mcdropout(
                    svs, sess, encode_model, y_op, z_pl, args)

            yhats.append(yhat)
            ytrues.append(lab)
            print('\tSlide label: {} predicted: {}'.format(lab, yhat))
        finally:
            # Close the slide before deleting its backing file, and delete
            # the ramdisk copy even when processing raised.
            try:
                if svs is not None:
                    svs.close()
            finally:
                os.remove(ramdisk_path)
            del svs

    if not yhats:
        # np.concatenate raises on an empty list; nothing to plot.
        print('No slides were processed')
        return

    yhats = np.concatenate(yhats, axis=0)
    ytrue = np.array(ytrues)

    for i, yt in enumerate(ytrue):
        print(yt, yhats[i, :])

    auc_curve(ytrue, yhats, savepath=dst)

    del encode_model
    del predict_model