コード例 #1
0
ファイル: test_eager.py プロジェクト: nathanin/milk
def main(args):
  """Evaluate a pretrained multi-head MilkEager model on the test split.

  Builds the model, runs one dummy forward pass, loads weights from
  ``args.pretrained_model`` and then walks the ``'test'`` iterator one bag at
  a time. The ``[0, 1]`` entry of every label and of every head's output is
  collected, a running accuracy is printed in place, and the collected
  values are handed to ``print_accuracy`` and ``write_out``.

  Args:
    args: parsed command-line namespace. Reads ``encoder``, ``mil``,
      ``deep_classifier``, ``temperature``, ``heads``, ``crop_size``,
      ``channels``, ``pretrained_model``, ``dataset``, ``seed``,
      ``bag_size`` and ``out``.
  """
  print('Model initializing')
  encoder_config = get_encoder_args(args.encoder)
  model = MilkEager(
    encoder_args=encoder_config,
    mil_type=args.mil,
    deep_classifier=args.deep_classifier,
    batch_size=16,
    temperature=args.temperature,
    heads=args.heads,
  )

  print('Running once to load CUDA')
  side = args.crop_size
  warmup_bag = tf.zeros((1, 1, side, side, args.channels))

  ## Eager mode complains when the head list is a range(), and even when it
  ## is a list(range()). With the tf.contrib.eager.defun decorator removed it
  ## is fine, so this looks like an autograph problem: pass a literal list.
  warmup_heads = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
  model(warmup_bag.gpu(), heads=warmup_heads, training=True, verbose=True)
  model.summary()

  model.load_weights(args.pretrained_model, by_name=True)

  ## Set up the data
  data_factory = MILDataset(args.dataset, crop=side, n_classes=2)
  data_factory.split_train_val('case_id', seed=args.seed)

  iterator_options = dict(
    mode='test',
    seed=args.seed,
    batch_size=1,
    buffer_size=1,
    threads=1,
    subset=args.bag_size,
    attr='stage_code',
    eager=True,
  )
  test_iterator = data_factory.tensorflow_iterator(**iterator_options)

  ## Track labels and per-head predictions
  print('-------------------------------------------------------\n\n')
  labels = []
  predictions = []
  eval_heads = list(range(args.heads))
  for step, (bag, label) in enumerate(test_iterator):
    labels.append(label[0,1])

    head_outputs = model(bag.gpu(), training=False, heads=eval_heads)
    predictions.append(np.array([out[0,1] for out in head_outputs]))

    # Running accuracy over everything seen so far, rewritten in place.
    running_acc, _ = calc_acc(labels, predictions)
    print('\r{:03d} Accuracy: {:3.3f} %'.format(step, running_acc),
          end='', flush=True)

  print('\n\n-------------------------------------------------------')
  print_accuracy(labels, predictions)
  write_out(labels, predictions, args.out)
コード例 #2
0
ファイル: train_graph.py プロジェクト: nathanin/milk
def main(args):
    """Train a Keras ``Classifier`` in graph mode and always save it.

    The dataset is built with ``eager=False`` and fed to ``model.fit`` through
    a one-shot iterator. A ``KeyboardInterrupt`` stops training early; in
    every case the model is saved to ``args.save_path`` afterwards.

    Args:
        args: parsed command-line namespace. Reads ``input_dim``,
            ``downsample``, ``dataset``, ``n_classes``, ``n_threads``,
            ``batch_size``, ``prefetch_buffer``, ``shuffle_buffer``,
            ``encoder``, ``learning_rate``, ``iterations``, ``epochs`` and
            ``save_path``.
    """
    print(args)

    # Build the dataset
    dataset_options = dict(
        record_path=args.dataset,
        crop_size=int(args.input_dim / args.downsample),
        downsample=args.downsample,
        n_classes=args.n_classes,
        n_threads=args.n_threads,
        batch=args.batch_size,
        prefetch_buffer=args.prefetch_buffer,
        shuffle_buffer=args.shuffle_buffer,
        eager=False,
    )
    dataset = ClassificationDataset(**dataset_options)

    # Build the model
    encoder_config = get_encoder_args(args.encoder)
    side = args.input_dim
    model = Classifier(
        input_shape=(side, side, 3),
        n_classes=args.n_classes,
        encoder_args=encoder_config,
        deep_classifier=True,
    )

    ## tf.train optimizers are needed for TPUs; plain Keras Adam is used here.
    adam = tf.keras.optimizers.Adam(lr=args.learning_rate, decay=1e-5)
    model.compile(
        optimizer=adam,
        loss=tf.keras.losses.categorical_crossentropy,
        metrics=['categorical_accuracy'],
    )
    model.summary()

    try:
        batches = dataset.dataset.make_one_shot_iterator()
        model.fit(batches,
                  steps_per_epoch=args.iterations,
                  epochs=args.epochs)
    except KeyboardInterrupt:
        print('Stop signal')
    finally:
        # Runs on success, on Ctrl-C and on any other error.
        print('Saving')
        model.save(args.save_path)
コード例 #3
0
ファイル: pretrain_mnist.py プロジェクト: nathanin/milk
def main(args):
    """Pretrain a ``Classifier`` on MNIST and save it to ``args.o``.

    Training and validation batches both come from ``generate_batch``; one
    training batch is drawn up front only to print its shapes and value range.

    Args:
        args: parsed command-line namespace. Reads ``batch``, ``mnist``,
            ``lr``, ``decay``, ``steps_per_epoch``, ``epochs`` and ``o``.
    """
    train_split, test_split = mnist.load_data()
    train_images, train_labels = train_split
    test_images, test_labels = test_split

    generator = generate_batch(train_images, train_labels, args.batch)
    test_generator = generate_batch(test_images, test_labels, args.batch)

    # Sanity-check one batch (this consumes it from the training generator).
    sample_x, sample_y = next(generator)
    print('batch:', sample_x.shape, sample_y.shape,
          sample_x.min(), sample_x.max())

    # NOTE(review): the encoder preset is looked up by args.mnist — confirm
    # that this argument really carries an encoder name.
    encoder_config = get_encoder_args(args.mnist)
    model = Classifier(
        input_shape=(28, 28, 1),
        n_classes=10,
        encoder_args=encoder_config,
    )

    adam = tf.keras.optimizers.Adam(lr=args.lr, decay=args.decay)
    model.compile(
        optimizer=adam,
        loss=tf.keras.losses.categorical_crossentropy,
        metrics=['categorical_accuracy'],
    )

    fit_options = dict(
        steps_per_epoch=args.steps_per_epoch,
        epochs=args.epochs,
        validation_data=test_generator,
        validation_steps=50,
    )
    model.fit_generator(generator, **fit_options)
    model.save(args.o)
コード例 #4
0
def main(args):
    """Evaluate a classifier snapshot and report accuracy, AUC and a projection.

    One batch is pulled from the dataset iterator first; it is printed and,
    when the eager model is built from scratch, used for the initial forward
    pass before ``load_weights``. The remaining batches are run through the
    model, and labels, predictions and features are concatenated. Class index
    2 is folded into class 1 before accuracy and the classification report
    are computed.

    Args:
        args: parsed command-line namespace. Reads ``input_dim``,
            ``downsample``, ``test_data``, ``n_classes``, ``batch_size``,
            ``prefetch_buffer``, ``repeats``, ``load_model``, ``snapshot``,
            ``encoder``, ``save`` and ``saveproj``.
    """
    dataset = ClassificationDataset(
        record_path=args.test_data,
        crop_size=int(args.input_dim / args.downsample),
        downsample=args.downsample,
        n_classes=args.n_classes,
        batch=args.batch_size,
        prefetch_buffer=args.prefetch_buffer,
        repeats=args.repeats,
        eager=True,
    )

    first_x, first_y = next(dataset.iterator)
    print('Test batch:')
    print('batchx: ', first_x.get_shape())
    print('batchy: ', first_y.get_shape())

    if not args.load_model:
        encoder_config = get_encoder_args(args.encoder)
        model = ClassifierEager(
            encoder_args=encoder_config,
            n_classes=args.n_classes,
        )
        # One forward pass before loading weights
        # (presumably so the variables exist — TODO confirm).
        model(first_x, training=False, verbose=True)
        model.summary()
        model.load_weights(args.snapshot)
    else:
        model = load_model(args.snapshot, compile=False)

    # Loop over the rest of the iterator
    labels, predictions, embeddings = [], [], []
    for step, (images, targets) in enumerate(dataset.iterator, start=1):
        probs, feats = model(images,
                             return_features=True,
                             training=False,
                             verbose=True)

        labels.append(targets)
        predictions.append(probs)
        embeddings.append(feats)

        if step % 10 != 0:
            continue
        print(step, 'ytrue:', targets.shape, 'yhat:', probs.shape)
        print(np.argmax(probs, axis=-1))
        print(np.argmax(targets, axis=-1))

    features = np.concatenate(embeddings, axis=0)
    ytrue_vector = np.concatenate(labels, axis=0)
    yhat_vector = np.concatenate(predictions, axis=0)
    print('features: ', features.shape)
    print('ytrues: ', ytrue_vector.shape)
    print('yhats: ', yhat_vector.shape)

    ytrue_max = np.argmax(ytrue_vector, axis=-1)
    yhat_max = np.argmax(yhat_vector, axis=-1)

    # Fold class 2 into class 1 for both truth and prediction.
    for class_ids in (ytrue_max, yhat_max):
        class_ids[class_ids == 2] = 1

    accuracy = np.mean(ytrue_max == yhat_max)
    print('Accuracy: {:3.3f}'.format(accuracy))
    print(classification_report(y_true=ytrue_max, y_pred=yhat_max))
    auc_curves(ytrue_vector, yhat_vector, savepath=args.save)

    draw_projection(features, yhat_max, savepath=args.saveproj)
コード例 #5
0
def main(args):
    """Compute a per-slide output image with a MilkEager model and save it.

    A ``compute_fn`` is chosen from ``args.iter_type``:

    * ``'python'``: iterate tiles with ``PythonIterator``, call the model on
      each batch and place the result into the slide's ``'prob'`` output.
    * ``'tf'``: iterate tiles with ``TensorflowIterator`` (eager mode), encode
      every batch, run ``mil_attention`` over all features at once and place
      the attention values into the slide's ``'att'`` output.

    Each path listed in ``args.slides`` is copied to the ramdisk, processed,
    saved with ``np.save`` and the ramdisk copy is removed again, even when
    processing fails.

    Args:
        args: parsed command-line namespace. Reads ``iter_type``, ``encoder``,
            ``mil``, ``deep_classifier``, ``batchsize``, ``temperature``,
            ``heads``, ``process_size``, ``snapshot``, ``slides``, ``suffix``,
            ``ramdisk`` and ``n_classes``.

    Raises:
        ValueError: if ``args.iter_type`` is neither ``'python'`` nor ``'tf'``.
    """
    # Define a compute_fn that should do three things:
    # 1. define an iterator over the slide's tiles
    # 2. compute an output with the given model
    # 3. return the reconstructed slide-level output

    if args.iter_type == 'python':

        def compute_fn(slide, args, model=None):
            print('Slide with {}'.format(len(slide.tile_list)))
            it_factory = PythonIterator(slide, args)
            for k, (img, idx) in enumerate(it_factory.yield_batch()):
                prob = model(img)
                if k % 50 == 0:
                    print('Batch #{:04d} idx:{} img:{} prob:{}'.format(
                        k, idx.shape, img.shape, prob.shape))
                # NOTE(review): this writes the 'prob' output while main only
                # initializes 'att' below — confirm 'prob' exists on the slide.
                slide.place_batch(prob, idx, 'prob', mode='tile')
            return slide.output_imgs['prob']

    # Tensorflow multithreaded queue-based iterator (in eager mode)
    elif args.iter_type == 'tf':

        def compute_fn(slide, args, model=None):
            assert tf.executing_eagerly()
            print('Slide with {}'.format(len(slide.tile_list)))

            # In eager mode, we get a tf.contrib.eager.Iterator
            eager_iterator = TensorflowIterator(slide, args).make_iterator()

            # The iterator can be used directly. Queueing and multithreading
            # are handled in the backend by the tf.data.Dataset ops
            features, indices = [], []
            for k, (img, idx) in enumerate(eager_iterator):
                features.append(
                    model.encode_bag(img, training=False, return_z=True))
                indices.append(idx.numpy())

                # Only pull the image to host memory when we actually log it.
                if k % 50 == 0:
                    print('Batch #{:04d}\t{}'.format(k, img.numpy().shape))

            features = tf.concat(features, axis=0)
            z_att, att = model.mil_attention(features,
                                             training=False,
                                             return_raw_att=True)
            att = np.squeeze(att)
            indices = np.concatenate(indices)
            slide.place_batch(att, indices, 'att', mode='tile')
            return slide.output_imgs['att']

    else:
        # Without this guard compute_fn would be undefined and every slide
        # would be copied to the ramdisk only to fail with a NameError.
        raise ValueError(
            "Unknown iter_type {!r}; expected 'python' or 'tf'".format(
                args.iter_type))

    # Set up the model first
    encoder_args = get_encoder_args(args.encoder)
    model = MilkEager(encoder_args=encoder_args,
                      mil_type=args.mil,
                      deep_classifier=args.deep_classifier,
                      batch_size=args.batchsize,
                      temperature=args.temperature,
                      heads=args.heads)

    # One dummy forward pass before loading the snapshot by name.
    x = tf.zeros((1, 1, args.process_size, args.process_size, 3))
    _ = model(x, verbose=True, head='all', training=True)
    model.load_weights(args.snapshot, by_name=True)

    # keras Model subclass
    model.summary()

    # Read list of inputs, ignoring blank lines
    with open(args.slides, 'r') as f:
        slides = [line.strip() for line in f if line.strip()]

    # Loop over slides
    for src in slides:
        # Dirty substitution of the file extension gives us the destination.
        # Set the --suffix option to reflect the model / type of processed
        # output.
        dst = repext(src, args.suffix)

        # Loading data from ramdisk incurs a one-time copy cost
        rdsrc = cpramdisk(src, args.ramdisk)
        print('File:', rdsrc)

        # Wrapped inside of a try-except-finally.
        # We want to make sure the slide gets cleaned from
        # memory in case there's an error or stop signal in the
        # middle of processing.
        try:
            # Initialize the slide from our temporary path, with
            # the arguments passed in from command-line.
            # This returns an svsutils.Slide object
            slide = Slide(rdsrc, args)

            # This step will eventually be included in slide creation
            # with some default compute_fn's provided by svsutils
            # For now, do it case-by-case, and use the compute_fn
            # that we defined just above.
            slide.initialize_output('att',
                                    args.n_classes,
                                    mode='tile',
                                    compute_fn=compute_fn)

            # Call the compute function to compute this output.
            # Again, this may change to something like...
            #     slide.compute_all
            # which would loop over all the defined output types.
            ret = slide.compute('att', args, model=model)
            print('{} --> {}'.format(ret.shape, dst))
            # The last axis is reversed before saving.
            np.save(dst, ret[:, :, ::-1])
        except Exception as e:
            # Best effort: report the failure and move on to the next slide.
            print(e)
            traceback.print_tb(e.__traceback__)
        finally:
            print('Removing {}'.format(rdsrc))
            os.remove(rdsrc)
コード例 #6
0
def main(args):
  """Run a trained MilkEager model over a saved test list and dump per-case results.

  The snapshot (``<snapshot_dir>/<timestamp>.h5``) and the test list
  (``<test_list_dir>/<timestamp>.txt``) are both located from
  ``args.timestamp``. For every case in the list the tiles are loaded,
  transformed and encoded; a projection figure, an image figure, the
  per-instance values and the features are written under
  ``<odir>/<timestamp>``. Overall accuracy and a confusion matrix are printed
  at the end.

  NOTE: the output directory is deleted and recreated on every run.

  Args:
    args: parsed command-line namespace. Reads ``input_dim``,
      ``snapshot_dir``, ``timestamp``, ``test_list_dir``, ``encoder``,
      ``batch_size``, ``temperature``, ``cls_normalize``, ``odir`` and
      ``sample``.
  """
  # NOTE(review): the meaning of the positional arguments (128, 128, ..., 1.0)
  # is defined by data_utils.make_transform_fn — confirm there.
  transform_fn = data_utils.make_transform_fn(128, 128, args.input_dim, 1.0, normalize=True)

  # Both artifacts of a training session are keyed by its timestamp.
  snapshot = os.path.join(args.snapshot_dir, '{}.h5'.format(args.timestamp))
  test_list = os.path.join(args.test_list_dir, '{}.txt'.format(args.timestamp))

  encoder_args = get_encoder_args(args.encoder)
  model = MilkEager(encoder_args=encoder_args, 
                    deep_classifier=True, 
                    mil_type='instance',
                    batch_size=args.batch_size,
                    temperature=args.temperature,
                    cls_normalize=args.cls_normalize)

  # One forward pass on zeros before loading weights by name
  # (presumably to create the variables — TODO confirm). `retvals` is not
  # used again in this function.
  x_dummy = tf.zeros(shape=[1, args.batch_size, args.input_dim, args.input_dim, 3], dtype=tf.float32)
  retvals = model(x_dummy, verbose=True)
  model.load_weights(snapshot, by_name=True)
  model.summary()

  # `test_list` is rebound here: file path above, parsed list from now on.
  test_list = read_test_list(test_list)
  savebase = os.path.join(args.odir, args.timestamp)
  # Start from an empty output directory; any previous results are removed.
  if os.path.exists(savebase):
    shutil.rmtree(savebase)
  os.makedirs(savebase)

  yhats, ytrues = [], []
  # These two lists are never filled in the code below.
  features_case, features_classifier = [], []
  for test_case in test_list:
    case_name = os.path.basename(test_case).replace('.npy', '')
    print(test_case, case_name)
    # case_path = os.path.join('../dataset/tiles_reduced', '{}.npy'.format(case_name))
    case_x = np.load(test_case)
    # Transform each element along the first axis, then re-stack.
    case_x = np.stack([transform_fn(x) for x in case_x], 0)
    # `case_dict` is not defined in this function; it must come from module scope.
    ytrue = case_dict[case_name]
    print(case_x.shape, ytrue)
    ytrues.append(ytrue)

    if args.sample:
      # Keep args.sample randomly chosen rows; np.random.choice samples with
      # replacement by default, so rows can repeat.
      case_x = case_x[np.random.choice(range(case_x.shape[0]), args.sample), ...]
      print(case_x.shape)

    # TODO variable names. attention --> something else
    features, attention = model.encode_bag(case_x, training=False, return_z=True)
    # Case-level prediction: mean of the per-instance outputs over axis 0.
    yhat = np.mean(attention, axis=0, keepdims=True)
    # From here on `attention` is column 1 of the per-instance outputs.
    attention = attention.numpy()[:,1]
    print('features:', features.shape)
    print('attention:', attention.shape)
    print('yhat:', yhat.shape)
    # features_att, attention = model.mil_attention(features, return_att=True, training=False)
    # print('features:', features_att.shape, 'attention:', attention.shape)

    features_avg = np.mean(features, axis=0, keepdims=True)
    # With mil_attention disabled above, the "attended" features are just the average.
    features_att = features_avg

    yhats.append(yhat)

    # yhat, attention, features, feat_case, feat_class = retvals
    # attention = np.squeeze(attention.numpy(), axis=0)
    high_att_idx, high_att_imgs, low_att_idx, low_att_imgs = get_attention_extremes(
      attention, case_x, n = 5)

    print('Case {}: predicted={} ({})'.format( test_case, np.argmax(yhat, axis=-1) , yhat))

    features = features.numpy()
    # File names carry the case name and the yhat[0,1] value.
    savepath = os.path.join(savebase, '{}_{:3.2f}.png'.format(case_name, yhat[0,1]))
    print('Saving figure {}'.format(savepath))
    # The returned `z` is reused by draw_projection_with_images below.
    z = draw_projection(features, features_avg, features_att, attention, savepath=savepath)

    # savepath = os.path.join(savebase, '{}_{:3.2f}_ys.png'.format(case_name, yhat[0,1]))
    # print('Saving figure {}'.format(savepath))
    # draw_projection_with_images(z, yhat_instances[:,1].numpy(), 
    #   high_att_idx, high_att_imgs, 
    #   low_att_idx, low_att_imgs, 
    #   savepath=savepath)

    savepath = os.path.join(savebase, '{}_{:3.2f}_imgs.png'.format(case_name, yhat[0,1]))
    print('Saving figure {}'.format(savepath))
    draw_projection_with_images(z, attention, 
      high_att_idx, high_att_imgs, 
      low_att_idx, low_att_imgs, 
      savepath=savepath)

    # Raw arrays for later analysis.
    savepath = os.path.join(savebase, '{}_atns.npy'.format(case_name))
    np.save(savepath, attention)
    savepath = os.path.join(savebase, '{}_feat.npy'.format(case_name))
    np.save(savepath, features)

  # Overall metrics: argmax of each case-level prediction vs. the looked-up label.
  yhats = np.concatenate(yhats, axis=0)
  yhats = np.argmax(yhats, axis=1)
  ytrues = np.array(ytrues)
  acc = (yhats == ytrues).mean()
  print(acc)
  cm = confusion_matrix(y_true=ytrues, y_pred=yhats)
  print(cm)
コード例 #7
0
def main(args):
    """Train a MilkEager model end to end on bagged MNIST in eager mode.

    A positive digit is chosen at random, MNIST is rearranged into positive /
    negative pools and bags are generated from them. The model is optionally
    initialized from ``args.pretrained``, optionally replicated across
    ``args.gpus`` GPUs, and trained with a manual ``tf.GradientTape`` loop.
    Weights are always saved to ``args.o``, even when training is interrupted
    or fails.

    Args:
        args: parsed command-line namespace. Reads ``mnist``, ``n``,
            ``batch_size``, ``mil``, ``pretrained``, ``gpus``, ``lr``,
            ``steps_per_epoch``, ``epochs`` and ``o``.
    """
    if args.mnist is not None:
        (train_x, train_y), (test_x, test_y) = mnist.load_data(args.mnist)
    else:
        (train_x, train_y), (test_x, test_y) = mnist.load_data()

    print('train_x:', train_x.shape, train_x.dtype, train_x.min(),
          train_x.max())
    print('train_y:', train_y.shape)
    print('test_x:', test_x.shape)
    print('test_y:', test_y.shape)

    positive_label = np.random.choice(range(10))
    print('using positive label = {}'.format(positive_label))

    train_x_pos, train_x_neg = rearrange_bagged_mnist(train_x, train_y,
                                                      positive_label)
    test_x_pos, test_x_neg = rearrange_bagged_mnist(test_x, test_y,
                                                    positive_label)
    print('rearranged training set:')
    print('\ttrain_x_pos:', train_x_pos.shape, train_x_pos.dtype,
          train_x_pos.min(), train_x_pos.max())
    print('\ttrain_x_neg:', train_x_neg.shape)
    print('\ttest_x_pos:', test_x_pos.shape)
    print('\ttest_x_neg:', test_x_neg.shape)

    generator = generate_bagged_mnist(train_x_pos, train_x_neg, args.n,
                                      args.batch_size)
    # The validation generator is created but not consumed in this function.
    val_generator = generate_bagged_mnist(test_x_pos, test_x_neg, args.n,
                                          args.batch_size)
    batch_x, batch_y = next(generator)
    print('batch_x:', batch_x.shape, 'batch_y:', batch_y.shape)

    encoder_args = get_encoder_args('mnist')
    model = MilkEager(
        encoder_args=encoder_args,
        mil_type=args.mil,
        deep_classifier=True,
    )
    # One forward pass on a real batch before summary() / load_weights().
    y_dummy = model(batch_x, verbose=True)
    model.summary()

    if args.pretrained is not None and os.path.exists(args.pretrained):
        model.load_weights(args.pretrained, by_name=True)
    else:
        print('Pretrained model not found ({}). Continuing end 2 end.'.format(
            args.pretrained))

    if args.gpus > 1:
        # Report the GPU count actually requested instead of a fixed "2".
        print('Duplicating model onto {} GPUs'.format(args.gpus))
        model = tf.keras.utils.multi_gpu_model(model,
                                               args.gpus,
                                               cpu_merge=True,
                                               cpu_relocation=False)

    optimizer = tf.train.AdamOptimizer(learning_rate=args.lr)

    try:
        for k in range(int(args.steps_per_epoch * args.epochs)):
            with tf.GradientTape() as tape:
                x, y = next(generator)
                yhat = model(tf.constant(x), training=True)
                loss = tf.keras.losses.categorical_crossentropy(
                    y_true=tf.constant(y, dtype=tf.float32), y_pred=yhat)

            # Differentiate only what the optimizer is meant to update. The
            # tape watches trainable variables only, so asking for
            # model.variables just produced None gradients for the rest.
            train_vars = model.trainable_variables
            grads = tape.gradient(loss, train_vars)
            optimizer.apply_gradients(zip(grads, train_vars))

            if k % 50 == 0:
                print('{:06d}: loss={:3.5f}'.format(k, np.mean(loss)))
                for y_, yh_ in zip(y, yhat):
                    print('\t{} {}'.format(y_, yh_))

    except KeyboardInterrupt:
        print('Keyboard interrupt caught')

    except Exception as e:
        # Deliberate best effort: report the error, then still save below.
        print('Other error caught')
        print(type(e))
        print(e)

    finally:
        model.save_weights(args.o)
        print('Saved model: {}'.format(args.o))
コード例 #8
0
def main(args):
  """Train and validate a Milk model with ``fit_generator``.

  1. Create generator datasets from the provided lists
  2. train and validate Milk

  v0 - create datasets within this script
  v1 - factor monolithic training_utils.mil_train_loop !!
  tpu - replace data feeder and mil_train_loop with tf.keras.Model.fit()

  The validation list, test list and the arguments of this session are
  written to ``./val_lists``, ``./test_lists`` and ``./args`` under a
  timestamped name; the model is always saved to
  ``<save_prefix>/<timestamp>.h5``, even if training is interrupted or fails.

  Args:
    args: parsed command-line namespace. Reads ``tpu``, ``test_list``,
      ``val_list``, ``data_patt``, ``val_pct``, ``test_pct``, ``seed``,
      ``verbose``, ``bag_size``, ``x_size``, ``y_size``, ``crop_size``,
      ``scale``, ``batch_size``, ``steps_per_epoch``, ``encoder``, ``mil``,
      ``gated_attention``, ``temperature``, ``freeze_encoder``,
      ``deep_classifier``, ``learning_rate``, ``accumulate``,
      ``save_prefix``, ``pretrained_model``, ``early_stop`` and ``epochs``.

  Raises:
    RuntimeError: if ``args.tpu`` is set but ``COLAB_TPU_ADDR`` is not in the
      environment.
  """
  # Fail fast: previously a missing TPU runtime only printed a message and
  # then crashed on an undefined `tpu_address` after all the data set-up.
  tpu_address = None
  if args.tpu:
    if 'COLAB_TPU_ADDR' not in os.environ:
      raise RuntimeError('ERROR: Not connected to a TPU runtime!')
    tpu_address = 'grpc://' + os.environ['COLAB_TPU_ADDR']

  # Take care of passed in test and val lists for the ensemble experiment
  # we need both test list and val list to be given.
  data_glob = os.path.join(args.data_patt, '*.npy')
  if (args.test_list is not None) and (args.val_list is not None):
    train_list, val_list, test_list = load_lists(
      data_glob, args.val_list, args.test_list)
  else:
    train_list, val_list, test_list = data_utils.list_data(
      data_glob,
      val_pct=args.val_pct,
      test_pct=args.test_pct,
      seed=args.seed)

  if args.verbose:
    print("train_list:")
    print(train_list)
    print("val_list:")
    print(val_list)
    print("test_list:")
    print(test_list)

  ## Filter out unwanted samples:
  train_list = filter_list_by_label(train_list)
  val_list = filter_list_by_label(val_list)
  test_list = filter_list_by_label(test_list)

  train_list = data_utils.enforce_minimum_size(train_list, args.bag_size, verbose=True)
  val_list = data_utils.enforce_minimum_size(val_list, args.bag_size, verbose=True)
  test_list = data_utils.enforce_minimum_size(test_list, args.bag_size, verbose=True)
  transform_fn = data_utils.make_transform_fn(args.x_size,
                                              args.y_size,
                                              args.crop_size,
                                              args.scale,
                                              normalize=True)

  # `case_label_fn` is not defined in this function; it must come from
  # module scope.
  train_sequence = data_utils.MILSequence(train_list, 0.5, args.batch_size, args.bag_size, args.steps_per_epoch,
    case_label_fn, transform_fn, pad_first_dim=True)
  val_sequence = data_utils.MILSequence(val_list, 1., args.batch_size, args.bag_size, 100,
    case_label_fn, transform_fn, pad_first_dim=True)

  print('Model initializing')
  encoder_args = get_encoder_args(args.encoder)
  model = Milk(input_shape=(args.bag_size, args.crop_size, args.crop_size, 3),
               encoder_args=encoder_args, mode=args.mil, use_gate=args.gated_attention,
               temperature=args.temperature, freeze_encoder=args.freeze_encoder,
               deep_classifier=args.deep_classifier)

  if args.tpu:
    # Need to use tensorflow optimizer
    optimizer = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
  else:
    optimizer = training_utils.AdamAccumulate(lr=args.learning_rate,
                               accum_iters=args.accumulate)

  exptime = datetime.datetime.now()
  exptime_str = exptime.strftime('%Y_%m_%d_%H_%M_%S')
  out_path = os.path.join(args.save_prefix, '{}.h5'.format(exptime_str))

  # Todo : clean up
  val_list_file = os.path.join('./val_lists', '{}.txt'.format(exptime_str))
  test_list_file = os.path.join('./test_lists', '{}.txt'.format(exptime_str))
  arg_file = os.path.join('./args', '{}.txt'.format(exptime_str))

  # Every output location must exist before it is written to. The dirname
  # can be empty (file in the current directory), which makedirs rejects.
  for target in (out_path, val_list_file, test_list_file, arg_file):
    target_dir = os.path.dirname(target)
    if target_dir and not os.path.exists(target_dir):
      os.makedirs(target_dir)

  with open(val_list_file, 'w+') as f:
    for v in val_list:
      f.write('{}\n'.format(v))

  with open(test_list_file, 'w+') as f:
    for v in test_list:
      f.write('{}\n'.format(v))

  ## Write out arguments passed for this session
  with open(arg_file, 'w+') as f:
    for a in vars(args):
      f.write('{}\t{}\n'.format(a, getattr(args, a)))

  ## Transfer to TPU
  if args.tpu:
    print('Setting up model on TPU')
    print('TPU address is', tpu_address)
    tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(tpu=tpu_address)
    strategy = tf.contrib.tpu.TPUDistributionStrategy(tpu_cluster_resolver)
    model = tf.contrib.tpu.keras_to_tpu_model(model, strategy)

  model.compile(optimizer=optimizer,
                loss=tf.keras.losses.categorical_crossentropy,
                metrics=['categorical_accuracy'])
  model.summary()

  ## Replace randomly initialized weights after model is compiled and on the correct device.
  if args.pretrained_model is not None and os.path.exists(args.pretrained_model):
    print('Replacing random weights with weights from {}'.format(args.pretrained_model))
    model.load_weights(args.pretrained_model, by_name=True)

  if args.early_stop:
    callbacks = [
        tf.keras.callbacks.EarlyStopping(monitor='val_loss',
                                         min_delta=0.00001,
                                         patience=5,
                                         verbose=1,
                                         mode='auto',)
    ]
  else:
    callbacks = []

  try:
    # steps_per_epoch is not passed here; the train sequence was built with
    # args.steps_per_epoch above.
    model.fit_generator(generator=train_sequence,
                        validation_data=val_sequence,
                        epochs=args.epochs,
                        workers=8,
                        use_multiprocessing=True,
                        callbacks=callbacks, )

  except KeyboardInterrupt:
    print('Keyboard interrupt caught')
  except Exception as e:
    # Deliberate best effort: report the error, then still save below.
    print('Other error caught')
    print(type(e))
    print(e)
  finally:
    model.save(out_path)
    print('Saved model: {}'.format(out_path))
    print('Training done. Find val and test datasets at')
    print(val_list_file)
    print(test_list_file)
コード例 #9
0
def main(args):
  """Save an attention image and the encoded features for every listed slide.

  A MilkEager model is built, run once on zeros and loaded from
  ``args.snapshot``. Each path in ``args.slides`` is copied to the ramdisk,
  processed with ``compute_fn`` and written out twice: the attention image to
  ``repext(src, args.suffix)`` and the features to
  ``repext(src, args.suffix + '.feat.npy')``. The ramdisk copy is removed
  afterwards, even when processing fails.

  Args:
    args: parsed command-line namespace. Reads ``encoder``, ``mil``,
      ``deep_classifier``, ``batchsize``, ``temperature``, ``heads``,
      ``process_size``, ``snapshot``, ``slides``, ``suffix``, ``ramdisk``
      and ``n_classes``.
  """
  # Head indices shared by the warm-up call and by compute_fn (via closure).
  head_ids = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

  # Define a compute_fn that should do three things:
  # 1. define an iterator over the slide's tiles
  # 2. compute an output with a given model / arguments
  # 3. return a reconstructed slide
  def compute_fn(slide, args, model=None, n_dropout=10 ):
    """Return (attention image, features array) for one slide.

    ``n_dropout`` belongs to a sample-dropout variant that is not
    implemented here; it is currently unused.
    """
    assert tf.executing_eagerly()
    print('Slide with {}'.format(len(slide.tile_list)))

    # In eager mode, we get a tf.contrib.eager.Iterator
    tile_iterator = TensorflowIterator(slide, args).make_iterator()

    # The iterator can be used directly. Queueing and multithreading
    # are handled in the backend by the tf.data.Dataset ops
    encoded = []
    positions = []
    for batch_no, (tiles, tile_idx) in enumerate(tile_iterator):
      encoded.append(model.encode_bag(tiles, training=False, return_z=True))
      positions.append(tile_idx.numpy())

      tiles_np = tiles.numpy()
      if batch_no % 50 == 0:
        print('Batch #{:04d}\t{}'.format(batch_no, tiles_np.shape))

    features = tf.concat(encoded, axis=0)

    z_att, raw_att = model.mil_attention(features,
                                         training=False,
                                         return_raw_att=True)

    yhat_multihead = model.apply_classifier(z_att, heads=head_ids,
      training=False)
    print('yhat mean {}'.format(np.mean(yhat_multihead, axis=0)))

    flat_positions = np.concatenate(positions)
    flat_att = np.squeeze(raw_att)
    slide.place_batch(flat_att, flat_positions, 'att', mode='tile')
    att_image = slide.output_imgs['att']
    print('Got attention image: {}'.format(att_image.shape))

    return att_image, features.numpy()

  ## Begin main script:
  # Set up the model first
  encoder_config = get_encoder_args(args.encoder)
  model = MilkEager(
    encoder_args=encoder_config,
    mil_type=args.mil,
    deep_classifier=args.deep_classifier,
    batch_size=args.batchsize,
    temperature=args.temperature,
    heads=args.heads,
  )

  side = args.process_size
  warmup = tf.zeros((1, 1, side, side, 3))
  model(warmup, verbose=True, heads=head_ids, training=True)
  model.load_weights(args.snapshot, by_name=True)

  # keras Model subclass
  model.summary()

  # Read list of inputs
  with open(args.slides, 'r') as f:
    slide_paths = [line.strip() for line in f]

  # Loop over slides
  for src in slide_paths:
    # Dirty substitution of the file extension gives us the destinations.
    # Set the --suffix option to reflect the model / type of processed output
    dst = repext(src, args.suffix)
    featdst = repext(src, args.suffix+'.feat.npy')

    # Loading data from ramdisk incurs a one-time copy cost
    rdsrc = cpramdisk(src, args.ramdisk)
    print('\n\nFile:', rdsrc)

    # try / except / finally so the ramdisk copy is always cleaned up, even
    # when there is an error or stop signal in the middle of processing.
    try:
      # Initialize the slide from our temporary path, with the arguments
      # passed in from command-line. This returns an svsutils.Slide object
      slide = Slide(rdsrc, args)

      # Register the 'att' output together with the compute_fn defined above.
      slide.initialize_output('att', args.n_classes, mode='tile',
        compute_fn=compute_fn)

      # Call the compute function to compute this output.
      att_image, slide_features = slide.compute('att', args, model=model)
      print('{} --> {}'.format(att_image.shape, dst))
      print('{} --> {}'.format(slide_features.shape, featdst))
      np.save(dst, att_image)
      np.save(featdst, slide_features)
    except Exception as e:
      print(e)
      traceback.print_tb(e.__traceback__)
    finally:
      print('Removing {}'.format(rdsrc))
      os.remove(rdsrc)
コード例 #10
0
def main(args):
    """Train a MilkEager model one randomly chosen head at a time.

    Workflow:
      1. Create the output paths and split the MILDataset on 'case_id'.
      2. Build the model, call it once on a zero tensor, then optionally
         load ``args.pretrained_model`` over the random weights.
      3. Feed training bags through a ``tf.data`` pipeline and apply one Adam
         update per batch. A new head index is drawn every ``train_len``
         steps.
      4. Whatever happens during training (including Ctrl-C), save the
         weights to ``out_path`` and dump step/loss/accuracy to a text file.

    Args:
      args: parsed command-line namespace. Reads dataset, crop_size, seed,
        encoder, mil, deep_classifier, temperature, heads, channels,
        pretrained_model, early_stop, learning_rate, bag_size, epochs,
        threads, batch_size and out_base.
    """
    out_path, exptime_str = create_outputs(args)

    data_factory = MILDataset(args.dataset, crop=args.crop_size, n_classes=2)
    data_factory.split_train_val('case_id', seed=args.seed)

    print('Model initializing')
    model = MilkEager(encoder_args=get_encoder_args(args.encoder),
                      mil_type=args.mil,
                      deep_classifier=args.deep_classifier,
                      batch_size=16,
                      temperature=args.temperature,
                      heads=args.heads)

    print('Running once to load CUDA')
    dummy_bag = tf.zeros(
        (1, 1, args.crop_size, args.crop_size, args.channels))

    # The head list is deliberately spelled out as a literal list; keep it
    # in this exact form.
    all_heads = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
    warmup = model(dummy_bag, heads=all_heads, training=True, verbose=True)
    first_out = warmup[0]
    print('yhat: {} ({} {})'.format(first_out, first_out.shape,
                                    first_out.dtype))
    model.summary()

    # Swap in pretrained weights only once the model variables exist.
    # A failed load is reported and training continues from random weights.
    weights_path = args.pretrained_model
    if weights_path is not None and os.path.exists(weights_path):
        print('Replacing random weights with weights from {}'.format(
            weights_path))
        try:
            model.load_weights(weights_path, by_name=True)
        except Exception as load_err:
            print(load_err)

    # Early-stopping helper. It is constructed here but the loop below does
    # not consult it.
    stopper = (ShouldStop(patience=args.heads)
               if args.early_stop else (lambda x: False))

    optimizer = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
    trackers = {'loss': [], 'acc': [], 'step': 0}

    def train_bags():
        # Factory handed to tf.data; each call yields a fresh python iterator.
        return data_factory.python_iterator(mode='train',
                                            subset=args.bag_size,
                                            attr='stage_code',
                                            seed=None,
                                            epochs=args.epochs)

    train_len = data_factory.dataset_lengths['train']
    tf_dataset = tf.data.Dataset.from_generator(
        train_bags, output_types=(tf.uint8, tf.uint8))
    tf_dataset = tf_dataset.map(data_factory.map_fn,
                                num_parallel_calls=args.threads)
    tf_dataset = tf_dataset.batch(args.batch_size)

    trainable_variables = model.trainable_variables
    try:
        avglosses, steptimes = [], []
        for step, (bag, label) in enumerate(tf_dataset):
            # Every train_len steps: collect garbage, reseed TF and draw the
            # head that will be trained until the next boundary.
            if step % train_len == 0:
                gc.collect()
                tf.set_random_seed(1)
                train_head = [np.random.choice(args.heads)]
                print('\nTraining head [{}]'.format(train_head))

            started = time.time()
            with tf.GradientTape() as tape:
                yhat = model(bag, training=True, heads=train_head)
                loss = tf.keras.losses.categorical_crossentropy(
                    y_true=tf.constant(label, dtype=tf.float32),
                    y_pred=yhat[0])

            grads = tape.gradient(loss, trainable_variables)
            del tape

            mean_loss = np.mean(loss)
            accuracy = eval_acc(label, yhat)
            avglosses.append(mean_loss)

            trackers['loss'].append(mean_loss)
            trackers['acc'].append(accuracy)
            trackers['step'] += 1

            # The step time covers forward + backward, not the optimizer
            # update that follows.
            steptimes.append(time.time() - started)

            optimizer.apply_gradients(zip(grads, trainable_variables))
            progress = '\r{:07d}: loss={:3.5f} dt={:3.3f}s   '.format(
                step, np.mean(avglosses), np.mean(steptimes))
            print(progress, end='', flush=True)

    except KeyboardInterrupt:
        print('Keyboard interrupt caught')

    except Exception as train_err:
        print('Other error caught')
        print(train_err)
        traceback.print_tb(train_err.__traceback__)

    finally:
        # Always persist whatever was learned so far.
        model.save_weights(out_path)
        print('Saved model: {}'.format(out_path))

        # One line per step: step index, loss, accuracy (tab separated).
        curves_name = '{}_training_curves.txt'.format(exptime_str)
        training_stats = os.path.join(args.out_base, 'save', curves_name)
        print('Dumping training stats --> {}'.format(training_stats))
        with open(training_stats, 'w+') as stats_file:
            rows = zip(range(trackers['step']), trackers['loss'],
                       trackers['acc'])
            for step_idx, step_loss, step_acc in rows:
                stats_file.write('{:06d}\t{:3.5f}\t{:3.5f}\n'.format(
                    step_idx, step_loss, step_acc))
コード例 #11
0
ファイル: deploy_eager.py プロジェクト: nathanin/milk
def main(args):
    """Apply a trained MilkEager model to every slide in a list.

    Slides are read from ``args.f``. A slide is skipped unless a foreground
    image named ``<basename>_fg.png`` exists in ``args.fg``. For each
    processed slide ``<basename>_ypred.npy`` is written to ``args.o``; for
    the 'attention' and 'instance' MIL types the raw attention array
    (``<basename>_att.npy``) and a rendered image (``<basename>_img.png``)
    are written as well.

    Each slide is copied to ``args.ramdisk`` before it is opened. The copy
    is removed and the slide is closed even when processing raises; the
    exception itself still propagates to the caller.

    Args:
      args: parsed command-line namespace. Reads f, encoder, mil,
        batch_size, temperature, deep_classifier, input_dim, s, fg, ramdisk,
        mag, oversample and o.
    """
    # Translate obfuscated file names to paths if necessary
    slide_list = read_list(args.f)
    print('Found {} slides'.format(len(slide_list)))

    encoder_args = get_encoder_args(args.encoder)
    model = MilkEager(encoder_args=encoder_args,
                      mil_type=args.mil,
                      batch_size=args.batch_size,
                      temperature=args.temperature,
                      deep_classifier=args.deep_classifier)

    # Call the model once on zeros before loading the weights by name.
    x_pl = np.zeros((1, args.batch_size, args.input_dim, args.input_dim, 3),
                    dtype=np.float32)
    yhat = model(tf.constant(x_pl), verbose=True)
    print('yhat:', yhat.shape)

    print('setting model weights')
    model.load_weights(args.s, by_name=True)

    ## Loop over found slides:
    yhats = []
    for i, src in enumerate(slide_list):
        print('\nSlide {}'.format(i))
        basename = os.path.basename(src).replace('.svs', '')
        fgpth = os.path.join(args.fg, '{}_fg.png'.format(basename))
        if not os.path.exists(fgpth):
            # No foreground image for this slide: report the path and skip.
            print(fgpth)
            continue

        ramdisk_path = transfer_to_ramdisk(
            src, args.ramdisk)  # never use the original src
        print('Using fg image at : {}'.format(fgpth))
        fgimg = cv2.imread(fgpth, 0)
        try:
            svs = Slide(
                slide_path=ramdisk_path,
                # background_speed  = 'accurate',
                background_speed='image',
                background_image=fgimg,
                preprocess_fn=lambda x: (x / 255.).astype(np.float32),
                normalize_fn=lambda x: x,
                process_mag=args.mag,
                process_size=args.input_dim,
                oversample_factor=args.oversample,
                verbose=False)
        except Exception as e:
            print(e)
            print(
                'Caught SVS related error. Cleaning ramdisk and continuing.'
            )
            print('Cleaning file: {}'.format(ramdisk_path))
            os.remove(ramdisk_path)
            continue

        # From here on the slide is open and the ramdisk copy exists; the
        # finally clause guarantees both are released on every exit path.
        try:
            svs.initialize_output(name='attention', dim=1, mode='tile')
            yhat, att, indices = process_slide(svs, model, args)

            yhats.append(yhat)
            print('\tSlide predicted: {}'.format(yhat))

            if args.mil == 'average':
                print('Average MIL; continuing')
            elif args.mil in ['attention', 'instance']:
                print('Placing values ranged {:3.3f} - {:3.3f}'.format(
                    att.min(), att.max()))
                print('Visualizing mean {:3.5f}'.format(np.mean(att)))
                print('Visualizing std {:3.5f}'.format(np.std(att)))
                svs.place_batch(att, indices, 'attention', mode='tile')
                attention_img_raw = np.squeeze(svs.output_imgs['attention'])

                # Scale to a maximum of 1. An all-zero map is left as is;
                # multiplying by 1/0 would fill it with NaNs.
                peak = attention_img_raw.max()
                if peak != 0:
                    attention_img = attention_img_raw * (1. / peak)
                else:
                    attention_img = attention_img_raw.copy()
                attention_img = draw_attention(attention_img, n_bins=50)
                print('attention image:', attention_img.shape,
                      attention_img.dtype, attention_img.min(),
                      attention_img.max())

                dst = os.path.join(args.o, '{}_att.npy'.format(basename))
                np.save(dst, attention_img_raw)
                dst = os.path.join(args.o, '{}_img.png'.format(basename))
                cv2.imwrite(dst, attention_img)

            yhat_dst = os.path.join(args.o, '{}_ypred.npy'.format(basename))
            np.save(yhat_dst, yhat)
        finally:
            # Close and delete independently so a failing close() cannot
            # leave the ramdisk copy behind.
            try:
                svs.close()
            except Exception as e:
                print('Failed to close slide: {}'.format(e))
            try:
                os.remove(ramdisk_path)
            except OSError:
                print('{} already removed'.format(ramdisk_path))
            del svs
コード例 #12
0
def main(args):
    """Run a saved MilkEager snapshot over the test slides and report accuracy.

    The test list ``<args.testdir>/<args.timestamp>.txt`` is reduced to
    unique ids, optionally shuffled and truncated, and resolved to slide
    paths and labels. Slides without a foreground image in ``args.fgdir``
    are skipped. For each processed slide the attention values and a
    rendered attention image are written under
    ``<args.odir>/<args.timestamp>/``. Finally the fraction of slides whose
    argmax prediction equals the label is printed.

    Each slide is copied to ``args.ramdisk`` before it is opened. The copy
    is removed and the slide is closed even when processing raises; the
    exception itself still propagates to the caller.

    Args:
      args: parsed command-line namespace. Reads testdir, timestamp,
        randomize, max_slides, savedir, encoder, mil, batch_size,
        deep_classifier, temperature, input_dim, fgdir, ramdisk, mag,
        oversample and odir.
    """
    # Translate obfuscated file names to paths if necessary
    test_list = os.path.join(args.testdir, '{}.txt'.format(args.timestamp))
    test_list = read_test_list(test_list)
    test_unique_ids = [
        os.path.basename(x).replace('.npy', '') for x in test_list
    ]
    if args.randomize:
        np.random.shuffle(test_unique_ids)

    if args.max_slides:
        test_unique_ids = test_unique_ids[:args.max_slides]

    slide_list, slide_labels = get_slidelist_from_uids(test_unique_ids)

    print('Found {} slides'.format(len(slide_list)))

    snapshot = os.path.join(args.savedir, '{}.h5'.format(args.timestamp))

    encoder_args = get_encoder_args(args.encoder)
    model = MilkEager(encoder_args=encoder_args,
                      mil_type=args.mil,
                      batch_size=args.batch_size,
                      deep_classifier=args.deep_classifier,
                      temperature=args.temperature)

    # Call the model once on zeros before loading the weights by name.
    x_pl = np.zeros((1, args.batch_size, args.input_dim, args.input_dim, 3),
                    dtype=np.float32)
    yhat = model(tf.constant(x_pl), verbose=True)
    print('yhat:', yhat.shape)

    print('setting model weights')
    model.load_weights(snapshot, by_name=True)

    ## Loop over found slides:
    yhats = []
    ytrues = []
    for i, (src, lab) in enumerate(zip(slide_list, slide_labels)):
        print('\nSlide {}'.format(i))
        basename = os.path.basename(src).replace('.svs', '')
        fgpth = os.path.join(args.fgdir, '{}_fg.png'.format(basename))
        if not os.path.exists(fgpth):
            ## require precomputed background; skip this slide.
            print('Required foreground image not found ({})'.format(fgpth))
            continue

        ramdisk_path = transfer_to_ramdisk(
            src, args.ramdisk)  # never use the original src
        svs = None
        # Everything after the copy runs under try/finally so the ramdisk
        # file and the open slide are released on every exit path.
        try:
            print('Using fg image at : {}'.format(fgpth))
            fgimg = cv2.imread(fgpth, 0)
            svs = Slide(
                slide_path=ramdisk_path,
                # background_speed  = 'accurate',
                background_speed='image',
                background_image=fgimg,
                # preprocess_fn     = lambda x: (reinhard(x)/255.).astype(np.float32),
                preprocess_fn=lambda x: (x / 255.).astype(np.float32),
                process_mag=args.mag,
                process_size=args.input_dim,
                oversample_factor=args.oversample,
                verbose=False)

            svs.initialize_output(name='attention', dim=1, mode='tile')

            yhat, att, indices = process_slide(svs, model, args)
            print('returned attention:', np.min(att), np.max(att), att.shape)

            yhat = yhat.numpy()
            yhats.append(yhat)
            ytrues.append(lab)
            print('\tSlide label: {} predicted: {}'.format(lab, yhat))

            svs.place_batch(att, indices, 'attention', mode='tile')
            attention_img = np.squeeze(svs.output_imgs['attention'])
            # Scale to a maximum of 1. An all-zero map is left as is;
            # multiplying by 1/0 would fill it with NaNs.
            peak = attention_img.max()
            if peak != 0:
                attention_img = attention_img * (1. / peak)
            else:
                attention_img = attention_img.copy()
            attention_img = draw_attention(attention_img, n_bins=25)
            print('attention image:', attention_img.shape,
                  attention_img.dtype, attention_img.min(),
                  attention_img.max())

            # File names carry the label and the predicted yhat[0, 1] value.
            dst = os.path.join(
                args.odir, args.timestamp,
                '{}_{}_{:3.3f}_att.npy'.format(basename, lab, yhat[0, 1]))
            np.save(dst, att)

            dst = os.path.join(
                args.odir, args.timestamp,
                '{}_{}_{:3.3f}_img.png'.format(basename, lab, yhat[0, 1]))
            cv2.imwrite(dst, attention_img)
        finally:
            # Close and delete independently so a failing close() cannot
            # leave the ramdisk copy behind.
            if svs is not None:
                try:
                    svs.close()
                except Exception as e:
                    print('Failed to close slide: {}'.format(e))
            try:
                os.remove(ramdisk_path)
            except OSError:
                print('{} already removed'.format(ramdisk_path))

    if not yhats:
        # np.concatenate raises on an empty list; nothing to score.
        print('No slides were processed; accuracy not computed')
        return

    yhats = np.concatenate(yhats, axis=0)
    ytrues = np.array(ytrues)
    acc = (np.argmax(yhats, axis=-1) == ytrues).mean()
    print(acc)
コード例 #13
0
ファイル: deploy_eager.py プロジェクト: nathanin/milk
def main(args):
    """Tile-classify every slide in a list and save probability maps.

    Slides are read from ``args.slide_list`` (optionally shuffled). A slide
    whose ``<basename>.npy`` already exists in ``args.save_dir`` is skipped.
    For each remaining slide the function writes to ``args.save_dir``:
      * ``<basename>.npy``   - the 'prob' output scaled by 255 as uint8
      * ``<basename>.jpg``   - the reassembled 'rgb' output
      * ``<basename>_c.jpg`` - the result of ``colorize(rgb, prob)``

    Errors while processing a slide are printed and the loop moves on. The
    ramdisk copy of the slide is always removed afterwards.

    Args:
      args: parsed command-line namespace. Reads slide_list, shuffle,
        encoder, n_classes, input_dim, snapshot, save_dir, ramdisk, fgdir,
        mag and batch_size.
    """
    ## Search for slides
    slide_list = read_list(args.slide_list)
    print('Found {} slides'.format(len(slide_list)))
    if args.shuffle:
        np.random.shuffle(slide_list)

    encoder_args = get_encoder_args(args.encoder)
    model = ClassifierEager(encoder_args=encoder_args,
                            deep_classifier=True,
                            n_classes=args.n_classes)
    # Call the model once on zeros before loading the snapshot; the output
    # itself is not needed.
    fake_data = tf.constant(
        np.zeros((1, args.input_dim, args.input_dim, 3), dtype=np.float32))
    model(fake_data)
    model.load_weights(args.snapshot)
    model.summary()
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    ## Loop over found slides:
    for src in slide_list:
        basename = os.path.basename(src).replace('.svs', '')
        dst = os.path.join(args.save_dir, '{}.npy'.format(basename))
        if os.path.exists(dst):
            print('{} exists. Skipping'.format(dst))
            continue

        ramdisk_path = transfer_to_ramdisk(
            src, args.ramdisk)  # never use the original src

        # Reset per slide. Without this, a failure before Slide() is
        # constructed would make the cleanup below close the slide left over
        # from the previous iteration.
        svs = None
        try:
            fgpth = os.path.join(args.fgdir, '{}_fg.png'.format(basename))
            fgimg = cv2.imread(fgpth, 0)
            fgimg = fill_fg(fgimg)
            svs = Slide(slide_path=ramdisk_path,
                        preprocess_fn=lambda x: (x / 255.).astype(np.float32),
                        normalize_fn=lambda x: x,
                        background_speed='image',
                        background_image=fgimg,
                        process_mag=args.mag,
                        process_size=args.input_dim,
                        oversample_factor=1.1)
            svs.initialize_output(name='prob', dim=args.n_classes, mode='full')
            svs.initialize_output(name='rgb', dim=3, mode='full')
            n_tiles = len(svs.tile_list)
            prefetch = min(512, n_tiles)
            # Get tensors for image and index
            iterator = get_img_idx(svs, args.batch_size, prefetch)
            for batches, (img_, idx_) in enumerate(iterator, start=1):
                yhat = model(img_, training=False)
                yhat = yhat.numpy()
                idx_ = idx_.numpy()
                img_ = img_.numpy()
                yhat = vect_to_tile(yhat, args.input_dim)
                svs.place_batch(yhat, idx_, 'prob', mode='full')
                svs.place_batch(img_, idx_, 'rgb', mode='full')
                if batches % 50 == 0:
                    print('\tbatch {:04d}'.format(batches))

            svs.make_outputs(reference='prob')
            prob_img = svs.output_imgs['prob']
            rgb_img = svs.output_imgs['rgb'] * 255
            color_img = colorize(rgb_img, prob_img)
            np.save(dst, (prob_img * 255).astype(np.uint8))
            # Channel order is reversed before handing the images to OpenCV.
            jpg_dst = os.path.join(args.save_dir, '{}.jpg'.format(basename))
            cv2.imwrite(jpg_dst, rgb_img[:, :, ::-1])
            color_dst = os.path.join(args.save_dir,
                                     '{}_c.jpg'.format(basename))
            cv2.imwrite(color_dst, color_img[:, :, ::-1])
        except Exception as e:
            # Best effort: report the problem and continue with the next slide.
            print(e)
        finally:
            if svs is None:
                print('No SVS to close')
            else:
                print('Closing SVS')
                try:
                    svs.close()
                except Exception as e:
                    print('Failed to close SVS: {}'.format(e))

            os.remove(ramdisk_path)
コード例 #14
0
def main(args):
    """Evaluate the MilkEager snapshot ``save/<timestamp>.h5`` on test cases.

    Every case listed in ``<args.testdir>/<args.timestamp>.txt`` is loaded
    with a non-augmenting transform and passed through ``run_sample``; the
    whole list is repeated ``args.n_repeat`` times. Accuracy over the
    stacked predictions is printed, and ``auc_curve`` / ``test_eval`` are
    called on the results. When ``args.odir`` is set, predictions and labels
    are also saved there as ``.npy`` files and the plot/metrics paths point
    into that directory; otherwise those paths are ``None``.

    Args:
      args: parsed command-line namespace. Reads x_size, y_size, crop_size,
        scale, timestamp, encoder, mcdropout, mil, deep_classifier,
        cls_normalize, batch_size, temperature, testdir, n_repeat,
        sample_mode and odir.
    """
    transform_fn = data_utils.make_transform_fn(
        args.x_size, args.y_size, args.crop_size, args.scale,
        flip=False, middle_crop=True, rotate=False, normalize=True)

    snapshot = 'save/{}.h5'.format(args.timestamp)

    encoder_args = get_encoder_args(args.encoder)
    if args.mcdropout:
        encoder_args['mcdropout'] = True

    print('Model initializing')
    model = MilkEager(encoder_args=encoder_args,
                      mil_type=args.mil,
                      deep_classifier=args.deep_classifier,
                      cls_normalize=args.cls_normalize,
                      batch_size=args.batch_size,
                      temperature=args.temperature)

    # Call the model once on zeros before loading the weights by name.
    dummy_in = tf.zeros((1, args.batch_size, args.x_size, args.y_size, 3))
    dummy_out = model(dummy_in, verbose=True)
    print(dummy_in.shape, dummy_out.shape)
    del dummy_in, dummy_out
    model.load_weights(snapshot, by_name=True)

    list_file = os.path.join(args.testdir, '{}.txt'.format(args.timestamp))
    test_list = read_test_list(list_file)

    ytrues, yhats = [], []
    print('MC Dropout: {}'.format(args.mcdropout))
    for _ in range(args.n_repeat):
        for test_case in test_list:
            case_x, case_y = data_utils.load(data_path=test_case,
                                             transform_fn=transform_fn,
                                             case_label_fn=case_label_fn,
                                             all_tiles=True)
            print('Running case x: ', case_x.shape)
            case_yhat = run_sample(case_x,
                                   model,
                                   mcdropout=args.mcdropout,
                                   sample_mode=args.sample_mode)
            ytrues.append(case_y)
            yhats.append(case_yhat)
            print(test_case, case_y, case_x.shape, case_yhat)

    ytrue = np.concatenate(ytrues, axis=0)
    yhat = np.concatenate(yhats, axis=0)

    # Accuracy = fraction of rows where the argmax of label and prediction agree.
    hits = np.argmax(ytrue, axis=-1) == np.argmax(yhat, axis=-1)
    print('Accuracy: {:3.3f}'.format(hits.mean()))

    # Without an output directory nothing is written to disk here and the
    # downstream helpers receive savepath=None.
    save_img = None
    save_metrics = None
    if args.odir is not None:
        stamp = args.timestamp
        save_img = os.path.join(args.odir, '{}.png'.format(stamp))
        save_metrics = os.path.join(args.odir, '{}.txt'.format(stamp))
        np.save(os.path.join(args.odir, '{}.npy'.format(stamp)), yhat)
        np.save(os.path.join(args.odir, '{}_ytrue.npy'.format(stamp)), ytrue)

    auc_curve(ytrue, yhat, savepath=save_img)
    test_eval(ytrue, yhat, savepath=save_metrics)
コード例 #15
0
def main(args):
    """Train a ClassifierEager model with a manual eager-mode loop.

    Builds a ClassificationDataset from ``args.dataset``, runs the model
    once on a test batch, optionally restores ``args.snapshot``, then takes
    ``args.steps_per_epoch * args.epochs`` Adam steps on categorical
    cross-entropy. The mean loss is printed every 50 steps. The weights are
    always written to ``args.saveto``, even when the loop fails.

    Args:
      args: parsed command-line namespace. Reads input_dim, downsample,
        dataset, n_classes, n_threads, batch_size, prefetch_buffer,
        shuffle_buffer, device, device_buffer, encoder, snapshot,
        learning_rate, steps_per_epoch, epochs and saveto.
    """
    print(args)
    # Get crop size from input_dim and downsample
    crop_size = int(args.input_dim / args.downsample)

    # Build the dataset
    dataset = ClassificationDataset(record_path=args.dataset,
                                    crop_size=crop_size,
                                    downsample=args.downsample,
                                    n_classes=args.n_classes,
                                    n_threads=args.n_threads,
                                    batch=args.batch_size,
                                    prefetch_buffer=args.prefetch_buffer,
                                    shuffle_buffer=args.shuffle_buffer,
                                    device=args.device,
                                    device_buffer=args.device_buffer,
                                    eager=True)

    # Test batch:
    batchx, batchy = next(dataset.iterator)
    print('Test batch:')
    print('batchx: ', batchx.get_shape())
    print('batchy: ', batchy.get_shape())

    encoder_args = get_encoder_args(args.encoder)
    model = ClassifierEager(encoder_args=encoder_args,
                            n_classes=args.n_classes)
    yhat = model(batchx, training=True, verbose=True)
    print('yhat: ', yhat.get_shape())

    if args.snapshot is not None and os.path.exists(args.snapshot):
        model.load_weights(args.snapshot)

    optimizer = tf.train.AdamOptimizer(learning_rate=args.learning_rate)

    model.summary()

    try:
        running_avg = []
        for k in range(args.steps_per_epoch * args.epochs):
            batchx, batchy = next(dataset.iterator)

            # Record only the forward pass and the loss on the tape.
            # NOTE(review): the model is called without training=True here,
            # unlike the test call above - confirm this is intended.
            with tf.GradientTape() as tape:
                yhat = model(batchx)
                loss = tf.keras.losses.categorical_crossentropy(y_true=batchy,
                                                                y_pred=yhat)
            running_avg.append(np.mean(loss))

            # Differentiate and update after leaving the tape context, and
            # only with respect to the trainable variables.
            variables = model.trainable_variables
            grads = tape.gradient(loss, variables)
            optimizer.apply_gradients(zip(grads, variables))

            if k % 50 == 0:
                print('{:05d} loss={:3.3f}'.format(k, np.mean(running_avg)))
                running_avg = []

    except Exception as e:
        print(e)
        print('Caught exception')
        traceback.print_tb(e.__traceback__)

    finally:
        print('Saving')
        model.save_weights(args.saveto)
コード例 #16
0
ファイル: train_mnist.py プロジェクト: nathanin/milk
def main(args):
    """Train a MilkBatch model on bags built from MNIST digits.

    One digit is drawn at random as the positive label; the train and test
    images are rearranged into positive and negative pools by
    ``rearrange_bagged_mnist`` and turned into bag generators by
    ``generate_bagged_mnist``. The model is compiled with Adam and
    categorical cross-entropy, trained with ``fit_generator`` (with a
    TensorBoard callback), and saved to ``args.o``.

    Args:
      args: parsed command-line namespace. Reads mnist, bag_size,
        batch_size, mil, pretrained, gpus, lr, decay, epoch_steps, epochs
        and o.
    """
    if args.mnist is not None:
        (train_x, train_y), (test_x, test_y) = mnist.load_data(args.mnist)
    else:
        (train_x, train_y), (test_x, test_y) = mnist.load_data()

    # Rescale the images by 1/255.
    train_x = train_x / 255.
    test_x = test_x / 255.
    print('train_x:', train_x.shape, train_x.dtype, train_x.min(),
          train_x.max())
    print('train_y:', train_y.shape)
    print('test_x:', test_x.shape)
    print('test_y:', test_y.shape)

    # np.random.choice with size 1 returns a length-1 array, not a scalar.
    positive_label = np.random.choice(range(10), 1, replace=False)
    print('using positive label = {}'.format(positive_label))

    train_x_pos, train_x_neg = rearrange_bagged_mnist(train_x, train_y,
                                                      positive_label)
    test_x_pos, test_x_neg = rearrange_bagged_mnist(test_x, test_y,
                                                    positive_label)
    print('rearranged training set:')
    print('\ttrain_x_pos:', train_x_pos.shape, train_x_pos.dtype,
          train_x_pos.min(), train_x_pos.max())
    print('\ttrain_x_neg:', train_x_neg.shape)
    print('\ttest_x_pos:', test_x_pos.shape)
    print('\ttest_x_neg:', test_x_neg.shape)

    generator = generate_bagged_mnist(train_x_pos, train_x_neg, args.bag_size,
                                      args.batch_size)
    val_generator = generate_bagged_mnist(test_x_pos, test_x_neg,
                                          args.bag_size, args.batch_size)
    # Pull one batch just to report its shapes.
    batch_x, batch_y = next(generator)
    print('batch_x:', batch_x[0].shape, 'batch_y:', batch_y.shape)

    encoder_args = get_encoder_args('mnist')
    model = MilkBatch(input_shape=(args.bag_size, 28, 28, 1),
                      encoder_args=encoder_args,
                      mode=args.mil,
                      batch_size=args.batch_size,
                      bag_size=args.bag_size,
                      deep_classifier=True)

    if args.pretrained is not None and os.path.exists(args.pretrained):
        print('Restoring weights from {}'.format(args.pretrained))
        model.load_weights(args.pretrained, by_name=True)
    else:
        print('Pretrained model not found ({}). Continuing end 2 end.'.format(
            args.pretrained))

    if args.gpus > 1:
        # Report the actual replica count rather than a fixed "2".
        print('Duplicating model onto {} GPUs'.format(args.gpus))
        model = tf.keras.utils.multi_gpu_model(model,
                                               args.gpus,
                                               cpu_merge=True,
                                               cpu_relocation=False)

    optimizer = tf.keras.optimizers.Adam(lr=args.lr, decay=args.decay)
    model.compile(
        optimizer=optimizer,
        loss=tf.keras.losses.categorical_crossentropy,
        metrics=['categorical_accuracy'],
    )

    model.summary()

    callbacks = [
        tf.keras.callbacks.TensorBoard(histogram_freq=1,
                                       write_graph=False,
                                       write_grads=True,
                                       update_freq='batch')
    ]

    model.fit_generator(generator=generator,
                        validation_data=val_generator,
                        validation_steps=100,
                        steps_per_epoch=args.epoch_steps,
                        epochs=args.epochs,
                        callbacks=callbacks)

    model.save(args.o)