def main(args):
    """Score a pretrained multi-head MilkEager model on the test split.

    The model is built from the command-line options and called once on a
    zero tensor before ``load_weights`` is invoked.  Every bag yielded by
    the test iterator is then classified by all heads; a running accuracy is
    printed in place while iterating, and the collected labels and per-head
    predictions are finally passed to ``print_accuracy`` and ``write_out``.

    Args:
        args: parsed options.  Attributes read here: ``encoder``, ``mil``,
            ``deep_classifier``, ``temperature``, ``heads``, ``crop_size``,
            ``channels``, ``pretrained_model``, ``dataset``, ``seed``,
            ``bag_size`` and ``out``.
    """
    print('Model initializing')
    model = MilkEager(
        encoder_args=get_encoder_args(args.encoder),
        mil_type=args.mil,
        deep_classifier=args.deep_classifier,
        batch_size=16,
        temperature=args.temperature,
        heads=args.heads,
    )

    print('Running once to load CUDA')
    warmup_shape = (1, 1, args.crop_size, args.crop_size, args.channels)
    warmup_input = tf.zeros(warmup_shape)

    # The head indices are spelled out as a literal list on purpose: the
    # original author observed that eager mode complains about range() and
    # even list(range()) here while the tf.contrib.eager.defun decorator is
    # in place, and attributed it to autograph.
    warmup_heads = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
    model(warmup_input.gpu(), heads=warmup_heads, training=True, verbose=True)
    model.summary()
    model.load_weights(args.pretrained_model, by_name=True)

    # Data: split by case and iterate the test portion one bag at a time.
    data_factory = MILDataset(args.dataset, crop=args.crop_size, n_classes=2)
    data_factory.split_train_val('case_id', seed=args.seed)
    test_iterator = data_factory.tensorflow_iterator(
        mode='test',
        seed=args.seed,
        batch_size=1,
        buffer_size=1,
        threads=1,
        subset=args.bag_size,
        attr='stage_code',
        eager=True,
    )

    print('-------------------------------------------------------\n\n')
    labels = []
    predictions = []
    eval_heads = list(range(args.heads))
    for step, (bag, label) in enumerate(test_iterator):
        labels.append(label[0, 1])

        # One output per head; keep element [0, 1] of each.
        head_outputs = model(bag.gpu(), training=False, heads=eval_heads)
        predictions.append(np.array([out[0, 1] for out in head_outputs]))

        # Running accuracy over everything seen so far, rewritten in place.
        running_acc, _ = calc_acc(labels, predictions)
        progress = '\r{:03d} Accuracy: {:3.3f} %'.format(step, running_acc)
        print(progress, end='', flush=True)

    print('\n\n-------------------------------------------------------')
    print_accuracy(labels, predictions)
    write_out(labels, predictions, args.out)
def main(args):
    """Draw feature projections for every case in the experiment's test list.

    For each ``.npy`` case file the tiles are transformed, optionally
    sub-sampled, encoded with the model, and classified both per instance
    and as a bag (the bag feature is the mean of the instance features).
    Two figures are written per case via ``draw_projection`` and
    ``draw_projection_with_images``.

    Args:
        args: parsed options.  Attributes read here: ``input_dim``,
            ``timestamp``, ``mil``, ``batch_size``, ``sample`` and
            ``savebase``.
    """
    transform_fn = data_utils.make_transform_fn(
        128, 128, args.input_dim, 1.0, normalize=True)
    snapshot = os.path.join(
        '../experiment/save', '{}.h5'.format(args.timestamp))
    test_list_path = os.path.join(
        '../experiment/test_lists', '{}.txt'.format(args.timestamp))

    # NOTE(review): `encoder_args` is not assigned anywhere in this function;
    # presumably it is a module-level name -- confirm.
    model = MilkEager(encoder_args=encoder_args,
                      mil_type=args.mil,
                      deep_classifier=True)

    # Call the model once on zeros before loading the snapshot by name.
    x_dummy = tf.zeros(
        shape=[1, args.batch_size, args.input_dim, args.input_dim, 3],
        dtype=tf.float32)
    model(x_dummy, verbose=True)
    model.load_weights(snapshot, by_name=True)
    model.summary()

    for test_case in read_test_list(test_list_path):
        case_name = os.path.basename(test_case).replace('.npy', '')
        print(test_case, case_name)

        case_x = np.load(test_case)
        case_x = np.stack([transform_fn(x) for x in case_x], 0)
        print(case_x.shape)

        if args.sample:
            # Draw `args.sample` instance indices (numpy's default, i.e.
            # with replacement) and keep only those instances.
            keep = np.random.choice(case_x.shape[0], args.sample)
            case_x = case_x[keep, ...]
            print(case_x.shape)

        # Inference only: encode with training=False, like the classifier
        # calls below.
        features = model.encode_bag(case_x,
                                    batch_size=args.batch_size,
                                    training=False,
                                    return_z=True)
        print('features:', features.shape)

        # Bag-level feature = mean over the instance axis.
        features_att = tf.reduce_mean(features, axis=0, keepdims=True)

        yhat_instances = model.apply_classifier(features, training=False)
        print('yhat instances:', yhat_instances.shape)
        yhat = model.apply_classifier(features_att, training=False)
        print('yhat:', yhat.shape)

        # Column 1 of the per-instance output is used in place of an
        # attention value for ranking and colouring instances.
        yhat_1 = yhat_instances[:, 1].numpy()
        (high_att_idx, high_att_imgs,
         low_att_idx, low_att_imgs) = get_attention_extremes(
            yhat_1, case_x, n=5)

        print('Case {}: predicted={} ({})'.format(
            test_case, np.argmax(yhat, axis=-1), yhat.numpy()))

        features = features.numpy()
        bag_score = yhat[0, 1]

        savepath = '{}_{}_{:3.2f}_yhat.png'.format(
            args.savebase, case_name, bag_score)
        print('Saving figure {}'.format(savepath))
        z = draw_projection(features, features_att, yhat_1,
                            savepath=savepath)

        savepath = '{}_{}_{:3.2f}_yhat_img.png'.format(
            args.savebase, case_name, bag_score)
        print('Saving figure {}'.format(savepath))
        draw_projection_with_images(z, yhat_1,
                                    high_att_idx, high_att_imgs,
                                    low_att_idx, low_att_imgs,
                                    savepath=savepath)
def main(args):
    """Compute a tile-level output image for every slide in a list file.

    A ``compute_fn`` is selected by ``args.iter_type`` (``'python'`` or
    ``'tf'``), the model is built and its snapshot loaded, and each slide
    listed in ``args.slides`` is copied to the ramdisk, processed, saved
    next to the source with ``args.suffix`` and removed from the ramdisk
    again.

    Args:
        args: parsed options.  Attributes read here: ``iter_type``,
            ``encoder``, ``mil``, ``deep_classifier``, ``batchsize``,
            ``temperature``, ``heads``, ``process_size``, ``snapshot``,
            ``slides``, ``suffix``, ``ramdisk`` and ``n_classes``.

    Raises:
        ValueError: if ``args.iter_type`` is neither ``'python'`` nor
            ``'tf'``.
    """
    # compute_fn should (1) iterate over the slide's tiles, (2) compute an
    # output with the given model and (3) return the assembled output image.
    if args.iter_type == 'python':
        def compute_fn(slide, args, model=None):
            print('Slide with {}'.format(len(slide.tile_list)))
            it_factory = PythonIterator(slide, args)
            for k, (img, idx) in enumerate(it_factory.yield_batch()):
                prob = model(img)
                if k % 50 == 0:
                    print('Batch #{:04d} idx:{} img:{} prob:{}'.format(
                        k, idx.shape, img.shape, prob.shape))
                # NOTE(review): this writes to the 'prob' output while the
                # slide loop below only initializes 'att' -- confirm that the
                # python path is still expected to work.
                slide.place_batch(prob, idx, 'prob', mode='tile')
            return slide.output_imgs['prob']

    # Tensorflow dataset-backed iterator (eager mode).
    elif args.iter_type == 'tf':
        def compute_fn(slide, args, model=None):
            assert tf.executing_eagerly()
            print('Slide with {}'.format(len(slide.tile_list)))

            # The eager iterator can be consumed directly with a for loop.
            eager_iterator = TensorflowIterator(slide, args).make_iterator()

            features, indices = [], []
            for k, (img, idx) in enumerate(eager_iterator):
                features.append(
                    model.encode_bag(img, training=False, return_z=True))
                indices.append(idx.numpy())
                if k % 50 == 0:
                    # Only convert the batch to numpy when it is logged.
                    print('Batch #{:04d}\t{}'.format(k, img.numpy().shape))

            # Attention is computed once over all encoded tiles of the slide.
            features = tf.concat(features, axis=0)
            z_att, att = model.mil_attention(features,
                                             training=False,
                                             return_raw_att=True)
            att = np.squeeze(att)
            indices = np.concatenate(indices)
            slide.place_batch(att, indices, 'att', mode='tile')
            return slide.output_imgs['att']

    else:
        # Fail before any model is built or slide is copied, instead of
        # hitting an undefined compute_fn once per slide.
        raise ValueError(
            'Unknown iter_type: {!r} (expected "python" or "tf")'.format(
                args.iter_type))

    # Set up the model first.
    encoder_args = get_encoder_args(args.encoder)
    model = MilkEager(encoder_args=encoder_args,
                      mil_type=args.mil,
                      deep_classifier=args.deep_classifier,
                      batch_size=args.batchsize,
                      temperature=args.temperature,
                      heads=args.heads)

    # Call the model once on zeros before loading the snapshot by name.
    x = tf.zeros((1, 1, args.process_size, args.process_size, 3))
    _ = model(x, verbose=True, head='all', training=True)
    model.load_weights(args.snapshot, by_name=True)
    model.summary()

    # Read the list of inputs, ignoring blank lines.
    with open(args.slides, 'r') as f:
        slides = [line.strip() for line in f if line.strip()]

    for src in slides:
        # The destination is the source path with its extension swapped for
        # --suffix, which should reflect the model / type of output.
        dst = repext(src, args.suffix)

        # Loading data from ramdisk incurs a one-time copy cost.
        rdsrc = cpramdisk(src, args.ramdisk)
        print('File:', rdsrc)

        # try/except/finally so the ramdisk copy is always removed, even if
        # processing fails part-way through.
        try:
            slide = Slide(rdsrc, args)
            slide.initialize_output('att', args.n_classes, mode='tile',
                                    compute_fn=compute_fn)
            ret = slide.compute('att', args, model=model)
            print('{} --> {}'.format(ret.shape, dst))
            # Saved with the last axis reversed (presumably a channel-order
            # flip -- confirm).
            np.save(dst, ret[:, :, ::-1])
        except Exception as e:
            # Best effort: report the failure and move on to the next slide.
            print(e)
            traceback.print_tb(e.__traceback__)
        finally:
            print('Removing {}'.format(rdsrc))
            os.remove(rdsrc)
def main(args):
    """Project instance features per test case and report overall accuracy.

    An ``'instance'``-type MilkEager model is restored from the snapshot
    named after ``args.timestamp``.  For each case in the matching test
    list the tiles are transformed, optionally sub-sampled and encoded; the
    mean of the second value returned by ``encode_bag`` is used as the case
    prediction.  Figures and ``.npy`` dumps are written into a freshly
    (re)created ``<odir>/<timestamp>`` directory, and accuracy plus a
    confusion matrix are printed at the end.

    Args:
        args: parsed options.  Attributes read here: ``input_dim``,
            ``snapshot_dir``, ``test_list_dir``, ``timestamp``, ``encoder``,
            ``batch_size``, ``temperature``, ``cls_normalize``, ``odir`` and
            ``sample``.
    """
    transform_fn = data_utils.make_transform_fn(
        128, 128, args.input_dim, 1.0, normalize=True)
    weights_path = os.path.join(
        args.snapshot_dir, '{}.h5'.format(args.timestamp))
    list_path = os.path.join(
        args.test_list_dir, '{}.txt'.format(args.timestamp))

    model = MilkEager(
        encoder_args=get_encoder_args(args.encoder),
        deep_classifier=True,
        mil_type='instance',
        batch_size=args.batch_size,
        temperature=args.temperature,
        cls_normalize=args.cls_normalize,
    )

    # Call the model once on zeros before loading the snapshot by name.
    warmup = tf.zeros(
        shape=[1, args.batch_size, args.input_dim, args.input_dim, 3],
        dtype=tf.float32)
    model(warmup, verbose=True)
    model.load_weights(weights_path, by_name=True)
    model.summary()

    case_paths = read_test_list(list_path)

    # Start from an empty output directory on every run.
    out_dir = os.path.join(args.odir, args.timestamp)
    if os.path.exists(out_dir):
        shutil.rmtree(out_dir)
    os.makedirs(out_dir)

    predictions = []
    labels = []
    for case_path in case_paths:
        case_name = os.path.basename(case_path).replace('.npy', '')
        print(case_path, case_name)

        bag = np.stack([transform_fn(tile) for tile in np.load(case_path)], 0)

        # `case_dict` is not defined in this function; it is looked up from
        # the enclosing scope.
        label = case_dict[case_name]
        print(bag.shape, label)
        labels.append(label)

        if args.sample:
            picked = np.random.choice(range(bag.shape[0]), args.sample)
            bag = bag[picked, ...]
            print(bag.shape)

        # The original TODO notes that "attention" is a misnomer for the
        # second return value; presumably these are per-instance class
        # scores -- confirm against MilkEager.encode_bag.
        features, scores = model.encode_bag(
            bag, training=False, return_z=True)

        # The case prediction must be taken over the full score array before
        # it is narrowed to column 1 below.
        case_pred = np.mean(scores, axis=0, keepdims=True)
        scores = scores.numpy()[:, 1]
        print('features:', features.shape)
        print('attention:', scores.shape)
        print('yhat:', case_pred.shape)

        mean_feature = np.mean(features, axis=0, keepdims=True)
        predictions.append(case_pred)

        extremes = get_attention_extremes(scores, bag, n=5)
        high_idx, high_imgs, low_idx, low_imgs = extremes

        print('Case {}: predicted={} ({})'.format(
            case_path, np.argmax(case_pred, axis=-1), case_pred))

        features = features.numpy()

        figure_path = os.path.join(
            out_dir, '{}_{:3.2f}.png'.format(case_name, case_pred[0, 1]))
        print('Saving figure {}'.format(figure_path))
        # The mean feature is passed for both the "avg" and "att" arguments.
        z = draw_projection(features, mean_feature, mean_feature, scores,
                            savepath=figure_path)

        figure_path = os.path.join(
            out_dir, '{}_{:3.2f}_imgs.png'.format(case_name, case_pred[0, 1]))
        print('Saving figure {}'.format(figure_path))
        draw_projection_with_images(z, scores,
                                    high_idx, high_imgs,
                                    low_idx, low_imgs,
                                    savepath=figure_path)

        np.save(os.path.join(out_dir, '{}_atns.npy'.format(case_name)),
                scores)
        np.save(os.path.join(out_dir, '{}_feat.npy'.format(case_name)),
                features)

    # Collapse the per-case predictions to class indices and score them.
    predicted = np.argmax(np.concatenate(predictions, axis=0), axis=1)
    truth = np.array(labels)
    print((predicted == truth).mean())
    print(confusion_matrix(y_true=truth, y_pred=predicted))
def main(args):
    """Train a MilkEager model on bagged MNIST and save its weights.

    MNIST is loaded (from ``args.mnist`` when given), one digit is drawn at
    random as the positive label, and the images are rearranged into
    positive / negative pools that feed a bag generator.  The model is
    optionally initialized from ``args.pretrained``, optionally wrapped
    with ``multi_gpu_model``, and trained with Adam for
    ``steps_per_epoch * epochs`` steps.  Weights are written to ``args.o``
    whether training finishes, is interrupted, or fails.

    Args:
        args: parsed options.  Attributes read here: ``mnist``, ``n``,
            ``batch_size``, ``mil``, ``pretrained``, ``gpus``, ``lr``,
            ``steps_per_epoch``, ``epochs`` and ``o``.
    """
    if args.mnist is not None:
        (train_x, train_y), (test_x, test_y) = mnist.load_data(args.mnist)
    else:
        (train_x, train_y), (test_x, test_y) = mnist.load_data()

    print('train_x:', train_x.shape, train_x.dtype,
          train_x.min(), train_x.max())
    print('train_y:', train_y.shape)
    print('test_x:', test_x.shape)
    print('test_y:', test_y.shape)

    # One digit class, chosen at random, plays the role of "positive".
    positive_label = np.random.choice(10)
    print('using positive label = {}'.format(positive_label))

    train_x_pos, train_x_neg = rearrange_bagged_mnist(
        train_x, train_y, positive_label)
    test_x_pos, test_x_neg = rearrange_bagged_mnist(
        test_x, test_y, positive_label)
    print('rearranged training set:')
    print('\ttrain_x_pos:', train_x_pos.shape, train_x_pos.dtype,
          train_x_pos.min(), train_x_pos.max())
    print('\ttrain_x_neg:', train_x_neg.shape)
    print('\ttest_x_pos:', test_x_pos.shape)
    print('\ttest_x_neg:', test_x_neg.shape)

    generator = generate_bagged_mnist(
        train_x_pos, train_x_neg, args.n, args.batch_size)
    # Created for parity with the training generator; not consumed below.
    val_generator = generate_bagged_mnist(
        test_x_pos, test_x_neg, args.n, args.batch_size)

    batch_x, batch_y = next(generator)
    print('batch_x:', batch_x.shape, 'batch_y:', batch_y.shape)

    encoder_args = get_encoder_args('mnist')
    model = MilkEager(
        encoder_args=encoder_args,
        mil_type=args.mil,
        deep_classifier=True,
    )

    # Call the model once on a real batch before loading any weights.
    model(batch_x, verbose=True)
    model.summary()

    if args.pretrained is not None and os.path.exists(args.pretrained):
        model.load_weights(args.pretrained, by_name=True)
    else:
        print('Pretrained model not found ({}). Continuing end 2 end.'.format(
            args.pretrained))

    if args.gpus > 1:
        # Report the number of GPUs actually requested.
        print('Duplicating model onto {} GPUs'.format(args.gpus))
        model = tf.keras.utils.multi_gpu_model(model, args.gpus,
                                               cpu_merge=True,
                                               cpu_relocation=False)

    optimizer = tf.train.AdamOptimizer(learning_rate=args.lr)

    try:
        for k in range(int(args.steps_per_epoch * args.epochs)):
            x, y = next(generator)
            with tf.GradientTape() as tape:
                yhat = model(tf.constant(x), training=True)
                loss = tf.keras.losses.categorical_crossentropy(
                    y_true=tf.constant(y, dtype=tf.float32), y_pred=yhat)

            # Differentiate and update only the trainable variables.
            train_vars = model.trainable_variables
            grads = tape.gradient(loss, train_vars)
            optimizer.apply_gradients(zip(grads, train_vars))

            if k % 50 == 0:
                print('{:06d}: loss={:3.5f}'.format(k, np.mean(loss)))
                for y_, yh_ in zip(y, yhat):
                    print('\t{} {}'.format(y_, yh_))

    except KeyboardInterrupt:
        print('Keyboard interrupt caught')

    except Exception as e:
        # Best effort: report the failure and still save below.
        print('Other error caught')
        print(type(e))
        print(e)

    finally:
        model.save_weights(args.o)
        print('Saved model: {}'.format(args.o))
def main(args):
    """Write an attention image and the encoded features for each slide.

    A nested ``compute_fn`` encodes all tiles of a slide, runs the model's
    attention over them, places the attention values back onto the slide's
    tile grid and returns that image together with the features.  The outer
    loop copies every slide listed in ``args.slides`` to the ramdisk,
    computes both arrays, saves them next to the source and removes the
    ramdisk copy.

    Args:
        args: parsed options.  Attributes read here: ``encoder``, ``mil``,
            ``deep_classifier``, ``batchsize``, ``temperature``, ``heads``,
            ``process_size``, ``snapshot``, ``slides``, ``suffix``,
            ``ramdisk`` and ``n_classes``.
    """
    # Literal list of head indices; compute_fn below reads it through its
    # closure, and the warm-up call uses it as well.
    all_heads = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

    def compute_fn(slide, args, model=None, n_dropout=10):
        """Return ``(attention_image, features)`` for one slide."""
        assert tf.executing_eagerly()
        print('Slide with {}'.format(len(slide.tile_list)))

        # In eager mode the iterator can be consumed with a plain for loop.
        tile_iterator = TensorflowIterator(slide, args).make_iterator()

        encoded = []
        tile_indices = []
        for step, (tiles, tile_idx) in enumerate(tile_iterator):
            encoded.append(
                model.encode_bag(tiles, training=False, return_z=True))
            tile_indices.append(tile_idx.numpy())
            if step % 50 == 0:
                print('Batch #{:04d}\t{}'.format(step, tiles.numpy().shape))

        features = tf.concat(encoded, axis=0)

        # Attention over every tile of the slide at once.  (`n_dropout`
        # belonged to a disabled sample-dropout experiment and is unused.)
        z_att, raw_att = model.mil_attention(
            features, training=False, return_raw_att=True)

        head_predictions = model.apply_classifier(
            z_att, heads=all_heads, training=False)
        print('yhat mean {}'.format(np.mean(head_predictions, axis=0)))

        flat_indices = np.concatenate(tile_indices)
        slide.place_batch(np.squeeze(raw_att), flat_indices, 'att',
                          mode='tile')
        att_image = slide.output_imgs['att']
        print('Got attention image: {}'.format(att_image.shape))
        return att_image, features.numpy()

    # ---- main script ----
    model = MilkEager(
        encoder_args=get_encoder_args(args.encoder),
        mil_type=args.mil,
        deep_classifier=args.deep_classifier,
        batch_size=args.batchsize,
        temperature=args.temperature,
        heads=args.heads,
    )

    # Call the model once on zeros before loading the snapshot by name.
    warmup = tf.zeros((1, 1, args.process_size, args.process_size, 3))
    model(warmup, verbose=True, heads=all_heads, training=True)
    model.load_weights(args.snapshot, by_name=True)
    model.summary()

    with open(args.slides, 'r') as list_file:
        slide_paths = [line.strip() for line in list_file]

    for slide_path in slide_paths:
        # Destinations are derived from the source path and --suffix, which
        # should reflect the model / type of processed output.
        att_dst = repext(slide_path, args.suffix)
        feat_dst = repext(slide_path, args.suffix + '.feat.npy')

        # Copying to the ramdisk is a one-time cost per slide.
        ramdisk_path = cpramdisk(slide_path, args.ramdisk)
        print('\n\nFile:', ramdisk_path)

        # The finally clause guarantees the ramdisk copy is removed even if
        # processing raises.
        try:
            slide = Slide(ramdisk_path, args)
            slide.initialize_output('att', args.n_classes, mode='tile',
                                    compute_fn=compute_fn)
            att_image, slide_features = slide.compute(
                'att', args, model=model)

            print('{} --> {}'.format(att_image.shape, att_dst))
            print('{} --> {}'.format(slide_features.shape, feat_dst))
            np.save(att_dst, att_image)
            np.save(feat_dst, slide_features)
        except Exception as err:
            print(err)
            traceback.print_tb(err.__traceback__)
        finally:
            print('Removing {}'.format(ramdisk_path))
            os.remove(ramdisk_path)
def main(args):
    """Train a multi-head MilkEager model and dump its training curves.

    Steps:
      1. Create the output paths and the train/val split (``MILDataset``).
      2. Build the model, call it once on zeros, optionally load pretrained
         weights.
      3. Stream training bags through a ``tf.data`` pipeline built on the
         dataset's python iterator and train one randomly chosen head at a
         time; a new head is drawn whenever the step index is a multiple of
         the training-set length.
      4. Always save the weights, then write step / loss / accuracy to a
         tab-separated text file.

    Args:
        args: parsed options.  Attributes read here: ``dataset``,
            ``crop_size``, ``seed``, ``encoder``, ``mil``,
            ``deep_classifier``, ``temperature``, ``heads``, ``channels``,
            ``pretrained_model``, ``early_stop``, ``learning_rate``,
            ``bag_size``, ``epochs``, ``threads``, ``batch_size`` and
            ``out_base``.
    """
    out_path, exptime_str = create_outputs(args)

    data_factory = MILDataset(args.dataset, crop=args.crop_size, n_classes=2)
    data_factory.split_train_val('case_id', seed=args.seed)

    print('Model initializing')
    model = MilkEager(
        encoder_args=get_encoder_args(args.encoder),
        mil_type=args.mil,
        deep_classifier=args.deep_classifier,
        batch_size=16,
        temperature=args.temperature,
        heads=args.heads,
    )

    print('Running once to load CUDA')
    warmup = tf.zeros(
        (1, 1, args.crop_size, args.crop_size, args.channels))
    # The original author notes that the way this list is given is "very
    # particular", hence the literal rather than a range().
    warmup_heads = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
    warmup_out = model(warmup, heads=warmup_heads, training=True,
                       verbose=True)
    print('yhat: {} ({} {})'.format(
        warmup_out[0], warmup_out[0].shape, warmup_out[0].dtype))
    model.summary()

    # Swap the random initialization for pretrained weights when available;
    # a failed load is reported and training continues.
    pretrained = args.pretrained_model
    if pretrained is not None and os.path.exists(pretrained):
        print('Replacing random weights with weights from {}'.format(
            pretrained))
        try:
            model.load_weights(pretrained, by_name=True)
        except Exception as err:
            print(err)

    # Early-stopping hook; constructed here but not consulted in the loop.
    if args.early_stop:
        stopper = ShouldStop(patience=args.heads)
    else:
        stopper = lambda x: False

    optimizer = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
    trackers = {'loss': [], 'acc': [], 'step': 0}

    def py_it():
        """Return a fresh python iterator over the training bags."""
        return data_factory.python_iterator(mode='train',
                                            subset=args.bag_size,
                                            attr='stage_code',
                                            seed=None,
                                            epochs=args.epochs)

    train_len = data_factory.dataset_lengths['train']

    tf_dataset = tf.data.Dataset.from_generator(
        py_it, output_types=(tf.uint8, tf.uint8))
    tf_dataset = tf_dataset.map(data_factory.map_fn,
                                num_parallel_calls=args.threads)
    tf_dataset = tf_dataset.batch(args.batch_size)

    trainable_variables = model.trainable_variables

    try:
        avglosses = []
        steptimes = []
        for k, (x, y) in enumerate(tf_dataset):
            if k % train_len == 0:
                # Once per pass over the training set: collect garbage,
                # reseed TF and draw the head to train next.
                gc.collect()
                tf.set_random_seed(1)
                train_head = [np.random.choice(args.heads)]
                print('\nTraining head [{}]'.format(train_head))

            tstart = time.time()
            with tf.GradientTape() as tape:
                yhat = model(x, training=True, heads=train_head)
                loss = tf.keras.losses.categorical_crossentropy(
                    y_true=tf.constant(y, dtype=tf.float32),
                    y_pred=yhat[0])

            grads = tape.gradient(loss, trainable_variables)
            del tape

            loss_mn = np.mean(loss)
            acc = eval_acc(y, yhat)
            avglosses.append(loss_mn)
            trackers['loss'].append(loss_mn)
            trackers['acc'].append(acc)
            trackers['step'] += 1

            # The step time is measured before the optimizer update.
            steptimes.append(time.time() - tstart)

            optimizer.apply_gradients(zip(grads, trainable_variables))
            status = '\r{:07d}: loss={:3.5f} dt={:3.3f}s '.format(
                k, np.mean(avglosses), np.mean(steptimes))
            print(status, end='', flush=True)

    except KeyboardInterrupt:
        print('Keyboard interrupt caught')

    except Exception as err:
        print('Other error caught')
        print(err)
        traceback.print_tb(err.__traceback__)

    finally:
        model.save_weights(out_path)
        print('Saved model: {}'.format(out_path))

        # Save the loss profile: one "step<TAB>loss<TAB>acc" row per step.
        training_stats = os.path.join(
            args.out_base, 'save',
            '{}_training_curves.txt'.format(exptime_str))
        print('Dumping training stats --> {}'.format(training_stats))
        rows = zip(range(trackers['step']),
                   trackers['loss'],
                   trackers['acc'])
        with open(training_stats, 'w+') as stats_file:
            for step, step_loss, step_acc in rows:
                stats_file.write('{:06d}\t{:3.5f}\t{:3.5f}\n'.format(
                    step, step_loss, step_acc))