Example #1
def main(args):
    transform_fn = data_utils.make_transform_fn(128,
                                                128,
                                                args.input_dim,
                                                1.0,
                                                normalize=True)

    snapshot = os.path.join('../experiment/save',
                            '{}.h5'.format(args.timestamp))
    test_list = os.path.join('../experiment/test_lists',
                             '{}.txt'.format(args.timestamp))

    encoder_args = get_encoder_args(args.encoder)
    model = MilkEager(encoder_args=encoder_args,
                      mil_type=args.mil,
                      deep_classifier=True)
    x_dummy = tf.zeros(
        shape=[1, args.batch_size, args.input_dim, args.input_dim, 3],
        dtype=tf.float32)
    retvals = model(x_dummy, verbose=True)
    model.load_weights(snapshot, by_name=True)
    model.summary()

    test_list = read_test_list(test_list)

    yhats, ytrues = [], []
    features_case, features_classifier = [], []
    for test_case in test_list:
        case_name = os.path.basename(test_case).replace('.npy', '')
        print(test_case, case_name)
        # case_path = os.path.join('../dataset/tiles_reduced', '{}.npy'.format(case_name))
        case_x = np.load(test_case)
        case_x = np.stack([transform_fn(x) for x in case_x], 0)
        print(case_x.shape)

        if args.sample:
            case_x = case_x[
                np.random.choice(range(case_x.shape[0]), args.sample), ...]
            print(case_x.shape)

        features = model.encode_bag(case_x,
                                    batch_size=args.batch_size,
                                    training=True,
                                    return_z=True)
        print('features:', features.shape)
        # features_att, attention = model.mil_attention(features, return_att=True, training=False)
        # print('features:', features_att.shape, 'attention:', attention.shape)

        features_att = tf.reduce_mean(features, axis=0, keepdims=True)
        yhat_instances = model.apply_classifier(features, training=False)
        print('yhat instances:', yhat_instances.shape)
        yhat = model.apply_classifier(features_att, training=False)
        print('yhat:', yhat.shape)

        yhat_1 = yhat_instances[:, 1].numpy()
        # yhat, attention, features, feat_case, feat_class = retvals
        # attention = np.squeeze(attention.numpy(), axis=0)
        high_att_idx, high_att_imgs, low_att_idx, low_att_imgs = get_attention_extremes(
            yhat_1, case_x, n=5)

        print('Case {}: predicted={} ({})'.format(test_case,
                                                  np.argmax(yhat, axis=-1),
                                                  yhat.numpy()))

        features = features.numpy()
        savepath = '{}_{}_{:3.2f}_yhat.png'.format(args.savebase, case_name,
                                                   yhat[0, 1])
        print('Saving figure {}'.format(savepath))
        z = draw_projection(features, features_att, yhat_1, savepath=savepath)

        savepath = '{}_{}_{:3.2f}_yhat_img.png'.format(args.savebase,
                                                       case_name, yhat[0, 1])
        print('Saving figure {}'.format(savepath))
        draw_projection_with_images(z,
                                    yhat_1,
                                    high_att_idx,
                                    high_att_imgs,
                                    low_att_idx,
                                    low_att_imgs,
                                    savepath=savepath)
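
The examples call read_test_list without showing its body. A minimal sketch, assuming the list file holds one .npy path per line (which matches how Example #3 writes its val/test lists), could look like this; the project's actual implementation may differ:

def read_test_list(test_list_path):
    # Assumed format: one path to a saved .npy tile stack per line.
    with open(test_list_path, 'r') as f:
        test_list = [line.strip() for line in f if line.strip()]
    return test_list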
Example #2
def main(args):
  transform_fn = data_utils.make_transform_fn(128, 128, args.input_dim, 1.0, normalize=True)

  snapshot = os.path.join(args.snapshot_dir, '{}.h5'.format(args.timestamp))
  test_list = os.path.join(args.test_list_dir, '{}.txt'.format(args.timestamp))

  encoder_args = get_encoder_args(args.encoder)
  model = MilkEager(encoder_args=encoder_args, 
                    deep_classifier=True, 
                    mil_type='instance',
                    batch_size=args.batch_size,
                    temperature=args.temperature,
                    cls_normalize=args.cls_normalize)

  x_dummy = tf.zeros(shape=[1, args.batch_size, args.input_dim, args.input_dim, 3], dtype=tf.float32)
  retvals = model(x_dummy, verbose=True)
  model.load_weights(snapshot, by_name=True)
  model.summary()

  test_list = read_test_list(test_list)
  savebase = os.path.join(args.odir, args.timestamp)
  if os.path.exists(savebase):
    shutil.rmtree(savebase)
  os.makedirs(savebase)

  yhats, ytrues = [], []
  features_case, features_classifier = [], []
  for test_case in test_list:
    case_name = os.path.basename(test_case).replace('.npy', '')
    print(test_case, case_name)
    # case_path = os.path.join('../dataset/tiles_reduced', '{}.npy'.format(case_name))
    case_x = np.load(test_case)
    case_x = np.stack([transform_fn(x) for x in case_x], 0)
    ytrue = case_dict[case_name]
    print(case_x.shape, ytrue)
    ytrues.append(ytrue)

    if args.sample:
      case_x = case_x[np.random.choice(range(case_x.shape[0]), args.sample), ...]
      print(case_x.shape)

    # encode_bag with return_z=True yields per-instance embeddings and
    # per-instance predictions; the bag-level prediction is their mean.
    features, yhat_instances = model.encode_bag(case_x, training=False, return_z=True)
    yhat = np.mean(yhat_instances, axis=0, keepdims=True)
    attention = yhat_instances.numpy()[:, 1]
    print('features:', features.shape)
    print('attention:', attention.shape)
    print('yhat:', yhat.shape)
    # features_att, attention = model.mil_attention(features, return_att=True, training=False)
    # print('features:', features_att.shape, 'attention:', attention.shape)

    features_avg = np.mean(features, axis=0, keepdims=True)
    features_att = features_avg

    yhats.append(yhat)

    # yhat, attention, features, feat_case, feat_class = retvals
    # attention = np.squeeze(attention.numpy(), axis=0)
    high_att_idx, high_att_imgs, low_att_idx, low_att_imgs = get_attention_extremes(
      attention, case_x, n=5)

    print('Case {}: predicted={} ({})'.format(test_case, np.argmax(yhat, axis=-1), yhat))

    features = features.numpy()
    savepath = os.path.join(savebase, '{}_{:3.2f}.png'.format(case_name, yhat[0,1]))
    print('Saving figure {}'.format(savepath))
    z = draw_projection(features, features_avg, features_att, attention, savepath=savepath)

    # savepath = os.path.join(savebase, '{}_{:3.2f}_ys.png'.format(case_name, yhat[0,1]))
    # print('Saving figure {}'.format(savepath))
    # draw_projection_with_images(z, yhat_instances[:,1].numpy(), 
    #   high_att_idx, high_att_imgs, 
    #   low_att_idx, low_att_imgs, 
    #   savepath=savepath)

    savepath = os.path.join(savebase, '{}_{:3.2f}_imgs.png'.format(case_name, yhat[0,1]))
    print('Saving figure {}'.format(savepath))
    draw_projection_with_images(z, attention, 
      high_att_idx, high_att_imgs, 
      low_att_idx, low_att_imgs, 
      savepath=savepath)

    savepath = os.path.join(savebase, '{}_atns.npy'.format(case_name))
    np.save(savepath, attention)
    savepath = os.path.join(savebase, '{}_feat.npy'.format(case_name))
    np.save(savepath, features)

  yhats = np.concatenate(yhats, axis=0)
  yhats = np.argmax(yhats, axis=1)
  ytrues = np.array(ytrues)
  acc = (yhats == ytrues).mean()
  print('Accuracy: {:3.3f}'.format(acc))
  cm = confusion_matrix(y_true=ytrues, y_pred=yhats)
  print(cm)
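
Both examples above rely on get_attention_extremes, whose body is not shown. Judging from how its return values are used (indices and image tiles for the n most- and least-attended instances), a plausible sketch is the following; the exact ranking rule is an assumption:

import numpy as np

def get_attention_extremes(attention, case_x, n=5):
    # Assumed behavior: rank instances by their attention (or instance
    # score) and return the indices and images of the n highest- and
    # n lowest-ranked tiles.
    order = np.argsort(attention)
    low_att_idx, high_att_idx = order[:n], order[-n:]
    return high_att_idx, case_x[high_att_idx], low_att_idx, case_x[low_att_idx]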
Example #3
def main(args):
  """ 
  1. Create generator datasets from the provided lists
  2. train and validate Milk

  v0 - create datasets within this script
  v1 - factor monolithic training_utils.mil_train_loop !!
  tpu - replace data feeder and mil_train_loop with tf.keras.Model.fit()
  """
  # Take care of passed in test and val lists for the ensemble experiment
  # we need both test list and val list to be given.
  if (args.test_list is not None) and (args.val_list is not None):
    train_list, val_list, test_list = load_lists(
      os.path.join(args.data_patt, '*.npy'), 
      args.val_list, args.test_list)
  else:
    train_list, val_list, test_list = data_utils.list_data(
      os.path.join(args.data_patt, '*.npy'), 
      val_pct=args.val_pct, 
      test_pct=args.test_pct, 
      seed=args.seed)
  
  if args.verbose:
    print("train_list:")
    print(train_list)
    print("val_list:")
    print(val_list)
    print("test_list:")
    print(test_list)

  ## Filter out unwanted samples:
  train_list = filter_list_by_label(train_list)
  val_list = filter_list_by_label(val_list)
  test_list = filter_list_by_label(test_list)

  train_list = data_utils.enforce_minimum_size(train_list, args.bag_size, verbose=True)
  val_list = data_utils.enforce_minimum_size(val_list, args.bag_size, verbose=True)
  test_list = data_utils.enforce_minimum_size(test_list, args.bag_size, verbose=True)
  transform_fn = data_utils.make_transform_fn(args.x_size, 
                                              args.y_size, 
                                              args.crop_size, 
                                              args.scale, 
                                              normalize=True)
  # train_x, train_y = data_utils.load_list_to_memory(train_list, case_label_fn)
  # val_x, val_y = data_utils.load_list_to_memory(val_list, case_label_fn)

  # train_generator = data_utils.generate_from_memory(train_x, train_y, 
  #     batch_size=args.batch_size,
  #     bag_size=args.bag_size, 
  #     transform_fn=transform_fn,)
  # val_generator = data_utils.generate_from_memory(val_x, val_y, 
  #     batch_size=args.batch_size,
  #     bag_size=args.bag_size, 
  #     transform_fn=transform_fn,)

  # train_generator = subset_and_generate(train_list, case_label_fn, transform_fn, args, pct=0.5)
  # val_generator = subset_and_generate(val_list, case_label_fn, transform_fn, args, pct=1.)

  train_sequence = data_utils.MILSequence(train_list, 0.5, args.batch_size, args.bag_size, args.steps_per_epoch,
    case_label_fn, transform_fn, pad_first_dim=True)
  val_sequence = data_utils.MILSequence(val_list, 1., args.batch_size, args.bag_size, 100,
    case_label_fn, transform_fn, pad_first_dim=True)

  # print('Testing batch generator')
  # ## Some api change between nightly built TF and R1.5
  # x, y = next(train_generator)
  # print('x: ', x.shape)
  # print('y: ', y.shape)
  # del x
  # del y 

  print('Model initializing')
  encoder_args = get_encoder_args(args.encoder)
  model = Milk(input_shape=(args.bag_size, args.crop_size, args.crop_size, 3), 
               encoder_args=encoder_args, mode=args.mil, use_gate=args.gated_attention,
               temperature=args.temperature, freeze_encoder=args.freeze_encoder, 
               deep_classifier=args.deep_classifier)
  
  if args.tpu:
    # Need to use tensorflow optimizer
    optimizer = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
  else:
    # optimizer = tf.keras.optimizers.Adam(lr=args.learning_rate, decay=1e-6)
    optimizer = training_utils.AdamAccumulate(lr=args.learning_rate,
                               accum_iters=args.accumulate)

  exptime = datetime.datetime.now()
  exptime_str = exptime.strftime('%Y_%m_%d_%H_%M_%S')
  out_path = os.path.join(args.save_prefix, '{}.h5'.format(exptime_str))
  if not os.path.exists(os.path.dirname(out_path)):
    os.makedirs(os.path.dirname(out_path)) 

  # TODO: clean up
  val_list_file = os.path.join('./val_lists', '{}.txt'.format(exptime_str))
  with open(val_list_file, 'w+') as f:
    for v in val_list:
      f.write('{}\n'.format(v))

  test_list_file = os.path.join('./test_lists', '{}.txt'.format(exptime_str))
  with open(test_list_file, 'w+') as f:
    for v in test_list:
      f.write('{}\n'.format(v))

  ## Write out arguments passed for this session
  arg_file = os.path.join('./args', '{}.txt'.format(exptime_str))
  with open(arg_file, 'w+') as f:
    for a in vars(args):
      f.write('{}\t{}\n'.format(a, getattr(args, a)))

  ## Transfer to TPU 
  if args.tpu:
    print('Setting up model on TPU')
    if 'COLAB_TPU_ADDR' not in os.environ:
      raise RuntimeError('Not connected to a TPU runtime!')
    tpu_address = 'grpc://' + os.environ['COLAB_TPU_ADDR']
    print('TPU address is', tpu_address)
    tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(tpu=tpu_address)
    strategy = tf.contrib.tpu.TPUDistributionStrategy(tpu_cluster_resolver)
    model = tf.contrib.tpu.keras_to_tpu_model(model, strategy)

  model.compile(optimizer=optimizer,
                loss=tf.keras.losses.categorical_crossentropy,
                metrics=['categorical_accuracy'])
  model.summary()

  ## Replace randomly initialized weights after model is compiled and on the correct device.
  if args.pretrained_model is not None and os.path.exists(args.pretrained_model):
    print('Replacing random weights with weights from {}'.format(args.pretrained_model))
    model.load_weights(args.pretrained_model, by_name=True)

  if args.early_stop:
    callbacks = [
        tf.keras.callbacks.EarlyStopping(monitor='val_loss',
                                         min_delta=0.00001,
                                         patience=5,
                                         verbose=1,
                                         mode='auto')
    ]
  else:
    callbacks = []

  try:
    # refresh the data generator with a new subset each epoch
    # test on the same validation data
    # for epc in range(args.epochs):
      # train_generator = []
      # train_generator = subset_and_generate(train_list, case_label_fn, transform_fn, args, pct=0.25)
    model.fit_generator(generator=train_sequence,
                        validation_data=val_sequence,
                        #steps_per_epoch=args.steps_per_epoch, 
                        epochs=args.epochs,
                        workers=8,
                        use_multiprocessing=True,
                        callbacks=callbacks, )

  except KeyboardInterrupt:
    print('Keyboard interrupt caught')
  except Exception as e:
    print('Other error caught')
    print(type(e))
    print(e)
  finally:
    model.save(out_path)
    print('Saved model: {}'.format(out_path))
    print('Training done. Find val and test datasets at')
    print(val_list_file)
    print(test_list_file)
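
Example #3 reads many attributes from args. A matching argparse setup, with flag names inferred from those attribute accesses (the defaults here are illustrative placeholders, not the project's values, and several flags such as --x_size and --encoder are omitted for brevity), might look like:

import argparse

def build_parser():
    p = argparse.ArgumentParser()
    # Data and split options
    p.add_argument('--data_patt', default='../dataset/tiles')
    p.add_argument('--val_list', default=None)
    p.add_argument('--test_list', default=None)
    p.add_argument('--val_pct', type=float, default=0.1)
    p.add_argument('--test_pct', type=float, default=0.2)
    p.add_argument('--seed', type=int, default=999)
    # Bag and training options
    p.add_argument('--bag_size', type=int, default=150)
    p.add_argument('--batch_size', type=int, default=1)
    p.add_argument('--steps_per_epoch', type=int, default=500)
    p.add_argument('--epochs', type=int, default=50)
    p.add_argument('--learning_rate', type=float, default=1e-4)
    p.add_argument('--accumulate', type=int, default=1)
    # Runtime switches
    p.add_argument('--tpu', action='store_true')
    p.add_argument('--early_stop', action='store_true')
    p.add_argument('--verbose', action='store_true')
    return p

if __name__ == '__main__':
    main(build_parser().parse_args())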
Example #4
            s=5.0, marker='x', alpha=0.5,
            label='M{} Incorrect'.format(c))

    plt.legend(frameon=True, fontsize=8)
    plt.xticks([])
    plt.yticks([])
    plt.title('Case Projections')

    if savepath is None:
        plt.show()
    else:
        plt.savefig(savepath, bbox_inches='tight')
    plt.close()


transform_fn = data_utils.make_transform_fn(X_SIZE, Y_SIZE, CROP_SIZE, SCALE, 
    normalize=True)
def main(args):
    model = Milk()
    x_dummy = tf.zeros(shape=[MIN_BAG, CROP_SIZE, CROP_SIZE, 3],
                       dtype=tf.float32)
    retvals = model(x_dummy, verbose=True, return_embedding=True)
    for k, retval in enumerate(retvals):
        print('retval {}: {}'.format(k, retval.shape))

    saver = tfe.Saver(model.variables)
    if args.snapshot is None:
        snapshot = tf.train.latest_checkpoint(args.snapshot_dir)
    else:
        snapshot = args.snapshot
    print('Restoring from {}'.format(snapshot))
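
The fragment at the top of this example is the tail of a projection-plotting helper. A simplified, self-contained stand-in (assuming a t-SNE projection via scikit-learn; the real function also splits markers by class and correctness, as the 'M{} Incorrect' legend labels above suggest) could be:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE  # assumed projection method

def draw_projection(features, features_att, scores, savepath=None):
    # Embed the instance features plus the bag-average feature in 2-D,
    # then color instances by their positive-class score.
    z_all = TSNE(n_components=2).fit_transform(
        np.concatenate([features, features_att], axis=0))
    z, z_att = z_all[:-1], z_all[-1:]
    plt.figure(figsize=(4, 4))
    plt.scatter(z[:, 0], z[:, 1], c=scores, s=5.0, alpha=0.5, cmap='coolwarm')
    plt.scatter(z_att[:, 0], z_att[:, 1], c='k', marker='*', s=40,
                label='Bag average')
    plt.legend(frameon=True, fontsize=8)
    plt.xticks([])
    plt.yticks([])
    plt.title('Case Projections')
    if savepath is None:
        plt.show()
    else:
        plt.savefig(savepath, bbox_inches='tight')
    plt.close()
    return z  # Examples #1 and #2 reuse the 2-D embedding downstream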
Example #5
def main(args):
  transform_fn = data_utils.make_transform_fn(args.x_size, args.y_size, 
                                              args.crop_size, args.scale)

  snapshot = 'save/{}.h5'.format(args.timestamp)
  # trained_model = load_model(snapshot)

  encoder_args = get_encoder_args(args.encoder)
  if args.mcdropout:
    encoder_args['mcdropout'] = True

  print('Model initializing')
  encode_model = MilkEncode(input_shape=(args.crop_size, args.crop_size, 3), 
                            encoder_args=encoder_args, deep_classifier=args.deep_classifier)
  encode_shape = list(encode_model.output.shape)
  x_pl = tf.placeholder(shape=(None, args.input_dim, args.input_dim, 3), dtype=tf.float32)
  z_op = encode_model(x_pl)

  input_shape = z_op.shape[-1]
  predict_model = MilkPredict(input_shape=[input_shape], mode=args.mil)

  print('loading weights from {}'.format(snapshot))
  encode_model.load_weights(snapshot, by_name=True)
  predict_model.load_weights(snapshot, by_name=True)

  # models = model_utils.make_inference_functions(encode_model,
  #                                               predict_model,
  #                                               trained_model, 
  #                                               attention_model=None)
  # encode_model, predict_model = models

  test_list = os.path.join(args.testdir, '{}.txt'.format(args.timestamp))
  test_list = read_test_list(test_list)

  ytrues = []
  yhats = []
  print('MC Dropout: {}'.format(args.mcdropout))
  for _ in range(args.n_repeat):
    for test_case in test_list:
      case_x, case_y = data_utils.load(
          data_path=test_case,
          transform_fn=transform_fn,
          case_label_fn=case_label_fn,
          all_tiles=True,
          )
      case_x = np.squeeze(case_x, axis=0)
      print('Running case x: ', case_x.shape)

      yhat = run_sample(case_x, encode_model, predict_model,
                        mcdropout=args.mcdropout,
                        batch_size=args.batch_size)
      ytrues.append(case_y)
      yhats.append(yhat)
      print(test_case, case_y, case_x.shape, yhat)

  ytrue = np.concatenate(ytrues, axis=0)
  yhat = np.concatenate(yhats, axis=0)

  ytrue_max = np.argmax(ytrue, axis=-1)
  yhat_max = np.argmax(yhat, axis=-1)
  accuracy = (ytrue_max == yhat_max).mean()
  print('Accuracy: {:3.3f}'.format(accuracy))

  if args.odir is not None:
    save_img = os.path.join(args.odir, '{}.png'.format(args.timestamp))
    save_metrics = os.path.join(args.odir, '{}.txt'.format(args.timestamp))
    save_yhat = os.path.join(args.odir, '{}.npy'.format(args.timestamp))
    save_ytrue = os.path.join(args.odir, '{}_ytrue.npy'.format(args.timestamp))
    np.save(save_yhat, yhat)
    np.save(save_ytrue, ytrue)
  else:
    save_img = None
    save_metrics = None
    save_yhat = None

  auc_curve(ytrue, yhat, savepath=save_img)
  test_eval(ytrue, yhat, savepath=save_metrics)
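
data_utils.load takes a case_label_fn that is not shown in these examples. Since the evaluation code decodes labels with np.argmax(ytrue, axis=-1), a hypothetical sketch (the case_dict lookup mirrors the name referenced in Example #2) is:

import os
import numpy as np

# Hypothetical label lookup; real labels come from the project's
# annotation source.
case_dict = {'case_001': 0, 'case_002': 1}

def case_label_fn(data_path):
    # Map a tile stack's path to a one-hot label of shape (1, 2) so the
    # concatenate/argmax evaluation above works unchanged.
    case_name = os.path.basename(data_path).replace('.npy', '')
    onehot = np.zeros((1, 2), dtype=np.float32)
    onehot[0, case_dict[case_name]] = 1.0
    return onehot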
Example #6
def main(args):
    transform_fn = data_utils.make_transform_fn(args.x_size,
                                                args.y_size,
                                                args.crop_size,
                                                args.scale,
                                                flip=False,
                                                middle_crop=True,
                                                rotate=False,
                                                normalize=True)

    snapshot = 'save/{}.h5'.format(args.timestamp)
    # trained_model = load_model(snapshot)

    encoder_args = get_encoder_args(args.encoder)
    if args.mcdropout:
        encoder_args['mcdropout'] = True

    print('Model initializing')
    model = MilkEager(encoder_args=encoder_args,
                      mil_type=args.mil,
                      deep_classifier=args.deep_classifier,
                      cls_normalize=args.cls_normalize,
                      batch_size=args.batch_size,
                      temperature=args.temperature)
    xdummy = tf.zeros((1, args.batch_size, args.x_size, args.y_size, 3))
    ydummy = model(xdummy, verbose=True)
    print(xdummy.shape, ydummy.shape)
    del xdummy, ydummy
    model.load_weights(snapshot, by_name=True)

    test_list = os.path.join(args.testdir, '{}.txt'.format(args.timestamp))
    test_list = read_test_list(test_list)

    ytrues = []
    yhats = []
    print('MC Dropout: {}'.format(args.mcdropout))
    for _ in range(args.n_repeat):
        for test_case in test_list:
            case_x, case_y = data_utils.load(
                data_path=test_case,
                transform_fn=transform_fn,
                case_label_fn=case_label_fn,
                all_tiles=True,
            )
            # case_x = np.squeeze(case_x, axis=0)
            print('Running case x: ', case_x.shape)
            yhat = run_sample(case_x,
                              model,
                              mcdropout=args.mcdropout,
                              sample_mode=args.sample_mode)
            ytrues.append(case_y)
            yhats.append(yhat)
            print(test_case, case_y, case_x.shape, yhat)

    ytrue = np.concatenate(ytrues, axis=0)
    yhat = np.concatenate(yhats, axis=0)

    ytrue_max = np.argmax(ytrue, axis=-1)
    yhat_max = np.argmax(yhat, axis=-1)
    accuracy = (ytrue_max == yhat_max).mean()
    print('Accuracy: {:3.3f}'.format(accuracy))

    if args.odir is not None:
        save_img = os.path.join(args.odir, '{}.png'.format(args.timestamp))
        save_metrics = os.path.join(args.odir, '{}.txt'.format(args.timestamp))
        save_yhat = os.path.join(args.odir, '{}.npy'.format(args.timestamp))
        save_ytrue = os.path.join(args.odir,
                                  '{}_ytrue.npy'.format(args.timestamp))
        np.save(save_yhat, yhat)
        np.save(save_ytrue, ytrue)
    else:
        save_img = None
        save_metrics = None
        save_yhat = None

    auc_curve(ytrue, yhat, savepath=save_img)
    test_eval(ytrue, yhat, savepath=save_metrics)
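
Examples #5 and #6 both call run_sample, which is defined elsewhere. For the eager model in this example, a minimal sketch (assuming a standard Keras-style training flag; sample_mode handling is omitted) might be:

import numpy as np
import tensorflow as tf

def run_sample(case_x, model, mcdropout=False, sample_mode=None):
    # Give the tile stack a leading bag dimension if it lacks one, then run
    # the model end to end. With MC dropout the forward pass keeps dropout
    # active, so the args.n_repeat loop above samples a predictive
    # distribution.
    if case_x.ndim == 4:
        case_x = np.expand_dims(case_x, axis=0)
    yhat = model(tf.constant(case_x, dtype=tf.float32), training=mcdropout)
    return yhat.numpy()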