示例#1
0
def main(args):
  """Evaluate a pretrained MilkEager model on the test split of a MILDataset.

  The model is built, run once on a dummy bag, and then loaded from
  ``args.pretrained_model`` (matched by layer name). The test bags are then
  streamed one at a time. For every bag the value ``y[0, 1]`` is recorded as
  the ground truth, and ``[0, 1]`` of every head's output is recorded as the
  prediction. A running accuracy from ``calc_acc`` is redrawn on one line.
  At the end ``print_accuracy`` reports the result and
  ``write_out(ytrues, yhats, args.out)`` stores it.
  """
  # with tf.device('/gpu:0'):
  print('Model initializing')
  model = MilkEager(encoder_args=get_encoder_args(args.encoder),
                    mil_type=args.mil,
                    deep_classifier=args.deep_classifier,
                    batch_size=16,
                    temperature=args.temperature,
                    heads=args.heads)

  # Push one dummy bag through the model before summary() / load_weights().
  print('Running once to load CUDA')
  dummy_bag = tf.zeros((1, 1, args.crop_size, args.crop_size, args.channels))

  # The head indices are spelled out literally on purpose. The original
  # author observed that eager mode complains when this is range() or even
  # list(range()) while the tf.contrib.eager.defun decorator is active, and
  # attributed it to autograph.
  build_heads = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
  model(dummy_bag.gpu(), heads=build_heads, training=True, verbose=True)
  model.summary()

  model.load_weights(args.pretrained_model, by_name=True)

  # Data: split train/val on 'case_id' with args.seed, then iterate the test
  # split eagerly, one bag per step.
  dataset = MILDataset(args.dataset, crop=args.crop_size, n_classes=2)
  dataset.split_train_val('case_id', seed=args.seed)
  test_batches = dataset.tensorflow_iterator(mode='test',
                                             seed=args.seed,
                                             batch_size=1,
                                             buffer_size=1,
                                             threads=1,
                                             subset=args.bag_size,
                                             attr='stage_code',
                                             eager=True)

  print('-------------------------------------------------------\n\n')
  ytrues = []
  yhats = []
  eval_heads = list(range(args.heads))
  for step, (bag, label) in enumerate(test_batches):
    ytrues.append(label[0, 1])

    # One output per requested head; keep element [0, 1] of each.
    head_outputs = model(bag.gpu(), training=False, heads=eval_heads)
    yhats.append(np.array([out[0, 1] for out in head_outputs]))

    # Accuracy over everything seen so far, redrawn in place with '\r'.
    running_acc, _ = calc_acc(ytrues, yhats)
    print('\r{:03d} Accuracy: {:3.3f} %'.format(step, running_acc),
          end='', flush=True)

  print('\n\n-------------------------------------------------------')
  print_accuracy(ytrues, yhats)
  write_out(ytrues, yhats, args.out)
示例#2
0
def main(args):
    """Plot feature projections for every case in a saved test list.

    Paths are derived from ``args.timestamp``:
    - weights come from ``../experiment/save/<timestamp>.h5``;
    - the case list comes from ``../experiment/test_lists/<timestamp>.txt``.

    For each listed ``.npy`` case the tiles are transformed, optionally
    subsampled to ``args.sample`` tiles, and encoded. They are then
    classified both per tile and on the mean-pooled bag feature. Two figures
    per case are written using the prefix ``args.savebase``.

    NOTE(review): ``encoder_args`` is not assigned in this function; it is
    presumably provided at module level -- confirm before running.
    """
    transform_fn = data_utils.make_transform_fn(128,
                                                128,
                                                args.input_dim,
                                                1.0,
                                                normalize=True)

    snapshot = os.path.join('../experiment/save',
                            '{}.h5'.format(args.timestamp))
    test_list = os.path.join('../experiment/test_lists',
                             '{}.txt'.format(args.timestamp))

    model = MilkEager(encoder_args=encoder_args,
                      mil_type=args.mil,
                      deep_classifier=True)
    # Run one dummy bag through the model before loading the snapshot.
    x_dummy = tf.zeros(
        shape=[1, args.batch_size, args.input_dim, args.input_dim, 3],
        dtype=tf.float32)
    model(x_dummy, verbose=True)
    model.load_weights(snapshot, by_name=True)
    model.summary()

    for test_case in read_test_list(test_list):
        case_name = os.path.basename(test_case).replace('.npy', '')
        print(test_case, case_name)
        case_x = np.load(test_case)
        case_x = np.stack([transform_fn(x) for x in case_x], 0)
        print(case_x.shape)

        if args.sample:
            # Random subset of tiles. np.random.choice draws WITH replacement
            # by default, so a tile may be picked more than once.
            case_x = case_x[
                np.random.choice(range(case_x.shape[0]), args.sample), ...]
            print(case_x.shape)

        # NOTE(review): encode_bag runs with training=True although this is
        # an evaluation script; confirm that is intended (e.g. MC dropout).
        features = model.encode_bag(case_x,
                                    batch_size=args.batch_size,
                                    training=True,
                                    return_z=True)
        print('features:', features.shape)

        # Mean-pool the tile features into one bag feature. `keepdims`
        # replaces the deprecated `keep_dims` spelling.
        features_att = tf.reduce_mean(features, axis=0, keepdims=True)
        yhat_instances = model.apply_classifier(features, training=False)
        print('yhat instances:', yhat_instances.shape)
        yhat = model.apply_classifier(features_att, training=False)
        print('yhat:', yhat.shape)

        # Column 1 of the per-tile output; used to rank tiles for the plots.
        yhat_1 = yhat_instances[:, 1].numpy()
        # Convert the bag prediction once, so the argmax, the printout and
        # the '{:3.2f}' file-name formatting all operate on NumPy values
        # rather than an eager tensor.
        yhat = yhat.numpy()
        high_att_idx, high_att_imgs, low_att_idx, low_att_imgs = get_attention_extremes(
            yhat_1, case_x, n=5)

        print('Case {}: predicted={} ({})'.format(test_case,
                                                  np.argmax(yhat, axis=-1),
                                                  yhat))

        features = features.numpy()
        savepath = '{}_{}_{:3.2f}_yhat.png'.format(args.savebase, case_name,
                                                   yhat[0, 1])
        print('Saving figure {}'.format(savepath))
        z = draw_projection(features, features_att, yhat_1, savepath=savepath)

        savepath = '{}_{}_{:3.2f}_yhat_img.png'.format(args.savebase,
                                                       case_name, yhat[0, 1])
        print('Saving figure {}'.format(savepath))
        draw_projection_with_images(z,
                                    yhat_1,
                                    high_att_idx,
                                    high_att_imgs,
                                    low_att_idx,
                                    low_att_imgs,
                                    savepath=savepath)
示例#3
0
def main(args):
    """Compute a per-tile attention map for every slide listed in ``args.slides``.

    For each slide:
    - the file is copied to ``args.ramdisk`` and opened as an svsutils
      ``Slide``;
    - it is processed by the ``compute_fn`` selected by ``args.iter_type``;
    - the resulting image is saved, with its last axis reversed, to
      ``repext(src, args.suffix)``.

    A failure on one slide is printed and the loop continues; the ramdisk
    copy is always removed.

    Raises:
        ValueError: if ``args.iter_type`` is neither ``'python'`` nor ``'tf'``.
    """
    # Define a compute_fn that should do three things:
    # 1. define an iterator over the slide's tiles
    # 2. compute an output with given model parameter
    # 3. return the slide's reconstructed output image

    if args.iter_type == 'python':

        def compute_fn(slide, args, model=None):
            # NOTE(review): this path writes the 'prob' output, but main()
            # only calls initialize_output('att', ...); confirm that 'prob'
            # exists on the slide.
            print('Slide with {}'.format(len(slide.tile_list)))
            it_factory = PythonIterator(slide, args)
            for k, (img, idx) in enumerate(it_factory.yield_batch()):
                prob = model(img)
                if k % 50 == 0:
                    print('Batch #{:04d} idx:{} img:{} prob:{}'.format(
                        k, idx.shape, img.shape, prob.shape))
                slide.place_batch(prob, idx, 'prob', mode='tile')
            return slide.output_imgs['prob']

    # Tensorflow multithreaded queue-based iterator (in eager mode)
    elif args.iter_type == 'tf':

        def compute_fn(slide, args, model=None):
            assert tf.executing_eagerly()
            print('Slide with {}'.format(len(slide.tile_list)))

            # In eager mode, we return a tf.contrib.eager.Iterator
            eager_iterator = TensorflowIterator(slide, args).make_iterator()

            # The iterator can be used directly. Queueing and multithreading
            # are handled in the backend by the tf.data.Dataset ops.
            features, indices = [], []
            for k, (img, idx) in enumerate(eager_iterator):
                features.append(
                    model.encode_bag(img, training=False, return_z=True))
                indices.append(idx.numpy())
                # Copy the image batch to host memory only when logging it.
                if k % 50 == 0:
                    print('Batch #{:04d}\t{}'.format(k, img.numpy().shape))

            # Attend over all tile encodings of the slide at once; only the
            # raw attention weights are kept.
            features = tf.concat(features, axis=0)
            _, att = model.mil_attention(features,
                                         training=False,
                                         return_raw_att=True)
            att = np.squeeze(att)
            indices = np.concatenate(indices)
            slide.place_batch(att, indices, 'att', mode='tile')
            return slide.output_imgs['att']

    else:
        # Fail fast. Otherwise compute_fn is never bound, and every slide
        # would fail inside the per-slide try/except below, after the model
        # had been loaded.
        raise ValueError(
            "Unknown iter_type {!r}; expected 'python' or 'tf'".format(
                args.iter_type))

    # Set up the model first
    encoder_args = get_encoder_args(args.encoder)
    model = MilkEager(encoder_args=encoder_args,
                      mil_type=args.mil,
                      deep_classifier=args.deep_classifier,
                      batch_size=args.batchsize,
                      temperature=args.temperature,
                      heads=args.heads)

    # Run one dummy bag through the model before loading the snapshot.
    # NOTE(review): confirm MilkEager.__call__ accepts `head='all'`.
    x = tf.zeros((1, 1, args.process_size, args.process_size, 3))
    _ = model(x, verbose=True, head='all', training=True)
    model.load_weights(args.snapshot, by_name=True)

    # keras Model subclass
    model.summary()

    # Read list of inputs
    with open(args.slides, 'r') as f:
        slides = [x.strip() for x in f]

    # Loop over slides
    for src in slides:
        # Dirty substitution of the file extension gives us the destination.
        # No existence check is done, so an existing output is overwritten.
        # Set the --suffix option to reflect the model / type of processed
        # output.
        dst = repext(src, args.suffix)

        # Loading data from ramdisk incurs a one-time copy cost
        rdsrc = cpramdisk(src, args.ramdisk)
        print('File:', rdsrc)

        # Wrapped inside a try-except-finally so the ramdisk copy is removed
        # even when there is an error or stop signal mid-processing.
        try:
            # Initialize the slide from our temporary path, with the
            # command-line arguments. This returns an svsutils.Slide object.
            slide = Slide(rdsrc, args)

            # This step will eventually be included in slide creation with
            # some default compute_fn's provided by svsutils. For now, use
            # the compute_fn defined above.
            slide.initialize_output('att',
                                    args.n_classes,
                                    mode='tile',
                                    compute_fn=compute_fn)

            # Call the compute function to compute this output.
            ret = slide.compute('att', args, model=model)
            print('{} --> {}'.format(ret.shape, dst))
            np.save(dst, ret[:, :, ::-1])
        except Exception as e:
            # Best-effort per slide: report and move on.
            print(e)
            traceback.print_tb(e.__traceback__)
        finally:
            print('Removing {}'.format(rdsrc))
            os.remove(rdsrc)
示例#4
0
def main(args):
  """Evaluate an instance-MIL MilkEager model on a test list and save artifacts.

  Inputs:
  - cases: the ``.npy`` files listed in
    ``<args.test_list_dir>/<timestamp>.txt``;
  - weights: ``<args.snapshot_dir>/<timestamp>.h5``.

  Each case's tiles are transformed, optionally subsampled, and encoded. The
  second value returned by ``encode_bag`` is averaged over tiles to form the
  case prediction. Projection figures, the per-tile scores (``*_atns.npy``)
  and the features (``*_feat.npy``) are written to
  ``<args.odir>/<timestamp>/``. That directory is deleted and recreated
  first. Finally the case-level accuracy and confusion matrix are printed.
  Nothing is scored if the test list is empty.

  NOTE(review): ``case_dict`` (case name -> true label) is not defined in
  this function; presumably a module-level name -- confirm.
  """
  transform_fn = data_utils.make_transform_fn(128, 128, args.input_dim, 1.0, normalize=True)

  snapshot = os.path.join(args.snapshot_dir, '{}.h5'.format(args.timestamp))
  test_list = os.path.join(args.test_list_dir, '{}.txt'.format(args.timestamp))

  encoder_args = get_encoder_args(args.encoder)
  model = MilkEager(encoder_args=encoder_args,
                    deep_classifier=True,
                    mil_type='instance',
                    batch_size=args.batch_size,
                    temperature=args.temperature,
                    cls_normalize=args.cls_normalize)

  # Run one dummy bag through the model before loading the snapshot.
  x_dummy = tf.zeros(shape=[1, args.batch_size, args.input_dim, args.input_dim, 3], dtype=tf.float32)
  model(x_dummy, verbose=True)
  model.load_weights(snapshot, by_name=True)
  model.summary()

  test_list = read_test_list(test_list)
  # Start from an empty output directory for this timestamp.
  savebase = os.path.join(args.odir, args.timestamp)
  if os.path.exists(savebase):
    shutil.rmtree(savebase)
  os.makedirs(savebase)

  yhats, ytrues = [], []
  for test_case in test_list:
    case_name = os.path.basename(test_case).replace('.npy', '')
    print(test_case, case_name)
    case_x = np.load(test_case)
    case_x = np.stack([transform_fn(x) for x in case_x], 0)
    ytrue = case_dict[case_name]
    print(case_x.shape, ytrue)
    ytrues.append(ytrue)

    if args.sample:
      # np.random.choice draws WITH replacement by default.
      case_x = case_x[np.random.choice(range(case_x.shape[0]), args.sample), ...]
      print(case_x.shape)

    # TODO variable names. attention --> something else
    # The second value is presumably per-tile class scores (instance MIL).
    features, attention = model.encode_bag(case_x, training=False, return_z=True)
    # Case prediction: mean over tiles, kept 2-D as (1, n_outputs).
    yhat = np.mean(attention, axis=0, keepdims=True)
    # Per-tile score: column 1 only.
    attention = attention.numpy()[:, 1]
    print('features:', features.shape)
    print('attention:', attention.shape)
    print('yhat:', yhat.shape)

    features_avg = np.mean(features, axis=0, keepdims=True)

    yhats.append(yhat)

    high_att_idx, high_att_imgs, low_att_idx, low_att_imgs = get_attention_extremes(
      attention, case_x, n=5)

    print('Case {}: predicted={} ({})'.format(test_case, np.argmax(yhat, axis=-1), yhat))

    features = features.numpy()
    savepath = os.path.join(savebase, '{}_{:3.2f}.png'.format(case_name, yhat[0, 1]))
    print('Saving figure {}'.format(savepath))
    # No attention-pooled feature is computed here, so the mean feature
    # fills both the "average" and "attention" slots of draw_projection.
    z = draw_projection(features, features_avg, features_avg, attention, savepath=savepath)

    savepath = os.path.join(savebase, '{}_{:3.2f}_imgs.png'.format(case_name, yhat[0, 1]))
    print('Saving figure {}'.format(savepath))
    draw_projection_with_images(z, attention,
      high_att_idx, high_att_imgs,
      low_att_idx, low_att_imgs,
      savepath=savepath)

    savepath = os.path.join(savebase, '{}_atns.npy'.format(case_name))
    np.save(savepath, attention)
    savepath = os.path.join(savebase, '{}_feat.npy'.format(case_name))
    np.save(savepath, features)

  if not yhats:
    # np.concatenate([]) would raise ValueError; nothing was evaluated.
    print('No test cases to evaluate; skipping accuracy.')
    return

  yhats = np.argmax(np.concatenate(yhats, axis=0), axis=1)
  ytrues = np.array(ytrues)
  acc = (yhats == ytrues).mean()
  print(acc)
  cm = confusion_matrix(y_true=ytrues, y_pred=yhats)
  print(cm)
示例#5
0
def main(args):
    """Train MilkEager end to end on bagged MNIST and save the weights.

    Setup:
    - A random digit in 0-9 is chosen as the positive label.
    - The train and test splits are rearranged around it with
      ``rearrange_bagged_mnist``.
    - Batches come from ``generate_bagged_mnist(..., args.n, args.batch_size)``.

    Training runs ``int(args.steps_per_epoch * args.epochs)`` Adam steps on
    categorical cross-entropy. The ``finally`` block saves the weights to
    ``args.o`` whether training finishes, is interrupted with Ctrl-C, or
    raises.
    """
    import traceback  # local import: only needed for error reporting below

    if args.mnist is not None:
        (train_x, train_y), (test_x, test_y) = mnist.load_data(args.mnist)
    else:
        (train_x, train_y), (test_x, test_y) = mnist.load_data()

    print('train_x:', train_x.shape, train_x.dtype, train_x.min(),
          train_x.max())
    print('train_y:', train_y.shape)
    print('test_x:', test_x.shape)
    print('test_y:', test_y.shape)

    positive_label = np.random.choice(range(10))
    print('using positive label = {}'.format(positive_label))

    train_x_pos, train_x_neg = rearrange_bagged_mnist(train_x, train_y,
                                                      positive_label)
    test_x_pos, test_x_neg = rearrange_bagged_mnist(test_x, test_y,
                                                    positive_label)
    print('rearranged training set:')
    print('\ttrain_x_pos:', train_x_pos.shape, train_x_pos.dtype,
          train_x_pos.min(), train_x_pos.max())
    print('\ttrain_x_neg:', train_x_neg.shape)
    print('\ttest_x_pos:', test_x_pos.shape)
    print('\ttest_x_neg:', test_x_neg.shape)

    generator = generate_bagged_mnist(train_x_pos, train_x_neg, args.n,
                                      args.batch_size)
    # NOTE: val_generator is not consumed anywhere in this function.
    val_generator = generate_bagged_mnist(test_x_pos, test_x_neg, args.n,
                                          args.batch_size)
    batch_x, batch_y = next(generator)
    print('batch_x:', batch_x.shape, 'batch_y:', batch_y.shape)

    encoder_args = get_encoder_args('mnist')
    model = MilkEager(
        encoder_args=encoder_args,
        mil_type=args.mil,
        deep_classifier=True,
    )
    # Run one real batch through the model before summary() / load_weights().
    model(batch_x, verbose=True)
    model.summary()

    if args.pretrained is not None and os.path.exists(args.pretrained):
        model.load_weights(args.pretrained, by_name=True)
    else:
        print('Pretrained model not found ({}). Continuing end 2 end.'.format(
            args.pretrained))

    # Keep the single-device model. The Keras docs say a multi_gpu_model's
    # weights should be saved through the template model, not the wrapper.
    template_model = model
    if args.gpus > 1:
        print('Duplicating model onto {} GPUs'.format(args.gpus))
        model = tf.keras.utils.multi_gpu_model(model,
                                               args.gpus,
                                               cpu_merge=True,
                                               cpu_relocation=False)

    optimizer = tf.train.AdamOptimizer(learning_rate=args.lr)

    try:
        for k in range(int(args.steps_per_epoch * args.epochs)):
            with tf.GradientTape() as tape:
                x, y = next(generator)
                yhat = model(tf.constant(x), training=True)
                loss = tf.keras.losses.categorical_crossentropy(
                    y_true=tf.constant(y, dtype=tf.float32), y_pred=yhat)

            grads = tape.gradient(loss, model.variables)
            optimizer.apply_gradients(zip(grads, model.variables))

            if k % 50 == 0:
                print('{:06d}: loss={:3.5f}'.format(k, np.mean(loss)))
                for y_, yh_ in zip(y, yhat):
                    print('\t{} {}'.format(y_, yh_))

    except KeyboardInterrupt:
        print('Keyboard interrupt caught')

    except Exception as e:
        # Deliberately best-effort: report the error (with traceback) and
        # still fall through to saving the weights.
        print('Other error caught')
        print(type(e))
        print(e)
        traceback.print_tb(e.__traceback__)

    finally:
        template_model.save_weights(args.o)
        print('Saved model: {}'.format(args.o))
示例#6
0
def main(args):
  """Compute attention maps and bag features for every slide in ``args.slides``.

  For each listed slide:
  - copy it to ``args.ramdisk`` and open it as an svsutils ``Slide``;
  - run ``compute_fn`` through ``slide.compute('att', ...)``;
  - save the attention image to ``repext(src, args.suffix)`` and the
    concatenated tile features to ``repext(src, args.suffix + '.feat.npy')``.

  A failing slide is reported and skipped; the ramdisk copy is always
  deleted.
  """
  # Head indices, used both for the dummy forward pass and inside
  # compute_fn. Written out literally rather than via range(args.heads).
  all_heads = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

  def compute_fn(slide, args, model=None, n_dropout=10):
    """Return ``(attention_image, features)`` for one slide.

    Steps:
    - tiles come eagerly from ``TensorflowIterator(slide, args)`` and each
      batch is encoded;
    - the encodings are pooled with ``model.mil_attention``;
    - the raw attention is written onto the slide's 'att' output.

    ``n_dropout`` belongs to a disabled sample-dropout experiment and is
    unused.
    """
    assert tf.executing_eagerly()
    print('Slide with {}'.format(len(slide.tile_list)))

    encoded, tile_positions = [], []
    # Queueing and multithreading are handled by the tf.data ops behind
    # the eager iterator.
    batches = TensorflowIterator(slide, args).make_iterator()
    for batch_no, (img, idx) in enumerate(batches):
      encoded.append(model.encode_bag(img, training=False, return_z=True))
      tile_positions.append(idx.numpy())
      img_shape = img.numpy().shape
      if batch_no % 50 == 0:
        print('Batch #{:04d}\t{}'.format(batch_no, img_shape))

    features = tf.concat(encoded, axis=0)

    # A sample-dropout variant (attending over random 70% subsets of the
    # features n_dropout times) was sketched here and is disabled.
    z_att, att = model.mil_attention(features,
                                     training=False,
                                     return_raw_att=True)

    yhat_multihead = model.apply_classifier(z_att, heads=all_heads,
                                            training=False)
    print('yhat mean {}'.format(np.mean(yhat_multihead, axis=0)))

    att = np.squeeze(att)
    slide.place_batch(att, np.concatenate(tile_positions), 'att', mode='tile')
    att_img = slide.output_imgs['att']
    print('Got attention image: {}'.format(att_img.shape))
    return att_img, features.numpy()

  ## Begin main script: build the model, then restore the snapshot.
  model = MilkEager(encoder_args=get_encoder_args(args.encoder),
                    mil_type=args.mil,
                    deep_classifier=args.deep_classifier,
                    batch_size=args.batchsize,
                    temperature=args.temperature,
                    heads=args.heads)
  dummy = tf.zeros((1, 1, args.process_size, args.process_size, 3))
  model(dummy, verbose=True, heads=all_heads, training=True)
  model.load_weights(args.snapshot, by_name=True)
  model.summary()

  with open(args.slides, 'r') as f:
    slides = [line.strip() for line in f]

  for src in slides:
    # Output paths come from swapping the file extension; np.save
    # overwrites anything already there.
    dst = repext(src, args.suffix)
    featdst = repext(src, args.suffix+'.feat.npy')

    # Work from a ramdisk copy (one-time copy cost).
    rdsrc = cpramdisk(src, args.ramdisk)
    print('\n\nFile:', rdsrc)
    try:
      slide = Slide(rdsrc, args)
      slide.initialize_output('att', args.n_classes, mode='tile',
                              compute_fn=compute_fn)
      att_img, feats = slide.compute('att', args, model=model)
      print('{} --> {}'.format(att_img.shape, dst))
      print('{} --> {}'.format(feats.shape, featdst))
      np.save(dst, att_img)
      np.save(featdst, feats)
    except Exception as e:
      # Report and continue with the next slide.
      print(e)
      traceback.print_tb(e.__traceback__)
    finally:
      print('Removing {}'.format(rdsrc))
      os.remove(rdsrc)
示例#7
0
def main(args):
    """Train a MilkEager model on a MILDataset, training one head at a time.

    1. Create generator datasets from the provided lists
    2. Train Milk. The validation iterator code below is commented out, so
       this version runs no validation pass.

    History:
      v0 - create datasets within this script
      v1 - factor monolithic training_utils.mil_train_loop !!
      tpu - replace data feeder and mil_train_loop with tf.keras.Model.fit()
      July 4 2019 - added MILDataset that takes away all the dataset nonsense

    A single head is trained at a time; it is re-drawn at random every
    ``train_len`` steps. Two things happen in a ``finally`` block, so they
    also run after Ctrl-C or a caught error:
    - the weights are saved to the path returned by ``create_outputs(args)``;
    - the per-step loss/accuracy are dumped to
      ``<args.out_base>/save/<exptime_str>_training_curves.txt``.
    """
    out_path, exptime_str = create_outputs(args)

    # Dataset wrapper; train/val split on 'case_id' with args.seed.
    data_factory = MILDataset(args.dataset, crop=args.crop_size, n_classes=2)
    data_factory.split_train_val('case_id', seed=args.seed)

    #with tf.device('/gpu:0'):
    print('Model initializing')
    encoder_args = get_encoder_args(args.encoder)
    model = MilkEager(encoder_args=encoder_args,
                      mil_type=args.mil,
                      deep_classifier=args.deep_classifier,
                      batch_size=16,
                      temperature=args.temperature,
                      heads=args.heads)
    # Run one dummy bag through the model before summary() / load_weights().
    print('Running once to load CUDA')
    x = tf.zeros((1, 1, args.crop_size, args.crop_size, args.channels))

    ## The way we give the list is very particular: the heads are spelled out
    ## as a literal list instead of range()/list(range()).
    ## NOTE(review): this presumes the model has at least 10 heads -- confirm
    ## against args.heads.
    all_heads = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
    yhat = model(x, heads=all_heads, training=True, verbose=True)
    print('yhat: {} ({} {})'.format(yhat[0], yhat[0].shape, yhat[0].dtype))
    model.summary()

    ## Replace randomly initialized weights after model is compiled and on the correct device.
    if args.pretrained_model is not None and os.path.exists(
            args.pretrained_model):
        print('Replacing random weights with weights from {}'.format(
            args.pretrained_model))
        try:
            model.load_weights(args.pretrained_model, by_name=True)
        except Exception as e:
            # Best-effort: a failed load is printed and training continues
            # from the current weights.
            print(e)

    ## Controlling overfitting by monitoring a metric, with some patience since the last improvement
    ## NOTE(review): `stopper` is never consulted in the training loop below,
    ## so --early_stop currently has no effect. Patience is also taken from
    ## args.heads, which looks unintended -- confirm.
    if args.early_stop:
        stopper = ShouldStop(patience=args.heads)
    else:
        stopper = lambda x: False

    # accumulator = GradientAccumulator(n = args.accumulate, variable_list=trainable_variables)

    optimizer = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
    # Per-step history; written to the training-curves file in `finally`.
    trackers = {stat: [] for stat in ['loss', 'acc']}
    trackers['step'] = 0

    def py_it():
        # Generator factory for tf.data.Dataset.from_generator. It passes
        # epochs=args.epochs, so one pass over tf_dataset presumably spans all
        # epochs (the .repeat() below is commented out) -- confirm.
        return data_factory.python_iterator(mode='train',
                                            subset=args.bag_size,
                                            attr='stage_code',
                                            seed=None,
                                            epochs=args.epochs)

    # Presumably the number of training items per epoch. Used only to decide
    # when to draw a new head below.
    train_len = data_factory.dataset_lengths['train']
    tf_dataset = (
        tf.data.Dataset.from_generator(py_it,
                                       output_types=(tf.uint8, tf.uint8))
        # .repeat(repeats)
        .map(data_factory.map_fn, num_parallel_calls=args.threads)
        # .prefetch(buffer_size)
        .batch(args.batch_size))

    # data_factory.tensorflow_iterator(mode='train', ref='train', repeats=args.epochs,
    #   subset=args.bag_size, seed=None, attr='stage_code', threads=args.threads)
    # data_factory.tensorflow_iterator(mode='val', ref='val', repeats=args.epochs,
    #   subset=args.bag_size, seed=None, attr='stage_code', threads=args.threads)

    # val_iterator, val_len = data_factory.iterator_refs['val']
    # train_iterator, train_len = data_factory.iterator_refs['train']

    # Captured once; gradients are computed and applied for exactly this list.
    trainable_variables = model.trainable_variables
    try:
        # for epc in range(args.epochs):
        #   tf.set_random_seed(1)
        # trackers = train_epoch(model, optimizer, train_iterator, train_len, epc, trackers, args)
        # train_head = [epc % args.heads]

        avglosses, steptimes = [], []
        # print('\nTraining head {}'.format(train_head))
        for k, (x, y) in enumerate(tf_dataset):
            # At step 0 and every train_len steps after: run the garbage
            # collector, reset the graph-level seed, and draw the single head
            # to train next. (train_head is first bound here at k == 0.)
            if k % train_len == 0:
                gc.collect()
                tf.set_random_seed(1)
                # tf.reset_default_graph()
                train_head = [np.random.choice(args.heads)]
                print('\nTraining head [{}]'.format(train_head))

            tstart = time.time()
            with tf.GradientTape() as tape:
                # Only the selected head is evaluated; the loss uses yhat[0],
                # presumably that head's output.
                yhat = model(x, training=True, heads=train_head)
                loss = tf.keras.losses.categorical_crossentropy(
                    y_true=tf.constant(y, dtype=tf.float32), y_pred=yhat[0])

            grads = tape.gradient(loss, trainable_variables)
            del tape
            loss_mn = np.mean(loss)
            acc = eval_acc(y, yhat)
            avglosses.append(loss_mn)

            trackers['loss'].append(loss_mn)
            trackers['acc'].append(acc)
            trackers['step'] += 1

            # The logged step time (dt) stops here, so it excludes
            # optimizer.apply_gradients below.
            tend = time.time()
            steptimes.append(tend - tstart)
            # if should_update:
            #   grads = accumulator.accumulate()
            # with tf.device('/cpu:0'):
            # print('Applying gradients')
            optimizer.apply_gradients(zip(grads, trainable_variables))
            # Running means since the start of training, redrawn in place.
            print('\r{:07d}: loss={:3.5f} dt={:3.3f}s   '.format(
                k, np.mean(avglosses), np.mean(steptimes)),
                  end='',
                  flush=1)
            # if (k+1) % train_len == 0: break

            # if epc % args.snapshot_epochs == 0:
            #   snapshot_path = out_path.replace('.h5', '-{:03d}.h5'.format(epc))
            #   print('Snapshotting to {}'.format(snapshot_path))
            #   model.save_weights(snapshot_path)

    except KeyboardInterrupt:
        print('Keyboard interrupt caught')

    except Exception as e:
        # Deliberately best-effort: report, then still save in `finally`.
        print('Other error caught')
        print(e)
        traceback.print_tb(e.__traceback__)

    finally:
        model.save_weights(out_path)
        print('Saved model: {}'.format(out_path))

        # Save the loss profile: one "step<TAB>loss<TAB>acc" line per step.
        training_stats = os.path.join(
            args.out_base, 'save',
            '{}_training_curves.txt'.format(exptime_str))
        print('Dumping training stats --> {}'.format(training_stats))
        with open(training_stats, 'w+') as f:
            for s, l, a in zip(np.arange(trackers['step']), trackers['loss'],
                               trackers['acc']):
                f.write('{:06d}\t{:3.5f}\t{:3.5f}\n'.format(s, l, a))
示例#8
0
def main(args):
    """Predict every slide listed in ``args.f`` and save predictions and maps.

    Slides without a foreground mask at ``<args.fg>/<basename>_fg.png`` are
    skipped. Each remaining slide is copied to ``args.ramdisk`` and processed
    from that copy with ``process_slide``.

    For every processed slide the prediction is saved to
    ``<args.o>/<basename>_ypred.npy``. For attention/instance MIL two more
    files are written:
    - the raw attention image, to ``<basename>_att.npy``;
    - a rendered version, to ``<basename>_img.png``.

    The ramdisk copy is deleted after each slide, including when processing
    raises; the exception is then propagated.
    """
    # Translate obfuscated file names to paths if necessary
    slide_list = read_list(args.f)
    print('Found {} slides'.format(len(slide_list)))

    encoder_args = get_encoder_args(args.encoder)
    model = MilkEager(encoder_args=encoder_args,
                      mil_type=args.mil,
                      batch_size=args.batch_size,
                      temperature=args.temperature,
                      deep_classifier=args.deep_classifier)

    # Run one dummy bag through the model before loading weights.
    x_pl = np.zeros((1, args.batch_size, args.input_dim, args.input_dim, 3),
                    dtype=np.float32)
    yhat = model(tf.constant(x_pl), verbose=True)
    print('yhat:', yhat.shape)

    print('setting model weights')
    model.load_weights(args.s, by_name=True)

    ## Loop over found slides:
    yhats = []
    for i, src in enumerate(slide_list):
        print('\nSlide {}'.format(i))
        basename = os.path.basename(src).replace('.svs', '')
        fgpth = os.path.join(args.fg, '{}_fg.png'.format(basename))
        if not os.path.exists(fgpth):
            # No foreground mask for this slide: report the path and skip.
            print(fgpth)
            continue

        ramdisk_path = transfer_to_ramdisk(
            src, args.ramdisk)  # never use the original src
        print('Using fg image at : {}'.format(fgpth))
        fgimg = cv2.imread(fgpth, 0)
        try:
            svs = Slide(
                slide_path=ramdisk_path,
                # background_speed  = 'accurate',
                background_speed='image',
                background_image=fgimg,
                preprocess_fn=lambda x: (x / 255.).astype(np.float32),
                normalize_fn=lambda x: x,
                process_mag=args.mag,
                process_size=args.input_dim,
                oversample_factor=args.oversample,
                verbose=False)
        except Exception as e:
            print(e)
            print(
                'Caught SVS related error. Cleaning ramdisk and continuing.'
            )
            print('Cleaning file: {}'.format(ramdisk_path))
            os.remove(ramdisk_path)
            continue

        try:
            svs.initialize_output(name='attention', dim=1, mode='tile')
            yhat, att, indices = process_slide(svs, model, args)

            yhats.append(yhat)
            print('\tSlide predicted: {}'.format(yhat))

            if args.mil == 'average':
                print('Average MIL; continuing')
            elif args.mil in ['attention', 'instance']:
                print('Placing values ranged {:3.3f} - {:3.3f}'.format(
                    att.min(), att.max()))
                print('Visualizing mean {:3.5f}'.format(np.mean(att)))
                print('Visualizing std {:3.5f}'.format(np.std(att)))
                svs.place_batch(att, indices, 'attention', mode='tile')
                attention_img_raw = np.squeeze(svs.output_imgs['attention'])

                # Scale so the maximum is 1 for display. An all-zero map is
                # left as zeros; 0 * (1 / 0) would otherwise produce NaNs.
                att_max = attention_img_raw.max()
                if att_max != 0:
                    attention_img = attention_img_raw * (1. / att_max)
                else:
                    attention_img = attention_img_raw.copy()
                attention_img = draw_attention(attention_img, n_bins=50)
                print('attention image:', attention_img.shape,
                      attention_img.dtype, attention_img.min(),
                      attention_img.max())

                dst = os.path.join(args.o, '{}_att.npy'.format(basename))
                np.save(dst, attention_img_raw)
                dst = os.path.join(args.o, '{}_img.png'.format(basename))
                cv2.imwrite(dst, attention_img)

            yhat_dst = os.path.join(args.o, '{}_ypred.npy'.format(basename))
            np.save(yhat_dst, yhat)
        finally:
            # Release the slide and delete the ramdisk copy even if
            # processing raised, so copies never pile up on the ramdisk.
            try:
                svs.close()
            except Exception as e:  # best-effort; still delete the copy
                print(e)
            try:
                os.remove(ramdisk_path)
            except OSError:
                print('{} already removed'.format(ramdisk_path))
            del svs
示例#9
0
def main(args):
    """Predict slide-level labels for a held-out test list and save attention maps.

    The test list ``<args.testdir>/<args.timestamp>.txt`` is read with
    ``read_test_list``. Its entries are reduced to unique ids (basename without
    ``.npy``), optionally shuffled (``args.randomize``) and truncated
    (``args.max_slides``), then mapped to slide paths and labels with
    ``get_slidelist_from_uids``.

    A ``MilkEager`` model is built and run once on a zero batch. Its weights
    are then loaded by name from ``<args.savedir>/<args.timestamp>.h5``.

    Each slide that has a precomputed foreground mask
    ``<args.fgdir>/<basename>_fg.png`` is processed as follows:

    * it is copied to ``args.ramdisk``;
    * it is opened with ``Slide``;
    * it is scored with ``process_slide``.

    For every scored slide, two files are written to
    ``<args.odir>/<args.timestamp>/``:

    * ``<basename>_<label>_<p>_att.npy``: the attention values returned by
      ``process_slide``;
    * ``<basename>_<label>_<p>_img.png``: the attention placed on the tile
      grid, scaled by its maximum and rendered with ``draw_attention``;

    where ``<p>`` is ``yhat[0, 1]`` formatted with three decimals. At the end,
    the accuracy of ``argmax(yhat)`` against the slide labels is printed.

    Some slides are skipped:

    * slides without a foreground mask;
    * slides that ``Slide`` fails to open.

    The ramdisk copy of every slide is removed, even when processing raises.

    Args:
        args: parsed command-line namespace (testdir, timestamp, randomize,
            max_slides, savedir, encoder, mil, batch_size, deep_classifier,
            temperature, input_dim, fgdir, ramdisk, mag, oversample, odir).
    """
    # Translate obfuscated file names to paths if necessary
    test_list = os.path.join(args.testdir, '{}.txt'.format(args.timestamp))
    test_list = read_test_list(test_list)
    test_unique_ids = [
        os.path.basename(x).replace('.npy', '') for x in test_list
    ]
    if args.randomize:
        np.random.shuffle(test_unique_ids)

    if args.max_slides:
        test_unique_ids = test_unique_ids[:args.max_slides]

    slide_list, slide_labels = get_slidelist_from_uids(test_unique_ids)

    print('Found {} slides'.format(len(slide_list)))

    snapshot = os.path.join(args.savedir, '{}.h5'.format(args.timestamp))

    encoder_args = get_encoder_args(args.encoder)
    model = MilkEager(encoder_args=encoder_args,
                      mil_type=args.mil,
                      batch_size=args.batch_size,
                      deep_classifier=args.deep_classifier,
                      temperature=args.temperature)

    # One forward pass on a zero batch before load_weights(by_name=True).
    # This is presumably needed so the eager model has created its variables.
    x_pl = np.zeros((1, args.batch_size, args.input_dim, args.input_dim, 3),
                    dtype=np.float32)
    yhat = model(tf.constant(x_pl), verbose=True)
    print('yhat:', yhat.shape)

    print('setting model weights')
    model.load_weights(snapshot, by_name=True)

    # If the target directory is missing, np.save raises and cv2.imwrite only
    # returns False. Create the directory once, up front.
    outdir = os.path.join(args.odir, args.timestamp)
    os.makedirs(outdir, exist_ok=True)

    ## Loop over found slides:
    yhats = []
    ytrues = []
    for i, (src, lab) in enumerate(zip(slide_list, slide_labels)):
        print('\nSlide {}'.format(i))
        basename = os.path.basename(src).replace('.svs', '')
        fgpth = os.path.join(args.fgdir, '{}_fg.png'.format(basename))
        if not os.path.exists(fgpth):
            ## require precomputed background; skip this slide.
            print('Required foreground image not found ({})'.format(fgpth))
            continue

        ramdisk_path = transfer_to_ramdisk(
            src, args.ramdisk)  # never use the original src
        svs = None
        try:
            print('Using fg image at : {}'.format(fgpth))
            fgimg = cv2.imread(fgpth, 0)
            try:
                svs = Slide(
                    slide_path=ramdisk_path,
                    background_speed='image',
                    background_image=fgimg,
                    preprocess_fn=lambda x: (x / 255.).astype(np.float32),
                    process_mag=args.mag,
                    process_size=args.input_dim,
                    oversample_factor=args.oversample,
                    verbose=False)
            except Exception as e:
                # A slide that cannot be opened is skipped rather than
                # aborting the whole run. The finally clause below still
                # removes its ramdisk copy.
                print(e)
                print('Caught SVS related error. Skipping slide.')
                continue

            svs.initialize_output(name='attention', dim=1, mode='tile')

            yhat, att, indices = process_slide(svs, model, args)
            print('returned attention:', np.min(att), np.max(att), att.shape)

            yhat = yhat.numpy()
            yhats.append(yhat)
            ytrues.append(lab)
            print('\tSlide label: {} predicted: {}'.format(lab, yhat))

            svs.place_batch(att, indices, 'attention', mode='tile')
            attention_img = np.squeeze(svs.output_imgs['attention'])
            att_max = attention_img.max()
            if att_max != 0:
                # Scale by the maximum for rendering. An all-zero map is left
                # unchanged instead of turning into NaNs via 0 * (1 / 0).
                attention_img = attention_img * (1. / att_max)
            attention_img = draw_attention(attention_img, n_bins=25)
            print('attention image:', attention_img.shape, attention_img.dtype,
                  attention_img.min(), attention_img.max())

            dst = os.path.join(
                outdir,
                '{}_{}_{:3.3f}_att.npy'.format(basename, lab, yhat[0, 1]))
            np.save(dst, att)

            dst = os.path.join(
                outdir,
                '{}_{}_{:3.3f}_img.png'.format(basename, lab, yhat[0, 1]))
            # cv2.imwrite reports failure only through its return value.
            if not cv2.imwrite(dst, attention_img):
                print('Failed to write {}'.format(dst))
        finally:
            # Always release the slide handle and the ramdisk copy, even if
            # processing raised. Closing and removing are separate steps, so a
            # failed close cannot leave the copy behind.
            if svs is not None:
                try:
                    svs.close()
                except Exception as e:
                    print('Could not close {}: {}'.format(ramdisk_path, e))
            try:
                os.remove(ramdisk_path)
            except FileNotFoundError:
                print('{} already removed'.format(ramdisk_path))
            except OSError as e:
                print('Could not remove {}: {}'.format(ramdisk_path, e))

    # np.concatenate raises on an empty list (e.g. every slide was skipped).
    if not yhats:
        print('No slides were processed; accuracy is undefined.')
        return

    yhats = np.concatenate(yhats, axis=0)
    ytrues = np.array(ytrues)
    acc = (np.argmax(yhats, axis=-1) == ytrues).mean()
    print(acc)
示例#10
0
def main(args):
    """Evaluate a trained MilkEager model on the cases of a saved test list.

    Steps:

    * Builds a test-time transform with ``data_utils.make_transform_fn``
      (flipping and rotation disabled).
    * Restores weights from ``save/<args.timestamp>.h5``, a path relative to
      the current working directory.
    * Scores every case listed in ``<args.testdir>/<args.timestamp>.txt``
      with ``run_sample``. The whole list is scored ``args.n_repeat`` times.
    * Prints the argmax accuracy over all collected predictions.
    * If ``args.odir`` is set, saves the stacked predictions and labels as
      ``<timestamp>.npy`` / ``<timestamp>_ytrue.npy``, and passes
      ``<timestamp>.png`` / ``<timestamp>.txt`` paths to ``auc_curve`` and
      ``test_eval``. Otherwise those functions receive ``savepath=None``.

    Args:
        args: parsed command-line namespace (x_size, y_size, crop_size, scale,
            timestamp, encoder, mcdropout, mil, deep_classifier, cls_normalize,
            batch_size, temperature, testdir, n_repeat, sample_mode, odir).
    """
    # Test-time preprocessing: no flips or rotations, middle crop, normalized.
    transform_fn = data_utils.make_transform_fn(args.x_size,
                                                args.y_size,
                                                args.crop_size,
                                                args.scale,
                                                flip=False,
                                                middle_crop=True,
                                                rotate=False,
                                                normalize=True)

    # NOTE(review): hard-coded 'save/' directory, resolved against the CWD.
    snapshot = 'save/{}.h5'.format(args.timestamp)
    # trained_model = load_model(snapshot)

    encoder_args = get_encoder_args(args.encoder)
    if args.mcdropout:
        encoder_args['mcdropout'] = True

    print('Model initializing')
    model = MilkEager(encoder_args=encoder_args,
                      mil_type=args.mil,
                      deep_classifier=args.deep_classifier,
                      cls_normalize=args.cls_normalize,
                      batch_size=args.batch_size,
                      temperature=args.temperature)
    # Run one dummy batch through the model before load_weights(by_name=True).
    # This is presumably needed so the eager model has created its variables.
    xdummy = tf.zeros((1, args.batch_size, args.x_size, args.y_size, 3))
    ydummy = model(xdummy, verbose=True)
    print(xdummy.shape, ydummy.shape)
    del xdummy, ydummy
    model.load_weights(snapshot, by_name=True)

    test_list = os.path.join(args.testdir, '{}.txt'.format(args.timestamp))
    test_list = read_test_list(test_list)

    ytrues = []
    yhats = []
    print('MC Dropout: {}'.format(args.mcdropout))
    # NOTE(review): each repeat appends every case again, so with
    # n_repeat > 1 every case is counted n_repeat times in the accuracy and
    # in the saved arrays.
    for _ in range(args.n_repeat):
        for test_case in test_list:
            # Load every tile of the case; the label comes from the
            # module-level case_label_fn.
            case_x, case_y = data_utils.load(
                data_path=test_case,
                transform_fn=transform_fn,
                case_label_fn=case_label_fn,
                all_tiles=True,
            )
            # case_x = np.squeeze(case_x, axis=0)
            print('Running case x: ', case_x.shape)
            yhat = run_sample(case_x,
                              model,
                              mcdropout=args.mcdropout,
                              sample_mode=args.sample_mode)
            ytrues.append(case_y)
            yhats.append(yhat)
            print(test_case, case_y, case_x.shape, yhat)

    # Stack per-case results along axis 0. This raises ValueError if
    # test_list is empty or n_repeat is 0.
    ytrue = np.concatenate(ytrues, axis=0)
    yhat = np.concatenate(yhats, axis=0)

    # Labels are argmax'ed along the last axis too, so case_y is presumably a
    # one-hot / per-class vector; confirm against case_label_fn.
    ytrue_max = np.argmax(ytrue, axis=-1)
    yhat_max = np.argmax(yhat, axis=-1)
    accuracy = (ytrue_max == yhat_max).mean()
    print('Accuracy: {:3.3f}'.format(accuracy))

    if args.odir is not None:
        save_img = os.path.join(args.odir, '{}.png'.format(args.timestamp))
        save_metrics = os.path.join(args.odir, '{}.txt'.format(args.timestamp))
        save_yhat = os.path.join(args.odir, '{}.npy'.format(args.timestamp))
        save_ytrue = os.path.join(args.odir,
                                  '{}_ytrue.npy'.format(args.timestamp))
        np.save(save_yhat, yhat)
        np.save(save_ytrue, ytrue)
    else:
        # NOTE(review): save_ytrue is not reset here, unlike the other
        # save_* paths, and save_yhat is not used after this branch.
        save_img = None
        save_metrics = None
        save_yhat = None

    # With savepath=None these presumably only report instead of writing
    # files; TODO confirm in auc_curve / test_eval.
    auc_curve(ytrue, yhat, savepath=save_img)
    test_eval(ytrue, yhat, savepath=save_metrics)