Ejemplo n.º 1
0
def run_training(continue_run):
    """Train the multi-task classifier using gradient accumulation.

    The function loads the (pre-processed) data, builds the TensorFlow graph
    for the model referenced by ``exp_config.clf_model_handle`` and iterates
    over the training set for ``exp_config.max_epochs`` epochs. Gradients are
    summed over ``exp_config.n_accum_batches`` mini-batches, averaged, and
    then applied in a single optimiser step. Summaries and checkpoints are
    written to the module-level ``log_dir``.

    Three families of checkpoints are written:

    * ``model.ckpt`` every ``exp_config.val_eval_frequency`` steps,
    * ``model_best_diag_f1.ckpt`` whenever the validation F1 score is at
      least as good as the best one seen so far,
    * ``model_best_xent.ckpt`` whenever the validation loss is at least as
      good as the best one seen so far.

    :param continue_run: if truthy, try to resume from the latest
        ``model.ckpt`` checkpoint in ``log_dir``. If no usable checkpoint is
        found a warning is logged and training starts from scratch.
    :return: None
    """

    logging.info('EXPERIMENT NAME: %s' % exp_config.experiment_name)

    init_step = 0

    if continue_run:
        logging.info(
            '!!!!!!!!!!!!!!!!!!!!!!!!!!!! Continuing previous run !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
        )
        try:
            init_checkpoint_path = utils.get_latest_model_checkpoint_path(
                log_dir, 'model.ckpt')
            logging.info('Checkpoint path: %s' % init_checkpoint_path)
            # The step number is the suffix after the last '-' of the file
            # name; plus 1 b/c otherwise starts with eval
            init_step = int(
                init_checkpoint_path.split('/')[-1].split('-')[-1]) + 1
            logging.info('Latest step was: %d' % init_step)
        except Exception:
            # Best-effort resume: any failure to locate or parse a checkpoint
            # falls back to a fresh run. ``Exception`` (instead of a bare
            # ``except``) lets KeyboardInterrupt / SystemExit propagate.
            logging.warning(
                '!!! Didnt find init checkpoint. Maybe first run failed. Disabling continue mode...'
            )
            continue_run = False
            init_step = 0

        logging.info(
            '!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
        )

    # Load data
    data = adni_data_loader_all.load_and_maybe_process_data(
        input_folder=exp_config.data_root,
        preprocessing_folder=exp_config.preproc_folder,
        size=exp_config.image_size,
        target_resolution=exp_config.target_resolution,
        label_list=exp_config.label_list,
        offset=exp_config.offset,
        rescale_to_one=True,
        force_overwrite=False)

    # the following are HDF5 datasets, not numpy arrays
    # NOTE: the classification targets used below are derived from the
    # field strength, even though the graph tensors are named "diag".
    images_train = data['images_train']
    fieldstr_train = data['field_strength_train']
    labels_train = utils.fstr_to_label(fieldstr_train,
                                       exp_config.field_strength_list,
                                       exp_config.fs_label_list)
    ages_train = data['age_train']

    if exp_config.age_ordinal_regression:
        ages_train = utils.age_to_ordinal_reg_format(ages_train,
                                                     bins=exp_config.age_bins)
        ordinal_reg_weights = utils.get_ordinal_reg_weights(ages_train)
    else:
        ages_train = utils.age_to_bins(ages_train, bins=exp_config.age_bins)
        ordinal_reg_weights = None

    images_val = data['images_val']
    fieldstr_val = data['field_strength_val']
    labels_val = utils.fstr_to_label(fieldstr_val,
                                     exp_config.field_strength_list,
                                     exp_config.fs_label_list)
    ages_val = data['age_val']

    if exp_config.age_ordinal_regression:
        ages_val = utils.age_to_ordinal_reg_format(ages_val,
                                                   bins=exp_config.age_bins)
    else:
        ages_val = utils.age_to_bins(ages_val, bins=exp_config.age_bins)

    if exp_config.use_data_fraction:
        num_images = images_train.shape[0]
        new_last_index = int(float(num_images) * exp_config.use_data_fraction)

        logging.warning('USING ONLY FRACTION OF DATA!')
        logging.warning(' - Number of imgs orig: %d, Number of imgs new: %d' %
                        (num_images, new_last_index))
        images_train = images_train[0:new_last_index, ...]
        labels_train = labels_train[0:new_last_index, ...]
        # Keep the age targets aligned with the truncated images/labels.
        # ``ordinal_reg_weights`` was computed above from the full training
        # set and is intentionally left unchanged.
        ages_train = ages_train[0:new_last_index]

    logging.info('Data summary:')
    logging.info('TRAINING')
    logging.info(' - Images:')
    logging.info(images_train.shape)
    logging.info(images_train.dtype)
    logging.info(' - Labels:')
    logging.info(labels_train.shape)
    logging.info(labels_train.dtype)
    logging.info('VALIDATiON')
    logging.info(' - Images:')
    logging.info(images_val.shape)
    logging.info(images_val.dtype)
    logging.info(' - Labels:')
    logging.info(labels_val.shape)
    logging.info(labels_val.dtype)

    # Tell TensorFlow that the model will be built into the default Graph.
    with tf.Graph().as_default():

        # Generate placeholders for the images and labels.
        image_tensor_shape = [exp_config.batch_size] + list(
            exp_config.image_size) + [1]
        labels_tensor_shape = [exp_config.batch_size]

        # Ordinal regression uses one target per age threshold; otherwise a
        # single bin index per sample is fed.
        if exp_config.age_ordinal_regression:
            ages_tensor_shape = [
                exp_config.batch_size,
                len(exp_config.age_bins)
            ]
        else:
            ages_tensor_shape = [exp_config.batch_size]

        images_placeholder = tf.placeholder(tf.float32,
                                            shape=image_tensor_shape,
                                            name='images')
        diag_placeholder = tf.placeholder(tf.uint8,
                                          shape=labels_tensor_shape,
                                          name='labels')
        ages_placeholder = tf.placeholder(tf.uint8,
                                          shape=ages_tensor_shape,
                                          name='ages')

        learning_rate_placeholder = tf.placeholder(tf.float32,
                                                   shape=[],
                                                   name='learning_rate')
        training_time_placeholder = tf.placeholder(tf.bool,
                                                   shape=[],
                                                   name='training_time')

        tf.summary.scalar('learning_rate', learning_rate_placeholder)

        # Build a Graph that computes predictions from the inference model.
        diag_logits, ages_logits = exp_config.clf_model_handle(
            images_placeholder,
            nlabels=exp_config.nlabels,
            training=training_time_placeholder,
            n_age_thresholds=len(exp_config.age_bins),
            bn_momentum=exp_config.bn_momentum)

        # Add to the Graph the Ops for loss calculation.
        [loss, diag_loss, age_loss, weights_norm
         ] = model_mt.loss(diag_logits,
                           ages_logits,
                           diag_placeholder,
                           ages_placeholder,
                           nlabels=exp_config.nlabels,
                           weight_decay=exp_config.weight_decay,
                           diag_weight=exp_config.diag_weight,
                           age_weight=exp_config.age_weight,
                           use_ordinal_reg=exp_config.age_ordinal_regression,
                           ordinal_reg_weights=ordinal_reg_weights)

        tf.summary.scalar('loss', loss)
        tf.summary.scalar('diag_loss', diag_loss)
        tf.summary.scalar('weights_norm_term', weights_norm)

        if exp_config.momentum is not None:
            optimiser = exp_config.optimizer_handle(
                learning_rate=learning_rate_placeholder,
                momentum=exp_config.momentum)
        else:
            optimiser = exp_config.optimizer_handle(
                learning_rate=learning_rate_placeholder)

        # create a copy of all global variables with `0` as initial values;
        # these hold the gradients accumulated over several mini-batches
        t_vars = tf.global_variables()  #tf.trainable_variables()
        accum_tvars = [
            tf.Variable(tf.zeros_like(tv.initialized_value()), trainable=False)
            for tv in t_vars
        ]

        # create a op to initialize all accums vars
        zero_ops = [tv.assign(tf.zeros_like(tv)) for tv in accum_tvars]

        # compute gradients for a batch
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            batch_grads_vars = optimiser.compute_gradients(loss, t_vars)

        # collect the batch gradient into accumulated vars
        accum_ops = [
            accum_tvar.assign_add(batch_grad_var[0])
            for accum_tvar, batch_grad_var in zip(accum_tvars,
                                                  batch_grads_vars)
        ]

        # turn the accumulated sum into a mean by dividing by the number of
        # accumulated batches (fed at run time)
        accum_normaliser_pl = tf.placeholder(dtype=tf.float32,
                                             name='accum_normaliser')
        accum_mean_op = [
            accum_tvar.assign(tf.divide(accum_tvar, accum_normaliser_pl))
            for accum_tvar in accum_tvars
        ]

        # apply accums gradients
        with tf.control_dependencies(update_ops):
            train_op = optimiser.apply_gradients([
                (accum_tvar, batch_grad_var[1])
                for accum_tvar, batch_grad_var in zip(accum_tvars,
                                                      batch_grads_vars)
            ])

        eval_diag_loss, eval_ages_loss, pred_labels, ages_softmaxs = model_mt.evaluation(
            diag_logits,
            ages_logits,
            diag_placeholder,
            ages_placeholder,
            images_placeholder,
            diag_weight=exp_config.diag_weight,
            age_weight=exp_config.age_weight,
            nlabels=exp_config.nlabels,
            use_ordinal_reg=exp_config.age_ordinal_regression)

        # Build the summary Tensor based on the TF collection of Summaries.
        summary = tf.summary.merge_all()

        # Add the variable initializer Op.
        init = tf.global_variables_initializer()

        # Create a saver for writing training checkpoints.
        saver = tf.train.Saver(max_to_keep=3)
        saver_best_diag_f1 = tf.train.Saver(max_to_keep=2)
        saver_best_xent = tf.train.Saver(max_to_keep=2)

        # prevents ResourceExhaustError when a lot of memory is used
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True  # Do not assign whole gpu memory, just use it on the go
        config.allow_soft_placement = True  # If a operation is not defined in the default device, let it execute in another.

        # Create a session for running Ops on the Graph.
        sess = tf.Session(config=config)
        summary_writer = None

        try:
            # Instantiate a SummaryWriter to output summaries and the Graph.
            summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

            # Placeholders/summaries for metrics that are computed in Python
            # (by do_eval) and only fed in for logging purposes.
            val_error_ = tf.placeholder(tf.float32,
                                        shape=[],
                                        name='val_error_diag')
            val_error_summary = tf.summary.scalar('validation_loss',
                                                  val_error_)

            val_diag_f1_score_ = tf.placeholder(tf.float32,
                                                shape=[],
                                                name='val_diag_f1')
            val_f1_diag_summary = tf.summary.scalar('validation_diag_f1',
                                                    val_diag_f1_score_)

            val_ages_f1_score_ = tf.placeholder(tf.float32,
                                                shape=[],
                                                name='val_ages_f1')

            val_summary = tf.summary.merge(
                [val_error_summary, val_f1_diag_summary])

            train_error_ = tf.placeholder(tf.float32,
                                          shape=[],
                                          name='train_error_diag')
            train_error_summary = tf.summary.scalar('training_loss',
                                                    train_error_)

            train_diag_f1_score_ = tf.placeholder(tf.float32,
                                                  shape=[],
                                                  name='train_diag_f1')
            train_diag_f1_summary = tf.summary.scalar('training_diag_f1',
                                                      train_diag_f1_score_)

            train_ages_f1_score_ = tf.placeholder(tf.float32,
                                                  shape=[],
                                                  name='train_ages_f1')

            train_summary = tf.summary.merge(
                [train_error_summary, train_diag_f1_summary])

            # Run the Op to initialize the variables.
            sess.run(init)

            if continue_run:
                # Restore session
                saver.restore(sess, init_checkpoint_path)

            step = init_step
            curr_lr = exp_config.learning_rate

            # Steps elapsed since the training loss last decreased between
            # two consecutive training-set evaluations (logging only).
            no_improvement_counter = 0
            best_val = np.inf
            last_train = np.inf
            loss_history = []
            loss_gradient = np.inf
            best_diag_f1_score = 0

            for epoch in range(exp_config.max_epochs):

                logging.info('EPOCH %d' % epoch)
                # Discard any partially accumulated gradients left over from
                # the previous epoch.
                sess.run(zero_ops)
                accum_counter = 0

                for batch in iterate_minibatches(
                        images_train, [labels_train, ages_train],
                        batch_size=exp_config.batch_size,
                        augmentation_function=exp_config.augmentation_function,
                        exp_config=exp_config):

                    # Use a 10x smaller learning rate for the first 50 steps
                    if exp_config.warmup_training:
                        if step < 50:
                            curr_lr = exp_config.learning_rate / 10.0
                        elif step == 50:
                            curr_lr = exp_config.learning_rate

                    start_time = time.time()

                    # get a batch
                    x, [y, a] = batch

                    # TEMPORARY HACK (to avoid incomplete batches): the
                    # placeholders have a fixed batch dimension.
                    if y.shape[0] < exp_config.batch_size:
                        step += 1
                        continue

                    # Run accumulation
                    feed_dict = {
                        images_placeholder: x,
                        diag_placeholder: y,
                        ages_placeholder: a,
                        learning_rate_placeholder: curr_lr,
                        training_time_placeholder: True
                    }

                    _, loss_value = sess.run([accum_ops, loss],
                                             feed_dict=feed_dict)

                    accum_counter += 1

                    # Everything below only happens once enough batches have
                    # been accumulated for one optimiser step.
                    if accum_counter != exp_config.n_accum_batches:
                        continue

                    # Average gradient over batches
                    sess.run(accum_mean_op,
                             feed_dict={
                                 accum_normaliser_pl:
                                 float(exp_config.n_accum_batches)
                             })
                    sess.run(train_op, feed_dict=feed_dict)

                    # Reset all counters etc.
                    sess.run(zero_ops)
                    accum_counter = 0

                    # NOTE: only covers the last accumulated mini-batch
                    duration = time.time() - start_time

                    # Write the summaries and print an overview fairly often.
                    if step % 10 == 0:
                        # Print status to stdout.
                        logging.info('Step %d: loss = %.2f (%.3f sec)' %
                                     (step, loss_value, duration))

                        # Update the events file.
                        summary_str = sess.run(summary, feed_dict=feed_dict)
                        summary_writer.add_summary(summary_str, step)
                        summary_writer.flush()

                    if (step + 1) % exp_config.train_eval_frequency == 0:

                        # Evaluate against the training set
                        logging.info('Training Data Eval:')
                        [train_loss, train_diag_f1, train_ages_f1] = do_eval(
                            sess,
                            eval_diag_loss,
                            eval_ages_loss,
                            pred_labels,
                            ages_softmaxs,
                            images_placeholder,
                            diag_placeholder,
                            ages_placeholder,
                            training_time_placeholder,
                            images_train, [labels_train, ages_train],
                            batch_size=exp_config.batch_size,
                            do_ordinal_reg=exp_config.age_ordinal_regression)

                        train_summary_msg = sess.run(
                            train_summary,
                            feed_dict={
                                train_error_: train_loss,
                                train_diag_f1_score_: train_diag_f1,
                                train_ages_f1_score_: train_ages_f1
                            })
                        summary_writer.add_summary(train_summary_msg, step)

                        # Keep a sliding window of the last five training
                        # losses and estimate how fast the loss is falling.
                        loss_history.append(train_loss)
                        if len(loss_history) > 5:
                            loss_history.pop(0)
                            loss_gradient = (loss_history[-5] -
                                             loss_history[-1]) / 2

                        logging.info('loss gradient is currently %f' %
                                     loss_gradient)

                        if exp_config.schedule_lr and loss_gradient < exp_config.schedule_gradient_threshold:
                            logging.warning('Reducing learning rate!')
                            curr_lr /= 10.0
                            logging.info('Learning rate changed to: %f' %
                                         curr_lr)

                            # reset loss history to give the optimisation some time to start decreasing again
                            loss_gradient = np.inf
                            loss_history = []

                        if train_loss <= last_train:  # best_train:
                            no_improvement_counter = 0
                            logging.info('Decrease in training error!')
                        else:
                            # Evaluations are train_eval_frequency steps
                            # apart, so advance the counter by that amount.
                            no_improvement_counter += exp_config.train_eval_frequency
                            logging.info(
                                'No improvment in training error for %d steps'
                                % no_improvement_counter)

                        last_train = train_loss

                    # Save a checkpoint and evaluate the model periodically.
                    if (step + 1) % exp_config.val_eval_frequency == 0:

                        checkpoint_file = os.path.join(log_dir, 'model.ckpt')
                        saver.save(sess, checkpoint_file, global_step=step)

                        # Evaluate against the validation set.
                        logging.info('Validation Data Eval:')

                        [val_loss, val_diag_f1, val_ages_f1] = do_eval(
                            sess,
                            eval_diag_loss,
                            eval_ages_loss,
                            pred_labels,
                            ages_softmaxs,
                            images_placeholder,
                            diag_placeholder,
                            ages_placeholder,
                            training_time_placeholder,
                            images_val, [labels_val, ages_val],
                            batch_size=exp_config.batch_size,
                            do_ordinal_reg=exp_config.age_ordinal_regression)

                        val_summary_msg = sess.run(
                            val_summary,
                            feed_dict={
                                val_error_: val_loss,
                                val_diag_f1_score_: val_diag_f1,
                                val_ages_f1_score_: val_ages_f1
                            })
                        summary_writer.add_summary(val_summary_msg, step)

                        if val_diag_f1 >= best_diag_f1_score:
                            best_diag_f1_score = val_diag_f1
                            best_file = os.path.join(
                                log_dir, 'model_best_diag_f1.ckpt')
                            saver_best_diag_f1.save(sess,
                                                    best_file,
                                                    global_step=step)
                            logging.info(
                                'Found new best DIAGNOSIS F1 score on validation set! - %f -  Saving model_best_diag_f1.ckpt'
                                % val_diag_f1)

                        if val_loss <= best_val:
                            best_val = val_loss
                            best_file = os.path.join(log_dir,
                                                     'model_best_xent.ckpt')
                            saver_best_xent.save(sess,
                                                 best_file,
                                                 global_step=step)
                            logging.info(
                                'Found new best crossentropy on validation set! - %f -  Saving model_best_xent.ckpt'
                                % val_loss)

                    # One step == one applied (accumulated) gradient update.
                    step += 1

        finally:
            # Flush/close the event file and release the session even when
            # training is interrupted or fails.
            if summary_writer is not None:
                summary_writer.close()
            sess.close()
Ejemplo n.º 2
0
def generate_and_evaluate_ad_classification(gan_experiment_path_list, clf_experiment_path, score_functions,
                                            image_saving_indices=set(), image_saving_path=None, max_batch_size=np.inf):
    """

    :param gan_experiment_path_list: list of GAN experiment paths to be evaluated. They must all have the same image settings and source/target field strengths as the classifier
    only gan experiments with the same source and target field strength are permitted
    :param clf_experiment_path: AD classifier used
    :param verbose: boolean. log all image classifications
    :param image_saving_indices: set of indices of the images to be saved
    :param image_saving_path: where to save the images. They are saved in subfolders for each experiment
    :return:
    """

    clf_config, logdir_clf = utils.load_log_exp_config(clf_experiment_path)

    # Load data
    data = adni_data_loader_all.load_and_maybe_process_data(
        input_folder=clf_config.data_root,
        preprocessing_folder=clf_config.preproc_folder,
        size=clf_config.image_size,
        target_resolution=clf_config.target_resolution,
        label_list=clf_config.label_list,
        offset=clf_config.offset,
        rescale_to_one=clf_config.rescale_to_one,
        force_overwrite=False
    )

    # extract images and indices of source/target images for the test set
    images_test = data['images_test']
    labels_test = data['diagnosis_test']
    ages_test = data['age_test']

    im_s = clf_config.image_size
    batch_size = min(clf_config.batch_size, std_params.batch_size, max_batch_size)
    logging.info('batch size %d is used for everything' % batch_size)
    img_tensor_shape = [batch_size, im_s[0], im_s[1], im_s[2], 1]
    clf_remainder_batch_size = images_test.shape[0] % batch_size

    # prevents ResourceExhaustError when a lot of memory is used
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True  # Do not assign whole gpu memory, just use it on the go
    config.allow_soft_placement = True  # If a operation is not defined in the default device, let it execute in another.

    # open field strength classifier save file from the selected experiment
    logging.info("loading Alzheimer's disease classifier")
    graph_clf, image_pl, predictions_clf_op, init_clf_op, saver_clf = build_clf_graph(img_tensor_shape, clf_config)
    # logging.info("getting savepoint with the best cross entropy")
    # init_checkpoint_path_clf = get_latest_checkpoint_and_log(logdir_clf, 'model_best_xent.ckpt')
    logging.info("getting savepoint with the best f1 score")
    init_checkpoint_path_clf = get_latest_checkpoint_and_log(logdir_clf, 'model_best_diag_f1.ckpt')
    sess_clf = tf.Session(config=config, graph=graph_clf)
    sess_clf.run(init_clf_op)
    saver_clf.restore(sess_clf, init_checkpoint_path_clf)

    # make a separate graph for the last batch where the batchsize is smaller
    if clf_remainder_batch_size > 0:
        img_tensor_shape_gan_remainder = [clf_remainder_batch_size, im_s[0], im_s[1], im_s[2], 1]
        graph_clf_rem, image_pl_rem, predictions_clf_op_rem, init_clf_op_rem, saver_clf_rem = build_clf_graph(img_tensor_shape_gan_remainder, clf_config)
        sess_clf_rem = tf.Session(config=config, graph=graph_clf_rem)
        sess_clf_rem.run(init_clf_op_rem)
        saver_clf_rem.restore(sess_clf_rem, init_checkpoint_path_clf)

    # classifiy all real test images
    logging.info('classify all original images')
    real_pred = []
    for batch in iterate_minibatches(images_test,
                                     [labels_test, ages_test],
                                     batch_size=batch_size,
                                     exp_config=clf_config,
                                     map_labels_to_standard_range=False,
                                     shuffle_data=False,
                                     skip_remainder=False):
        # ignore the labels because data are in order, which means the label list in data can be used
        image_batch, [real_label, real_age] = batch

        current_batch_size = image_batch.shape[0]
        if current_batch_size < batch_size:
            clf_prediction_real = sess_clf_rem.run(predictions_clf_op_rem, feed_dict={image_pl_rem: image_batch})
        else:
            clf_prediction_real = sess_clf.run(predictions_clf_op, feed_dict={image_pl: image_batch})

        real_pred = real_pred + list(clf_prediction_real['label'])
        logging.info('new image batch')
        logging.info('ground truth labels: ' + str(real_label))
        logging.info('predicted labels: ' + str(clf_prediction_real['label']))

    gan_config0, logdir_gan0 = utils.load_log_exp_config(gan_experiment_path_list[0])

    source_indices = []
    target_indices = []
    source_true_labels = []
    target_true_labels = []
    for i, field_strength in enumerate(data['field_strength_test']):
        if field_strength == gan_config0.source_field_strength:
            source_indices.append(i)
            source_true_labels.append(labels_test[i])
        elif field_strength == gan_config0.target_field_strength:
            target_indices.append(i)
            target_true_labels.append(labels_test[i])

    # balance the test set
    (source_indices, source_true_labels), (
    target_indices, target_true_labels) = utils.balance_source_target(
        (source_indices, source_true_labels), (target_indices, target_true_labels), random_seed=0)
    source_pred = [pred for ind, pred in enumerate(real_pred) if ind in source_indices]
    target_pred = [pred for ind, pred in enumerate(real_pred) if ind in target_indices]

    assert len(source_pred) == len(source_true_labels)
    assert len(target_pred) == len(target_true_labels)

    # no unexpected labels
    assert all([label in clf_config.label_list for label in source_true_labels])
    assert all([label in clf_config.label_list for label in target_true_labels])
    assert all([label in clf_config.label_list for label in source_pred])
    assert all([label in clf_config.label_list for label in target_pred])

    num_source_images = len(source_indices)
    num_target_images = len(target_indices)

    source_label_count = Counter(source_true_labels)
    target_label_count = Counter(target_true_labels)

    logging.info('Data summary:')
    logging.info(' - Domains:')
    logging.info('number of source images: ' + str(num_source_images))
    logging.info('source label distribution ' + str(source_label_count))
    logging.info('number of target images: ' + str(num_target_images))
    logging.info('target label distribution ' + str(target_label_count))

    assert num_source_images == num_target_images
    assert source_label_count == target_label_count

    #2d image saving folder
    folder_2d = 'coronal_2d'
    image_saving_path2d = os.path.join(image_saving_path, folder_2d)
    utils.makefolder(image_saving_path2d)

    # save real images
    target_image_path = os.path.join(image_saving_path, 'target')
    source_image_path = os.path.join(image_saving_path, 'source')
    utils.makefolder(target_image_path)
    utils.makefolder(source_image_path)
    target_image_path2d = os.path.join(image_saving_path2d, 'target')
    source_image_path2d = os.path.join(image_saving_path2d, 'source')
    utils.makefolder(target_image_path2d)
    utils.makefolder(source_image_path2d)
    sorted_saving_indices = sorted(image_saving_indices)
    target_saving_indices = [target_indices[index] for index in sorted_saving_indices]
    for target_index in target_saving_indices:
        target_img_name = 'target_img_%.1fT_diag%d_ind%d' % (gan_config0.target_field_strength, labels_test[target_index], target_index)
        utils.save_image_and_cut(images_test[target_index], target_img_name, target_image_path, target_image_path2d)
        logging.info(target_img_name + ' saved')

    source_saving_indices = [source_indices[index] for index in sorted_saving_indices]
    for source_index in source_saving_indices:
        source_img_name = 'source_img_%.1fT_diag%d_ind%d' % (gan_config0.source_field_strength, labels_test[source_index], source_index)
        utils.save_image_and_cut(images_test[source_index], source_img_name, source_image_path,
                                 source_image_path2d)
        logging.info(source_img_name + ' saved')

    logging.info('source and target images saved')

    gan_remainder_batch_size = num_source_images % batch_size

    scores = {}
    for gan_experiment_path in gan_experiment_path_list:
        gan_config, logdir_gan = utils.load_log_exp_config(gan_experiment_path)

        gan_experiment_name = gan_config.experiment_name

        # make sure the experiments all have the same configuration as the classifier
        assert gan_config.source_field_strength == gan_config0.source_field_strength
        assert gan_config.target_field_strength == gan_config0.target_field_strength
        assert gan_config.image_size == clf_config.image_size
        assert gan_config.target_resolution == clf_config.target_resolution
        assert gan_config.offset == clf_config.offset

        logging.info('\nGAN Experiment (%.1f T to %.1f T): %s' % (gan_config.source_field_strength,
                                                              gan_config.target_field_strength, gan_experiment_name))
        logging.info(gan_config)
        # open GAN save file from the selected experiment
        logging.info('loading GAN')
        # open the latest GAN savepoint
        init_checkpoint_path_gan = get_latest_checkpoint_and_log(logdir_gan, 'model.ckpt')

        # build a separate graph for the generator
        graph_generator, generator_img_pl, x_fake_op, init_gan_op, saver_gan = test_utils.build_gen_graph(img_tensor_shape, gan_config)

        # Create a session for running Ops on the Graph.
        sess_gan = tf.Session(config=config, graph=graph_generator)

        # Run the Op to initialize the variables.
        sess_gan.run(init_gan_op)
        saver_gan.restore(sess_gan, init_checkpoint_path_gan)

        # path where the generated images are saved
        experiment_generate_path = os.path.join(image_saving_path, gan_experiment_name)
        experiment_generate_path2d = os.path.join(image_saving_path2d, gan_experiment_name)
        # make a folder for the generated images
        utils.makefolder(experiment_generate_path)
        utils.makefolder(experiment_generate_path2d)

        # make separate graphs for the last batch where the batchsize is smaller
        if clf_remainder_batch_size > 0:
            img_tensor_shape_gan_remainder = [gan_remainder_batch_size, im_s[0], im_s[1], im_s[2], 1]
            # classifier
            graph_clf_rem, image_pl_rem, predictions_clf_op_rem, init_clf_op_rem, saver_clf_rem = build_clf_graph(img_tensor_shape_gan_remainder, clf_config)
            sess_clf_rem = tf.Session(config=config, graph=graph_clf_rem)
            sess_clf_rem.run(init_clf_op_rem)
            saver_clf_rem.restore(sess_clf_rem, init_checkpoint_path_clf)

            # generator
            graph_generator_rem, generator_img_rem_pl, x_fake_op_rem, init_gan_op_rem, saver_gan_rem = \
                test_utils.build_gen_graph(img_tensor_shape_gan_remainder, gan_config)
            # Create a session for running Ops on the Graph.
            sess_gan_rem = tf.Session(config=config, graph=graph_generator_rem)
            # Run the Op to initialize the variables.
            sess_gan_rem.run(init_gan_op_rem)
            saver_gan_rem.restore(sess_gan_rem, init_checkpoint_path_gan)

        logging.info('image generation begins')
        generated_pred = []
        batch_beginning_index = 0
        # loops through all images from the source domain
        for batch in iterate_minibatches(images_test,
                                     [labels_test, ages_test],
                                     batch_size=batch_size,
                                     exp_config=clf_config,
                                     map_labels_to_standard_range=False,
                                     selection_indices=source_indices,
                                     shuffle_data=False,
                                     skip_remainder=False):
            # ignore the labels because data are in order, which means the label list in data can be used
            image_batch, [real_label, real_age] = batch

            current_batch_size = image_batch.shape[0]
            if current_batch_size < batch_size:
                fake_img = sess_gan_rem.run(x_fake_op_rem, feed_dict={generator_img_rem_pl: image_batch})
                # classify fake image
                clf_prediction_fake = sess_clf_rem.run(predictions_clf_op_rem, feed_dict={image_pl_rem: fake_img})
            else:
                fake_img = sess_gan.run(x_fake_op, feed_dict={generator_img_pl: image_batch})
                # classify fake image
                clf_prediction_fake = sess_clf.run(predictions_clf_op, feed_dict={image_pl: fake_img})

            generated_pred = generated_pred + list(clf_prediction_fake['label'])

            # save images
            current_source_indices = range(batch_beginning_index, batch_beginning_index + current_batch_size)

            # test whether minibatches are really iterated in order by checking if the labels are as expected
            assert [source_true_labels[i] for i in current_source_indices] == list(real_label)

            source_indices_to_save = image_saving_indices.intersection(set(current_source_indices))
            for source_index in source_indices_to_save:
                batch_index = source_index - batch_beginning_index
                # index of the image in the complete test data
                global_index = source_indices[source_index]
                generated_img_name = 'generated_img_%.1fT_diag%d_ind%d' % (gan_config.target_field_strength, labels_test[global_index], global_index)
                utils.save_image_and_cut(np.squeeze(fake_img[batch_index]), generated_img_name, experiment_generate_path, experiment_generate_path2d)
                logging.info(generated_img_name + ' saved')
                # save the difference g(xs)-xs
                corresponding_source_img = images_test[global_index]
                difference_image_gs = np.squeeze(fake_img[batch_index]) - corresponding_source_img
                difference_img_name = 'difference_img_%.1fT_diag%d_ind%d' % (gan_config.target_field_strength, labels_test[global_index], global_index)
                utils.save_image_and_cut(difference_image_gs, difference_img_name,
                                         experiment_generate_path, experiment_generate_path2d)
                logging.info(difference_img_name + ' saved')

            logging.info('new image batch')
            logging.info('ground truth labels: ' + str(real_label))
            logging.info('predicted labels for generated images: ' + str(clf_prediction_fake['label']))
            # no unexpected labels
            assert all([label in clf_config.label_list for label in clf_prediction_fake['label']])

            batch_beginning_index += current_batch_size
        logging.info('generated prediction for %s: %s' % (gan_experiment_name, str(generated_pred)))
        scores[gan_experiment_name] = evaluate_scores(source_true_labels, generated_pred, score_functions)

    logging.info('source prediction: ' + str(source_pred))
    logging.info('source ground truth: ' + str(source_true_labels))
    logging.info('target prediction: ' + str(target_pred))
    logging.info('target ground truth: ' + str(target_true_labels))

    scores['source_%.1fT' % gan_config0.source_field_strength] = evaluate_scores(source_true_labels, source_pred, score_functions)
    scores['target_%.1fT' % gan_config0.target_field_strength] = evaluate_scores(target_true_labels, target_pred, score_functions)

    return scores
import clf_model_multitask as model_mt
import utils
from batch_generator_list import iterate_minibatches
import data_utils
import gan_model
from collections import OrderedDict
import csv

from experiments.adni_clf import allconv_bn as exp_config

# Load data
# The dataset is loaded (and, judging by the helper's name, preprocessed on
# demand) with the image and label settings taken from the classifier
# experiment config imported above as exp_config.
data = adni_data_loader_all.load_and_maybe_process_data(
    input_folder=exp_config.data_root,
    preprocessing_folder=exp_config.preproc_folder,
    size=exp_config.image_size,
    target_resolution=exp_config.target_resolution,
    label_list=exp_config.label_list,
    offset=exp_config.offset,
    rescale_to_one=exp_config.rescale_to_one,
    force_overwrite=False)

# Print, for each split, how many distinct values the 'rid_<split>' entry holds
# (presumably one rid per subject -- TODO confirm against the data loader).
for tt in ['train', 'test', 'val']:
    print(len(np.unique(data['rid_%s' % tt])))

# make list of index, label, rid
# The diagnosis labels of the test split are logged here before the list is built.
test_labels = data['diagnosis_test']
logging.info(test_labels)
with open(os.path.join(sys_config.project_root,
                       'results/final/label_list.csv'),
          'w+',
          newline='') as csvfile:
Ejemplo n.º 4
0
def generate_with_noise(gan_experiment_path_list,
                        noise_list,
                        image_saving_indices=set(),
                        image_saving_path3d=None,
                        image_saving_path2d=None):
    """
    Run selected source-domain test images through each GAN generator once per noise input.

    For every experiment in gan_experiment_path_list the latest 'model.ckpt'
    checkpoint is restored. Each selected source image is then fed to the
    generator together with every entry of noise_list. Per selected image the
    following files are written into a folder 'image_test<index>' (index into
    the complete test set): the source image, one generated image and one
    difference image (generated minus source) per noise input, and the
    voxel-wise standard deviation over all generated images.

    :param gan_experiment_path_list: list of GAN experiment paths to be evaluated. Each experiment must use
        generator input noise whose shape matches std_params.generator_input_noise_shape in every dimension
        except the first
    :param noise_list: sequence of noise arrays; each one is fed to the generator noise placeholder as is
    :param image_saving_indices: set of positions within the list of source-domain test images (not indices
        into the complete test set) selecting the images to process. The default set is never mutated.
    :param image_saving_path3d: root folder for the 3D images. A 'source' subfolder and one subfolder per
        experiment are created in it. Must be given although it defaults to None.
    :param image_saving_path2d: root folder for the 2D cuts, one subfolder per experiment.
        Must be given although it defaults to None.
    :return: None
    """

    batch_size = 1
    logging.info('batch size %d is used for everything' % batch_size)

    for gan_experiment_path in gan_experiment_path_list:
        gan_config, logdir_gan = utils.load_log_exp_config(gan_experiment_path)

        gan_experiment_name = gan_config.experiment_name

        # a log directory ending in '_cont' marks a continued experiment
        log_dir_ending = logdir_gan.split('_')[-1]
        continued_experiment = (log_dir_ending == 'cont')
        if continued_experiment:
            gan_experiment_name += '_cont'

        # make sure the noise has the right dimension
        assert gan_config.use_generator_input_noise
        assert gan_config.generator_input_noise_shape[
            1:] == std_params.generator_input_noise_shape[1:]

        # Load data
        data = adni_data_loader_all.load_and_maybe_process_data(
            input_folder=gan_config.data_root,
            preprocessing_folder=gan_config.preproc_folder,
            size=gan_config.image_size,
            target_resolution=gan_config.target_resolution,
            label_list=gan_config.label_list,
            offset=gan_config.offset,
            rescale_to_one=gan_config.rescale_to_one,
            force_overwrite=False)

        # extract images and indices of source/target images for the test set
        images_test = data['images_test']

        im_s = gan_config.image_size

        img_tensor_shape = [batch_size, im_s[0], im_s[1], im_s[2], 1]

        logging.info('\nGAN Experiment (%.1f T to %.1f T): %s' %
                     (gan_config.source_field_strength,
                      gan_config.target_field_strength, gan_experiment_name))
        logging.info(gan_config)

        # prevents ResourceExhaustError when a lot of memory is used
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True  # Do not assign whole gpu memory, just use it on the go
        config.allow_soft_placement = True  # If a operation is not defined in the default device, let it execute in another.

        # split the test set positions by field strength
        source_indices = []
        target_indices = []
        for i, field_strength in enumerate(data['field_strength_test']):
            if field_strength == gan_config.source_field_strength:
                source_indices.append(i)
            elif field_strength == gan_config.target_field_strength:
                target_indices.append(i)

        num_source_images = len(source_indices)
        num_target_images = len(target_indices)

        logging.info('Data summary:')
        logging.info(' - Images:')
        logging.info(images_test.shape)
        logging.info(images_test.dtype)
        logging.info(' - Domains:')
        logging.info('number of source images: ' + str(num_source_images))
        logging.info('number of target images: ' + str(num_target_images))

        # save real images
        source_image_path = os.path.join(image_saving_path3d, 'source')
        utils.makefolder(source_image_path)
        sorted_saving_indices = sorted(image_saving_indices)

        # translate positions in the source list into indices of the complete test set
        source_saving_indices = [
            source_indices[index] for index in sorted_saving_indices
        ]
        for source_index in source_saving_indices:
            source_img_name = 'source_img_%.1fT_%d.nii.gz' % (
                gan_config.source_field_strength, source_index)
            utils.create_and_save_nii(
                images_test[source_index],
                os.path.join(source_image_path, source_img_name))
            logging.info(source_img_name + ' saved')

        logging.info('source images saved')

        logging.info('loading GAN')
        # open the latest GAN savepoint (the step number is not needed here)
        init_checkpoint_path_gan, _ = utils.get_latest_checkpoint_and_step(
            logdir_gan, 'model.ckpt')

        logging.info(init_checkpoint_path_gan)

        # build a separate graph for the generator
        graph_generator, generator_img_pl, z_noise_pl, x_fake_op, init_gan_op, saver_gan = build_gen_graph(
            img_tensor_shape, gan_config)

        # Create a session for running Ops on the Graph.
        sess_gan = tf.Session(config=config, graph=graph_generator)

        # the session is closed in the finally clause so that evaluating many
        # experiments does not accumulate open sessions
        try:
            # Run the Op to initialize the variables.
            sess_gan.run(init_gan_op)
            saver_gan.restore(sess_gan, init_checkpoint_path_gan)

            experiment_folder_name = gan_experiment_name + (
                '_%.1fT_source' % gan_config.source_field_strength)

            # path where the generated images are saved
            experiment_generate_path_3d = os.path.join(image_saving_path3d,
                                                       experiment_folder_name)
            # make a folder for the generated images
            utils.makefolder(experiment_generate_path_3d)

            # path where the generated image 2d cuts are saved
            experiment_generate_path_2d = os.path.join(image_saving_path2d,
                                                       experiment_folder_name)
            # make a folder for the generated images
            utils.makefolder(experiment_generate_path_2d)

            logging.info('image generation begins')
            # loops through the selected images from the source domain
            for image_index in source_saving_indices:
                # index the test set directly so that the image really is the
                # one the folder name refers to
                curr_img = images_test[image_index]
                source_img = np.squeeze(curr_img)

                img_folder_name = 'image_test%d' % image_index
                curr_img_path_3d = os.path.join(experiment_generate_path_3d,
                                                img_folder_name)
                utils.makefolder(curr_img_path_3d)
                curr_img_path_2d = os.path.join(experiment_generate_path_2d,
                                                img_folder_name)
                utils.makefolder(curr_img_path_2d)
                # save source image
                source_img_name = 'source_img'
                utils.save_image_and_cut(source_img,
                                         source_img_name,
                                         curr_img_path_3d,
                                         curr_img_path_2d,
                                         vmin=-1,
                                         vmax=1)
                logging.info(source_img_name + ' saved')

                generator_input = np.reshape(curr_img, img_tensor_shape)
                img_list = []
                for noise_index, noise in enumerate(noise_list):
                    fake_img = sess_gan.run(x_fake_op,
                                            feed_dict={
                                                generator_img_pl: generator_input,
                                                z_noise_pl: noise
                                            })
                    fake_img = np.squeeze(fake_img)
                    # make sure the dimensions are right
                    assert len(fake_img.shape) == 3

                    img_list.append(fake_img)

                    generated_img_name = 'generated_img_noise_%d' % (noise_index)
                    utils.save_image_and_cut(fake_img,
                                             generated_img_name,
                                             curr_img_path_3d,
                                             curr_img_path_2d,
                                             vmin=-1,
                                             vmax=1)
                    logging.info(generated_img_name + ' saved')

                    # save the difference g(xs)-xs
                    difference_image_gs = fake_img - source_img
                    difference_img_name = 'difference_img_noise_%d' % (noise_index)
                    utils.save_image_and_cut(difference_image_gs,
                                             difference_img_name,
                                             curr_img_path_3d,
                                             curr_img_path_2d,
                                             vmin=-1,
                                             vmax=1)
                    logging.info(difference_img_name + ' saved')

                # without any noise input there is nothing to stack
                if img_list:
                    # stacking along a new axis 0 lets np.std reduce over the noise inputs
                    all_imgs = np.stack(img_list, axis=0)
                    std_img = np.std(all_imgs, axis=0)
                    std_img_name = 'std_img'
                    utils.save_image_and_cut(std_img,
                                             std_img_name,
                                             curr_img_path_3d,
                                             curr_img_path_2d,
                                             vmin=0,
                                             vmax=1)
                    logging.info(std_img_name + ' saved')
        finally:
            sess_gan.close()

        logging.info('generated all images for %s' % (gan_experiment_name))
Ejemplo n.º 5
0
def run_training(continue_run, log_dir):
    """
    Build the GAN training graph from exp_config and run the training loop.

    Each step runs several discriminator updates followed by one generator
    update. Summaries are written every exp_config.update_tensorboard_frequency
    steps, validation losses are computed every exp_config.validation_frequency
    steps (saving 'model_best_d_loss.ckpt' when the averaged validation
    discriminator loss is the lowest so far) and 'model.ckpt' is saved every
    exp_config.save_frequency steps.

    :param continue_run: if True, try to resume from the latest 'model.ckpt' checkpoint in log_dir.
        If no checkpoint can be located, training starts from scratch.
    :param log_dir: directory for summaries and checkpoints. When resuming succeeds, '_cont' is
        appended and everything new is written to that directory instead.
    :return: None
    """

    logging.info('===== RUNNING EXPERIMENT ========')
    logging.info(exp_config.experiment_name)
    logging.info('=================================')

    init_step = 0

    if continue_run:
        logging.info('!!!!!!!!!!!!!!!!!!!!!!!!!!!! Continuing previous run !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
        try:
            init_checkpoint_path = utils.get_latest_model_checkpoint_path(log_dir, 'model.ckpt')
            logging.info('Checkpoint path: %s' % init_checkpoint_path)
            init_step = int(init_checkpoint_path.split('/')[-1].split('-')[-1]) + 1  # plus 1 b/c otherwise starts with eval
            logging.info('Latest step was: %d' % init_step)
            log_dir += '_cont'

        # Exception instead of a bare except, so that KeyboardInterrupt and
        # SystemExit are not mistaken for a missing checkpoint
        except Exception:
            logging.warning('!!! Didnt find init checkpoint. Maybe first run failed. Disabling continue mode...')
            continue_run = False
            init_step = 0

        logging.info('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')

    # import data
    data = adni_data_loader_all.load_and_maybe_process_data(
        input_folder=exp_config.data_root,
        preprocessing_folder=exp_config.preproc_folder,
        size=exp_config.image_size,
        target_resolution=exp_config.target_resolution,
        label_list=exp_config.label_list,
        offset=exp_config.offset,
        rescale_to_one=exp_config.rescale_to_one,
        force_overwrite=False
    )

    # extract images and indices of source/target images for the training and validation set
    images_train, source_images_train_ind, target_images_train_ind,\
    images_val, source_images_val_ind, target_images_val_ind = data_utils.get_images_and_fieldstrength_indices(
        data, exp_config.source_field_strength, exp_config.target_field_strength)

    generator = exp_config.generator
    discriminator = exp_config.discriminator

    # z batches come from the source domain, x batches from the target domain
    z_sampler_train = iterate_minibatches_endlessly(images_train,
                                                    batch_size=exp_config.batch_size,
                                                    exp_config=exp_config,
                                                    selection_indices=source_images_train_ind)
    x_sampler_train = iterate_minibatches_endlessly(images_train,
                                                    batch_size=exp_config.batch_size,
                                                    exp_config=exp_config,
                                                    selection_indices=target_images_train_ind)

    with tf.Graph().as_default():

        # Generate placeholders for the images and labels.

        im_s = exp_config.image_size

        training_placeholder = tf.placeholder(tf.bool, name='training_phase')

        # despite its name this is a random tensor, not a placeholder: it is never fed
        if exp_config.use_generator_input_noise:
            noise_in_gen_pl = tf.random_uniform(shape=exp_config.generator_input_noise_shape, minval=-1, maxval=1)
        else:
            noise_in_gen_pl = None

        # target image batch
        x_pl = tf.placeholder(tf.float32, [exp_config.batch_size, im_s[0], im_s[1], im_s[2], exp_config.n_channels], name='x')

        # source image batch
        z_pl = tf.placeholder(tf.float32, [exp_config.batch_size, im_s[0], im_s[1], im_s[2], exp_config.n_channels], name='z')

        # generated fake image batch
        x_pl_ = generator(z_pl, noise_in_gen_pl, training_placeholder)

        # difference between generated and source images
        diff_img_pl = x_pl_ - z_pl

        # visualize the images by showing one slice of them in the z direction
        tf.summary.image('sample_outputs', tf_utils.put_kernels_on_grid3d(x_pl_, exp_config.cut_axis,
                                                                          exp_config.cut_index, rescale_mode='manual',
                                                                          input_range=exp_config.image_range))

        tf.summary.image('sample_xs', tf_utils.put_kernels_on_grid3d(x_pl, exp_config.cut_axis,
                                                                     exp_config.cut_index, rescale_mode='manual',
                                                                     input_range=exp_config.image_range))

        tf.summary.image('sample_zs', tf_utils.put_kernels_on_grid3d(z_pl, exp_config.cut_axis,
                                                                     exp_config.cut_index, rescale_mode='manual',
                                                                     input_range=exp_config.image_range))

        tf.summary.image('sample_difference_gx-x', tf_utils.put_kernels_on_grid3d(diff_img_pl, exp_config.cut_axis,
                                                                                  exp_config.cut_index, rescale_mode='centered',
                                                                                  cutoff_abs=exp_config.diff_threshold))

        # output of the discriminator for real image
        d_pl = discriminator(x_pl, training_placeholder, scope_reuse=False)

        # output of the discriminator for fake image
        d_pl_ = discriminator(x_pl_, training_placeholder, scope_reuse=True)

        d_hat = None
        x_hat = None
        if exp_config.improved_training:

            # random interpolation between a real and a generated batch, also
            # scored by the discriminator and handed to the training ops
            epsilon = tf.random_uniform([], 0.0, 1.0)
            x_hat = epsilon * x_pl + (1 - epsilon) * x_pl_
            d_hat = discriminator(x_hat, training_placeholder, scope_reuse=True)

        dist_l1 = tf.reduce_mean(tf.abs(diff_img_pl))

        # nr means no regularization, meaning the loss without the regularization term
        discriminator_train_op, generator_train_op, \
        disc_loss_pl, gen_loss_pl, \
        disc_loss_nr_pl, gen_loss_nr_pl = gan_model.training_ops(d_pl, d_pl_,
                                                                 optimizer_handle=exp_config.optimizer_handle,
                                                                 learning_rate=exp_config.learning_rate,
                                                                 l1_img_dist=dist_l1,
                                                                 w_reg_img_dist_l1=exp_config.w_reg_img_dist_l1,
                                                                 w_reg_gen_l1=exp_config.w_reg_gen_l1,
                                                                 w_reg_disc_l1=exp_config.w_reg_disc_l1,
                                                                 w_reg_gen_l2=exp_config.w_reg_gen_l2,
                                                                 w_reg_disc_l2=exp_config.w_reg_disc_l2,
                                                                 d_hat=d_hat, x_hat=x_hat, scale=exp_config.scale)

        # Build the operation for clipping the discriminator weights
        d_clip_op = gan_model.clip_op()

        # Put L1 distance of generated image and original image on summary
        dist_l1_summary_op = tf.summary.scalar('L1_distance_to_source_img', dist_l1)

        # Build the summary Tensor based on the TF collection of Summaries.
        summary_op = tf.summary.merge_all()

        # validation summaries (created after merge_all, so they are not part of summary_op)
        val_disc_loss_pl = tf.placeholder(tf.float32, shape=[], name='disc_val_loss')
        disc_val_summary_op = tf.summary.scalar('validation_discriminator_loss', val_disc_loss_pl)

        val_gen_loss_pl = tf.placeholder(tf.float32, shape=[], name='gen_val_loss')
        gen_val_summary_op = tf.summary.scalar('validation_generator_loss', val_gen_loss_pl)

        val_summary_op = tf.summary.merge([disc_val_summary_op, gen_val_summary_op])

        # Add the variable initializer Op.
        init = tf.global_variables_initializer()

        # Create a savers for writing training checkpoints.
        saver_latest = tf.train.Saver(max_to_keep=3)
        saver_best_disc = tf.train.Saver(max_to_keep=3)  # disc loss is scaled negative EM distance

        # prevents ResourceExhaustError when a lot of memory is used
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True  # Do not assign whole gpu memory, just use it on the go
        config.allow_soft_placement = True  # If a operation is not defined in the default device, let it execute in another.

        # Create a session for running Ops on the Graph.
        sess = tf.Session(config=config)
        summary_writer = None

        # session and summary writer are released in the finally clause, whether
        # the loop finishes, fails or is interrupted
        try:
            summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

            # Run the Op to initialize the variables.
            sess.run(init)

            if continue_run:
                # Restore session
                saver_latest.restore(sess, init_checkpoint_path)

            # initialize value of lowest (i. e. best) discriminator loss
            best_d_loss = np.inf

            for step in range(init_step, 1000000):

                start_time = time.time()

                # discriminator training iterations: more of them during the
                # first steps and at every 500th step
                d_iters = 5
                if step % 500 == 0 or step < 25:
                    d_iters = 100

                for _ in range(d_iters):

                    x = next(x_sampler_train)
                    z = next(z_sampler_train)

                    # train discriminator
                    sess.run(discriminator_train_op,
                             feed_dict={z_pl: z, x_pl: x, training_placeholder: True})

                    # weight clipping is only applied when improved training is off
                    if not exp_config.improved_training:
                        sess.run(d_clip_op)

                # NOTE: this covers the discriminator iterations only, the
                # generator update below is not included
                elapsed_time = time.time() - start_time

                # train generator on freshly drawn batches
                x = next(x_sampler_train)
                z = next(z_sampler_train)
                sess.run(generator_train_op,
                         feed_dict={z_pl: z, x_pl: x, training_placeholder: True})

                if step % exp_config.update_tensorboard_frequency == 0:

                    x = next(x_sampler_train)
                    z = next(z_sampler_train)

                    g_loss_train, d_loss_train, summary_str = sess.run(
                        [gen_loss_nr_pl, disc_loss_nr_pl, summary_op],
                        feed_dict={z_pl: z, x_pl: x, training_placeholder: False})

                    summary_writer.add_summary(summary_str, step)
                    summary_writer.flush()

                    logging.info("[Step: %d], generator loss: %g, discriminator_loss: %g" % (step, g_loss_train, d_loss_train))
                    logging.info(" - elapsed time for one step: %f secs" % elapsed_time)

                if step % exp_config.validation_frequency == 0:

                    z_sampler_val = iterate_minibatches_endlessly(images_val,
                                                                  batch_size=exp_config.batch_size,
                                                                  exp_config=exp_config,
                                                                  selection_indices=source_images_val_ind)
                    x_sampler_val = iterate_minibatches_endlessly(images_val,
                                                                  batch_size=exp_config.batch_size,
                                                                  exp_config=exp_config,
                                                                  selection_indices=target_images_val_ind)

                    # evaluate the validation batch with batch_size images (from each domain) at a time
                    g_loss_val_list = []
                    d_loss_val_list = []
                    for _ in range(exp_config.num_val_batches):
                        x = next(x_sampler_val)
                        z = next(z_sampler_val)
                        g_loss_val, d_loss_val = sess.run(
                            [gen_loss_nr_pl, disc_loss_nr_pl], feed_dict={z_pl: z,
                                                                          x_pl: x,
                                                                          training_placeholder: False})
                        g_loss_val_list.append(g_loss_val)
                        d_loss_val_list.append(d_loss_val)

                    g_loss_val_avg = np.mean(g_loss_val_list)
                    d_loss_val_avg = np.mean(d_loss_val_list)

                    validation_summary_str = sess.run(val_summary_op, feed_dict={val_disc_loss_pl: d_loss_val_avg,
                                                                                 val_gen_loss_pl: g_loss_val_avg})
                    summary_writer.add_summary(validation_summary_str, step)
                    summary_writer.flush()

                    # save best variables (if discriminator loss is the lowest yet)
                    if d_loss_val_avg <= best_d_loss:
                        best_d_loss = d_loss_val_avg
                        best_file = os.path.join(log_dir, 'model_best_d_loss.ckpt')
                        saver_best_disc.save(sess, best_file, global_step=step)
                        logging.info('Found new best discriminator loss on validation set! - %f -  Saving model_best_d_loss.ckpt' % best_d_loss)

                    logging.info("[Validation], generator loss: %g, discriminator_loss: %g" % (g_loss_val_avg, d_loss_val_avg))

                # Save the latest model regularly.
                if step % exp_config.save_frequency == 0:

                    saver_latest.save(sess, os.path.join(log_dir, 'model.ckpt'), global_step=step)
        finally:
            if summary_writer is not None:
                summary_writer.close()
            sess.close()
Ejemplo n.º 6
0
def classifier_test(clf_experiment_path,
                    score_functions,
                    batch_size=1,
                    balanced_test=True,
                    checkpoint_file_name='model_best_xent.ckpt'):
    """
    Classify the whole test set with a trained classifier and score the
    predictions separately for the source domain, the target domain and both
    domains together.

    :param clf_experiment_path: path of the classifier experiment; passed to
        utils.load_log_exp_config to obtain its config and log directory
    :param score_functions: forwarded unchanged to test_utils.evaluate_scores
    :param batch_size: number of images fed to the classifier per session run
    :param balanced_test: if True, source and target test data are subsampled
        with utils.balance_source_target (fixed random_seed=0) before scoring
    :param checkpoint_file_name: checkpoint base name that is looked up in the
        classifier log directory
    :return: tuple (sorted_scores, latest_step); sorted_scores is an
        OrderedDict sorted by str(key) with the keys
        clf_config.source_field_strength, clf_config.target_field_strength and
        'all data'; latest_step is the step returned by
        utils.get_latest_checkpoint_and_step
    """

    clf_config, logdir_clf = utils.load_log_exp_config(clf_experiment_path)

    # Load data
    data = adni_data_loader_all.load_and_maybe_process_data(
        input_folder=clf_config.data_root,
        preprocessing_folder=clf_config.preproc_folder,
        size=clf_config.image_size,
        target_resolution=clf_config.target_resolution,
        label_list=clf_config.label_list,
        offset=clf_config.offset,
        rescale_to_one=clf_config.rescale_to_one,
        force_overwrite=False)

    # extract images, diagnosis labels and ages of the test set
    images_test = data['images_test']
    labels_test = data['diagnosis_test']
    ages_test = data['age_test']

    logging.info('batch size %d is used for classifier' % batch_size)
    img_tensor_shape = [None] + list(clf_config.image_size) + [1]

    # prevents ResourceExhaustError when a lot of memory is used
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True  # Do not assign whole gpu memory, just use it on the go
    config.allow_soft_placement = True  # If a operation is not defined in the default device, let it execute in another.

    # build the classifier graph and locate the checkpoint of the selected experiment
    logging.info("loading Alzheimer's disease classifier")
    graph_clf, image_pl, predictions_clf_op, init_clf_op, saver_clf = test_utils.build_clf_graph(
        img_tensor_shape, clf_config)
    logging.info("getting savepoint %s" % checkpoint_file_name)
    init_checkpoint_path_clf, latest_step = utils.get_latest_checkpoint_and_step(
        logdir_clf, checkpoint_file_name)

    all_predictions = []
    ground_truth_labels = []
    sess_clf = tf.Session(config=config, graph=graph_clf)
    try:
        sess_clf.run(init_clf_op)  # probably not necessary
        saver_clf.restore(sess_clf, init_checkpoint_path_clf)

        # classify all real test images
        logging.info('classify all test images')
        for batch in iterate_minibatches(images_test, [labels_test, ages_test],
                                         batch_size=batch_size,
                                         exp_config=clf_config,
                                         map_labels_to_standard_range=False,
                                         shuffle_data=False,
                                         skip_remainder=False):
            # the ages are not needed for scoring; the labels are only kept
            # to verify the iteration order below
            image_batch, [real_label, _real_age] = batch

            clf_prediction_batch = sess_clf.run(
                predictions_clf_op, feed_dict={image_pl: image_batch})

            # extend in place instead of re-copying the accumulators per batch
            all_predictions.extend(clf_prediction_batch['label'])
            ground_truth_labels.extend(real_label)
            logging.info('new image batch')
            logging.info('ground truth labels: ' + str(real_label))
            logging.info('predicted labels: ' + str(clf_prediction_batch['label']))
    finally:
        # the session is not used after classification; release its resources
        sess_clf.close()

    # check that the data has really been iterated in order and in full
    assert np.array_equal(ground_truth_labels, labels_test)

    # split the test set indices (and their labels) by field strength
    source_indices = []
    target_indices = []
    source_true_labels = []
    target_true_labels = []
    for i, field_strength in enumerate(data['field_strength_test']):
        if field_strength == clf_config.source_field_strength:
            source_indices.append(i)
            source_true_labels.append(labels_test[i])
        elif field_strength == clf_config.target_field_strength:
            target_indices.append(i)
            target_true_labels.append(labels_test[i])

    # check that the source and target images together are all images
    all_indices = source_indices + target_indices
    all_indices.sort()
    assert np.array_equal(all_indices, range(images_test.shape[0]))

    source_label_count = Counter(source_true_labels)
    target_label_count = Counter(target_true_labels)
    logging.info('before balancing')
    logging.info('source labels count: ' + str(source_label_count))
    logging.info('target labels count: ' + str(target_label_count))

    # throw away some data from source and target such that they have the same AD/normal ratio
    # this stratified test dataset should make comparisons between the scores with the different test sets more meaningful
    # the seed makes sure that the new test data are always the same
    if balanced_test:
        (source_indices_new, source_true_labels_new), (
            target_indices_new,
            target_true_labels_new) = utils.balance_source_target(
                (source_indices, source_true_labels),
                (target_indices, target_true_labels),
                random_seed=0)
        all_indices = source_indices_new + target_indices_new
        all_indices.sort()
        # set lookup keeps the filtering linear in the number of test images
        kept_indices = set(all_indices)
        labels_test = [
            label for ind, label in enumerate(labels_test)
            if ind in kept_indices
        ]

        source_label_count = Counter(source_true_labels_new)
        target_label_count = Counter(target_true_labels_new)
        logging.info('balanced the test set')
        logging.info('source labels count: ' + str(source_label_count))
        logging.info('target labels count: ' + str(target_label_count))

        source_set_new = set(source_indices_new)
        target_set_new = set(target_indices_new)
        # check if the new indices are a subset of the old ones
        assert source_set_new <= set(source_indices)
        assert target_set_new <= set(target_indices)
        # check for duplicates
        assert len(source_set_new) == len(source_indices_new)
        assert len(target_set_new) == len(target_indices_new)
        # make tuples of (index, label) to check if the new index label pairs are a subset of the old ones
        source_tuples = utils.tuple_of_lists_to_list_of_tuples(
            (source_indices, source_true_labels))
        target_tuples = utils.tuple_of_lists_to_list_of_tuples(
            (target_indices, target_true_labels))
        source_tuples_new = utils.tuple_of_lists_to_list_of_tuples(
            (source_indices_new, source_true_labels_new))
        target_tuples_new = utils.tuple_of_lists_to_list_of_tuples(
            (target_indices_new, target_true_labels_new))
        assert set(source_tuples_new) <= set(source_tuples)
        assert set(target_tuples_new) <= set(target_tuples)

        # from here on only the balanced subsets are used
        source_indices, source_true_labels = source_indices_new, source_true_labels_new
        target_indices, target_true_labels = target_indices_new, target_true_labels_new

    source_pred = [all_predictions[ind] for ind in source_indices]
    target_pred = [all_predictions[ind] for ind in target_indices]

    # no unexpected labels
    assert all(label in clf_config.label_list for label in source_true_labels)
    assert all(label in clf_config.label_list for label in target_true_labels)
    assert all(label in clf_config.label_list for label in source_pred)
    assert all(label in clf_config.label_list for label in target_pred)

    num_source_images = len(source_indices)
    num_target_images = len(target_indices)

    # indices, ground truth and predictions must line up for each domain
    assert set(source_indices).isdisjoint(target_indices)
    assert num_source_images == len(source_true_labels)
    assert num_source_images == len(source_pred)
    assert num_target_images == len(target_true_labels)
    assert num_target_images == len(target_pred)
    assert num_target_images + num_source_images == len(labels_test)

    if balanced_test:
        assert num_source_images == num_target_images

    label_count = Counter(labels_test)
    assert label_count == source_label_count + target_label_count

    logging.info('Data summary:')
    logging.info(' - Images (before reduction):')
    logging.info(images_test.shape)
    logging.info(images_test.dtype)
    logging.info(' - Labels:')
    logging.info(len(labels_test))
    logging.info('number of images for each label')
    logging.info(label_count)
    logging.info(' - Domains:')
    logging.info('number of source images: ' + str(num_source_images))
    logging.info('source label distribution ' + str(source_label_count))
    logging.info('number of target images: ' + str(num_target_images))
    logging.info('target label distribution ' + str(target_label_count))

    # find out how many unique subjects there are in the (possibly reduced) test set
    rid_numbers = data['rid_test']
    used_indices = set(all_indices)
    reduced_rid_numbers = [
        number for ind, number in enumerate(rid_numbers) if ind in used_indices
    ]
    logging.info('number of unique subjects: %d' %
                 len(np.unique(reduced_rid_numbers)))

    logging.info('source prediction: ' + str(source_pred))
    logging.info('source ground truth: ' + str(source_true_labels))
    logging.info('target prediction: ' + str(target_pred))
    logging.info('target ground truth: ' + str(target_true_labels))

    scores = {}
    scores[clf_config.source_field_strength] = test_utils.evaluate_scores(
        source_true_labels, source_pred, score_functions)
    scores[clf_config.target_field_strength] = test_utils.evaluate_scores(
        target_true_labels, target_pred, score_functions)
    true_labels_together = source_true_labels + target_true_labels
    pred_together = source_pred + target_pred
    scores['all data'] = test_utils.evaluate_scores(true_labels_together,
                                                    pred_together,
                                                    score_functions)
    # dictionary sorted by key (keys are compared as strings because they
    # mix field strengths with the 'all data' string)
    sorted_scores = OrderedDict(sorted(scores.items(),
                                       key=lambda t: str(t[0])))

    return sorted_scores, latest_step
Ejemplo n.º 7
0
def run_training(continue_run, log_dir):

    logging.info('===== RUNNING EXPERIMENT ========')
    logging.info(exp_config.experiment_name)
    logging.info('=================================')

    init_step = 0

    if continue_run:
        logging.info('!!!!!!!!!!!!!!!!!!!!!!!!!!!! Continuing previous run !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
        try:
            init_checkpoint_path = utils.get_latest_model_checkpoint_path(log_dir, 'model.ckpt')
            logging.info('Checkpoint path: %s' % init_checkpoint_path)
            init_step = int(init_checkpoint_path.split('/')[-1].split('-')[-1]) + 1  # plus 1 b/c otherwise starts with eval
            logging.info('Latest step was: %d' % init_step)
            log_dir += '_cont'
        except:
            logging.warning('!!! Didnt find init checkpoint. Maybe first run failed. Disabling continue mode...')
            continue_run = False
            init_step = 0

        logging.info('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')

    # import data
    data = adni_data_loader_all.load_and_maybe_process_data(
        input_folder=exp_config.data_root,
        preprocessing_folder=exp_config.preproc_folder,
        size=exp_config.image_size,
        target_resolution=exp_config.target_resolution,
        label_list = exp_config.label_list,
        offset=exp_config.offset,
        rescale_to_one=exp_config.rescale_to_one,
        force_overwrite=False
    )

    # extract images and indices of source/target images for the training and validation set
    images_train, source_images_train_ind, target_images_train_ind,\
    images_val, source_images_val_ind, target_images_val_ind = data_utils.get_images_and_fieldstrength_indices(
        data, exp_config.source_field_strength, exp_config.target_field_strength)

    # get labels
    # the following are HDF5 datasets, not numpy arrays
    labels_train = data['diagnosis_train']
    ages_train = data['age_train']
    labels_val = data['diagnosis_val']
    ages_val = data['age_val']

    if exp_config.age_ordinal_regression:
        ages_train = utils.age_to_ordinal_reg_format(ages_train, bins=exp_config.age_bins)
        ordinal_reg_weights = utils.get_ordinal_reg_weights(ages_train)
    else:
        ages_train = utils.age_to_bins(ages_train, bins=exp_config.age_bins)
        ordinal_reg_weights = None

    if exp_config.age_ordinal_regression:
        ages_val = utils.age_to_ordinal_reg_format(ages_val, bins=exp_config.age_bins)
    else:
        ages_val= utils.age_to_bins(ages_val, bins=exp_config.age_bins)

    generator = exp_config.generator
    discriminator = exp_config.discriminator
    augmentation_function = exp_config.augmentation_function if exp_config.use_augmentation else None

    s_sampler_train = iterate_minibatches_endlessly(images_train,
                                                    batch_size=2*exp_config.batch_size,
                                                    exp_config=exp_config,
                                                    labels_list=[labels_train, ages_train],
                                                    selection_indices=source_images_train_ind,
                                                    augmentation_function=augmentation_function)

    t_sampler_train = iterate_minibatches_endlessly(images_train,
                                                    batch_size=exp_config.batch_size,
                                                    exp_config=exp_config,
                                                    labels_list=[labels_train, ages_train],
                                                    selection_indices=target_images_train_ind,
                                                    augmentation_function=augmentation_function)


    with tf.Graph().as_default():

        training_time_placeholder = tf.placeholder(tf.bool, shape=[], name='training_time')

        # GAN

        # input noise for generator
        if exp_config.use_generator_input_noise:
            noise_in_gen_pl = tf.random_uniform(shape=exp_config.generator_input_noise_shape, minval=-1, maxval=1)
        else:
            noise_in_gen_pl = None

        # target image batch
        xt_pl = tf.placeholder(tf.float32, image_tensor_shape(exp_config.batch_size), name='x_target')

        # the classifier uses 2 times the batch size of the GAN
        clf_batch_size = 2 * exp_config.batch_size

        # source image batch
        xs_pl, diag_s_pl, ages_s_pl = placeholders_clf(clf_batch_size, 'source')
        # split source batch into 1 to be translated to xf and 2 for the classifier
        # for the discriminator train op half 2 of the batch is not used
        xs1_pl, xs2_pl = tf.split(xs_pl, 2, axis=0)

        # generated fake image batch
        xf_pl = generator(xs1_pl, noise_in_gen_pl, training_time_placeholder)

        # difference between generated and source images
        diff_img_pl = xf_pl - xs1_pl

        # visualize the images by showing one slice of them in the z direction
        tf.summary.image('sample_outputs', tf_utils.put_kernels_on_grid3d(xf_pl, exp_config.cut_axis,
                                                                          exp_config.cut_index, rescale_mode='manual',
                                                                          input_range=exp_config.image_range))

        tf.summary.image('sample_xt', tf_utils.put_kernels_on_grid3d(xt_pl, exp_config.cut_axis,
                                                                          exp_config.cut_index, rescale_mode='manual',
                                                                          input_range=exp_config.image_range))

        tf.summary.image('sample_xs', tf_utils.put_kernels_on_grid3d(xs1_pl, exp_config.cut_axis,
                                                                          exp_config.cut_index, rescale_mode='manual',
                                                                          input_range=exp_config.image_range))

        tf.summary.image('sample_difference_xf-xs', tf_utils.put_kernels_on_grid3d(diff_img_pl, exp_config.cut_axis,
                                                                          exp_config.cut_index, rescale_mode='centered',
                                                                          cutoff_abs=exp_config.diff_threshold))

        # output of the discriminator for real image
        d_pl = discriminator(xt_pl, training_time_placeholder, scope_reuse=False)

        # output of the discriminator for fake image
        d_pl_ = discriminator(xf_pl, training_time_placeholder, scope_reuse=True)

        d_hat = None
        x_hat = None
        if exp_config.improved_training:

            epsilon = tf.random_uniform([], 0.0, 1.0)
            x_hat = epsilon * xt_pl + (1 - epsilon) * xf_pl
            d_hat = discriminator(x_hat, training_time_placeholder, scope_reuse=True)

        dist_l1 = tf.reduce_mean(tf.abs(diff_img_pl))

        learning_rate_gan_pl = tf.placeholder(tf.float32, shape=[], name='learning_rate')
        learning_rate_clf_pl = tf.placeholder(tf.float32, shape=[], name='learning_rate')

        if exp_config.momentum is not None:
            optimizer_handle = lambda learning_rate: exp_config.optimizer_handle(learning_rate=learning_rate,
                                                    momentum=exp_config.momentum)
        else:
            optimizer_handle = lambda learning_rate: exp_config.optimizer_handle(learning_rate=learning_rate)

        # Build the operation for clipping the discriminator weights
        d_clip_op = gan_model.clip_op()

        # Put L1 distance of generated image and original image on summary
        dist_l1_summary_op = tf.summary.scalar('L1_distance_to_source_img', dist_l1)

        # Classifier ----------------------------------------------------------------------------------------
        # for training usually false so xt and xf get concatenated as classifier input, otherwise
        directly_feed_clf_pl = tf.placeholder(tf.bool, shape=[], name='direct_classifier_feeding')

        # conditionally assign either a concatenation of the generated dataset and the source data
        # cond to avoid having to specify not needed placeholders in the feed dict
        images_clf, diag_clf, ages_clf = tf.cond(
            directly_feed_clf_pl,
            lambda: placeholders_clf(clf_batch_size, 'direct_clf'),
            lambda: concatenate_clf_input([xf_pl, xs2_pl], diag_s_pl, ages_s_pl, scope_name = 'fs_concat')
        )

        tf.summary.scalar('learning_rate_gan', learning_rate_gan_pl)
        tf.summary.scalar('learning_rate_clf', learning_rate_clf_pl)

        # Build a Graph that computes predictions from the inference model.
        diag_logits_train, ages_logits_train = exp_config.clf_model_handle(images_clf,
                                                                           nlabels=exp_config.nlabels,
                                                                           training=training_time_placeholder,
                                                                           n_age_thresholds=len(exp_config.age_bins),
                                                                           bn_momentum=exp_config.bn_momentum)

        # Add to the Graph the Ops for loss calculation.

        [classifier_loss, diag_loss, age_loss, weights_norm_clf] = clf_model_mt.loss(diag_logits_train,
                                                                                 ages_logits_train,
                                                                                 diag_clf,
                                                                                 ages_clf,
                                                                                 nlabels=exp_config.nlabels,
                                                                                 weight_decay=exp_config.weight_decay,
                                                                                 diag_weight=exp_config.diag_weight,
                                                                                 age_weight=exp_config.age_weight,
                                                                                 use_ordinal_reg=exp_config.age_ordinal_regression,
                                                                                 ordinal_reg_weights=ordinal_reg_weights)

        # nr means no regularization, meaning the loss without the regularization term
        train_ops_dict, losses_gan_dict = joint_model.training_ops(d_pl, d_pl_,
                                                             classifier_loss,
                                                             optimizer_handle=optimizer_handle,
                                                             learning_rate_gan=learning_rate_gan_pl,
                                                             learning_rate_clf=learning_rate_clf_pl,
                                                             l1_img_dist=dist_l1,
                                                             gan_loss_weight=exp_config.gan_loss_weight,
                                                             task_loss_weight=exp_config.task_loss_weight,
                                                             w_reg_img_dist_l1=exp_config.w_reg_img_dist_l1,
                                                             w_reg_gen_l1=exp_config.w_reg_gen_l1,
                                                             w_reg_disc_l1=exp_config.w_reg_disc_l1,
                                                             w_reg_gen_l2=exp_config.w_reg_gen_l2,
                                                             w_reg_disc_l2=exp_config.w_reg_disc_l2,
                                                             d_hat=d_hat, x_hat=x_hat, scale=exp_config.scale)


        tf.summary.scalar('classifier loss', classifier_loss)
        tf.summary.scalar('diag_loss', diag_loss)
        tf.summary.scalar('age_loss', age_loss)
        tf.summary.scalar('weights_norm_term_classifier', weights_norm_clf)
        tf.summary.scalar('generator loss joint', losses_gan_dict['gen']['joint'])
        tf.summary.scalar('discriminator loss joint', losses_gan_dict['disc']['joint'])

        eval_diag_loss, eval_ages_loss, pred_labels, ages_softmaxs = clf_model_mt.evaluation(diag_logits_train, ages_logits_train,
                                                                                             diag_clf,
                                                                                             ages_clf,
                                                                                             images_clf,
                                                                                             diag_weight=exp_config.diag_weight,
                                                                                             age_weight=exp_config.age_weight,
                                                                                             nlabels=exp_config.nlabels,
                                                                                             use_ordinal_reg=exp_config.age_ordinal_regression)

        # Build the summary Tensor based on the TF collection of Summaries.
        summary = tf.summary.merge_all()


        # Add the variable initializer Op.
        init = tf.global_variables_initializer()

        # Create a savers for writing training checkpoints.
        saver_latest = tf.train.Saver(max_to_keep=2)
        saver_best_disc = tf.train.Saver(max_to_keep=2)  # disc loss is scaled negative EM distance
        saver_best_diag_f1 = tf.train.Saver(max_to_keep=5)
        saver_best_ages_f1 = tf.train.Saver(max_to_keep=1)
        saver_best_xent = tf.train.Saver(max_to_keep=5)

        # validation summaries gan
        val_disc_loss_pl = tf.placeholder(tf.float32, shape=[], name='disc_val_loss')
        disc_val_summary_op = tf.summary.scalar('validation_discriminator_loss', val_disc_loss_pl)

        val_gen_loss_pl = tf.placeholder(tf.float32, shape=[], name='gen_val_loss')
        gen_val_summary_op = tf.summary.scalar('validation_generator_loss', val_gen_loss_pl)

        val_summary_gan = tf.summary.merge([disc_val_summary_op, gen_val_summary_op])

        # Classifier summary
        val_error_clf_ = tf.placeholder(tf.float32, shape=[], name='val_error_diag')
        val_error_summary = tf.summary.scalar('classifier_validation_loss', val_error_clf_)

        val_diag_f1_score_ = tf.placeholder(tf.float32, shape=[], name='val_diag_f1')
        val_f1_diag_summary = tf.summary.scalar('validation_diag_f1', val_diag_f1_score_)

        val_ages_f1_score_ = tf.placeholder(tf.float32, shape=[], name='val_ages_f1')
        val_f1_ages_summary = tf.summary.scalar('validation_ages_f1', val_ages_f1_score_)

        val_summary_clf = tf.summary.merge([val_error_summary, val_f1_diag_summary, val_f1_ages_summary])
        val_summary = tf.summary.merge([val_summary_clf, val_summary_gan])

        train_error_clf_ = tf.placeholder(tf.float32, shape=[], name='train_error_diag')
        train_error_clf_summary = tf.summary.scalar('classifier_training_loss', train_error_clf_)

        train_diag_f1_score_ = tf.placeholder(tf.float32, shape=[], name='train_diag_f1')
        train_diag_f1_summary = tf.summary.scalar('training_diag_f1', train_diag_f1_score_)

        train_ages_f1_score_ = tf.placeholder(tf.float32, shape=[], name='train_ages_f1')
        train_f1_ages_summary = tf.summary.scalar('training_ages_f1', train_ages_f1_score_)

        train_summary = tf.summary.merge([train_error_clf_summary, train_diag_f1_summary, train_f1_ages_summary])

        # prevents ResourceExhaustError when a lot of memory is used
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True  # Do not assign whole gpu memory, just use it on the go
        config.allow_soft_placement = True  # If a operation is not defined in the default device, let it execute in another.

        # Create a session for running Ops on the Graph.
        sess = tf.Session(config=config)

        summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

        sess.graph.finalize()

        # Run the Op to initialize the variables.
        sess.run(init)

        if continue_run:
            # Restore session
            saver_latest.restore(sess, init_checkpoint_path)

        curr_lr_gan = exp_config.learning_rate_gan
        curr_lr_clf = exp_config.learning_rate_clf

        no_improvement_counter = 0
        best_val = np.inf
        last_train = np.inf
        loss_history = []
        loss_gradient = np.inf
        best_diag_f1_score = 0
        best_ages_f1_score = 0
        # initialize value of lowest (i. e. best) discriminator loss
        best_d_loss = np.inf

        for step in range(init_step, exp_config.max_steps):

            start_time = time.time()

            # discriminator and classifier (task) training iterations
            d_iters = 5
            t_iters = 1
            if step % 500 == 0 or step < 25:
                d_iters = 100
            for iteration in range(max(d_iters, t_iters)):

                x_t, [diag_t, age_t] = next(t_sampler_train)
                x_s, [diag_s, age_s] = next(s_sampler_train)

                feed_dict_dc = {xs_pl: x_s,
                                xt_pl: x_t,
                                learning_rate_gan_pl: curr_lr_gan,
                                learning_rate_clf_pl: curr_lr_clf,
                                diag_s_pl: diag_s,
                                ages_s_pl: age_s,
                                training_time_placeholder: True,
                                directly_feed_clf_pl: False}
                train_ops_list_dc = []
                if iteration < t_iters:
                    # train classifier
                    train_ops_list_dc.append(train_ops_dict['clf'])

                if iteration < d_iters:
                    # train discriminator
                    train_ops_list_dc.append(train_ops_dict['disc'])

                sess.run(train_ops_list_dc, feed_dict=feed_dict_dc)

                if not exp_config.improved_training:
                    sess.run(d_clip_op)

            elapsed_time = time.time() - start_time

            # train generator
            x_t, [diag_t, age_t] = next(t_sampler_train)
            x_s, [diag_s, age_s] = next(s_sampler_train)
            sess.run(train_ops_dict['gen'],
                     feed_dict={xs_pl: x_s,
                                xt_pl: x_t,
                                learning_rate_gan_pl: curr_lr_gan,
                                learning_rate_clf_pl: curr_lr_clf,
                                diag_s_pl: diag_s,
                                ages_s_pl: age_s,
                                training_time_placeholder: True,
                                directly_feed_clf_pl: False
                                })

            if step % exp_config.update_tensorboard_frequency == 0:
                x_t, [diag_t, age_t] = next(t_sampler_train)
                x_s, [diag_s, age_s] = next(s_sampler_train)

                feed_dict_summary={xs_pl: x_s,
                                    xt_pl: x_t,
                                    learning_rate_gan_pl: curr_lr_gan,
                                    learning_rate_clf_pl: curr_lr_clf,
                                    diag_s_pl: diag_s,
                                    ages_s_pl: age_s,
                                    training_time_placeholder: True,
                                    directly_feed_clf_pl: False
                                    }

                c_loss_one_batch, gan_losses_one_batch_dict, summary_str = sess.run(
                        [classifier_loss, losses_gan_dict, summary], feed_dict=feed_dict_summary)

                summary_writer.add_summary(summary_str, step)
                summary_writer.flush()

                logging.info("[Step: %d], classifier_loss: %g, GAN losses: %s" % (step, c_loss_one_batch, str(gan_losses_one_batch_dict)))
                logging.info(" - elapsed time for one step: %f secs" % elapsed_time)

            if (step + 1) % exp_config.train_eval_frequency == 0:

                # Evaluate against the training set
                logging.info('Training data eval for classifier (target domain):')
                [train_loss, train_diag_f1, train_ages_f1] = do_eval_classifier(sess, eval_diag_loss,
                                                                                eval_ages_loss,
                                                                                pred_labels,
                                                                                ages_softmaxs,
                                                                                xs_pl,
                                                                                diag_s_pl,
                                                                                ages_s_pl,
                                                                                training_time_placeholder,
                                                                                directly_feed_clf_pl,
                                                                                images_train,
                                                                                [labels_train, ages_train],
                                                                                clf_batch_size=clf_batch_size,
                                                                                do_ordinal_reg=exp_config.age_ordinal_regression,
                                                                                selection_indices=source_images_train_ind)

                train_summary_msg = sess.run(train_summary, feed_dict={train_error_clf_: train_loss,
                                                                       train_diag_f1_score_: train_diag_f1,
                                                                       train_ages_f1_score_: train_ages_f1}
                                             )
                summary_writer.add_summary(train_summary_msg, step)

                loss_history.append(train_loss)
                if len(loss_history) > 5:
                    loss_history.pop(0)
                    loss_gradient = (loss_history[-5] - loss_history[-1]) / 2

                logging.info('loss gradient is currently %f' % loss_gradient)

                if exp_config.schedule_lr and loss_gradient < exp_config.schedule_gradient_threshold:
                    logging.warning('Reducing learning rate of the classifier!')
                    curr_lr_clf /= 10.0
                    logging.info('Learning rate of the classifier changed to: %f' % curr_lr_clf)

                    # reset loss history to give the optimisation some time to start decreasing again
                    loss_gradient = np.inf
                    loss_history = []

                if train_loss <= last_train:  # best_train:
                    logging.info('Decrease in training error!')
                else:
                    logging.info('No improvment in training error for %d steps' % no_improvement_counter)

                last_train = train_loss


            if (step + 1) % exp_config.validation_frequency == 0:

                # evaluate gan losses
                g_loss_val_avg, d_loss_val_avg = do_eval_gan(sess=sess,
                                                             losses=[losses_gan_dict['gen']['nr'], losses_gan_dict['disc']['nr']],
                                                             images_s_pl=xs_pl,
                                                             images_t_pl=xt_pl,
                                                             training_time_placeholder=training_time_placeholder,
                                                             images=images_val,
                                                             source_images_ind=source_images_val_ind,
                                                             target_images_ind=target_images_val_ind)

                # evaluate classifier losses
                [val_loss, val_diag_f1, val_ages_f1] = do_eval_classifier(sess,
                                                                          eval_diag_loss,
                                                                          eval_ages_loss,
                                                                          pred_labels,
                                                                          ages_softmaxs,
                                                                          xs_pl,
                                                                          diag_s_pl,
                                                                          ages_s_pl,
                                                                          training_time_pl=training_time_placeholder,
                                                                          directly_feed_clf_pl=directly_feed_clf_pl,
                                                                          images=images_val,
                                                                          labels_list=[labels_val, ages_val],
                                                                          clf_batch_size=clf_batch_size,
                                                                          do_ordinal_reg=exp_config.age_ordinal_regression,
                                                                          selection_indices=source_images_val_ind)


                feed_dict_val = {
                    val_error_clf_: val_loss,
                    val_diag_f1_score_: val_diag_f1,
                    val_ages_f1_score_: val_ages_f1,
                    val_disc_loss_pl: d_loss_val_avg,
                    val_gen_loss_pl: g_loss_val_avg
                }

                validation_summary_msg = sess.run(val_summary, feed_dict=feed_dict_val)
                summary_writer.add_summary(validation_summary_msg, step)
                summary_writer.flush()

                # save best variables (if discriminator loss is the lowest yet)
                if d_loss_val_avg <= best_d_loss:
                    best_d_loss = d_loss_val_avg
                    best_file = os.path.join(log_dir, 'model_best_d_loss.ckpt')
                    saver_best_disc.save(sess, best_file, global_step=step)
                    logging.info('Found new best discriminator loss on validation set! - %f -  Saving model_best_d_loss.ckpt' % best_d_loss)

                if val_diag_f1 >= best_diag_f1_score:
                    best_diag_f1_score = val_diag_f1
                    best_file = os.path.join(log_dir, 'model_best_diag_f1.ckpt')
                    saver_best_diag_f1.save(sess, best_file, global_step=step)
                    logging.info(
                        'Found new best DIAGNOSIS F1 score on validation set! - %f -  Saving model_best_diag_f1.ckpt' % val_diag_f1)

                if val_ages_f1 >= best_ages_f1_score:
                    best_ages_f1_score = val_ages_f1
                    best_file = os.path.join(log_dir, 'model_best_ages_f1.ckpt')
                    saver_best_ages_f1.save(sess, best_file, global_step=step)
                    logging.info(
                        'Found new best AGES F1 score on validation set! - %f -  Saving model_best_ages_f1.ckpt' % val_ages_f1)

                if val_loss <= best_val:
                    best_val = val_loss
                    best_file = os.path.join(log_dir, 'model_best_xent.ckpt')
                    saver_best_xent.save(sess, best_file, global_step=step)
                    logging.info(
                        'Found new best crossentropy on validation set! - %f -  Saving model_best_xent.ckpt' % val_loss)

                logging.info("[Validation], generator loss: %g, discriminator_loss: %g" % (g_loss_val_avg, d_loss_val_avg))

            # Save the latest model checkpoint every `save_frequency` steps.
            if step % exp_config.save_frequency == 0:

                saver_latest.save(sess, os.path.join(log_dir, 'model.ckpt'), global_step=step)

        sess.close()