def run_training(continue_run):
    """Train the multi-task (diagnosis + age) classifier configured in `exp_config`.

    Gradients are accumulated over `exp_config.n_accum_batches` minibatches, divided by that
    number and then applied in a single optimiser step. Summaries and checkpoints are written to
    the module-level `log_dir`: the most recent checkpoints ('model.ckpt'), the best ones by
    validation diagnosis F1 score ('model_best_diag_f1.ckpt') and the best ones by validation
    loss ('model_best_xent.ckpt').

    :param continue_run: if True, restore the latest 'model.ckpt' checkpoint from `log_dir` and
        continue from the step after it. If no checkpoint can be found, a fresh run is started.
    """
    logging.info('EXPERIMENT NAME: %s' % exp_config.experiment_name)

    init_step = 0

    if continue_run:
        logging.info(
            '!!!!!!!!!!!!!!!!!!!!!!!!!!!! Continuing previous run !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
        )
        try:
            init_checkpoint_path = utils.get_latest_model_checkpoint_path(
                log_dir, 'model.ckpt')
            logging.info('Checkpoint path: %s' % init_checkpoint_path)
            # checkpoints are saved as 'model.ckpt-<step>' (see saver.save below);
            # plus 1 b/c otherwise starts with eval
            init_step = int(init_checkpoint_path.split('/')[-1].split('-')[-1]) + 1
            logging.info('Latest step was: %d' % init_step)
        except Exception:
            # best effort: any failure to locate/parse a checkpoint falls back to a fresh run
            logging.warning(
                '!!! Didnt find init checkpoint. Maybe first run failed. Disabling continue mode...'
            )
            continue_run = False
            init_step = 0
        logging.info(
            '!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
        )

    # Load data
    data = adni_data_loader_all.load_and_maybe_process_data(
        input_folder=exp_config.data_root,
        preprocessing_folder=exp_config.preproc_folder,
        size=exp_config.image_size,
        target_resolution=exp_config.target_resolution,
        label_list=exp_config.label_list,
        offset=exp_config.offset,
        rescale_to_one=True,
        force_overwrite=False)

    # the following are HDF5 datasets, not numpy arrays
    images_train = data['images_train']
    fieldstr_train = data['field_strength_train']
    labels_train = utils.fstr_to_label(fieldstr_train,
                                       exp_config.field_strength_list,
                                       exp_config.fs_label_list)
    ages_train = data['age_train']
    if exp_config.age_ordinal_regression:
        ages_train = utils.age_to_ordinal_reg_format(ages_train, bins=exp_config.age_bins)
        ordinal_reg_weights = utils.get_ordinal_reg_weights(ages_train)
    else:
        ages_train = utils.age_to_bins(ages_train, bins=exp_config.age_bins)
        ordinal_reg_weights = None

    images_val = data['images_val']
    fieldstr_val = data['field_strength_val']
    labels_val = utils.fstr_to_label(fieldstr_val,
                                     exp_config.field_strength_list,
                                     exp_config.fs_label_list)
    ages_val = data['age_val']
    if exp_config.age_ordinal_regression:
        ages_val = utils.age_to_ordinal_reg_format(ages_val, bins=exp_config.age_bins)
    else:
        ages_val = utils.age_to_bins(ages_val, bins=exp_config.age_bins)

    if exp_config.use_data_fraction:
        num_images = images_train.shape[0]
        new_last_index = int(float(num_images) * exp_config.use_data_fraction)

        logging.warning('USING ONLY FRACTION OF DATA!')
        logging.warning(' - Number of imgs orig: %d, Number of imgs new: %d' %
                        (num_images, new_last_index))
        images_train = images_train[0:new_last_index, ...]
        labels_train = labels_train[0:new_last_index, ...]
        # truncate the age targets as well, so the arrays iterated together stay the same length
        ages_train = ages_train[0:new_last_index]

    logging.info('Data summary:')
    logging.info('TRAINING')
    logging.info(' - Images:')
    logging.info(images_train.shape)
    logging.info(images_train.dtype)
    logging.info(' - Labels:')
    logging.info(labels_train.shape)
    logging.info(labels_train.dtype)
    logging.info('VALIDATiON')
    logging.info(' - Images:')
    logging.info(images_val.shape)
    logging.info(images_val.dtype)
    logging.info(' - Labels:')
    logging.info(labels_val.shape)
    logging.info(labels_val.dtype)

    # Tell TensorFlow that the model will be built into the default Graph.
    with tf.Graph().as_default():

        # Generate placeholders for the images and labels.
        image_tensor_shape = [exp_config.batch_size] + list(exp_config.image_size) + [1]
        labels_tensor_shape = [exp_config.batch_size]
        if exp_config.age_ordinal_regression:
            ages_tensor_shape = [exp_config.batch_size, len(exp_config.age_bins)]
        else:
            ages_tensor_shape = [exp_config.batch_size]

        images_placeholder = tf.placeholder(tf.float32, shape=image_tensor_shape, name='images')
        diag_placeholder = tf.placeholder(tf.uint8, shape=labels_tensor_shape, name='labels')
        ages_placeholder = tf.placeholder(tf.uint8, shape=ages_tensor_shape, name='ages')

        learning_rate_placeholder = tf.placeholder(tf.float32, shape=[], name='learning_rate')
        training_time_placeholder = tf.placeholder(tf.bool, shape=[], name='training_time')

        tf.summary.scalar('learning_rate', learning_rate_placeholder)

        # Build a Graph that computes predictions from the inference model.
        diag_logits, ages_logits = exp_config.clf_model_handle(
            images_placeholder,
            nlabels=exp_config.nlabels,
            training=training_time_placeholder,
            n_age_thresholds=len(exp_config.age_bins),
            bn_momentum=exp_config.bn_momentum)

        # Add to the Graph the Ops for loss calculation.
        [loss, diag_loss, age_loss, weights_norm] = model_mt.loss(
            diag_logits,
            ages_logits,
            diag_placeholder,
            ages_placeholder,
            nlabels=exp_config.nlabels,
            weight_decay=exp_config.weight_decay,
            diag_weight=exp_config.diag_weight,
            age_weight=exp_config.age_weight,
            use_ordinal_reg=exp_config.age_ordinal_regression,
            ordinal_reg_weights=ordinal_reg_weights)

        tf.summary.scalar('loss', loss)
        tf.summary.scalar('diag_loss', diag_loss)
        tf.summary.scalar('weights_norm_term', weights_norm)

        if exp_config.momentum is not None:
            optimiser = exp_config.optimizer_handle(
                learning_rate=learning_rate_placeholder, momentum=exp_config.momentum)
        else:
            optimiser = exp_config.optimizer_handle(learning_rate=learning_rate_placeholder)

        # create a zero-initialised accumulator for every variable
        t_vars = tf.global_variables()  # tf.trainable_variables()
        accum_tvars = [
            tf.Variable(tf.zeros_like(tv.initialized_value()), trainable=False)
            for tv in t_vars
        ]
        # ops resetting all accumulators to zero
        zero_ops = [tv.assign(tf.zeros_like(tv)) for tv in accum_tvars]

        # compute gradients for a batch
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            batch_grads_vars = optimiser.compute_gradients(loss, t_vars)

        # compute_gradients returns None as the gradient of every variable in t_vars that the
        # loss does not depend on (global_variables() may contain such, e.g. non-trainable ones);
        # a None gradient can neither be accumulated nor applied, so those pairs are skipped
        grad_accum_triples = [
            (accum_tvar, grad, var)
            for accum_tvar, (grad, var) in zip(accum_tvars, batch_grads_vars)
            if grad is not None
        ]

        # collect the batch gradients into the accumulators
        accum_ops = [
            accum_tvar.assign_add(grad) for accum_tvar, grad, _ in grad_accum_triples
        ]

        # divide the accumulated sums by the number of accumulated batches
        accum_normaliser_pl = tf.placeholder(dtype=tf.float32, name='accum_normaliser')
        accum_mean_op = [
            accum_tvar.assign(tf.divide(accum_tvar, accum_normaliser_pl))
            for accum_tvar in accum_tvars
        ]

        # apply the accumulated gradients
        with tf.control_dependencies(update_ops):
            train_op = optimiser.apply_gradients(
                [(accum_tvar, var) for accum_tvar, _, var in grad_accum_triples])

        eval_diag_loss, eval_ages_loss, pred_labels, ages_softmaxs = model_mt.evaluation(
            diag_logits,
            ages_logits,
            diag_placeholder,
            ages_placeholder,
            images_placeholder,
            diag_weight=exp_config.diag_weight,
            age_weight=exp_config.age_weight,
            nlabels=exp_config.nlabels,
            use_ordinal_reg=exp_config.age_ordinal_regression)

        # Build the summary Tensor based on the TF collection of Summaries.
        summary = tf.summary.merge_all()

        # Add the variable initializer Op.
        init = tf.global_variables_initializer()

        # Create a saver for writing training checkpoints.
        saver = tf.train.Saver(max_to_keep=3)
        saver_best_diag_f1 = tf.train.Saver(max_to_keep=2)
        saver_best_xent = tf.train.Saver(max_to_keep=2)

        # prevents ResourceExhaustError when a lot of memory is used
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True  # Do not assign whole gpu memory, just use it on the go
        config.allow_soft_placement = True  # If a operation is not defined in the default device, let it execute in another.

        # Create a session for running Ops on the Graph.
        sess = tf.Session(config=config)

        # Instantiate a SummaryWriter to output summaries and the Graph.
        summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

        # the writer and session are released even if training raises
        try:
            # with tf.name_scope('monitoring'):
            val_error_ = tf.placeholder(tf.float32, shape=[], name='val_error_diag')
            val_error_summary = tf.summary.scalar('validation_loss', val_error_)

            val_diag_f1_score_ = tf.placeholder(tf.float32, shape=[], name='val_diag_f1')
            val_f1_diag_summary = tf.summary.scalar('validation_diag_f1', val_diag_f1_score_)

            val_ages_f1_score_ = tf.placeholder(tf.float32, shape=[], name='val_ages_f1')

            val_summary = tf.summary.merge([val_error_summary, val_f1_diag_summary])

            train_error_ = tf.placeholder(tf.float32, shape=[], name='train_error_diag')
            train_error_summary = tf.summary.scalar('training_loss', train_error_)

            train_diag_f1_score_ = tf.placeholder(tf.float32, shape=[], name='train_diag_f1')
            train_diag_f1_summary = tf.summary.scalar('training_diag_f1', train_diag_f1_score_)

            train_ages_f1_score_ = tf.placeholder(tf.float32, shape=[], name='train_ages_f1')

            train_summary = tf.summary.merge([train_error_summary, train_diag_f1_summary])

            # Run the Op to initialize the variables.
            sess.run(init)

            if continue_run:
                # Restore session
                saver.restore(sess, init_checkpoint_path)

            step = init_step
            curr_lr = exp_config.learning_rate

            # number of consecutive training evaluations without a decrease of the training loss
            no_improvement_counter = 0
            best_val = np.inf
            last_train = np.inf
            loss_history = []
            loss_gradient = np.inf
            best_diag_f1_score = 0

            for epoch in range(exp_config.max_epochs):

                logging.info('EPOCH %d' % epoch)
                sess.run(zero_ops)
                accum_counter = 0

                for batch in iterate_minibatches(
                        images_train, [labels_train, ages_train],
                        batch_size=exp_config.batch_size,
                        augmentation_function=exp_config.augmentation_function,
                        exp_config=exp_config):

                    if exp_config.warmup_training:
                        if step < 50:
                            curr_lr = exp_config.learning_rate / 10.0
                        elif step == 50:
                            curr_lr = exp_config.learning_rate

                    start_time = time.time()

                    # get a batch
                    x, [y, a] = batch

                    # TEMPORARY HACK (to avoid incomplete batches)
                    if y.shape[0] < exp_config.batch_size:
                        step += 1
                        continue

                    # Run accumulation
                    feed_dict = {
                        images_placeholder: x,
                        diag_placeholder: y,
                        ages_placeholder: a,
                        learning_rate_placeholder: curr_lr,
                        training_time_placeholder: True
                    }

                    _, loss_value = sess.run([accum_ops, loss], feed_dict=feed_dict)

                    accum_counter += 1

                    if accum_counter == exp_config.n_accum_batches:
                        # Average gradient over batches
                        sess.run(accum_mean_op, feed_dict={
                            accum_normaliser_pl: float(exp_config.n_accum_batches)})
                        sess.run(train_op, feed_dict=feed_dict)

                        # Reset all counters etc.
                        sess.run(zero_ops)
                        accum_counter = 0

                    duration = time.time() - start_time

                    # Write the summaries and print an overview fairly often.
                    if step % 10 == 0:
                        # Print status to stdout.
                        logging.info('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value, duration))
                        # Update the events file.
                        summary_str = sess.run(summary, feed_dict=feed_dict)
                        summary_writer.add_summary(summary_str, step)
                        summary_writer.flush()

                    if (step + 1) % exp_config.train_eval_frequency == 0:

                        # Evaluate against the training set
                        logging.info('Training Data Eval:')
                        [train_loss, train_diag_f1, train_ages_f1] = do_eval(
                            sess,
                            eval_diag_loss,
                            eval_ages_loss,
                            pred_labels,
                            ages_softmaxs,
                            images_placeholder,
                            diag_placeholder,
                            ages_placeholder,
                            training_time_placeholder,
                            images_train,
                            [labels_train, ages_train],
                            batch_size=exp_config.batch_size,
                            do_ordinal_reg=exp_config.age_ordinal_regression)

                        train_summary_msg = sess.run(train_summary, feed_dict={
                            train_error_: train_loss,
                            train_diag_f1_score_: train_diag_f1,
                            train_ages_f1_score_: train_ages_f1
                        })
                        summary_writer.add_summary(train_summary_msg, step)

                        # keep the last 5 training losses to estimate how fast the loss still drops
                        loss_history.append(train_loss)
                        if len(loss_history) > 5:
                            loss_history.pop(0)
                            loss_gradient = (loss_history[-5] - loss_history[-1]) / 2

                        logging.info('loss gradient is currently %f' % loss_gradient)

                        if exp_config.schedule_lr and loss_gradient < exp_config.schedule_gradient_threshold:
                            logging.warning('Reducing learning rate!')
                            curr_lr /= 10.0
                            logging.info('Learning rate changed to: %f' % curr_lr)

                            # reset loss history to give the optimisation some time to start decreasing again
                            loss_gradient = np.inf
                            loss_history = []

                        if train_loss <= last_train:  # best_train:
                            no_improvement_counter = 0
                            logging.info('Decrease in training error!')
                        else:
                            # counted in training evaluations, i.e. every train_eval_frequency steps
                            no_improvement_counter += 1
                            logging.info(
                                'No improvment in training error for %d steps' % no_improvement_counter)

                        last_train = train_loss

                    # Save a checkpoint and evaluate the model periodically.
                    if (step + 1) % exp_config.val_eval_frequency == 0:

                        checkpoint_file = os.path.join(log_dir, 'model.ckpt')
                        saver.save(sess, checkpoint_file, global_step=step)

                        # Evaluate against the validation set.
                        logging.info('Validation Data Eval:')
                        [val_loss, val_diag_f1, val_ages_f1] = do_eval(
                            sess,
                            eval_diag_loss,
                            eval_ages_loss,
                            pred_labels,
                            ages_softmaxs,
                            images_placeholder,
                            diag_placeholder,
                            ages_placeholder,
                            training_time_placeholder,
                            images_val,
                            [labels_val, ages_val],
                            batch_size=exp_config.batch_size,
                            do_ordinal_reg=exp_config.age_ordinal_regression)

                        val_summary_msg = sess.run(val_summary, feed_dict={
                            val_error_: val_loss,
                            val_diag_f1_score_: val_diag_f1,
                            val_ages_f1_score_: val_ages_f1
                        })
                        summary_writer.add_summary(val_summary_msg, step)

                        if val_diag_f1 >= best_diag_f1_score:
                            best_diag_f1_score = val_diag_f1
                            best_file = os.path.join(log_dir, 'model_best_diag_f1.ckpt')
                            saver_best_diag_f1.save(sess, best_file, global_step=step)
                            logging.info(
                                'Found new best DIAGNOSIS F1 score on validation set! - %f - Saving model_best_diag_f1.ckpt'
                                % val_diag_f1)

                        if val_loss <= best_val:
                            best_val = val_loss
                            best_file = os.path.join(log_dir, 'model_best_xent.ckpt')
                            saver_best_xent.save(sess, best_file, global_step=step)
                            logging.info(
                                'Found new best crossentropy on validation set! - %f - Saving model_best_xent.ckpt'
                                % val_loss)

                    step += 1
        finally:
            summary_writer.close()
            sess.close()
def _open_restored_tf_session(graph, init_op, saver, checkpoint_path, tf_config):
    """Open a session on `graph`, run `init_op` and restore the weights stored at `checkpoint_path`."""
    sess = tf.Session(config=tf_config, graph=graph)
    try:
        sess.run(init_op)
        saver.restore(sess, checkpoint_path)
    except Exception:
        # do not leak the session if initialisation or restoring fails
        sess.close()
        raise
    return sess


def _open_restored_clf_session(img_tensor_shape, clf_config, checkpoint_path, tf_config):
    """Build the classifier graph for `img_tensor_shape`; return (session, image placeholder, prediction op)."""
    graph, image_pl, predictions_op, init_op, saver = build_clf_graph(img_tensor_shape, clf_config)
    sess = _open_restored_tf_session(graph, init_op, saver, checkpoint_path, tf_config)
    return sess, image_pl, predictions_op


def _open_restored_gen_session(img_tensor_shape, gan_config, checkpoint_path, tf_config):
    """Build the generator graph for `img_tensor_shape`; return (session, input placeholder, fake image op)."""
    graph, generator_img_pl, x_fake_op, init_op, saver = test_utils.build_gen_graph(img_tensor_shape, gan_config)
    sess = _open_restored_tf_session(graph, init_op, saver, checkpoint_path, tf_config)
    return sess, generator_img_pl, x_fake_op


def generate_and_evaluate_ad_classification(gan_experiment_path_list, clf_experiment_path, score_functions,
                                            image_saving_indices=None, image_saving_path=None,
                                            max_batch_size=np.inf):
    """Score an AD classifier on real test images and on images translated by each GAN.

    The classifier (checkpoint 'model_best_diag_f1.ckpt') is applied to all real test images. The
    source/target field strength test sets are then balanced, and every GAN translates the balanced
    source images, which are classified again.

    :param gan_experiment_path_list: list of GAN experiment paths to be evaluated. They must all have
        the same image settings as the classifier and the same source/target field strengths as the
        first experiment in the list.
    :param clf_experiment_path: experiment path of the AD classifier used
    :param score_functions: passed to evaluate_scores for every set of predictions
    :param image_saving_indices: indices into the balanced source/target index lists of the images to
        be saved (any iterable of ints; None means no images are saved)
    :param image_saving_path: where to save the images. They are saved in subfolders for each
        experiment. If None, no folders are created and no images are saved.
    :param max_batch_size: upper bound for the batch size used for classification and generation
    :return: dict mapping each GAN experiment name, 'source_<x>T' and 'target_<x>T' to the output of
        evaluate_scores for the corresponding predictions
    """
    image_saving_indices = set() if image_saving_indices is None else set(image_saving_indices)
    save_images = image_saving_path is not None
    if not save_images:
        image_saving_indices = set()

    clf_config, logdir_clf = utils.load_log_exp_config(clf_experiment_path)

    # Load data
    data = adni_data_loader_all.load_and_maybe_process_data(
        input_folder=clf_config.data_root,
        preprocessing_folder=clf_config.preproc_folder,
        size=clf_config.image_size,
        target_resolution=clf_config.target_resolution,
        label_list=clf_config.label_list,
        offset=clf_config.offset,
        rescale_to_one=clf_config.rescale_to_one,
        force_overwrite=False
    )

    # extract images, diagnosis labels and ages of the test set
    images_test = data['images_test']
    labels_test = data['diagnosis_test']
    ages_test = data['age_test']

    im_s = clf_config.image_size
    batch_size = min(clf_config.batch_size, std_params.batch_size, max_batch_size)
    logging.info('batch size %d is used for everything' % batch_size)
    img_tensor_shape = [batch_size, im_s[0], im_s[1], im_s[2], 1]
    clf_remainder_batch_size = images_test.shape[0] % batch_size

    # prevents ResourceExhaustError when a lot of memory is used
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True  # Do not assign whole gpu memory, just use it on the go
    config.allow_soft_placement = True  # If a operation is not defined in the default device, let it execute in another.

    logging.info("loading Alzheimer's disease classifier")
    logging.info("getting savepoint with the best f1 score")
    init_checkpoint_path_clf = get_latest_checkpoint_and_log(logdir_clf, 'model_best_diag_f1.ckpt')

    # classifier sessions that live for the whole evaluation; closed in the finally clause
    clf_sessions = []
    try:
        sess_clf, image_pl, predictions_clf_op = _open_restored_clf_session(
            img_tensor_shape, clf_config, init_checkpoint_path_clf, config)
        clf_sessions.append(sess_clf)

        # a separate graph is needed for the last batch of real test images, which is smaller
        if clf_remainder_batch_size > 0:
            sess_clf_rem, image_pl_rem, predictions_clf_op_rem = _open_restored_clf_session(
                [clf_remainder_batch_size, im_s[0], im_s[1], im_s[2], 1],
                clf_config, init_checkpoint_path_clf, config)
            clf_sessions.append(sess_clf_rem)

        # classify all real test images
        logging.info('classify all original images')
        real_pred = []
        for batch in iterate_minibatches(images_test,
                                         [labels_test, ages_test],
                                         batch_size=batch_size,
                                         exp_config=clf_config,
                                         map_labels_to_standard_range=False,
                                         shuffle_data=False,
                                         skip_remainder=False):
            # the data are iterated in order, so real_pred[i] belongs to test image i
            image_batch, [real_label, real_age] = batch
            if image_batch.shape[0] < batch_size:
                clf_prediction_real = sess_clf_rem.run(predictions_clf_op_rem,
                                                       feed_dict={image_pl_rem: image_batch})
            else:
                clf_prediction_real = sess_clf.run(predictions_clf_op, feed_dict={image_pl: image_batch})
            real_pred.extend(clf_prediction_real['label'])
            logging.info('new image batch')
            logging.info('ground truth labels: ' + str(real_label))
            logging.info('predicted labels: ' + str(clf_prediction_real['label']))

        gan_config0, logdir_gan0 = utils.load_log_exp_config(gan_experiment_path_list[0])

        # split the test set by field strength
        source_indices = []
        target_indices = []
        source_true_labels = []
        target_true_labels = []
        for i, field_strength in enumerate(data['field_strength_test']):
            if field_strength == gan_config0.source_field_strength:
                source_indices.append(i)
                source_true_labels.append(labels_test[i])
            elif field_strength == gan_config0.target_field_strength:
                target_indices.append(i)
                target_true_labels.append(labels_test[i])

        # balance the test set
        (source_indices, source_true_labels), (target_indices, target_true_labels) = utils.balance_source_target(
            (source_indices, source_true_labels), (target_indices, target_true_labels), random_seed=0)

        # predictions in the same order as the balanced index lists (and thus their true labels)
        source_pred = [real_pred[ind] for ind in source_indices]
        target_pred = [real_pred[ind] for ind in target_indices]

        assert len(source_pred) == len(source_true_labels)
        assert len(target_pred) == len(target_true_labels)

        # no unexpected labels
        assert all([label in clf_config.label_list for label in source_true_labels])
        assert all([label in clf_config.label_list for label in target_true_labels])
        assert all([label in clf_config.label_list for label in source_pred])
        assert all([label in clf_config.label_list for label in target_pred])

        num_source_images = len(source_indices)
        num_target_images = len(target_indices)

        source_label_count = Counter(source_true_labels)
        target_label_count = Counter(target_true_labels)

        logging.info('Data summary:')
        logging.info(' - Domains:')
        logging.info('number of source images: ' + str(num_source_images))
        logging.info('source label distribution ' + str(source_label_count))
        logging.info('number of target images: ' + str(num_target_images))
        logging.info('target label distribution ' + str(target_label_count))

        assert num_source_images == num_target_images
        assert source_label_count == target_label_count

        sorted_saving_indices = sorted(image_saving_indices)
        if save_images:
            # 2d image saving folder
            image_saving_path2d = os.path.join(image_saving_path, 'coronal_2d')
            utils.makefolder(image_saving_path2d)

            # save real images
            target_image_path = os.path.join(image_saving_path, 'target')
            source_image_path = os.path.join(image_saving_path, 'source')
            utils.makefolder(target_image_path)
            utils.makefolder(source_image_path)
            target_image_path2d = os.path.join(image_saving_path2d, 'target')
            source_image_path2d = os.path.join(image_saving_path2d, 'source')
            utils.makefolder(target_image_path2d)
            utils.makefolder(source_image_path2d)

            for target_index in (target_indices[index] for index in sorted_saving_indices):
                target_img_name = 'target_img_%.1fT_diag%d_ind%d' % (
                    gan_config0.target_field_strength, labels_test[target_index], target_index)
                utils.save_image_and_cut(images_test[target_index], target_img_name,
                                         target_image_path, target_image_path2d)
                logging.info(target_img_name + ' saved')

            for source_index in (source_indices[index] for index in sorted_saving_indices):
                source_img_name = 'source_img_%.1fT_diag%d_ind%d' % (
                    gan_config0.source_field_strength, labels_test[source_index], source_index)
                utils.save_image_and_cut(images_test[source_index], source_img_name,
                                         source_image_path, source_image_path2d)
                logging.info(source_img_name + ' saved')

            logging.info('source and target images saved')

        # size of the last (smaller) batch when iterating over the balanced source images
        gan_remainder_batch_size = num_source_images % batch_size
        img_tensor_shape_gan_remainder = [gan_remainder_batch_size, im_s[0], im_s[1], im_s[2], 1]

        # the classifier for that last batch is the same for every GAN experiment, so build it once
        if gan_remainder_batch_size > 0:
            sess_clf_gan_rem, image_pl_gan_rem, predictions_clf_op_gan_rem = _open_restored_clf_session(
                img_tensor_shape_gan_remainder, clf_config, init_checkpoint_path_clf, config)
            clf_sessions.append(sess_clf_gan_rem)

        scores = {}
        for gan_experiment_path in gan_experiment_path_list:
            gan_config, logdir_gan = utils.load_log_exp_config(gan_experiment_path)
            gan_experiment_name = gan_config.experiment_name

            # make sure the experiments all have the same configuration as the classifier
            assert gan_config.source_field_strength == gan_config0.source_field_strength
            assert gan_config.target_field_strength == gan_config0.target_field_strength
            assert gan_config.image_size == clf_config.image_size
            assert gan_config.target_resolution == clf_config.target_resolution
            assert gan_config.offset == clf_config.offset

            logging.info('\nGAN Experiment (%.1f T to %.1f T): %s' % (gan_config.source_field_strength,
                                                                        gan_config.target_field_strength,
                                                                        gan_experiment_name))
            logging.info(gan_config)
            # open GAN save file from the selected experiment
            logging.info('loading GAN')
            # open the latest GAN savepoint
            init_checkpoint_path_gan = get_latest_checkpoint_and_log(logdir_gan, 'model.ckpt')

            # generator sessions of this experiment; closed before the next experiment is loaded
            gan_sessions = []
            try:
                sess_gan, generator_img_pl, x_fake_op = _open_restored_gen_session(
                    img_tensor_shape, gan_config, init_checkpoint_path_gan, config)
                gan_sessions.append(sess_gan)

                # separate generator graph for the last batch where the batch size is smaller
                if gan_remainder_batch_size > 0:
                    sess_gan_rem, generator_img_rem_pl, x_fake_op_rem = _open_restored_gen_session(
                        img_tensor_shape_gan_remainder, gan_config, init_checkpoint_path_gan, config)
                    gan_sessions.append(sess_gan_rem)

                if save_images:
                    # paths where the generated images are saved
                    experiment_generate_path = os.path.join(image_saving_path, gan_experiment_name)
                    experiment_generate_path2d = os.path.join(image_saving_path2d, gan_experiment_name)
                    utils.makefolder(experiment_generate_path)
                    utils.makefolder(experiment_generate_path2d)

                logging.info('image generation begins')
                generated_pred = []
                # position of the current batch's first image within source_indices
                batch_beginning_index = 0
                # loops through all images from the source domain
                for batch in iterate_minibatches(images_test,
                                                 [labels_test, ages_test],
                                                 batch_size=batch_size,
                                                 exp_config=clf_config,
                                                 map_labels_to_standard_range=False,
                                                 selection_indices=source_indices,
                                                 shuffle_data=False,
                                                 skip_remainder=False):
                    image_batch, [real_label, real_age] = batch
                    current_batch_size = image_batch.shape[0]
                    if current_batch_size < batch_size:
                        fake_img = sess_gan_rem.run(x_fake_op_rem, feed_dict={generator_img_rem_pl: image_batch})
                        # classify fake image
                        clf_prediction_fake = sess_clf_gan_rem.run(predictions_clf_op_gan_rem,
                                                                   feed_dict={image_pl_gan_rem: fake_img})
                    else:
                        fake_img = sess_gan.run(x_fake_op, feed_dict={generator_img_pl: image_batch})
                        # classify fake image
                        clf_prediction_fake = sess_clf.run(predictions_clf_op, feed_dict={image_pl: fake_img})

                    generated_pred.extend(clf_prediction_fake['label'])

                    current_source_indices = range(batch_beginning_index,
                                                   batch_beginning_index + current_batch_size)
                    # test whether minibatches are really iterated in order by checking if the labels are as expected
                    assert [source_true_labels[i] for i in current_source_indices] == list(real_label)

                    for source_index in sorted(image_saving_indices.intersection(current_source_indices)):
                        batch_index = source_index - batch_beginning_index
                        # index of the image in the complete test data
                        global_index = source_indices[source_index]
                        generated_img = np.squeeze(fake_img[batch_index])

                        generated_img_name = 'generated_img_%.1fT_diag%d_ind%d' % (
                            gan_config.target_field_strength, labels_test[global_index], global_index)
                        utils.save_image_and_cut(generated_img, generated_img_name,
                                                 experiment_generate_path, experiment_generate_path2d)
                        logging.info(generated_img_name + ' saved')

                        # save the difference g(xs)-xs
                        difference_image_gs = generated_img - images_test[global_index]
                        difference_img_name = 'difference_img_%.1fT_diag%d_ind%d' % (
                            gan_config.target_field_strength, labels_test[global_index], global_index)
                        utils.save_image_and_cut(difference_image_gs, difference_img_name,
                                                 experiment_generate_path, experiment_generate_path2d)
                        logging.info(difference_img_name + ' saved')

                    logging.info('new image batch')
                    logging.info('ground truth labels: ' + str(real_label))
                    logging.info('predicted labels for generated images: ' + str(clf_prediction_fake['label']))
                    # no unexpected labels
                    assert all([label in clf_config.label_list for label in clf_prediction_fake['label']])

                    batch_beginning_index += current_batch_size
            finally:
                for sess in gan_sessions:
                    sess.close()

            logging.info('generated prediction for %s: %s' % (gan_experiment_name, str(generated_pred)))
            scores[gan_experiment_name] = evaluate_scores(source_true_labels, generated_pred, score_functions)
    finally:
        for sess in clf_sessions:
            sess.close()

    logging.info('source prediction: ' + str(source_pred))
    logging.info('source ground truth: ' + str(source_true_labels))
    logging.info('target prediction: ' + str(target_pred))
    logging.info('target ground truth: ' + str(target_true_labels))

    scores['source_%.1fT' % gan_config0.source_field_strength] = evaluate_scores(
        source_true_labels, source_pred, score_functions)
    scores['target_%.1fT' % gan_config0.target_field_strength] = evaluate_scores(
        target_true_labels, target_pred, score_functions)

    return scores
import clf_model_multitask as model_mt import utils from batch_generator_list import iterate_minibatches import data_utils import gan_model from collections import OrderedDict import csv from experiments.adni_clf import allconv_bn as exp_config # Load data data = adni_data_loader_all.load_and_maybe_process_data( input_folder=exp_config.data_root, preprocessing_folder=exp_config.preproc_folder, size=exp_config.image_size, target_resolution=exp_config.target_resolution, label_list=exp_config.label_list, offset=exp_config.offset, rescale_to_one=exp_config.rescale_to_one, force_overwrite=False) for tt in ['train', 'test', 'val']: print(len(np.unique(data['rid_%s' % tt]))) # make list of index, label, rid test_labels = data['diagnosis_test'] logging.info(test_labels) with open(os.path.join(sys_config.project_root, 'results/final/label_list.csv'), 'w+', newline='') as csvfile:
def generate_with_noise(gan_experiment_path_list, noise_list, image_saving_indices=None,
                        image_saving_path3d=None, image_saving_path2d=None):
    """Translate selected source-domain test images with every GAN, once per generator-input noise sample.

    For each GAN experiment, the selected source test images are saved as '.nii.gz' files in
    '<image_saving_path3d>/source'. For every selected image a folder 'image_test<test index>' is
    created below an experiment subfolder of both saving paths. It holds:
    - the source image;
    - one generated image and one difference image (generated - source) per noise sample;
    - the voxel-wise standard deviation over the generated images.

    :param gan_experiment_path_list: list of GAN experiment paths to be evaluated. Every experiment must
        use generator input noise whose shape matches std_params.generator_input_noise_shape apart from
        the first dimension.
    :param noise_list: noise values fed to the generator's noise placeholder; one generated image each
    :param image_saving_indices: indices into the list of source-domain test images selecting the images
        to process (any iterable of ints; None means none)
    :param image_saving_path3d: folder where the 3d images are saved (required)
    :param image_saving_path2d: folder where the 2d cuts are saved (required)
    :return: None
    :raises ValueError: if one of the saving paths is missing
    """
    if image_saving_path3d is None or image_saving_path2d is None:
        raise ValueError('image_saving_path3d and image_saving_path2d must both be given')
    image_saving_indices = set() if image_saving_indices is None else set(image_saving_indices)

    batch_size = 1
    logging.info('batch size %d is used for everything' % batch_size)

    # prevents ResourceExhaustError when a lot of memory is used
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True  # Do not assign whole gpu memory, just use it on the go
    config.allow_soft_placement = True  # If a operation is not defined in the default device, let it execute in another.

    for gan_experiment_path in gan_experiment_path_list:
        gan_config, logdir_gan = utils.load_log_exp_config(gan_experiment_path)

        gan_experiment_name = gan_config.experiment_name
        # experiments continued from a previous run have a log dir ending in '_cont'
        log_dir_ending = logdir_gan.split('_')[-1]
        continued_experiment = (log_dir_ending == 'cont')
        if continued_experiment:
            gan_experiment_name += '_cont'

        # make sure the noise has the right dimension
        assert gan_config.use_generator_input_noise
        assert gan_config.generator_input_noise_shape[1:] == std_params.generator_input_noise_shape[1:]

        # Load data
        data = adni_data_loader_all.load_and_maybe_process_data(
            input_folder=gan_config.data_root,
            preprocessing_folder=gan_config.preproc_folder,
            size=gan_config.image_size,
            target_resolution=gan_config.target_resolution,
            label_list=gan_config.label_list,
            offset=gan_config.offset,
            rescale_to_one=gan_config.rescale_to_one,
            force_overwrite=False)

        images_test = data['images_test']

        im_s = gan_config.image_size
        img_tensor_shape = [batch_size, im_s[0], im_s[1], im_s[2], 1]

        logging.info('\nGAN Experiment (%.1f T to %.1f T): %s' %
                     (gan_config.source_field_strength, gan_config.target_field_strength, gan_experiment_name))
        logging.info(gan_config)

        # indices of the source/target images in the test set
        source_indices = []
        target_indices = []
        for i, field_strength in enumerate(data['field_strength_test']):
            if field_strength == gan_config.source_field_strength:
                source_indices.append(i)
            elif field_strength == gan_config.target_field_strength:
                target_indices.append(i)

        num_source_images = len(source_indices)
        num_target_images = len(target_indices)

        logging.info('Data summary:')
        logging.info(' - Images:')
        logging.info(images_test.shape)
        logging.info(images_test.dtype)
        logging.info(' - Domains:')
        logging.info('number of source images: ' + str(num_source_images))
        logging.info('number of target images: ' + str(num_target_images))

        # save real images
        source_image_path = os.path.join(image_saving_path3d, 'source')
        utils.makefolder(source_image_path)

        # test-set indices of the selected source images
        source_saving_indices = [source_indices[index] for index in sorted(image_saving_indices)]

        for source_index in source_saving_indices:
            source_img_name = 'source_img_%.1fT_%d.nii.gz' % (gan_config.source_field_strength, source_index)
            utils.create_and_save_nii(images_test[source_index],
                                      os.path.join(source_image_path, source_img_name))
            logging.info(source_img_name + ' saved')

        logging.info('source images saved')

        logging.info('loading GAN')
        # open the latest GAN savepoint
        init_checkpoint_path_gan, _ = utils.get_latest_checkpoint_and_step(logdir_gan, 'model.ckpt')
        logging.info(init_checkpoint_path_gan)

        # build a separate graph for the generator
        graph_generator, generator_img_pl, z_noise_pl, x_fake_op, init_gan_op, saver_gan = build_gen_graph(
            img_tensor_shape, gan_config)

        # Create a session for running Ops on the Graph; closed after this experiment
        sess_gan = tf.Session(config=config, graph=graph_generator)
        try:
            # Run the Op to initialize the variables.
            sess_gan.run(init_gan_op)
            saver_gan.restore(sess_gan, init_checkpoint_path_gan)

            experiment_folder_name = gan_experiment_name + ('_%.1fT_source' % gan_config.source_field_strength)
            # path where the generated images are saved
            experiment_generate_path_3d = os.path.join(image_saving_path3d, experiment_folder_name)
            utils.makefolder(experiment_generate_path_3d)
            # path where the generated image 2d cuts are saved
            experiment_generate_path_2d = os.path.join(image_saving_path2d, experiment_folder_name)
            utils.makefolder(experiment_generate_path_2d)

            logging.info('image generation begins')
            # loops through the selected images from the source domain
            for image_index in source_saving_indices:
                curr_img = images_test[image_index]

                img_folder_name = 'image_test%d' % image_index
                curr_img_path_3d = os.path.join(experiment_generate_path_3d, img_folder_name)
                utils.makefolder(curr_img_path_3d)
                curr_img_path_2d = os.path.join(experiment_generate_path_2d, img_folder_name)
                utils.makefolder(curr_img_path_2d)

                # save source image
                source_img_name = 'source_img'
                utils.save_image_and_cut(np.squeeze(curr_img), source_img_name, curr_img_path_3d,
                                         curr_img_path_2d, vmin=-1, vmax=1)
                logging.info(source_img_name + ' saved')

                # the generator input is the same for every noise sample
                generator_input = np.reshape(curr_img, img_tensor_shape)

                img_list = []
                for noise_index, noise in enumerate(noise_list):
                    fake_img = sess_gan.run(x_fake_op, feed_dict={generator_img_pl: generator_input,
                                                                  z_noise_pl: noise})
                    fake_img = np.squeeze(fake_img)
                    # make sure the dimensions are right
                    assert len(fake_img.shape) == 3

                    img_list.append(fake_img)

                    generated_img_name = 'generated_img_noise_%d' % (noise_index)
                    utils.save_image_and_cut(fake_img, generated_img_name, curr_img_path_3d,
                                             curr_img_path_2d, vmin=-1, vmax=1)
                    logging.info(generated_img_name + ' saved')

                    # save the difference g(xs)-xs
                    difference_image_gs = fake_img - curr_img
                    difference_img_name = 'difference_img_noise_%d' % (noise_index)
                    utils.save_image_and_cut(difference_image_gs, difference_img_name, curr_img_path_3d,
                                             curr_img_path_2d, vmin=-1, vmax=1)
                    logging.info(difference_img_name + ' saved')

                # voxel-wise standard deviation over the images generated with the different noise samples
                # (np.stack needs at least one image)
                if img_list:
                    std_img = np.std(np.stack(img_list, axis=0), axis=0)
                    std_img_name = 'std_img'
                    utils.save_image_and_cut(std_img, std_img_name, curr_img_path_3d,
                                             curr_img_path_2d, vmin=0, vmax=1)
                    logging.info(std_img_name + ' saved')
        finally:
            sess_gan.close()

        logging.info('generated all images for %s' % (gan_experiment_name))
def run_training(continue_run, log_dir):
    """Train the source-to-target field-strength translation GAN.

    The graph passes a batch of source-domain images ``z`` through
    ``exp_config.generator`` and lets ``exp_config.discriminator`` compare the
    result with a batch of real target-domain images ``x``. Every step runs
    several discriminator updates followed by one generator update. TensorBoard
    summaries, averaged validation losses and checkpoints (latest and best
    validation discriminator loss) are written periodically.

    :param continue_run: if True, restore the latest ``model.ckpt`` found in
        ``log_dir`` and resume from the step after it. If no checkpoint can be
        read, a fresh run is started instead.
    :param log_dir: directory for summaries and checkpoints. When a checkpoint
        is restored, output is written to ``log_dir + '_cont'`` instead.
    """

    logging.info('===== RUNNING EXPERIMENT ========')
    logging.info(exp_config.experiment_name)
    logging.info('=================================')

    init_step = 0

    if continue_run:
        logging.info('!!!!!!!!!!!!!!!!!!!!!!!!!!!! Continuing previous run !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
        try:
            init_checkpoint_path = utils.get_latest_model_checkpoint_path(log_dir, 'model.ckpt')
            logging.info('Checkpoint path: %s' % init_checkpoint_path)
            # the step is the suffix after the last '-' of the checkpoint file name;
            # plus 1 b/c otherwise starts with eval
            init_step = int(init_checkpoint_path.split('/')[-1].split('-')[-1]) + 1
            logging.info('Latest step was: %d' % init_step)
            # a continued run writes its summaries/checkpoints into a sibling folder
            log_dir += '_cont'
        except Exception:
            # Exception rather than a bare except so KeyboardInterrupt/SystemExit still abort the run
            logging.warning('!!! Didnt find init checkpoint. Maybe first run failed. Disabling continue mode...')
            continue_run = False
            init_step = 0
        logging.info('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')

    # import data
    data = adni_data_loader_all.load_and_maybe_process_data(
        input_folder=exp_config.data_root,
        preprocessing_folder=exp_config.preproc_folder,
        size=exp_config.image_size,
        target_resolution=exp_config.target_resolution,
        label_list=exp_config.label_list,
        offset=exp_config.offset,
        rescale_to_one=exp_config.rescale_to_one,
        force_overwrite=False
    )

    # extract images and indices of source/target images for the training and validation set
    (images_train, source_images_train_ind, target_images_train_ind,
     images_val, source_images_val_ind, target_images_val_ind) = data_utils.get_images_and_fieldstrength_indices(
        data, exp_config.source_field_strength, exp_config.target_field_strength)

    generator = exp_config.generator
    discriminator = exp_config.discriminator

    # endless minibatch streams: z from the source domain, x from the target domain
    z_sampler_train = iterate_minibatches_endlessly(images_train,
                                                    batch_size=exp_config.batch_size,
                                                    exp_config=exp_config,
                                                    selection_indices=source_images_train_ind)
    x_sampler_train = iterate_minibatches_endlessly(images_train,
                                                    batch_size=exp_config.batch_size,
                                                    exp_config=exp_config,
                                                    selection_indices=target_images_train_ind)

    with tf.Graph().as_default():

        # Generate placeholders for the images and labels.
        im_s = exp_config.image_size

        training_placeholder = tf.placeholder(tf.bool, name='training_phase')

        if exp_config.use_generator_input_noise:
            noise_in_gen_pl = tf.random_uniform(shape=exp_config.generator_input_noise_shape, minval=-1, maxval=1)
        else:
            noise_in_gen_pl = None

        # target image batch
        x_pl = tf.placeholder(tf.float32, [exp_config.batch_size, im_s[0], im_s[1], im_s[2], exp_config.n_channels],
                              name='x')

        # source image batch
        z_pl = tf.placeholder(tf.float32, [exp_config.batch_size, im_s[0], im_s[1], im_s[2], exp_config.n_channels],
                              name='z')

        # generated fake image batch
        x_pl_ = generator(z_pl, noise_in_gen_pl, training_placeholder)

        # difference between generated and source images
        diff_img_pl = x_pl_ - z_pl

        # visualize the images by showing one slice of them in the z direction
        tf.summary.image('sample_outputs', tf_utils.put_kernels_on_grid3d(x_pl_, exp_config.cut_axis,
                                                                          exp_config.cut_index, rescale_mode='manual',
                                                                          input_range=exp_config.image_range))

        tf.summary.image('sample_xs', tf_utils.put_kernels_on_grid3d(x_pl, exp_config.cut_axis,
                                                                     exp_config.cut_index, rescale_mode='manual',
                                                                     input_range=exp_config.image_range))

        tf.summary.image('sample_zs', tf_utils.put_kernels_on_grid3d(z_pl, exp_config.cut_axis,
                                                                     exp_config.cut_index, rescale_mode='manual',
                                                                     input_range=exp_config.image_range))

        tf.summary.image('sample_difference_gx-x', tf_utils.put_kernels_on_grid3d(diff_img_pl, exp_config.cut_axis,
                                                                                  exp_config.cut_index,
                                                                                  rescale_mode='centered',
                                                                                  cutoff_abs=exp_config.diff_threshold))

        # output of the discriminator for real image
        d_pl = discriminator(x_pl, training_placeholder, scope_reuse=False)

        # output of the discriminator for fake image
        d_pl_ = discriminator(x_pl_, training_placeholder, scope_reuse=True)

        # with improved_training, the discriminator is also evaluated on a random
        # interpolation between real and fake images and both are handed to the loss
        d_hat = None
        x_hat = None
        if exp_config.improved_training:
            epsilon = tf.random_uniform([], 0.0, 1.0)
            x_hat = epsilon * x_pl + (1 - epsilon) * x_pl_
            d_hat = discriminator(x_hat, training_placeholder, scope_reuse=True)

        dist_l1 = tf.reduce_mean(tf.abs(diff_img_pl))

        # nr means no regularization, meaning the loss without the regularization term
        discriminator_train_op, generator_train_op, \
            disc_loss_pl, gen_loss_pl, \
            disc_loss_nr_pl, gen_loss_nr_pl = gan_model.training_ops(d_pl, d_pl_,
                                                                     optimizer_handle=exp_config.optimizer_handle,
                                                                     learning_rate=exp_config.learning_rate,
                                                                     l1_img_dist=dist_l1,
                                                                     w_reg_img_dist_l1=exp_config.w_reg_img_dist_l1,
                                                                     w_reg_gen_l1=exp_config.w_reg_gen_l1,
                                                                     w_reg_disc_l1=exp_config.w_reg_disc_l1,
                                                                     w_reg_gen_l2=exp_config.w_reg_gen_l2,
                                                                     w_reg_disc_l2=exp_config.w_reg_disc_l2,
                                                                     d_hat=d_hat, x_hat=x_hat,
                                                                     scale=exp_config.scale)

        # Build the operation for clipping the discriminator weights
        d_clip_op = gan_model.clip_op()

        # Put L1 distance of generated image and original image on summary
        # (registered in the summary collection, so merge_all below picks it up)
        tf.summary.scalar('L1_distance_to_source_img', dist_l1)

        # Build the summary Tensor based on the TF collection of Summaries.
        summary_op = tf.summary.merge_all()

        # validation summaries (fed with losses averaged in Python)
        val_disc_loss_pl = tf.placeholder(tf.float32, shape=[], name='disc_val_loss')
        disc_val_summary_op = tf.summary.scalar('validation_discriminator_loss', val_disc_loss_pl)

        val_gen_loss_pl = tf.placeholder(tf.float32, shape=[], name='gen_val_loss')
        gen_val_summary_op = tf.summary.scalar('validation_generator_loss', val_gen_loss_pl)

        val_summary_op = tf.summary.merge([disc_val_summary_op, gen_val_summary_op])

        # Add the variable initializer Op.
        init = tf.global_variables_initializer()

        # Create savers for writing training checkpoints.
        saver_latest = tf.train.Saver(max_to_keep=3)
        saver_best_disc = tf.train.Saver(max_to_keep=3)  # disc loss is scaled negative EM distance

        # prevents ResourceExhaustError when a lot of memory is used
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True  # Do not assign whole gpu memory, just use it on the go
        config.allow_soft_placement = True  # If a operation is not defined in the default device, let it execute in another.

        # Create a session for running Ops on the Graph.
        sess = tf.Session(config=config)

        summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

        # make sure the session and event file are released however training ends
        try:
            # Run the Op to initialize the variables.
            sess.run(init)

            if continue_run:
                # Restore session
                saver_latest.restore(sess, init_checkpoint_path)

            # initialize value of lowest (i. e. best) discriminator loss
            best_d_loss = np.inf

            for step in range(init_step, 1000000):

                start_time = time.time()

                # discriminator training iterations (many more at the start and every 500 steps)
                d_iters = 5
                if step % 500 == 0 or step < 25:
                    d_iters = 100

                for _ in range(d_iters):

                    x = next(x_sampler_train)
                    z = next(z_sampler_train)

                    # train discriminator
                    sess.run(discriminator_train_op, feed_dict={z_pl: z, x_pl: x, training_placeholder: True})

                    # weights are only clipped when improved_training is off
                    if not exp_config.improved_training:
                        sess.run(d_clip_op)

                # train generator on a fresh pair of batches
                x = next(x_sampler_train)
                z = next(z_sampler_train)
                sess.run(generator_train_op, feed_dict={z_pl: z, x_pl: x, training_placeholder: True})

                # measured after the generator update so the logged value covers the whole step
                elapsed_time = time.time() - start_time

                if step % exp_config.update_tensorboard_frequency == 0:

                    x = next(x_sampler_train)
                    z = next(z_sampler_train)

                    g_loss_train, d_loss_train, summary_str = sess.run(
                        [gen_loss_nr_pl, disc_loss_nr_pl, summary_op],
                        feed_dict={z_pl: z, x_pl: x, training_placeholder: False})

                    summary_writer.add_summary(summary_str, step)
                    summary_writer.flush()

                    logging.info("[Step: %d], generator loss: %g, discriminator_loss: %g" % (step, g_loss_train, d_loss_train))
                    logging.info(" - elapsed time for one step: %f secs" % elapsed_time)

                if step % exp_config.validation_frequency == 0:

                    # fresh samplers so every validation pass starts from the same state of the generator
                    z_sampler_val = iterate_minibatches_endlessly(images_val,
                                                                  batch_size=exp_config.batch_size,
                                                                  exp_config=exp_config,
                                                                  selection_indices=source_images_val_ind)
                    x_sampler_val = iterate_minibatches_endlessly(images_val,
                                                                  batch_size=exp_config.batch_size,
                                                                  exp_config=exp_config,
                                                                  selection_indices=target_images_val_ind)

                    # evaluate the validation batch with batch_size images (from each domain) at a time
                    g_loss_val_list = []
                    d_loss_val_list = []
                    for _ in range(exp_config.num_val_batches):
                        x = next(x_sampler_val)
                        z = next(z_sampler_val)
                        g_loss_val, d_loss_val = sess.run(
                            [gen_loss_nr_pl, disc_loss_nr_pl],
                            feed_dict={z_pl: z, x_pl: x, training_placeholder: False})
                        g_loss_val_list.append(g_loss_val)
                        d_loss_val_list.append(d_loss_val)

                    g_loss_val_avg = np.mean(g_loss_val_list)
                    d_loss_val_avg = np.mean(d_loss_val_list)

                    validation_summary_str = sess.run(val_summary_op,
                                                      feed_dict={val_disc_loss_pl: d_loss_val_avg,
                                                                 val_gen_loss_pl: g_loss_val_avg})
                    summary_writer.add_summary(validation_summary_str, step)
                    summary_writer.flush()

                    # save best variables (if discriminator loss is the lowest yet)
                    if d_loss_val_avg <= best_d_loss:
                        best_d_loss = d_loss_val_avg
                        best_file = os.path.join(log_dir, 'model_best_d_loss.ckpt')
                        saver_best_disc.save(sess, best_file, global_step=step)
                        logging.info('Found new best discriminator loss on validation set! - %f - Saving model_best_d_loss.ckpt' % best_d_loss)

                    logging.info("[Validation], generator loss: %g, discriminator_loss: %g" % (g_loss_val_avg, d_loss_val_avg))

                # Save the latest checkpoint every save_frequency steps.
                if step % exp_config.save_frequency == 0:
                    saver_latest.save(sess, os.path.join(log_dir, 'model.ckpt'), global_step=step)
        finally:
            summary_writer.close()
            sess.close()
def classifier_test(clf_experiment_path, score_functions, batch_size=1, balanced_test=True,
                    checkpoint_file_name='model_best_xent.ckpt'):
    """Evaluate a trained AD classifier on the test set, split by field strength.

    Restores the classifier stored in ``clf_experiment_path`` and predicts a
    label for every test image in data order. It then splits the test set into
    source and target field strength and scores each domain as well as both
    domains together.

    :param clf_experiment_path: experiment folder of the AD classifier; read
        with ``utils.load_log_exp_config``.
    :param score_functions: passed unchanged to ``test_utils.evaluate_scores``.
    :param batch_size: number of images classified per forward pass.
    :param balanced_test: if True, images are dropped (deterministically,
        ``random_seed=0`` in ``utils.balance_source_target``) so that source and
        target end up with the same number of images and label ratio.
    :param checkpoint_file_name: name of the checkpoint to restore from the
        classifier's log directory.
    :return: tuple ``(scores, latest_step)``. ``scores`` is an OrderedDict with
        entries for the source field strength, the target field strength and
        ``'all data'``, sorted by ``str`` of the key. ``latest_step`` is the step
        of the restored checkpoint.
    """
    clf_config, logdir_clf = utils.load_log_exp_config(clf_experiment_path)

    # Load data
    data = adni_data_loader_all.load_and_maybe_process_data(
        input_folder=clf_config.data_root,
        preprocessing_folder=clf_config.preproc_folder,
        size=clf_config.image_size,
        target_resolution=clf_config.target_resolution,
        label_list=clf_config.label_list,
        offset=clf_config.offset,
        rescale_to_one=clf_config.rescale_to_one,
        force_overwrite=False)

    # extract images, labels and ages for the test set
    images_test = data['images_test']
    labels_test = data['diagnosis_test']
    ages_test = data['age_test']

    logging.info('batch size %d is used for classifier' % batch_size)
    img_tensor_shape = [None] + list(clf_config.image_size) + [1]

    # prevents ResourceExhaustError when a lot of memory is used
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True  # Do not assign whole gpu memory, just use it on the go
    config.allow_soft_placement = True  # If a operation is not defined in the default device, let it execute in another.

    # open the classifier save file from the selected experiment
    logging.info("loading Alzheimer's disease classifier")
    graph_clf, image_pl, predictions_clf_op, init_clf_op, saver_clf = test_utils.build_clf_graph(
        img_tensor_shape, clf_config)
    logging.info("getting savepoint %s" % checkpoint_file_name)
    init_checkpoint_path_clf, latest_step = utils.get_latest_checkpoint_and_step(
        logdir_clf, checkpoint_file_name)

    # classify all real test images; the session is closed when the block exits
    all_predictions = []
    ground_truth_labels = []
    with tf.Session(config=config, graph=graph_clf) as sess_clf:
        sess_clf.run(init_clf_op)  # probably not necessary
        saver_clf.restore(sess_clf, init_checkpoint_path_clf)

        logging.info('classify all test images')
        for batch in iterate_minibatches(images_test,
                                         [labels_test, ages_test],
                                         batch_size=batch_size,
                                         exp_config=clf_config,
                                         map_labels_to_standard_range=False,
                                         shuffle_data=False,
                                         skip_remainder=False):
            # the data are iterated in order, so predictions line up with the label list in data
            image_batch, [real_label, _] = batch

            clf_prediction_batch = sess_clf.run(predictions_clf_op, feed_dict={image_pl: image_batch})

            all_predictions.extend(clf_prediction_batch['label'])
            ground_truth_labels.extend(real_label)
            logging.info('new image batch')
            logging.info('ground truth labels: ' + str(real_label))
            logging.info('predicted labels: ' + str(clf_prediction_batch['label']))

    # check that the data has really been iterated in order and in full
    assert np.array_equal(ground_truth_labels, labels_test)

    # split the test set by field strength
    source_indices = []
    target_indices = []
    source_true_labels = []
    target_true_labels = []
    for i, field_strength in enumerate(data['field_strength_test']):
        if field_strength == clf_config.source_field_strength:
            source_indices.append(i)
            source_true_labels.append(labels_test[i])
        elif field_strength == clf_config.target_field_strength:
            target_indices.append(i)
            target_true_labels.append(labels_test[i])

    # check that the source and target images together are all images
    all_indices = source_indices + target_indices
    all_indices.sort()
    assert np.array_equal(all_indices, range(images_test.shape[0]))

    source_label_count = Counter(source_true_labels)
    target_label_count = Counter(target_true_labels)

    logging.info('before balancing')
    logging.info('source labels count: ' + str(source_label_count))
    logging.info('target labels count: ' + str(target_label_count))

    # throw away some data from source and target such that they have the same AD/normal ratio
    # this stratified test dataset should make comparisons between the scores with the different test sets more meaningful
    # the seed makes sure that the new test data are always the same
    if balanced_test:
        (source_indices_new, source_true_labels_new), (target_indices_new, target_true_labels_new) = \
            utils.balance_source_target((source_indices, source_true_labels),
                                        (target_indices, target_true_labels),
                                        random_seed=0)
        all_indices = source_indices_new + target_indices_new
        all_indices.sort()
        # set membership keeps the filtering linear in the number of test images
        kept_indices = set(all_indices)
        labels_test = [label for ind, label in enumerate(labels_test) if ind in kept_indices]

        source_label_count = Counter(source_true_labels_new)
        target_label_count = Counter(target_true_labels_new)
        logging.info('balanced the test set')
        logging.info('source labels count: ' + str(source_label_count))
        logging.info('target labels count: ' + str(target_label_count))

        # to make sure the new indices and labels are subsets of the old ones
        source_set_new = set(source_indices_new)
        target_set_new = set(target_indices_new)
        # check if the new indices are a subset of the old ones
        assert source_set_new <= set(source_indices)
        assert target_set_new <= set(target_indices)
        # check for duplicates
        assert len(source_set_new) == len(source_indices_new)
        assert len(target_set_new) == len(target_indices_new)

        # make tuples of (index, label) to check if the new index label pairs are a subset of the old ones
        source_tuples = utils.tuple_of_lists_to_list_of_tuples((source_indices, source_true_labels))
        target_tuples = utils.tuple_of_lists_to_list_of_tuples((target_indices, target_true_labels))
        source_tuples_new = utils.tuple_of_lists_to_list_of_tuples((source_indices_new, source_true_labels_new))
        target_tuples_new = utils.tuple_of_lists_to_list_of_tuples((target_indices_new, target_true_labels_new))
        assert set(source_tuples_new) <= set(source_tuples)
        assert set(target_tuples_new) <= set(target_tuples)

        source_indices, source_true_labels = source_indices_new, source_true_labels_new
        target_indices, target_true_labels = target_indices_new, target_true_labels_new

    source_pred = [all_predictions[ind] for ind in source_indices]
    target_pred = [all_predictions[ind] for ind in target_indices]

    # no unexpected labels
    assert all(label in clf_config.label_list for label in source_true_labels)
    assert all(label in clf_config.label_list for label in target_true_labels)
    assert all(label in clf_config.label_list for label in source_pred)
    assert all(label in clf_config.label_list for label in target_pred)

    num_source_images = len(source_indices)
    num_target_images = len(target_indices)

    assert set(source_indices).isdisjoint(target_indices)
    # labels and predictions must both line up with the selected indices
    assert num_source_images == len(source_true_labels)
    assert num_source_images == len(source_pred)
    assert num_target_images == len(target_true_labels)
    assert num_target_images == len(target_pred)
    assert num_target_images + num_source_images == len(labels_test)
    if balanced_test:
        assert num_source_images == num_target_images

    label_count = Counter(labels_test)
    assert label_count == source_label_count + target_label_count

    logging.info('Data summary:')
    logging.info(' - Images (before reduction):')
    logging.info(images_test.shape)
    logging.info(images_test.dtype)
    logging.info(' - Labels:')
    logging.info(len(labels_test))
    logging.info('number of images for each label')
    logging.info(label_count)
    logging.info(' - Domains:')
    logging.info('number of source images: ' + str(num_source_images))
    logging.info('source label distribution ' + str(source_label_count))
    logging.info('number of target images: ' + str(num_target_images))
    logging.info('target label distribution ' + str(target_label_count))

    # find out how many unique subjects there are in the (possibly reduced) test set
    rid_numbers = data['rid_test']
    kept_indices = set(all_indices)
    reduced_rid_numbers = [number for ind, number in enumerate(rid_numbers) if ind in kept_indices]
    logging.info('number of unique subjects: %d' % len(np.unique(reduced_rid_numbers)))

    logging.info('source prediction: ' + str(source_pred))
    logging.info('source ground truth: ' + str(source_true_labels))
    logging.info('target prediction: ' + str(target_pred))
    logging.info('target ground truth: ' + str(target_true_labels))

    scores = {}
    scores[clf_config.source_field_strength] = test_utils.evaluate_scores(
        source_true_labels, source_pred, score_functions)
    scores[clf_config.target_field_strength] = test_utils.evaluate_scores(
        target_true_labels, target_pred, score_functions)
    true_labels_together = source_true_labels + target_true_labels
    pred_together = source_pred + target_pred
    scores['all data'] = test_utils.evaluate_scores(true_labels_together, pred_together, score_functions)

    # dictionary sorted by key
    sorted_scores = OrderedDict(sorted(scores.items(), key=lambda t: str(t[0])))

    return sorted_scores, latest_step
def run_training(continue_run, log_dir):
    """Jointly train the domain-translation GAN and the multi-task (diagnosis + age) classifier.

    Source-domain batches are drawn with twice the GAN batch size. The first
    half is translated by ``exp_config.generator`` into fake target images.
    Unless ``directly_feed_clf_pl`` is True, ``concatenate_clf_input`` combines
    the fakes, the untouched second half and the source labels into the
    classifier input.

    Each step runs ``t_iters`` classifier updates and ``d_iters`` discriminator
    updates (interleaved), then one generator update. The function also
    periodically:

    - writes TensorBoard summaries;
    - evaluates the classifier on the training data (optionally dividing the
      classifier learning rate by 10 when the training-loss trend flattens);
    - evaluates GAN and classifier on the validation data;
    - saves the latest checkpoint and the best checkpoints for validation
      discriminator loss, diagnosis F1, age F1 and classifier loss.

    :param continue_run: if True, restore the latest ``model.ckpt`` found in
        ``log_dir`` and resume from the step after it. If no checkpoint can be
        read, a fresh run is started instead.
    :param log_dir: directory for summaries and checkpoints. When a checkpoint
        is restored, output is written to ``log_dir + '_cont'`` instead.
    """

    logging.info('===== RUNNING EXPERIMENT ========')
    logging.info(exp_config.experiment_name)
    logging.info('=================================')

    init_step = 0

    if continue_run:
        logging.info('!!!!!!!!!!!!!!!!!!!!!!!!!!!! Continuing previous run !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
        try:
            init_checkpoint_path = utils.get_latest_model_checkpoint_path(log_dir, 'model.ckpt')
            logging.info('Checkpoint path: %s' % init_checkpoint_path)
            # the step is the suffix after the last '-' of the checkpoint file name;
            # plus 1 b/c otherwise starts with eval
            init_step = int(init_checkpoint_path.split('/')[-1].split('-')[-1]) + 1
            logging.info('Latest step was: %d' % init_step)
            # a continued run writes its summaries/checkpoints into a sibling folder
            log_dir += '_cont'
        except Exception:
            # Exception rather than a bare except so KeyboardInterrupt/SystemExit still abort the run
            logging.warning('!!! Didnt find init checkpoint. Maybe first run failed. Disabling continue mode...')
            continue_run = False
            init_step = 0
        logging.info('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')

    # import data
    data = adni_data_loader_all.load_and_maybe_process_data(
        input_folder=exp_config.data_root,
        preprocessing_folder=exp_config.preproc_folder,
        size=exp_config.image_size,
        target_resolution=exp_config.target_resolution,
        label_list=exp_config.label_list,
        offset=exp_config.offset,
        rescale_to_one=exp_config.rescale_to_one,
        force_overwrite=False
    )

    # extract images and indices of source/target images for the training and validation set
    (images_train, source_images_train_ind, target_images_train_ind,
     images_val, source_images_val_ind, target_images_val_ind) = data_utils.get_images_and_fieldstrength_indices(
        data, exp_config.source_field_strength, exp_config.target_field_strength)

    # get labels
    # the following are HDF5 datasets, not numpy arrays
    labels_train = data['diagnosis_train']
    ages_train = data['age_train']
    labels_val = data['diagnosis_val']
    ages_val = data['age_val']

    # encode the ages either in ordinal-regression format or as bin indices
    if exp_config.age_ordinal_regression:
        ages_train = utils.age_to_ordinal_reg_format(ages_train, bins=exp_config.age_bins)
        ordinal_reg_weights = utils.get_ordinal_reg_weights(ages_train)
        ages_val = utils.age_to_ordinal_reg_format(ages_val, bins=exp_config.age_bins)
    else:
        ages_train = utils.age_to_bins(ages_train, bins=exp_config.age_bins)
        ordinal_reg_weights = None
        ages_val = utils.age_to_bins(ages_val, bins=exp_config.age_bins)

    generator = exp_config.generator
    discriminator = exp_config.discriminator

    augmentation_function = exp_config.augmentation_function if exp_config.use_augmentation else None

    # the source sampler yields twice the GAN batch size (see the tf.split below)
    s_sampler_train = iterate_minibatches_endlessly(images_train,
                                                    batch_size=2 * exp_config.batch_size,
                                                    exp_config=exp_config,
                                                    labels_list=[labels_train, ages_train],
                                                    selection_indices=source_images_train_ind,
                                                    augmentation_function=augmentation_function)
    t_sampler_train = iterate_minibatches_endlessly(images_train,
                                                    batch_size=exp_config.batch_size,
                                                    exp_config=exp_config,
                                                    labels_list=[labels_train, ages_train],
                                                    selection_indices=target_images_train_ind,
                                                    augmentation_function=augmentation_function)

    with tf.Graph().as_default():

        training_time_placeholder = tf.placeholder(tf.bool, shape=[], name='training_time')

        # GAN ----------------------------------------------------------------------------------

        # input noise for generator
        if exp_config.use_generator_input_noise:
            noise_in_gen_pl = tf.random_uniform(shape=exp_config.generator_input_noise_shape, minval=-1, maxval=1)
        else:
            noise_in_gen_pl = None

        # target image batch
        xt_pl = tf.placeholder(tf.float32, image_tensor_shape(exp_config.batch_size), name='x_target')

        # the classifier uses 2 times the batch size of the GAN
        clf_batch_size = 2 * exp_config.batch_size

        # source image batch
        xs_pl, diag_s_pl, ages_s_pl = placeholders_clf(clf_batch_size, 'source')
        # split source batch into half 1 (translated to xf) and half 2 (fed to the classifier as-is);
        # for the discriminator train op half 2 of the batch is not used
        xs1_pl, xs2_pl = tf.split(xs_pl, 2, axis=0)

        # generated fake image batch
        xf_pl = generator(xs1_pl, noise_in_gen_pl, training_time_placeholder)

        # difference between generated and source images
        diff_img_pl = xf_pl - xs1_pl

        # visualize the images by showing one slice of them in the z direction
        tf.summary.image('sample_outputs', tf_utils.put_kernels_on_grid3d(xf_pl, exp_config.cut_axis,
                                                                          exp_config.cut_index, rescale_mode='manual',
                                                                          input_range=exp_config.image_range))

        tf.summary.image('sample_xt', tf_utils.put_kernels_on_grid3d(xt_pl, exp_config.cut_axis,
                                                                     exp_config.cut_index, rescale_mode='manual',
                                                                     input_range=exp_config.image_range))

        tf.summary.image('sample_xs', tf_utils.put_kernels_on_grid3d(xs1_pl, exp_config.cut_axis,
                                                                     exp_config.cut_index, rescale_mode='manual',
                                                                     input_range=exp_config.image_range))

        tf.summary.image('sample_difference_xf-xs', tf_utils.put_kernels_on_grid3d(diff_img_pl, exp_config.cut_axis,
                                                                                   exp_config.cut_index,
                                                                                   rescale_mode='centered',
                                                                                   cutoff_abs=exp_config.diff_threshold))

        # output of the discriminator for real image
        d_pl = discriminator(xt_pl, training_time_placeholder, scope_reuse=False)

        # output of the discriminator for fake image
        d_pl_ = discriminator(xf_pl, training_time_placeholder, scope_reuse=True)

        # with improved_training, the discriminator is also evaluated on a random
        # interpolation between real and fake images and both are handed to the loss
        d_hat = None
        x_hat = None
        if exp_config.improved_training:
            epsilon = tf.random_uniform([], 0.0, 1.0)
            x_hat = epsilon * xt_pl + (1 - epsilon) * xf_pl
            d_hat = discriminator(x_hat, training_time_placeholder, scope_reuse=True)

        dist_l1 = tf.reduce_mean(tf.abs(diff_img_pl))

        # NOTE: both placeholders request the name 'learning_rate'; TensorFlow uniquifies the
        # second one. Kept unchanged so existing graph tensor names stay the same.
        learning_rate_gan_pl = tf.placeholder(tf.float32, shape=[], name='learning_rate')
        learning_rate_clf_pl = tf.placeholder(tf.float32, shape=[], name='learning_rate')

        # bind the optional momentum so training_ops only has to supply the learning rate
        if exp_config.momentum is not None:
            optimizer_handle = lambda learning_rate: exp_config.optimizer_handle(learning_rate=learning_rate,
                                                                                 momentum=exp_config.momentum)
        else:
            optimizer_handle = lambda learning_rate: exp_config.optimizer_handle(learning_rate=learning_rate)

        # Build the operation for clipping the discriminator weights
        d_clip_op = gan_model.clip_op()

        # Put L1 distance of generated image and original image on summary
        # (registered in the summary collection, so merge_all below picks it up)
        tf.summary.scalar('L1_distance_to_source_img', dist_l1)

        # Classifier ---------------------------------------------------------------------------
        # usually False during training, so that xf and xs2 get concatenated as classifier input;
        # when True the separate 'direct_clf' placeholders are used instead
        directly_feed_clf_pl = tf.placeholder(tf.bool, shape=[], name='direct_classifier_feeding')

        # conditionally assign either a concatenation of the generated dataset and the source data
        # cond to avoid having to specify not needed placeholders in the feed dict
        images_clf, diag_clf, ages_clf = tf.cond(
            directly_feed_clf_pl,
            lambda: placeholders_clf(clf_batch_size, 'direct_clf'),
            lambda: concatenate_clf_input([xf_pl, xs2_pl], diag_s_pl, ages_s_pl, scope_name='fs_concat')
        )

        tf.summary.scalar('learning_rate_gan', learning_rate_gan_pl)
        tf.summary.scalar('learning_rate_clf', learning_rate_clf_pl)

        # Build a Graph that computes predictions from the inference model.
        diag_logits_train, ages_logits_train = exp_config.clf_model_handle(images_clf,
                                                                           nlabels=exp_config.nlabels,
                                                                           training=training_time_placeholder,
                                                                           n_age_thresholds=len(exp_config.age_bins),
                                                                           bn_momentum=exp_config.bn_momentum)

        # Add to the Graph the Ops for loss calculation.
        [classifier_loss, diag_loss, age_loss, weights_norm_clf] = clf_model_mt.loss(
            diag_logits_train,
            ages_logits_train,
            diag_clf,
            ages_clf,
            nlabels=exp_config.nlabels,
            weight_decay=exp_config.weight_decay,
            diag_weight=exp_config.diag_weight,
            age_weight=exp_config.age_weight,
            use_ordinal_reg=exp_config.age_ordinal_regression,
            ordinal_reg_weights=ordinal_reg_weights)

        # nr means no regularization, meaning the loss without the regularization term
        train_ops_dict, losses_gan_dict = joint_model.training_ops(d_pl, d_pl_,
                                                                   classifier_loss,
                                                                   optimizer_handle=optimizer_handle,
                                                                   learning_rate_gan=learning_rate_gan_pl,
                                                                   learning_rate_clf=learning_rate_clf_pl,
                                                                   l1_img_dist=dist_l1,
                                                                   gan_loss_weight=exp_config.gan_loss_weight,
                                                                   task_loss_weight=exp_config.task_loss_weight,
                                                                   w_reg_img_dist_l1=exp_config.w_reg_img_dist_l1,
                                                                   w_reg_gen_l1=exp_config.w_reg_gen_l1,
                                                                   w_reg_disc_l1=exp_config.w_reg_disc_l1,
                                                                   w_reg_gen_l2=exp_config.w_reg_gen_l2,
                                                                   w_reg_disc_l2=exp_config.w_reg_disc_l2,
                                                                   d_hat=d_hat, x_hat=x_hat, scale=exp_config.scale)

        tf.summary.scalar('classifier loss', classifier_loss)
        tf.summary.scalar('diag_loss', diag_loss)
        tf.summary.scalar('age_loss', age_loss)
        tf.summary.scalar('weights_norm_term_classifier', weights_norm_clf)
        tf.summary.scalar('generator loss joint', losses_gan_dict['gen']['joint'])
        tf.summary.scalar('discriminator loss joint', losses_gan_dict['disc']['joint'])

        eval_diag_loss, eval_ages_loss, pred_labels, ages_softmaxs = clf_model_mt.evaluation(
            diag_logits_train, ages_logits_train,
            diag_clf,
            ages_clf,
            images_clf,
            diag_weight=exp_config.diag_weight,
            age_weight=exp_config.age_weight,
            nlabels=exp_config.nlabels,
            use_ordinal_reg=exp_config.age_ordinal_regression)

        # Build the summary Tensor based on the TF collection of Summaries.
        summary = tf.summary.merge_all()

        # Add the variable initializer Op.
        init = tf.global_variables_initializer()

        # Create savers for writing training checkpoints.
        saver_latest = tf.train.Saver(max_to_keep=2)
        saver_best_disc = tf.train.Saver(max_to_keep=2)  # disc loss is scaled negative EM distance
        saver_best_diag_f1 = tf.train.Saver(max_to_keep=5)
        saver_best_ages_f1 = tf.train.Saver(max_to_keep=1)
        saver_best_xent = tf.train.Saver(max_to_keep=5)

        # validation summaries gan (fed with values computed in Python)
        val_disc_loss_pl = tf.placeholder(tf.float32, shape=[], name='disc_val_loss')
        disc_val_summary_op = tf.summary.scalar('validation_discriminator_loss', val_disc_loss_pl)

        val_gen_loss_pl = tf.placeholder(tf.float32, shape=[], name='gen_val_loss')
        gen_val_summary_op = tf.summary.scalar('validation_generator_loss', val_gen_loss_pl)

        val_summary_gan = tf.summary.merge([disc_val_summary_op, gen_val_summary_op])

        # Classifier summary
        val_error_clf_ = tf.placeholder(tf.float32, shape=[], name='val_error_diag')
        val_error_summary = tf.summary.scalar('classifier_validation_loss', val_error_clf_)

        val_diag_f1_score_ = tf.placeholder(tf.float32, shape=[], name='val_diag_f1')
        val_f1_diag_summary = tf.summary.scalar('validation_diag_f1', val_diag_f1_score_)

        val_ages_f1_score_ = tf.placeholder(tf.float32, shape=[], name='val_ages_f1')
        val_f1_ages_summary = tf.summary.scalar('validation_ages_f1', val_ages_f1_score_)

        val_summary_clf = tf.summary.merge([val_error_summary, val_f1_diag_summary, val_f1_ages_summary])
        val_summary = tf.summary.merge([val_summary_clf, val_summary_gan])

        train_error_clf_ = tf.placeholder(tf.float32, shape=[], name='train_error_diag')
        train_error_clf_summary = tf.summary.scalar('classifier_training_loss', train_error_clf_)

        train_diag_f1_score_ = tf.placeholder(tf.float32, shape=[], name='train_diag_f1')
        train_diag_f1_summary = tf.summary.scalar('training_diag_f1', train_diag_f1_score_)

        train_ages_f1_score_ = tf.placeholder(tf.float32, shape=[], name='train_ages_f1')
        train_f1_ages_summary = tf.summary.scalar('training_ages_f1', train_ages_f1_score_)

        train_summary = tf.summary.merge([train_error_clf_summary, train_diag_f1_summary, train_f1_ages_summary])

        # prevents ResourceExhaustError when a lot of memory is used
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True  # Do not assign whole gpu memory, just use it on the go
        config.allow_soft_placement = True  # If a operation is not defined in the default device, let it execute in another.

        # Create a session for running Ops on the Graph.
        sess = tf.Session(config=config)

        summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

        # no more ops may be added to the graph from here on
        sess.graph.finalize()

        # make sure the session and event file are released however training ends
        try:
            # Run the Op to initialize the variables.
            sess.run(init)

            if continue_run:
                # Restore session
                saver_latest.restore(sess, init_checkpoint_path)

            curr_lr_gan = exp_config.learning_rate_gan
            curr_lr_clf = exp_config.learning_rate_clf

            # number of consecutive training evaluations without a decrease in training loss
            no_improvement_counter = 0
            best_val = np.inf
            last_train = np.inf
            loss_history = []
            loss_gradient = np.inf
            best_diag_f1_score = 0
            best_ages_f1_score = 0
            # initialize value of lowest (i. e. best) discriminator loss
            best_d_loss = np.inf

            for step in range(init_step, exp_config.max_steps):

                start_time = time.time()

                # discriminator and classifier (task) training iterations
                d_iters = 5
                t_iters = 1
                if step % 500 == 0 or step < 25:
                    d_iters = 100

                for iteration in range(max(d_iters, t_iters)):
                    x_t, [diag_t, age_t] = next(t_sampler_train)
                    x_s, [diag_s, age_s] = next(s_sampler_train)
                    feed_dict_dc = {xs_pl: x_s, xt_pl: x_t,
                                    learning_rate_gan_pl: curr_lr_gan, learning_rate_clf_pl: curr_lr_clf,
                                    diag_s_pl: diag_s, ages_s_pl: age_s,
                                    training_time_placeholder: True, directly_feed_clf_pl: False}
                    train_ops_list_dc = []
                    if iteration < t_iters:
                        # train classifier
                        train_ops_list_dc.append(train_ops_dict['clf'])

                    if iteration < d_iters:
                        # train discriminator
                        train_ops_list_dc.append(train_ops_dict['disc'])

                    sess.run(train_ops_list_dc, feed_dict=feed_dict_dc)

                    # weights are only clipped when improved_training is off
                    if not exp_config.improved_training:
                        sess.run(d_clip_op)

                # train generator
                x_t, [diag_t, age_t] = next(t_sampler_train)
                x_s, [diag_s, age_s] = next(s_sampler_train)
                sess.run(train_ops_dict['gen'],
                         feed_dict={xs_pl: x_s, xt_pl: x_t,
                                    learning_rate_gan_pl: curr_lr_gan, learning_rate_clf_pl: curr_lr_clf,
                                    diag_s_pl: diag_s, ages_s_pl: age_s,
                                    training_time_placeholder: True, directly_feed_clf_pl: False})

                # measured after the generator update so the logged value covers the whole step
                elapsed_time = time.time() - start_time

                if step % exp_config.update_tensorboard_frequency == 0:
                    x_t, [diag_t, age_t] = next(t_sampler_train)
                    x_s, [diag_s, age_s] = next(s_sampler_train)

                    feed_dict_summary = {xs_pl: x_s, xt_pl: x_t,
                                         learning_rate_gan_pl: curr_lr_gan, learning_rate_clf_pl: curr_lr_clf,
                                         diag_s_pl: diag_s, ages_s_pl: age_s,
                                         training_time_placeholder: True, directly_feed_clf_pl: False}

                    c_loss_one_batch, gan_losses_one_batch_dict, summary_str = sess.run(
                        [classifier_loss, losses_gan_dict, summary], feed_dict=feed_dict_summary)

                    summary_writer.add_summary(summary_str, step)
                    summary_writer.flush()

                    logging.info("[Step: %d], classifier_loss: %g, GAN losses: %s" % (step, c_loss_one_batch, str(gan_losses_one_batch_dict)))
                    logging.info(" - elapsed time for one step: %f secs" % elapsed_time)

                if (step + 1) % exp_config.train_eval_frequency == 0:

                    # Evaluate against the training set
                    # NOTE(review): selection_indices below are the *source* training indices;
                    # confirm the '(target domain)' wording of this log message is intended.
                    logging.info('Training data eval for classifier (target domain):')
                    [train_loss, train_diag_f1, train_ages_f1] = do_eval_classifier(
                        sess, eval_diag_loss, eval_ages_loss, pred_labels, ages_softmaxs,
                        xs_pl, diag_s_pl, ages_s_pl,
                        training_time_placeholder, directly_feed_clf_pl,
                        images_train, [labels_train, ages_train],
                        clf_batch_size=clf_batch_size,
                        do_ordinal_reg=exp_config.age_ordinal_regression,
                        selection_indices=source_images_train_ind)

                    train_summary_msg = sess.run(train_summary,
                                                 feed_dict={train_error_clf_: train_loss,
                                                            train_diag_f1_score_: train_diag_f1,
                                                            train_ages_f1_score_: train_ages_f1})
                    summary_writer.add_summary(train_summary_msg, step)

                    # trend of the last 5 training losses drives the learning-rate schedule
                    loss_history.append(train_loss)
                    if len(loss_history) > 5:
                        loss_history.pop(0)
                        loss_gradient = (loss_history[-5] - loss_history[-1]) / 2

                    logging.info('loss gradient is currently %f' % loss_gradient)

                    if exp_config.schedule_lr and loss_gradient < exp_config.schedule_gradient_threshold:
                        logging.warning('Reducing learning rate of the classifier!')
                        curr_lr_clf /= 10.0
                        logging.info('Learning rate of the classifier changed to: %f' % curr_lr_clf)
                        # reset loss history to give the optimisation some time to start decreasing again
                        loss_gradient = np.inf
                        loss_history = []

                    if train_loss <= last_train:  # best_train:
                        no_improvement_counter = 0
                        logging.info('Decrease in training error!')
                    else:
                        no_improvement_counter += 1
                        logging.info('No improvment in training error for %d steps' % no_improvement_counter)

                    last_train = train_loss

                if (step + 1) % exp_config.validation_frequency == 0:

                    # evaluate gan losses
                    g_loss_val_avg, d_loss_val_avg = do_eval_gan(sess=sess,
                                                                 losses=[losses_gan_dict['gen']['nr'],
                                                                         losses_gan_dict['disc']['nr']],
                                                                 images_s_pl=xs_pl,
                                                                 images_t_pl=xt_pl,
                                                                 training_time_placeholder=training_time_placeholder,
                                                                 images=images_val,
                                                                 source_images_ind=source_images_val_ind,
                                                                 target_images_ind=target_images_val_ind)

                    # evaluate classifier losses
                    [val_loss, val_diag_f1, val_ages_f1] = do_eval_classifier(
                        sess, eval_diag_loss, eval_ages_loss, pred_labels, ages_softmaxs,
                        xs_pl, diag_s_pl, ages_s_pl,
                        training_time_pl=training_time_placeholder,
                        directly_feed_clf_pl=directly_feed_clf_pl,
                        images=images_val,
                        labels_list=[labels_val, ages_val],
                        clf_batch_size=clf_batch_size,
                        do_ordinal_reg=exp_config.age_ordinal_regression,
                        selection_indices=source_images_val_ind)

                    feed_dict_val = {
                        val_error_clf_: val_loss,
                        val_diag_f1_score_: val_diag_f1,
                        val_ages_f1_score_: val_ages_f1,
                        val_disc_loss_pl: d_loss_val_avg,
                        val_gen_loss_pl: g_loss_val_avg
                    }

                    validation_summary_msg = sess.run(val_summary, feed_dict=feed_dict_val)
                    summary_writer.add_summary(validation_summary_msg, step)
                    summary_writer.flush()

                    # save best variables (if discriminator loss is the lowest yet)
                    if d_loss_val_avg <= best_d_loss:
                        best_d_loss = d_loss_val_avg
                        best_file = os.path.join(log_dir, 'model_best_d_loss.ckpt')
                        saver_best_disc.save(sess, best_file, global_step=step)
                        logging.info('Found new best discriminator loss on validation set! - %f - Saving model_best_d_loss.ckpt' % best_d_loss)

                    if val_diag_f1 >= best_diag_f1_score:
                        best_diag_f1_score = val_diag_f1
                        best_file = os.path.join(log_dir, 'model_best_diag_f1.ckpt')
                        saver_best_diag_f1.save(sess, best_file, global_step=step)
                        logging.info(
                            'Found new best DIAGNOSIS F1 score on validation set! - %f - Saving model_best_diag_f1.ckpt' % val_diag_f1)

                    if val_ages_f1 >= best_ages_f1_score:
                        best_ages_f1_score = val_ages_f1
                        best_file = os.path.join(log_dir, 'model_best_ages_f1.ckpt')
                        saver_best_ages_f1.save(sess, best_file, global_step=step)
                        logging.info(
                            'Found new best AGES F1 score on validation set! - %f - Saving model_best_ages_f1.ckpt' % val_ages_f1)

                    if val_loss <= best_val:
                        best_val = val_loss
                        best_file = os.path.join(log_dir, 'model_best_xent.ckpt')
                        saver_best_xent.save(sess, best_file, global_step=step)
                        logging.info(
                            'Found new best crossentropy on validation set! - %f - Saving model_best_xent.ckpt' % val_loss)

                    logging.info("[Validation], generator loss: %g, discriminator_loss: %g" % (g_loss_val_avg, d_loss_val_avg))

                # Save the latest checkpoint every save_frequency steps.
                if step % exp_config.save_frequency == 0:
                    saver_latest.save(sess, os.path.join(log_dir, 'model.ckpt'), global_step=step)
        finally:
            summary_writer.close()
            sess.close()