Example 1
def main(sess, ramdisk_path, image_op, predict_op):
    input_size = image_op.get_shape().as_list()
    print(input_size)
    x_size, y_size = input_size[1:3]

    PAD = int((x_size - INPUT_SIZE) / 2)

    print('Working {}'.format(ramdisk_path))
    svs = Slide(slide_path=ramdisk_path,
                preprocess_fn=preprocess_fn,
                background_speed='fast',
                process_mag=PROCESS_MAG,
                process_size=INPUT_SIZE,
                oversample_factor=OVERSAMPLE,
                verbose=True)
    print('calculated foreground: ', svs.foreground.shape)
    print('calculated ds_tile_map: ', svs.ds_tile_map.shape)

    svs.initialize_output('prob', dim=2, mode='tile')
    PREFETCH = min(len(svs.tile_list), 1024)

    def wrapped_fn(idx):
        coords = svs.tile_list[idx]
        img = svs._read_tile(coords)  # (h, w, 3)
        img = np.pad(img,
                     pad_width=((PAD, PAD), (PAD, PAD), (0, 0)),
                     mode='constant',
                     constant_values=0.)
        return img, idx

    def read_region_at_index(idx):
        return tf.py_func(func=wrapped_fn,
                          inp=[idx],
                          Tout=[tf.float32, tf.int64],
                          stateful=False)

    ds = tf.data.Dataset.from_generator(generator=svs.generate_index,
                                        output_types=tf.int64)
    ds = ds.map(read_region_at_index, num_parallel_calls=12)
    ds = ds.prefetch(PREFETCH)
    ds = ds.batch(BATCH_SIZE)

    iterator = ds.make_one_shot_iterator()
    img, idx = iterator.get_next()

    print('Processing {} tiles'.format(len(svs.tile_list)))
    tstart = time.time()
    n_processed = 0
    while True:
        try:
            tile, idx_ = sess.run([img, idx])
            output = sess.run(predict_op, {image_op: tile})
            svs.place_batch(output, idx_, 'prob', mode='tile')

            n_processed += BATCH_SIZE
            if n_processed % PRINT_ITER == 0:
                print('[{:06d}] elapsed time [{:3.3f}] ({})'.format(
                    n_processed,
                    time.time() - tstart, tile.shape))

        except tf.errors.OutOfRangeError:
            print('Finished')
            break
        except Exception as e:
            print(e)
            break

    dt = time.time() - tstart
    spt = dt / float(len(svs.tile_list))
    fps = len(svs.tile_list) / dt
    print('\nFinished. {:2.2f}min {:3.3f}s/tile\n'.format(dt / 60., spt))
    print('\t {:3.3f} fps\n'.format(fps))

    prob_img = prob_output(svs)
    svs.close()

    return prob_img, fps
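
The tile-loading pipeline above hinges on the index -> tf.py_func -> batch pattern. Below is a minimal, self-contained sketch of that pattern (TF 1.x) with an in-memory array standing in for svs._read_tile; all names and values here are illustrative, not taken from the source.

import numpy as np
import tensorflow as tf

fake_tiles = np.random.rand(10, 64, 64, 3).astype(np.float32)  # stand-in for slide tiles

def generate_index():
    for i in range(len(fake_tiles)):
        yield i

def read_tile(idx):
    # Stands in for svs._read_tile(coords); returns (tile, index)
    return fake_tiles[idx], idx

def read_region_at_index(idx):
    return tf.py_func(func=read_tile, inp=[idx],
                      Tout=[tf.float32, tf.int64], stateful=False)

ds = tf.data.Dataset.from_generator(generator=generate_index, output_types=tf.int64)
ds = ds.map(read_region_at_index, num_parallel_calls=2)
ds = ds.prefetch(8)
ds = ds.batch(4)
img, idx = ds.make_one_shot_iterator().get_next()

with tf.Session() as sess:
    while True:
        try:
            tiles, indices = sess.run([img, idx])
            print(tiles.shape, indices)
        except tf.errors.OutOfRangeError:
            break
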
Example 2
def process_slide(slide_path, fg_path, model, sess, out_dir, process_mag,
                  process_size, oversample, batch_size, n_classes):
    """ Process a slide

  Args:
  slide_path: str
    absolute or relative path to svs formatted slide
  fg_path: str
    absolute or relative path to foreground png
  model: tfmodels.SegmentationBasemodel object
    model definition to use. Weights must be restored first, 
    i.e. call model.restore() before passing model
  sess: tf.Session
  out_dir: str
    path to use for output
  process_mag: int
    Usually one of: 5, 10, 20, 40.
    Other values may work but have not been tested
  process_size: int
    The input size required by model. 
  oversample: float. Usually in [1., 2.]
    How much to oversample between tiles. Larger values 
    will increase processing time.
  batch_size: int
    The batch size for inference. If the batch size is too 
    large given the model and process_size, then OOM errors
    will be raised
  n_classes: int
    The number of classes output by model. 
    i.e. shape(model.yhat) = (batch, h, w, n_classes)
  """

    print('Working {}'.format(slide_path))
    # print('Working {}'.format(fg_path))
    # fgimg = cv2.imread(fg_path, 0)
    # fgimg = cv2.morphologyEx(fgimg, cv2.MORPH_CLOSE,
    #                          cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7,7)))
    svs = Slide(
        slide_path=slide_path,
        background_speed='all',
        # background_image = fgimg,
        preprocess_fn=preprocess_fn,
        normalize_fn=lambda x: x,
        process_mag=process_mag,
        process_size=process_size,
        oversample=oversample,
        verbose=False,
    )
    svs.initialize_output('prob', dim=n_classes)
    svs.initialize_output('rgb', dim=3)
    PREFETCH = min(len(svs.place_list), 256)

    def wrapped_fn(idx):
        try:
            coords = svs.tile_list[idx]
            img = svs._read_tile(coords)
            return img, idx
        except Exception:
            # Fall back to a blank tile so the outputs still match
            # Tout=[tf.float32, tf.int64]; a bare `return 0` would crash py_func.
            return np.zeros((process_size, process_size, 3), dtype=np.float32), idx

    def read_region_at_index(idx):
        return tf.py_func(func=wrapped_fn,
                          inp=[idx],
                          Tout=[tf.float32, tf.int64],
                          stateful=False)

    ds = tf.data.Dataset.from_generator(generator=svs.generate_index,
                                        output_types=tf.int64)
    ds = ds.map(read_region_at_index, num_parallel_calls=6)
    ds = ds.prefetch(PREFETCH)
    ds = ds.batch(batch_size)

    iterator = ds.make_one_shot_iterator()
    img, idx = iterator.get_next()

    print('Processing {} tiles'.format(len(svs.tile_list)))
    tstart = time.time()
    n_processed = 0
    while True:
        try:
            tile, idx_ = sess.run([img, idx])
            output = model.inference(tile)
            svs.place_batch(output, idx_, 'prob')
            svs.place_batch(tile, idx_, 'rgb')

            n_processed += batch_size
            if n_processed % PRINT_ITER == 0:
                print('[{:06d}] elapsed time [{:3.3f}]'.format(
                    n_processed,
                    time.time() - tstart))

        except tf.errors.OutOfRangeError:
            print('Finished')
            dt = time.time() - tstart
            spt = dt / float(len(svs.tile_list))
            fps = len(svs.tile_list) / dt
            print('\nFinished. {:2.2f}min {:3.3f}s/tile\n'.format(
                dt / 60., spt))
            print('\t {:3.3f} fps\n'.format(fps))

            svs.make_outputs()
            prob_img = prob_output(svs)
            rgb_img = rgb_output(svs)
            break

        except Exception as e:
            print('Caught exception at tile indices {}'.format(idx_))
            print(e)
            prob_img = None
            rgb_img = None
            fps = None
            break

    svs.close()

    return prob_img, rgb_img, fps
Example 3
def main(args):
    # Translate obfuscated file names to paths if necessary
    slide_list = read_list(args.f)
    print('Found {} slides'.format(len(slide_list)))

    encoder_args = get_encoder_args(args.encoder)
    model = MilkEager(encoder_args=encoder_args,
                      mil_type=args.mil,
                      batch_size=args.batch_size,
                      temperature=args.temperature,
                      deep_classifier=args.deep_classifier)

    x_pl = np.zeros((1, args.batch_size, args.input_dim, args.input_dim, 3),
                    dtype=np.float32)
    yhat = model(tf.constant(x_pl), verbose=True)
    print('yhat:', yhat.shape)

    print('setting model weights')
    model.load_weights(args.s, by_name=True)

    ## Loop over found slides:
    yhats = []
    for i, src in enumerate(slide_list):
        print('\nSlide {}'.format(i))
        basename = os.path.basename(src).replace('.svs', '')
        fgpth = os.path.join(args.fg, '{}_fg.png'.format(basename))
        if os.path.exists(fgpth):
            ramdisk_path = transfer_to_ramdisk(
                src, args.ramdisk)  # never use the original src
            print('Using fg image at : {}'.format(fgpth))
            fgimg = cv2.imread(fgpth, 0)
            try:
                svs = Slide(
                    slide_path=ramdisk_path,
                    # background_speed  = 'accurate',
                    background_speed='image',
                    background_image=fgimg,
                    preprocess_fn=lambda x: (x / 255.).astype(np.float32),
                    normalize_fn=lambda x: x,
                    process_mag=args.mag,
                    process_size=args.input_dim,
                    oversample_factor=args.oversample,
                    verbose=False)
            except Exception as e:
                print(e)
                print(
                    'Caught SVS related error. Cleaning ramdisk and continuing.'
                )
                print('Cleaning file: {}'.format(ramdisk_path))
                os.remove(ramdisk_path)
                continue
        else:
            print('Required foreground image not found ({})'.format(fgpth))
            continue

        svs.initialize_output(name='attention', dim=1, mode='tile')
        yhat, att, indices = process_slide(svs, model, args)

        yhats.append(yhat)
        print('\tSlide predicted: {}'.format(yhat))

        if args.mil == 'average':
            print('Average MIL; continuing')
        elif args.mil in ['attention', 'instance']:
            print('Placing values ranged {:3.3f} - {:3.3f}'.format(
                att.min(), att.max()))
            print('Visualizing mean {:3.5f}'.format(np.mean(att)))
            print('Visualizing std {:3.5f}'.format(np.std(att)))
            svs.place_batch(att, indices, 'attention', mode='tile')
            attention_img = np.squeeze(svs.output_imgs['attention'])
            attention_img_raw = np.squeeze(svs.output_imgs['attention'])

            attention_img = attention_img * (1. / attention_img.max())
            attention_img = draw_attention(attention_img, n_bins=50)
            print('attention image:', attention_img.shape, attention_img.dtype,
                  attention_img.min(), attention_img.max())

            dst = os.path.join(args.o, '{}_att.npy'.format(basename))
            np.save(dst, attention_img_raw)
            dst = os.path.join(args.o, '{}_img.png'.format(basename))
            cv2.imwrite(dst, attention_img)

        yhat_dst = os.path.join(args.o, '{}_ypred.npy'.format(basename))
        np.save(yhat_dst, yhat)

        try:
            svs.close()
            os.remove(ramdisk_path)
            del svs
        except:
            print('{} already removed'.format(ramdisk_path))
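
transfer_to_ramdisk is used throughout these examples but never shown. A plausible sketch, assuming it simply copies the slide onto fast local storage and returns the copy's path (the _sketch suffix marks this as a guess, not the source implementation):

import os
import shutil

def transfer_to_ramdisk_sketch(src, ramdisk):
    # Copy the slide to the ramdisk and return the copy's path,
    # so the original file is never opened directly.
    dst = os.path.join(ramdisk, os.path.basename(src))
    shutil.copyfile(src, dst)
    return dst
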
Example 4
def process_slide(slide_path, sess, out_dir, process_mag, process_size,
                  oversample, batch_size, n_classes, fgspeed, fgimg):
    """ Process a slide
  # PyTorch-compatible mode: uses sess only for the tile-loading pipeline

  Args:
  slide_path: str
    absolute or relative path to svs formatted slide
  sess: tf.Session
    session used only to drive the tf.data tile-loading pipeline
  out_dir: str
    path to use for output
  process_mag: int
    Usually one of: 5, 10, 20, 40.
    Other values may work but have not been tested
  process_size: int
    The input size required by the model.
  oversample: float. Usually in [1., 2.]
    How much to oversample between tiles. Larger values
    will increase processing time.
  batch_size: int
    The batch size for inference. If the batch size is too
    large given the model and process_size, then OOM errors
    will be raised
  n_classes: int
    The number of classes in the 'prob' output,
    i.e. shape = (batch, h, w, n_classes)
  fgspeed: str
    background detection mode passed to Slide (background_speed)
  fgimg: np.ndarray
    precomputed foreground mask passed to Slide (background_img)
  """

    try:
        print('Working {}'.format(slide_path))
        svs = Slide(
            slide_path=slide_path,
            preprocess_fn=lambda x: x.astype(np.float32),
            normalize_fn=lambda x: x,
            process_mag=process_mag,
            process_size=process_size,
            oversample=oversample,
            background_speed=fgspeed,
            background_img=fgimg,
            verbose=False,
        )
        # svs.initialize_output('prob', dim=n_classes)
        svs.initialize_output('rgb', dim=3)
        svs.initialize_output('prob', dim=n_classes, mode='tile')
        PREFETCH = min(len(svs.place_list), 1024)
    except Exception as e:
        print(e)
        print("Error loading slide at {}".format(slide_path))
        return None, None

    def wrapped_fn(idx):
        try:
            coords = svs.tile_list[idx]
            img = svs._read_tile(coords)
            return img, idx
        except Exception:
            # Fall back to a blank tile so the outputs still match
            # Tout=[tf.float32, tf.int64]; a bare `return 0` would crash py_func.
            return np.zeros((process_size, process_size, 3), dtype=np.float32), idx

    def read_region_at_index(idx):
        return tf.py_func(func=wrapped_fn,
                          inp=[idx],
                          Tout=[tf.float32, tf.int64],
                          stateful=False)

    ds = tf.data.Dataset.from_generator(generator=svs.generate_index,
                                        output_types=tf.int64)
    ds = ds.map(read_region_at_index, num_parallel_calls=8)
    ds = ds.prefetch(PREFETCH)
    ds = ds.batch(batch_size)

    iterator = ds.make_one_shot_iterator()
    img, idx = iterator.get_next()

    print('Processing {} tiles'.format(len(svs.tile_list)))
    tstart = time.time()
    n_processed = 0
    while True:
        try:
            tile, idx_ = sess.run([img, idx])
            svs.place_batch(tile, idx_, 'rgb', clobber=True)

            n_processed += batch_size
            if n_processed % PRINT_ITER == 0:
                print('[{:06d}] elapsed time [{:3.3f}]'.format(
                    n_processed,
                    time.time() - tstart))

        except tf.errors.OutOfRangeError:
            dt = time.time() - tstart
            spt = dt / float(len(svs.tile_list))
            fps = len(svs.tile_list) / dt
            print('\nFinished. {:2.2f}min {:3.3f}s/tile\n'.format(
                dt / 60., spt))
            print('{:3.3f} fps\n'.format(fps))

            # svs.make_outputs()
            rgb_img = rgb_output(svs)
            break

        except Exception as e:
            print('Caught exception at tile indices {}'.format(idx_))
            print(e)
            rgb_img = None
            fps = None
            break

    svs.close()

    # Replace pure black (unprocessed background) with white
    if rgb_img is not None:
        bwimg = np.mean(rgb_img, axis=-1)
        rgb_img[bwimg < 10, :] = 255
    return rgb_img, fps
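
A self-contained illustration of the final "pure black -> white" cleanup above, on a dummy RGB array (all values made up):

import numpy as np

rgb_img = np.zeros((4, 4, 3), dtype=np.uint8)
rgb_img[0, 0] = (200, 180, 160)      # one "tissue" pixel is left untouched
bwimg = np.mean(rgb_img, axis=-1)    # per-pixel brightness
rgb_img[bwimg < 10, :] = 255         # near-black background becomes white
print(rgb_img[0, 0], rgb_img[1, 1])  # [200 180 160] [255 255 255]
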
Example 5
def main(args):
    # Translate obfuscated file names to paths if necessary
    test_list = os.path.join(args.testdir, '{}.txt'.format(args.timestamp))
    test_list = read_test_list(test_list)
    test_unique_ids = [
        os.path.basename(x).replace('.npy', '') for x in test_list
    ]
    if args.randomize:
        np.random.shuffle(test_unique_ids)

    if args.max_slides:
        test_unique_ids = test_unique_ids[:args.max_slides]

    slide_list, slide_labels = get_slidelist_from_uids(test_unique_ids)

    print('Found {} slides'.format(len(slide_list)))

    snapshot = os.path.join(args.savedir, '{}.h5'.format(args.timestamp))
    # trained_model = load_model(snapshot)
    # if args.mcdropout:
    #   encoder_args['mcdropout'] = True

    encoder_args = get_encoder_args(args.encoder)
    model = MilkEager(encoder_args=encoder_args,
                      mil_type=args.mil,
                      batch_size=args.batch_size,
                      deep_classifier=args.deep_classifier,
                      temperature=args.temperature)

    x_pl = np.zeros((1, args.batch_size, args.input_dim, args.input_dim, 3),
                    dtype=np.float32)
    yhat = model(tf.constant(x_pl), verbose=True)
    print('yhat:', yhat.shape)

    print('setting model weights')
    model.load_weights(snapshot, by_name=True)

    ## Loop over found slides:
    yhats = []
    ytrues = []
    for i, (src, lab) in enumerate(zip(slide_list, slide_labels)):
        print('\nSlide {}'.format(i))
        basename = os.path.basename(src).replace('.svs', '')
        fgpth = os.path.join(args.fgdir, '{}_fg.png'.format(basename))
        if os.path.exists(fgpth):
            ramdisk_path = transfer_to_ramdisk(
                src, args.ramdisk)  # never use the original src
            print('Using fg image at : {}'.format(fgpth))
            fgimg = cv2.imread(fgpth, 0)
            svs = Slide(
                slide_path=ramdisk_path,
                # background_speed  = 'accurate',
                background_speed='image',
                background_image=fgimg,
                # preprocess_fn     = lambda x: (reinhard(x)/255.).astype(np.float32),
                preprocess_fn=lambda x: (x / 255.).astype(np.float32),
                process_mag=args.mag,
                process_size=args.input_dim,
                oversample_factor=args.oversample,
                verbose=False)
        else:
            ## Require a precomputed foreground; skip this slide.
            print('Required foreground image not found ({})'.format(fgpth))
            continue

        svs.initialize_output(name='attention', dim=1, mode='tile')
        n_tiles = len(svs.tile_list)

        yhat, att, indices = process_slide(svs, model, args)
        print('returned attention:', np.min(att), np.max(att), att.shape)

        yhat = yhat.numpy()
        yhats.append(yhat)
        ytrues.append(lab)
        print('\tSlide label: {} predicted: {}'.format(lab, yhat))

        svs.place_batch(att, indices, 'attention', mode='tile')
        attention_img = np.squeeze(svs.output_imgs['attention'])
        attention_img = attention_img * (1. / attention_img.max())
        attention_img = draw_attention(attention_img, n_bins=25)
        print('attention image:', attention_img.shape, attention_img.dtype,
              attention_img.min(), attention_img.max())

        dst = os.path.join(
            args.odir, args.timestamp,
            '{}_{}_{:3.3f}_att.npy'.format(basename, lab, yhat[0, 1]))
        np.save(dst, att)

        dst = os.path.join(
            args.odir, args.timestamp,
            '{}_{}_{:3.3f}_img.png'.format(basename, lab, yhat[0, 1]))
        cv2.imwrite(dst, attention_img)

        try:
            svs.close()
            os.remove(ramdisk_path)
        except:
            print('{} already removed'.format(ramdisk_path))

    yhats = np.concatenate(yhats, axis=0)
    ytrues = np.array(ytrues)
    acc = (np.argmax(yhats, axis=-1) == ytrues).mean()
    print('Accuracy: {:3.3f}'.format(acc))
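
A toy version of the final accuracy computation above, with made-up per-slide class probabilities:

import numpy as np

yhats = np.array([[0.2, 0.8],    # argmax -> class 1
                  [0.9, 0.1],    # argmax -> class 0
                  [0.4, 0.6]])   # argmax -> class 1
ytrues = np.array([1, 0, 0])
acc = (np.argmax(yhats, axis=-1) == ytrues).mean()
print(acc)  # 2 of 3 correct -> ~0.667
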
Example 6
def main(args):
    ## Search for slides
    # slide_list = sorted(glob.glob(os.path.join(args.slide_dir, '*.svs')))
    slide_list = read_list(args.slide_list)
    print('Found {} slides'.format(len(slide_list)))
    if args.shuffle:
        np.random.shuffle(slide_list)

    encoder_args = get_encoder_args(args.encoder)
    model = ClassifierEager(encoder_args=encoder_args,
                            deep_classifier=True,
                            n_classes=args.n_classes)
    fake_data = tf.constant(
        np.zeros((1, args.input_dim, args.input_dim, 3), dtype=np.float32))
    yhat_ = model(fake_data)
    model.load_weights(args.snapshot)
    model.summary()
    if not os.path.exists(args.save_dir):
        # shutil.rmtree(args.save_dir)
        os.makedirs(args.save_dir)

    ## Loop over found slides:
    for src in slide_list:
        basename = os.path.basename(src).replace('.svs', '')
        dst = os.path.join(args.save_dir, '{}.npy'.format(basename))
        if os.path.exists(dst):
            print('{} exists. Skipping'.format(dst))
            continue

        ramdisk_path = transfer_to_ramdisk(
            src, args.ramdisk)  # never use the original src
        try:
            fgpth = os.path.join(args.fgdir, '{}_fg.png'.format(basename))
            fgimg = cv2.imread(fgpth, 0)
            fgimg = fill_fg(fgimg)
            svs = Slide(slide_path=ramdisk_path,
                        preprocess_fn=lambda x: (x / 255.).astype(np.float32),
                        normalize_fn=lambda x: x,
                        background_speed='image',
                        background_image=fgimg,
                        process_mag=args.mag,
                        process_size=args.input_dim,
                        oversample_factor=1.1)
            svs.initialize_output(name='prob', dim=args.n_classes, mode='full')
            svs.initialize_output(name='rgb', dim=3, mode='full')
            n_tiles = len(svs.tile_list)
            prefetch = min(512, n_tiles)
            # Get tensors for image and index
            iterator = get_img_idx(svs, args.batch_size, prefetch)
            batches = 0
            for img_, idx_ in iterator:
                batches += 1
                yhat = model(img_, training=False)
                yhat = yhat.numpy()
                idx_ = idx_.numpy()
                img_ = img_.numpy()
                yhat = vect_to_tile(yhat, args.input_dim)
                svs.place_batch(yhat, idx_, 'prob', mode='full')
                svs.place_batch(img_, idx_, 'rgb', mode='full')
                if batches % 50 == 0:
                    print('\tbatch {:04d}'.format(batches))

            svs.make_outputs(reference='prob')
            prob_img = svs.output_imgs['prob']
            rgb_img = svs.output_imgs['rgb'] * 255
            color_img = colorize(rgb_img, prob_img)
            dst = os.path.join(args.save_dir, '{}.npy'.format(basename))
            np.save(dst, (prob_img * 255).astype(np.uint8))
            dst = os.path.join(args.save_dir, '{}.jpg'.format(basename))
            cv2.imwrite(dst, rgb_img[:, :, ::-1])
            dst = os.path.join(args.save_dir, '{}_c.jpg'.format(basename))
            cv2.imwrite(dst, color_img[:, :, ::-1])
        except Exception as e:
            print(e)
        finally:
            try:
                print('Closing SVS')
                svs.close()
            except:
                print('No SVS to close')

            os.remove(ramdisk_path)
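
vect_to_tile is not shown in these examples. A plausible guess at what it does, assuming it broadcasts each tile's class vector into a (size, size, n_classes) map so place_batch can paint it into the 'prob' output (the _sketch suffix marks this as an assumption):

import numpy as np

def vect_to_tile_sketch(yhat, size):
    # (batch, n_classes) -> (batch, size, size, n_classes)
    return np.tile(yhat[:, np.newaxis, np.newaxis, :], (1, size, size, 1))

print(vect_to_tile_sketch(np.ones((2, 4), dtype=np.float32), 8).shape)  # (2, 8, 8, 4)
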
Example 7
def main(args):
    ## Search for slides
    slide_list = sorted(glob.glob(os.path.join(args.slide_dir, '*.svs')))
    print('Found {} slides'.format(len(slide_list)))

    if args.shuffle:
        np.random.shuffle(slide_list)

    model = ClassifierEager(encoder_args=encoder_args,
                            n_classes=args.n_classes)
    xdummy = tf.constant(
        np.zeros((args.batch_size, args.input_dim, args.input_dim, 3),
                 dtype=np.float32))
    yhat_op = model(xdummy, verbose=True)
    model.load_weights(args.snapshot, by_name=True)

    # model = tf.keras.models.load_model(args.snapshot, compile=False)
    model.summary()

    ## Loop over found slides:
    for src in slide_list:
        basename = os.path.basename(src).replace('.svs', '')
        dst = os.path.join(args.save_dir, '{}.npy'.format(basename))

        if os.path.exists(dst):
            print('{} exists. continuing'.format(dst))
            continue

        fgpth = os.path.join(args.fgdir, '{}_fg.png'.format(basename))
        if os.path.exists(fgpth):
            print('fg image found:', fgpth)
            fgimg = cv2.imread(fgpth, 0)
            speed = 'image'
        else:
            speed = 'fast'
            fgimg = None

        ramdisk_path = transfer_to_ramdisk(
            src, args.ramdisk)  # never use the original src
        try:
            svs = Slide(slide_path=ramdisk_path,
                        preprocess_fn=lambda x:
                        (reinhard(x) / 255.).astype(np.float32),
                        background_speed=speed,
                        background_image=fgimg,
                        process_mag=args.mag,
                        process_size=args.input_dim,
                        oversample_factor=1.75)

            svs.initialize_output(name='prob', dim=args.n_classes, mode='full')
            svs.initialize_output(name='rgb', dim=3, mode='full')

            n_tiles = len(svs.tile_list)
            prefetch = min(128, n_tiles)
            print('Tiles:', n_tiles)

            # Get tensors for image and index
            # img, idx = get_img_idx(svs, args.batch_size, prefetch)
            iterator = get_iterator(svs, args.batch_size, prefetch)

            batches = 0
            for img_, idx_ in iterator:
                batches += 1
                # img_, idx_ = next(iterator)

                yhat = model(img_, training=False)
                yhat = yhat.numpy()
                idx_ = idx_.numpy()

                yhat = vect_to_tile(yhat, args.input_dim)
                svs.place_batch(yhat, idx_, 'prob', mode='full')
                svs.place_batch(img_.numpy(), idx_, 'rgb', mode='full')
                if batches % 50 == 0:
                    print('batch {:04d}'.format(batches))

            print('Making outputs')
            svs.make_outputs(reference='prob')
            prob_img = svs.output_imgs['prob']
            print('prob img', prob_img.shape)

            rgb_img = svs.output_imgs['rgb'] * 255
            print('rgb img', rgb_img.shape)

            color_img = colorize(rgb_img, prob_img)
            print('color img', color_img.shape)

            dst = os.path.join(args.save_dir, '{}.npy'.format(basename))
            np.save(dst, (prob_img * 255).astype(np.uint8))
            dst = os.path.join(args.save_dir, '{}.jpg'.format(basename))
            cv2.imwrite(dst, rgb_img[:, :, ::-1])
            dst = os.path.join(args.save_dir, '{}_c.jpg'.format(basename))
            cv2.imwrite(dst, color_img[:, :, ::-1])

            svs.close()
            svs = []

        except Exception as e:
            print(e)

        finally:
            print('Removing {}'.format(ramdisk_path))
            os.remove(ramdisk_path)
Example 8
def main(args, sess):
    # Translate obfuscated file names to paths if necessary
    test_list = os.path.join(args.testdir, '{}.txt'.format(args.timestamp))
    test_list = read_test_list(test_list)
    test_unique_ids = [
        os.path.basename(x).replace('.npy', '') for x in test_list
    ]
    if args.randomize:
        np.random.shuffle(test_unique_ids)
    slide_list, slide_labels = get_slidelist_from_uids(test_unique_ids)

    print('Found {} slides'.format(len(slide_list)))

    snapshot = os.path.join(args.savedir, '{}.h5'.format(args.timestamp))
    # trained_model = load_model(snapshot)
    # if args.mcdropout:
    #   encoder_args['mcdropout'] = True

    encode_model = MilkEncode(input_shape=(args.input_dim, args.input_dim, 3),
                              encoder_args=encoder_args,
                              deep_classifier=args.deep_classifier)

    x_pl = tf.placeholder(shape=(None, args.input_dim, args.input_dim, 3),
                          dtype=tf.float32)
    z_op = encode_model(x_pl)

    input_shape = z_op.shape[-1]
    predict_model = MilkPredict(input_shape=[input_shape],
                                mode=args.mil,
                                use_gate=args.gated_attention,
                                deep_classifier=args.deep_classifier)
    attention_model = MilkAttention(input_shape=[input_shape],
                                    use_gate=args.gated_attention)

    print('setting encoder weights')
    encode_model.load_weights(snapshot, by_name=True)
    print('setting predict weights')
    predict_model.load_weights(snapshot, by_name=True)
    print('setting attention weights')
    attention_model.load_weights(snapshot, by_name=True)

    z_pl = tf.placeholder(shape=(None, input_shape), dtype=tf.float32)
    y_op = predict_model(z_pl)
    att_op = attention_model(z_pl)

    # fig = plt.figure(figsize=(2,2), dpi=180)

    ## Loop over found slides:
    yhats = []
    ytrues = []
    for i, (src, lab) in enumerate(zip(slide_list, slide_labels)):
        print('\nSlide {}'.format(i))
        basename = os.path.basename(src).replace('.svs', '')
        fgpth = os.path.join(args.fgdir, '{}_fg.png'.format(basename))
        if os.path.exists(fgpth):
            ramdisk_path = transfer_to_ramdisk(
                src, args.ramdisk)  # never use the original src
            print('Using fg image at : {}'.format(fgpth))
            fgimg = cv2.imread(fgpth, 0)
            svs = Slide(
                slide_path=ramdisk_path,
                # background_speed  = 'accurate',
                background_speed='image',
                background_image=fgimg,
                preprocess_fn=lambda x: (x / 255.).astype(np.float32),
                process_mag=args.mag,
                process_size=args.input_dim,
                oversample_factor=args.oversample,
                verbose=False)
        else:
            ## Require a precomputed foreground; skip this slide.
            print('Required foreground image not found ({})'.format(fgpth))
            continue

        svs.initialize_output(name='attention', dim=1, mode='tile')
        n_tiles = len(svs.tile_list)

        if not args.mcdropout:
            yhat, att, indices = process_slide(svs, encode_model, y_op, att_op,
                                               z_pl, args)
        else:
            yhat, att, yhat_sd, att_sd, indices = process_slide_mcdropout(
                svs, encode_model, y_op, att_op, z_pl, args)

        yhats.append(yhat)
        ytrues.append(lab)
        print('\tSlide label: {} predicted: {}'.format(lab, yhat))

        svs.place_batch(att, indices, 'attention', mode='tile')
        attention_img = np.squeeze(svs.output_imgs['attention'])
        attention_img = attention_img * (1. / attention_img.max())
        attention_img = draw_attention(attention_img, n_bins=50)
        print('attention image:', attention_img.shape, attention_img.dtype,
              attention_img.min(), attention_img.max())

        dst = os.path.join(args.odir, args.timestamp,
                           '{}_att.npy'.format(basename))
        np.save(dst, att)

        dst = os.path.join(args.odir, args.timestamp,
                           '{}_img.png'.format(basename))
        cv2.imwrite(dst, attention_img)

        # dst = os.path.join(args.odir, args.timestamp, '{}_hist.png'.format(basename))
        # fig.clf()
        # plt.hist(att, bins=100);
        # plt.title('Attention distribution\n{} ({} tiles)'.format(basename, n_tiles))
        # plt.xlabel('Attention score')
        # plt.ylabel('Tile count')
        # plt.savefig(dst, bbox_inches='tight')

        try:
            svs.close()
            os.remove(ramdisk_path)
        except:
            print('{} already removed'.format(ramdisk_path))
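
A self-contained TF 1.x sketch of the two-stage pattern above: one op encodes image batches into features, and a second placeholder (z_pl, as above) feeds the features to a prediction head. The tiny "encoder" and "head" below are stand-ins, not the MilkEncode/MilkPredict models.

import numpy as np
import tensorflow as tf

x_pl = tf.placeholder(tf.float32, (None, 8, 8, 3))
z_op = tf.reduce_mean(x_pl, axis=(1, 2))       # stand-in encoder: (None, 3) features
z_pl = tf.placeholder(tf.float32, (None, 3))
y_op = tf.layers.dense(z_pl, 2)                # stand-in prediction head

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    x = np.random.rand(16, 8, 8, 3).astype(np.float32)
    z = sess.run(z_op, {x_pl: x})                              # encode a tile batch
    y = sess.run(y_op, {z_pl: z.mean(axis=0, keepdims=True)})  # pool, then predict
    print(z.shape, y.shape)  # (16, 3) (1, 2)
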
Example 9
def main(args, sess):
    dst = os.path.join(args.odir, 'auc_{}.png'.format(args.timestamp))
    if os.path.exists(dst):
        print('{} exists. Exiting.'.format(dst))
        return

    # Translate obfuscated file names to paths if necessary
    test_list = os.path.join(args.testdir, '{}.txt'.format(args.timestamp))
    test_list = read_test_list(test_list)
    test_unique_ids = [
        os.path.basename(x).replace('.npy', '') for x in test_list
    ]
    slide_list, slide_labels = get_slidelist_from_uids(test_unique_ids)

    print('Found {} slides'.format(len(slide_list)))

    snapshot = os.path.join(args.savedir, '{}.h5'.format(args.timestamp))
    trained_model = load_model(snapshot)
    if args.mcdropout:
        encoder_args['mcdropout'] = True

    encode_model = MilkEncode(input_shape=(args.input_dim, args.input_dim, 3),
                              encoder_args=encoder_args)
    encode_shape = list(encode_model.output.shape)
    predict_model = MilkPredict(input_shape=[512],
                                mode=args.mil,
                                use_gate=args.gated_attention)
    # attention_model = MilkAttention(input_shape=[512], use_gate=args.gated_attention)

    models = model_utils.make_inference_functions(
        encode_model,
        predict_model,
        trained_model,
    )
    encode_model, predict_model = models

    z_pl = tf.placeholder(shape=(None, 512), dtype=tf.float32)
    y_op = predict_model(z_pl)

    fig = plt.figure(figsize=(2, 2), dpi=180)

    ## Loop over found slides:
    yhats = []
    ytrues = []
    for i, (src, lab) in enumerate(zip(slide_list, slide_labels)):
        print('\nSlide {}'.format(i))
        basename = os.path.basename(src).replace('.svs', '')
        fgpth = os.path.join(args.fgdir, '{}_fg.png'.format(basename))
        if os.path.exists(fgpth):
            ramdisk_path = transfer_to_ramdisk(
                src, args.ramdisk)  # never use the original src
            print('Using fg image at : {}'.format(fgpth))
            fgimg = cv2.imread(fgpth, 0)
            print('fgimg shape: ', fgimg.shape)
            svs = Slide(
                slide_path=ramdisk_path,
                # background_speed  = 'accurate',
                background_speed='image',
                background_image=fgimg,
                preprocess_fn=lambda x: (x / 255.).astype(np.float32),
                process_mag=args.mag,
                process_size=args.input_dim,
                oversample_factor=1.5,
                verbose=True)
            print('calculated foreground: ', svs.foreground.shape)
            print('calculated ds_tile_map: ', svs.ds_tile_map.shape)
        else:
            ## Require precomputed background
            print('No fg image found ({})'.format(fgpth))
            continue

        n_tiles = len(svs.tile_list)

        if not args.mcdropout:
            yhat, indices = process_slide(svs, sess, encode_model, y_op, z_pl,
                                          args)
        else:
            yhat, yhat_sd, indices = process_slide_mcdropout(
                svs, sess, encode_model, y_op, z_pl, args)

        yhats.append(yhat)
        ytrues.append(lab)
        print('\tSlide label: {} predicted: {}'.format(lab, yhat))
        os.remove(ramdisk_path)
        svs.close()
        del svs

    yhats = np.concatenate(yhats, axis=0)
    ytrue = np.array(ytrues)
    for i, yt in enumerate(ytrue):
        print(yt, yhats[i, :])

    auc_curve(ytrue, yhats, savepath=dst)
    del encode_model
    del predict_model
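
auc_curve is not shown. A guess at the metric behind it, assuming it scores class-1 probabilities against the slide labels with a standard ROC AUC (scikit-learn used here purely for illustration):

import numpy as np
from sklearn.metrics import roc_auc_score

ytrue = np.array([0, 1, 1, 0])
yhats = np.array([[0.9, 0.1],
                  [0.2, 0.8],
                  [0.4, 0.6],
                  [0.7, 0.3]])
print(roc_auc_score(ytrue, yhats[:, 1]))  # 1.0: every positive outranks every negative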