Example #1
def do_eval(sess, eval_loss, images_placeholder, labels_placeholder,
            training_time_placeholder, images, labels, batch_size):
    '''
    Function for running the evaluations every X iterations on the training and validation sets. 
    :param sess: The current tf session 
    :param eval_loss: The placeholder containing the eval loss
    :param images_placeholder: Placeholder for the images
    :param labels_placeholder: Placeholder for the masks
    :param training_time_placeholder: Placeholder toggling the training/testing mode. 
    :param images: A numpy array or h5py dataset containing the images
    :param labels: A numpy array or h5py dataset containing the corresponding labels 
    :param batch_size: The batch_size to use. 
    :return: The average loss (as defined in the experiment), and the average dice over all `images`. 
    '''

    loss_ii = 0
    dice_ii = 0
    num_batches = 0

    for batch in BackgroundGenerator(
            iterate_minibatches(images,
                                labels,
                                batch_size=batch_size,
                                augment_batch=False)):  # No aug in evaluation
        # As before, you can wrap the iterate_minibatches function in the
        # BackgroundGenerator class for speed improvements, at the risk of
        # exceptions going uncaught.

        x, y = batch

        if y.shape[0] < batch_size:
            continue

        feed_dict = {
            images_placeholder: x,
            labels_placeholder: y,
            training_time_placeholder: False
        }

        closs, cdice = sess.run(eval_loss, feed_dict=feed_dict)
        loss_ii += closs
        dice_ii += cdice
        num_batches += 1

    avg_loss = loss_ii / num_batches
    avg_dice = dice_ii / num_batches

    logging.info('  Average loss: %0.04f, average dice: %0.04f' %
                 (avg_loss, avg_dice))

    return avg_loss, avg_dice
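
The BackgroundGenerator class itself is not shown in these examples. A minimal sketch of the common prefetching pattern it presumably follows (an assumption: a daemon thread that pre-loads items from the wrapped generator into a bounded queue, as in the well-known prefetch_generator recipe):

import queue
import threading

class BackgroundGenerator(threading.Thread):
    def __init__(self, generator, max_prefetch=3):
        super(BackgroundGenerator, self).__init__()
        self.queue = queue.Queue(max_prefetch)  # bounded: producer blocks when full
        self.generator = generator
        self.daemon = True  # do not keep the process alive on exit
        self.start()

    def run(self):
        # Producer side, running in the background thread. If the wrapped
        # generator raises, this thread dies without queueing the sentinel,
        # which is exactly the "exceptions may not be caught" caveat noted
        # in the comments above.
        for item in self.generator:
            self.queue.put(item)
        self.queue.put(None)  # sentinel: iteration finished

    def __iter__(self):
        return self

    def __next__(self):
        item = self.queue.get()
        if item is None:
            raise StopIteration
        return item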
Example #2

(Note: the BackgroundGenerator in this example takes the gameworld rather than an iterator, so it is evidently a game-background builder, unrelated to the minibatch prefetcher in the other examples.)
def __init__(self, **kwargs):
    super(YACSGame, self).__init__(**kwargs)
    self.gameworld.init_gameworld([
        'back_stars', 'mid_stars', 'position', 'sun1', 'sun2',
        'camera_stars1', 'camera_stars2', 'map', 'planet1', 'planet2',
        'camera_sun1', 'camera_sun2', 'camera_planet1', 'camera_planet2',
        'scale', 'rotate', 'color', 'particles', 'emitters',
        'particle_renderer', 'cymunk_physics', 'steering', 'ship_system',
        'projectiles', 'projectile_weapons', 'lifespan', 'combat_stats',
        'asteroids', 'steering_ai', 'weapon_ai', 'shields',
        'shield_renderer', 'map_grid', 'grid_camera', 'radar_renderer',
        'radar_color', 'world_grid', 'global_map', 'global_camera',
        'world_map', 'global_map_renderer', 'global_map_renderer2',
        'global_map_planet_renderer'
    ], callback=self.init_game)
    self.background_generator = BackgroundGenerator(self.gameworld)
Example #3
def run_training(continue_run):

    logging.info('EXPERIMENT NAME: %s' % exp_config.experiment_name)
    already_created_recursion = True
    print("ALready created recursion : " + str(already_created_recursion))
    init_step = 0
    # Load data
    base_data, recursion_data, recursion = acdc_data.load_and_maybe_process_scribbles(
        scribble_file=sys_config.project_root + exp_config.scribble_data,
        target_folder=log_dir,
        percent_full_sup=exp_config.percent_full_sup,
        scr_ratio=exp_config.length_ratio)
    # Wrap everything from this point onwards in a try/except so that a
    # KeyboardInterrupt can be caught and the h5py data files closed cleanly.
    try:
        loaded_previous_recursion = False
        start_epoch = 0
        if continue_run:
            logging.info(
                '!!!!!!!!!!!!!!!!!!!!!!!!!!!! Continuing previous run !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
            )
            try:
                try:
                    init_checkpoint_path = utils.get_latest_model_checkpoint_path(
                        log_dir, 'recursion_{}_model.ckpt'.format(recursion))

                except Exception:
                    print("Entered except block; falling back to previous recursion")
                    init_checkpoint_path = utils.get_latest_model_checkpoint_path(
                        log_dir,
                        'recursion_{}_model.ckpt'.format(recursion - 1))
                    loaded_previous_recursion = True
                logging.info('Checkpoint path: %s' % init_checkpoint_path)
                init_step = int(
                    init_checkpoint_path.split('/')[-1].split('-')
                    [-1]) + 1  # plus 1 b/c otherwise starts with eval
                start_epoch = int(
                    init_step /
                    (len(base_data['images_train']) / exp_config.batch_size))
                logging.info('Latest step was: %d' % init_step)
                logging.info('Continuing with epoch: %d' % start_epoch)
            except Exception:
                logging.warning(
                    '!!! Did not find init checkpoint. Maybe first run failed. Disabling continue mode...'
                )
                continue_run = False
                init_step = 0
                start_epoch = 0

            logging.info(
                '!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
            )

        if loaded_previous_recursion:
            logging.info(
                "Data file exists for recursion {} "
                "but checkpoints only present up to recursion {}".format(
                    recursion, recursion - 1))
            logging.info("Likely means postprocessing was terminated")

            #            if not already_created_recursion:
            #
            #                recursion_data = acdc_data.load_different_recursion(recursion_data, -1)
            #                recursion-=1
            #            else:
            start_epoch = 0
            init_step = 0
        # load images and validation data
        images_train = np.array(base_data['images_train'])
        scribbles_train = np.array(base_data['scribbles_train'])
        images_val = np.array(base_data['images_test'])
        labels_val = np.array(base_data['masks_test'])

        # if exp_config.use_data_fraction:
        #     num_images = images_train.shape[0]
        #     new_last_index = int(float(num_images)*exp_config.use_data_fraction)
        #
        #     logging.warning('USING ONLY FRACTION OF DATA!')
        #     logging.warning(' - Number of imgs orig: %d, Number of imgs new: %d' % (num_images, new_last_index))
        #     images_train = images_train[0:new_last_index,...]
        #     labels_train = labels_train[0:new_last_index,...]

        logging.info('Data summary:')
        logging.info(' - Images:')
        logging.info(images_train.shape)
        logging.info(images_train.dtype)
        #logging.info(' - Labels:')
        #logging.info(labels_train.shape)
        #logging.info(labels_train.dtype)

        # Tell TensorFlow that the model will be built into the default Graph.
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True
        #        with tf.Graph().as_default():
        with tf.Session(config=config) as sess:
            # Generate placeholders for the images and labels.

            image_tensor_shape = [exp_config.batch_size] + list(
                exp_config.image_size) + [1]
            mask_tensor_shape = [exp_config.batch_size] + list(
                exp_config.image_size)

            images_placeholder = tf.placeholder(tf.float32,
                                                shape=image_tensor_shape,
                                                name='images')
            labels_placeholder = tf.placeholder(tf.uint8,
                                                shape=mask_tensor_shape,
                                                name='labels')

            learning_rate_placeholder = tf.placeholder(tf.float32, shape=[])
            training_time_placeholder = tf.placeholder(tf.bool, shape=[])
            keep_prob = tf.placeholder(tf.float32, shape=[])
            crf_learning_rate_placeholder = tf.placeholder(tf.float32,
                                                           shape=[])
            tf.summary.scalar('learning_rate', learning_rate_placeholder)

            # Build a Graph that computes predictions from the inference model.
            logits = model.inference(images_placeholder,
                                     keep_prob,
                                     exp_config.model_handle,
                                     training=training_time_placeholder,
                                     nlabels=exp_config.nlabels)

            # Add to the Graph the Ops for loss calculation.
            [loss, _, weights_norm
             ] = model.loss(logits,
                            labels_placeholder,
                            nlabels=exp_config.nlabels,
                            loss_type=exp_config.loss_type,
                            weight_decay=exp_config.weight_decay
                            )  # second output is unregularised loss

            tf.summary.scalar('loss', loss)
            tf.summary.scalar('weights_norm_term', weights_norm)

            # Add to the Graph the Ops that calculate and apply gradients.

            global_step = tf.Variable(0, name='global_step', trainable=False)

            crf_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                              scope='crf_scope')

            # Exclude the CRF variables from the network optimiser's update
            # list. Compare by name: crf_variables holds Variable objects,
            # so testing `v.name not in crf_variables` would never exclude
            # anything.
            crf_variable_names = [v.name for v in crf_variables]
            restore_var = [
                v for v in tf.all_variables()
                if v.name not in crf_variable_names
            ]

            network_train_op = tf.train.AdamOptimizer(
                learning_rate=learning_rate_placeholder).minimize(
                    loss,
                    var_list=restore_var,
                    colocate_gradients_with_ops=True,
                    global_step=global_step)

            crf_train_op = tf.train.AdamOptimizer(
                learning_rate=crf_learning_rate_placeholder).minimize(
                    loss,
                    var_list=crf_variables,
                    colocate_gradients_with_ops=True,
                    global_step=global_step)

            eval_val_loss = model.evaluation(
                logits,
                labels_placeholder,
                images_placeholder,
                nlabels=exp_config.nlabels,
                loss_type=exp_config.loss_type,
                weak_supervision=True,
                cnn_threshold=exp_config.cnn_threshold,
                include_bg=False)

            # Build the summary Tensor based on the TF collection of Summaries.
            summary = tf.summary.merge_all()

            # Add the variable initializer Op.
            init = tf.global_variables_initializer()

            # Create a saver for writing training checkpoints.
            # Only keep two checkpoints, as checkpoints are kept for every recursion
            # and they can be 300MB +
            saver = tf.train.Saver(max_to_keep=2)
            saver_best_dice = tf.train.Saver(max_to_keep=2)
            saver_best_xent = tf.train.Saver(max_to_keep=2)

            # The session is the one created by the enclosing
            # `with tf.Session(config=config) as sess:` statement above.

            # Instantiate a SummaryWriter to output summaries and the Graph.
            summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

            # with tf.name_scope('monitoring'):

            val_error_ = tf.placeholder(tf.float32, shape=[], name='val_error')
            val_error_summary = tf.summary.scalar('validation_loss',
                                                  val_error_)

            val_dice_ = tf.placeholder(tf.float32, shape=[], name='val_dice')
            val_dice_summary = tf.summary.scalar('validation_dice', val_dice_)

            val_summary = tf.summary.merge(
                [val_error_summary, val_dice_summary])

            train_error_ = tf.placeholder(tf.float32,
                                          shape=[],
                                          name='train_error')
            train_error_summary = tf.summary.scalar('training_loss',
                                                    train_error_)

            train_dice_ = tf.placeholder(tf.float32,
                                         shape=[],
                                         name='train_dice')
            train_dice_summary = tf.summary.scalar('training_dice',
                                                   train_dice_)

            train_summary = tf.summary.merge(
                [train_error_summary, train_dice_summary])

            # Run the Op to initialize the variables.
            sess.run(init)

            #            if continue_run:
            #                # Restore session
            #                saver.restore(sess, init_checkpoint_path)
            #            saver.restore(sess,"/scratch_net/biwirender02/cany/scribble/logdir/heart_dropout_rnn_exp/recursion_1_model_best_dice.ckpt-12699")
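
            # NOTE: the resets below discard whatever init_step / recursion /
            # start_epoch the continue_run logic recovered above.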

            init_step = 0
            recursion = 0
            start_epoch = 0
            #

            step = init_step
            curr_lr = exp_config.learning_rate / 10
            crf_curr_lr = 1e-07 / 10
            no_improvement_counter = 0
            best_val = np.inf
            last_train = np.inf
            loss_history = []
            loss_gradient = np.inf
            best_dice = 0
            logging.info('RECURSION {0}'.format(recursion))

            # random walk - if it already has been random walked it won't redo
            if recursion == 0:
                recursion_data = acdc_data.random_walk_epoch(
                    recursion_data, exp_config.rw_beta,
                    exp_config.rw_threshold, exp_config.random_walk)
                print("Random walku geçti")
                #get ground truths
                labels_train = np.array(recursion_data['random_walked'])
            else:
                labels_train = np.array(recursion_data['predicted'])
            print("Start epoch : " + str(start_epoch) + " : max epochs : " +
                  str(exp_config.epochs_per_recursion))
            for epoch in range(start_epoch, exp_config.max_epochs):
                if (epoch % exp_config.epochs_per_recursion == 0
                        and epoch != 0):

                    #Have reached end of recursion
                    recursion_data = predict_next_gt(
                        data=recursion_data,
                        images_train=images_train,
                        images_placeholder=images_placeholder,
                        training_time_placeholder=training_time_placeholder,
                        keep_prob=keep_prob,
                        logits=logits,
                        sess=sess)

                    #                        recursion_data = postprocess_gt(data=recursion_data,
                    #                                                        images_train=images_train,
                    #                                                        scribbles_train=scribbles_train)
                    recursion += 1
                    # random walk - if it already has been random walked it won't redo
                    #                        recursion_data = acdc_data.random_walk_epoch(recursion_data,
                    #                                                                     exp_config.rw_beta,
                    #                                                                     exp_config.rw_threshold,
                    #                                                                     exp_config.random_walk)
                    #get ground truths
                    labels_train = np.array(recursion_data['predicted'])

                    #reinitialise savers - otherwise, no checkpoints will be saved for each recursion
                    saver = tf.train.Saver(max_to_keep=2)
                    saver_best_dice = tf.train.Saver(max_to_keep=2)
                    saver_best_xent = tf.train.Saver(max_to_keep=2)
                logging.info(
                    'Epoch {0} ({1} of {2} epochs for recursion {3})'.format(
                        epoch, 1 + epoch % exp_config.epochs_per_recursion,
                        exp_config.epochs_per_recursion, recursion))
                # for batch in iterate_minibatches(images_train,
                #                                  labels_train,
                #                                  batch_size=exp_config.batch_size,
                #                                  augment_batch=exp_config.augment_batch):

                # You can run this loop with the BACKGROUND GENERATOR, which will lead to some improvements in the
                # training speed. However, be aware that currently an exception inside this loop may not be caught.
                # The batch generator may just continue running silently without warning even though the code has
                # crashed.

                for batch in BackgroundGenerator(
                        iterate_minibatches(
                            images_train,
                            labels_train,
                            batch_size=exp_config.batch_size,
                            augment_batch=exp_config.augment_batch)):

                    if exp_config.warmup_training:
                        if step < 50:
                            curr_lr = exp_config.learning_rate / 10.0
                        elif step == 50:
                            curr_lr = exp_config.learning_rate
                    if step % 3000 == 0 and step > 0:
                        curr_lr = curr_lr * 0.9
                        crf_curr_lr = crf_curr_lr * 0.9

                    start_time = time.time()

                    # batch = bgn_train.retrieve()
                    x, y = batch

                    # TEMPORARY HACK (to avoid incomplete batches)
                    if y.shape[0] < exp_config.batch_size:
                        step += 1
                        continue

                    network_feed_dict = {
                        images_placeholder: x,
                        labels_placeholder: y,
                        learning_rate_placeholder: curr_lr,
                        keep_prob: 0.5,
                        training_time_placeholder: True
                    }

                    crf_feed_dict = {
                        images_placeholder: x,
                        labels_placeholder: y,
                        crf_learning_rate_placeholder: crf_curr_lr,
                        keep_prob: 1,
                        training_time_placeholder: True
                    }

                    if (step % 10 == 0):
                        _, loss_value = sess.run([crf_train_op, loss],
                                                 feed_dict=crf_feed_dict)
                    _, loss_value = sess.run([network_train_op, loss],
                                             feed_dict=network_feed_dict)
                    duration = time.time() - start_time

                    # Write the summaries and print an overview fairly often.
                    if step % 10 == 0:
                        # Print status to stdout.
                        logging.info('Step %d: loss = %.6f (%.3f sec)' %
                                     (step, loss_value, duration))
                        # Update the events file.

                    # Save a checkpoint and evaluate the model periodically.
                    if (step + 1) % exp_config.val_eval_frequency == 0:

                        checkpoint_file = os.path.join(
                            log_dir,
                            'recursion_{}_model.ckpt'.format(recursion))
                        saver.save(sess, checkpoint_file, global_step=step)
                        # Evaluate against the training set.

                        # Evaluate against the validation set.
                        logging.info('Validation Data Eval:')
                        [val_loss, val_dice
                         ] = do_eval(sess, eval_val_loss, images_placeholder,
                                     labels_placeholder,
                                     training_time_placeholder, keep_prob,
                                     images_val, labels_val,
                                     exp_config.batch_size)

                        val_summary_msg = sess.run(val_summary,
                                                   feed_dict={
                                                       val_error_: val_loss,
                                                       val_dice_: val_dice
                                                   })
                        summary_writer.add_summary(val_summary_msg, step)

                        if val_dice > best_dice:
                            best_dice = val_dice
                            best_file = os.path.join(
                                log_dir,
                                'recursion_{}_model_best_dice.ckpt'.format(
                                    recursion))
                            saver_best_dice.save(sess,
                                                 best_file,
                                                 global_step=step)
                            logging.info(
                                'Found new best dice on validation set! - {} - '
                                'Saving recursion_{}_model_best_dice.ckpt'.
                                format(val_dice, recursion))
                            with open('val_results.txt', 'a') as text_file:
                                text_file.write("\nVal dice " + str(step) +
                                                " : " + str(val_dice))
                        if val_loss < best_val:
                            best_val = val_loss
                            best_file = os.path.join(
                                log_dir,
                                'recursion_{}_model_best_xent.ckpt'.format(
                                    recursion))
                            saver_best_xent.save(sess,
                                                 best_file,
                                                 global_step=step)
                            logging.info(
                                'Found new best crossentropy on validation set! - {} - '
                                'Saving recursion_{}_model_best_xent.ckpt'.
                                format(val_loss, recursion))

                    step += 1

    except Exception:
        raise
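
The learning-rate handling buried in the training loop above (warm-up for the first 50 steps, then multiply by 0.9 every 3000 steps) can be written as a small closed-form helper. A sketch under those assumptions; scheduled_lr is illustrative, not part of the original code:

def scheduled_lr(step, base_lr):
    """Warm up at base_lr / 10 for the first 50 steps, then decay
    the full rate by a factor of 0.9 every 3000 steps."""
    if step < 50:
        return base_lr / 10.0
    return base_lr * (0.9 ** (step // 3000))
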
def run_training(continue_run):

    logging.info('EXPERIMENT NAME: %s' % exp_config.experiment_name)

    init_step = 0
    # Load data
    base_data, recursion_data, recursion = acdc_data.load_and_maybe_process_scribbles(
        scribble_file=sys_config.project_root + exp_config.scribble_data,
        target_folder=log_dir,
        percent_full_sup=exp_config.percent_full_sup,
        scr_ratio=exp_config.length_ratio)
    # Wrap everything from this point onwards in a try/except so that a
    # KeyboardInterrupt can be caught and the h5py data files closed cleanly.
    try:
        loaded_previous_recursion = False
        start_epoch = 0
        if continue_run:
            logging.info(
                '!!!!!!!!!!!!!!!!!!!!!!!!!!!! Continuing previous run !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
            )
            try:
                try:
                    init_checkpoint_path = utils.get_latest_model_checkpoint_path(
                        log_dir, 'recursion_{}_model.ckpt'.format(recursion))
                except Exception:
                    init_checkpoint_path = utils.get_latest_model_checkpoint_path(
                        log_dir,
                        'recursion_{}_model.ckpt'.format(recursion - 1))
                    loaded_previous_recursion = True
                logging.info('Checkpoint path: %s' % init_checkpoint_path)
                init_step = int(
                    init_checkpoint_path.split('/')[-1].split('-')
                    [-1]) + 1  # plus 1 b/c otherwise starts with eval
                start_epoch = int(
                    init_step /
                    (len(base_data['images_train']) / exp_config.batch_size))
                logging.info('Latest step was: %d' % init_step)
                logging.info('Continuing with epoch: %d' % start_epoch)
            except Exception:
                logging.warning(
                    '!!! Did not find init checkpoint. Maybe first run failed. Disabling continue mode...'
                )
                continue_run = False
                init_step = 0
                start_epoch = 0

            logging.info(
                '!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
            )

        if loaded_previous_recursion:
            logging.info(
                "Data file exists for recursion {} "
                "but checkpoints only present up to recursion {}".format(
                    recursion, recursion - 1))
            logging.info("Likely means postprocessing was terminated")
            recursion_data = acdc_data.load_different_recursion(
                recursion_data, -1)
            recursion -= 1

        # load images and validation data
        images_train = np.array(base_data['images_train'])
        scribbles_train = np.array(base_data['scribbles_train'])
        images_val = np.array(base_data['images_test'])
        labels_val = np.array(base_data['masks_test'])

        # if exp_config.use_data_fraction:
        #     num_images = images_train.shape[0]
        #     new_last_index = int(float(num_images)*exp_config.use_data_fraction)
        #
        #     logging.warning('USING ONLY FRACTION OF DATA!')
        #     logging.warning(' - Number of imgs orig: %d, Number of imgs new: %d' % (num_images, new_last_index))
        #     images_train = images_train[0:new_last_index,...]
        #     labels_train = labels_train[0:new_last_index,...]

        logging.info('Data summary:')
        logging.info(' - Images:')
        logging.info(images_train.shape)
        logging.info(images_train.dtype)
        #logging.info(' - Labels:')
        #logging.info(labels_train.shape)
        #logging.info(labels_train.dtype)

        # Tell TensorFlow that the model will be built into the default Graph.

        with tf.Graph().as_default():

            # Generate placeholders for the images and labels.

            image_tensor_shape = [exp_config.batch_size] + list(
                exp_config.image_size) + [1]
            mask_tensor_shape = [exp_config.batch_size] + list(
                exp_config.image_size)

            images_placeholder = tf.placeholder(tf.float32,
                                                shape=image_tensor_shape,
                                                name='images')
            labels_placeholder = tf.placeholder(tf.uint8,
                                                shape=mask_tensor_shape,
                                                name='labels')

            learning_rate_placeholder = tf.placeholder(tf.float32, shape=[])
            training_time_placeholder = tf.placeholder(tf.bool, shape=[])

            tf.summary.scalar('learning_rate', learning_rate_placeholder)

            # Build a Graph that computes predictions from the inference model.
            logits = model.inference(images_placeholder,
                                     exp_config.model_handle,
                                     training=training_time_placeholder,
                                     nlabels=exp_config.nlabels)

            # Add to the Graph the Ops for loss calculation.
            [loss, _, weights_norm
             ] = model.loss(logits,
                            labels_placeholder,
                            nlabels=exp_config.nlabels,
                            loss_type=exp_config.loss_type,
                            weight_decay=exp_config.weight_decay
                            )  # second output is unregularised loss

            tf.summary.scalar('loss', loss)
            tf.summary.scalar('weights_norm_term', weights_norm)

            # Add to the Graph the Ops that calculate and apply gradients.
            if exp_config.momentum is not None:
                train_op = model.training_step(loss,
                                               exp_config.optimizer_handle,
                                               learning_rate_placeholder,
                                               momentum=exp_config.momentum)
            else:
                train_op = model.training_step(loss,
                                               exp_config.optimizer_handle,
                                               learning_rate_placeholder)

            # Add the Op to compare the logits to the labels during evaluation.
#            eval_loss = model.evaluation(logits,
#                                         labels_placeholder,
#                                         images_placeholder,
#                                         nlabels=exp_config.nlabels,
#                                         loss_type=exp_config.loss_type,
#                                         weak_supervision=True,
#                                         cnn_threshold=exp_config.cnn_threshold,
#                                         include_bg=True)
            eval_val_loss = model.evaluation(
                logits,
                labels_placeholder,
                images_placeholder,
                nlabels=exp_config.nlabels,
                loss_type=exp_config.loss_type,
                weak_supervision=True,
                cnn_threshold=exp_config.cnn_threshold,
                include_bg=False)

            # Build the summary Tensor based on the TF collection of Summaries.
            summary = tf.summary.merge_all()

            # Add the variable initializer Op.
            init = tf.global_variables_initializer()

            # Create a saver for writing training checkpoints.
            # Only keep two checkpoints, as checkpoints are kept for every recursion
            # and they can be 300MB +
            saver = tf.train.Saver(max_to_keep=2)
            saver_best_dice = tf.train.Saver(max_to_keep=2)
            saver_best_xent = tf.train.Saver(max_to_keep=2)

            # Create a session for running Ops on the Graph.
            sess = tf.Session()

            # Instantiate a SummaryWriter to output summaries and the Graph.
            summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

            # with tf.name_scope('monitoring'):

            val_error_ = tf.placeholder(tf.float32, shape=[], name='val_error')
            val_error_summary = tf.summary.scalar('validation_loss',
                                                  val_error_)

            val_dice_ = tf.placeholder(tf.float32, shape=[], name='val_dice')
            val_dice_summary = tf.summary.scalar('validation_dice', val_dice_)

            val_summary = tf.summary.merge(
                [val_error_summary, val_dice_summary])

            train_error_ = tf.placeholder(tf.float32,
                                          shape=[],
                                          name='train_error')
            train_error_summary = tf.summary.scalar('training_loss',
                                                    train_error_)

            train_dice_ = tf.placeholder(tf.float32,
                                         shape=[],
                                         name='train_dice')
            train_dice_summary = tf.summary.scalar('training_dice',
                                                   train_dice_)

            train_summary = tf.summary.merge(
                [train_error_summary, train_dice_summary])

            # Run the Op to initialize the variables.
            sess.run(init)

            # Restore session
            #            crf_weights = []
            #            for v in tf.all_variables():
            #
            #                if v.name[0:4]=='bila':
            #                    print(str(v))
            #                    crf_weights.append(v.name)
            #                elif v.name[0:4] =='spat':
            #                    print(str(v))
            #                    crf_weights.append(v.name)
            #                elif v.name[0:4] =='comp':
            #                    print(str(v))
            #                    crf_weights.append(v.name)
            #            restore_var = [v for v in tf.all_variables() if v.name not in crf_weights]
            #
            #            load_saver = tf.train.Saver(var_list=restore_var)
            #            load_saver.restore(sess, '/scratch_net/biwirender02/cany/basil/logdir/unet2D_ws_spot_blur/recursion_0_model.ckpt-5699')

            if continue_run:
                # Restore session
                saver.restore(sess, init_checkpoint_path)

            step = init_step
            curr_lr = exp_config.learning_rate

            no_improvement_counter = 0
            best_val = np.inf
            last_train = np.inf
            loss_history = []
            loss_gradient = np.inf
            best_dice = 0
            logging.info('RECURSION {0}'.format(recursion))

            # random walk - if it already has been random walked it won't redo
            recursion_data = acdc_data.random_walk_epoch(
                recursion_data, exp_config.rw_beta, exp_config.rw_threshold,
                exp_config.random_walk)

            #get ground truths
            labels_train = np.array(recursion_data['random_walked'])

            for epoch in range(start_epoch, exp_config.max_epochs):
                if (epoch % exp_config.epochs_per_recursion == 0 and epoch != 0) \
                        or loaded_previous_recursion:
                    loaded_previous_recursion = False
                    #Have reached end of recursion
                    recursion_data = predict_next_gt(
                        data=recursion_data,
                        images_train=images_train,
                        images_placeholder=images_placeholder,
                        training_time_placeholder=training_time_placeholder,
                        logits=logits,
                        sess=sess)

                    recursion_data = postprocess_gt(
                        data=recursion_data,
                        images_train=images_train,
                        scribbles_train=scribbles_train)
                    recursion += 1
                    # random walk - if it already has been random walked it won't redo
                    recursion_data = acdc_data.random_walk_epoch(
                        recursion_data, exp_config.rw_beta,
                        exp_config.rw_threshold, exp_config.random_walk)
                    #get ground truths
                    labels_train = np.array(recursion_data['random_walked'])

                    #reinitialise savers - otherwise, no checkpoints will be saved for each recursion
                    saver = tf.train.Saver(max_to_keep=2)
                    saver_best_dice = tf.train.Saver(max_to_keep=2)
                    saver_best_xent = tf.train.Saver(max_to_keep=2)
                logging.info(
                    'Epoch {0} ({1} of {2} epochs for recursion {3})'.format(
                        epoch, 1 + epoch % exp_config.epochs_per_recursion,
                        exp_config.epochs_per_recursion, recursion))
                # for batch in iterate_minibatches(images_train,
                #                                  labels_train,
                #                                  batch_size=exp_config.batch_size,
                #                                  augment_batch=exp_config.augment_batch):

                # You can run this loop with the BACKGROUND GENERATOR, which will lead to some improvements in the
                # training speed. However, be aware that currently an exception inside this loop may not be caught.
                # The batch generator may just continue running silently without warning even though the code has
                # crashed.

                for batch in BackgroundGenerator(
                        iterate_minibatches(
                            images_train,
                            labels_train,
                            batch_size=exp_config.batch_size,
                            augment_batch=exp_config.augment_batch)):

                    if exp_config.warmup_training:
                        if step < 50:
                            curr_lr = exp_config.learning_rate / 10.0
                        elif step == 50:
                            curr_lr = exp_config.learning_rate

                    start_time = time.time()

                    # batch = bgn_train.retrieve()
                    x, y = batch

                    # TEMPORARY HACK (to avoid incomplete batches)
                    if y.shape[0] < exp_config.batch_size:
                        step += 1
                        continue

                    feed_dict = {
                        images_placeholder: x,
                        labels_placeholder: y,
                        learning_rate_placeholder: curr_lr,
                        training_time_placeholder: True
                    }

                    _, loss_value = sess.run([train_op, loss],
                                             feed_dict=feed_dict)

                    duration = time.time() - start_time

                    # Write the summaries and print an overview fairly often.
                    if step % 10 == 0:
                        # Print status to stdout.
                        logging.info('Step %d: loss = %.6f (%.3f sec)' %
                                     (step, loss_value, duration))
                        # Update the events file.

                        summary_str = sess.run(summary, feed_dict=feed_dict)
                        summary_writer.add_summary(summary_str, step)
                        summary_writer.flush()


#                    if (step + 1) % exp_config.train_eval_frequency == 0:
#
#                        logging.info('Training Data Eval:')
#                        [train_loss, train_dice] = do_eval(sess,
#                                                           eval_loss,
#                                                           images_placeholder,
#                                                           labels_placeholder,
#                                                           training_time_placeholder,
#                                                           images_train,
#                                                           labels_train,
#                                                           exp_config.batch_size)
#
#                        train_summary_msg = sess.run(train_summary, feed_dict={train_error_: train_loss,
#                                                                               train_dice_: train_dice}
#                                                     )
#                        summary_writer.add_summary(train_summary_msg, step)
#
#                        loss_history.append(train_loss)
#                        if len(loss_history) > 5:
#                            loss_history.pop(0)
#                            loss_gradient = (loss_history[-5] - loss_history[-1]) / 2
#
#                        logging.info('loss gradient is currently %f' % loss_gradient)
#
#                        if exp_config.schedule_lr and loss_gradient < exp_config.schedule_gradient_threshold:
#                            logging.warning('Reducing learning rate!')
#                            curr_lr /= 10.0
#                            logging.info('Learning rate changed to: %f' % curr_lr)
#
#                            # reset loss history to give the optimisation some time to start decreasing again
#                            loss_gradient = np.inf
#                            loss_history = []
#
#                        if train_loss <= last_train:  # best_train:
#                            logging.info('Decrease in training error!')
#                        else:
#                            logging.info('No improvement in training error for %d steps' % no_improvement_counter)
#
#                        last_train = train_loss

                    # Save a checkpoint and evaluate the model periodically.
                    if (step + 1) % exp_config.val_eval_frequency == 0:

                        checkpoint_file = os.path.join(
                            log_dir,
                            'recursion_{}_model.ckpt'.format(recursion))
                        saver.save(sess, checkpoint_file, global_step=step)
                        # Evaluate against the training set.

                        # Evaluate against the validation set.
                        logging.info('Validation Data Eval:')
                        [val_loss, val_dice
                         ] = do_eval(sess, eval_val_loss, images_placeholder,
                                     labels_placeholder,
                                     training_time_placeholder, images_val,
                                     labels_val, exp_config.batch_size)

                        val_summary_msg = sess.run(val_summary,
                                                   feed_dict={
                                                       val_error_: val_loss,
                                                       val_dice_: val_dice
                                                   })
                        summary_writer.add_summary(val_summary_msg, step)

                        if val_dice > best_dice:
                            best_dice = val_dice
                            best_file = os.path.join(
                                log_dir,
                                'recursion_{}_model_best_dice.ckpt'.format(
                                    recursion))
                            saver_best_dice.save(sess,
                                                 best_file,
                                                 global_step=step)
                            logging.info(
                                'Found new best dice on validation set! - {} - '
                                'Saving recursion_{}_model_best_dice.ckpt'.
                                format(val_dice, recursion))

                        if val_loss < best_val:
                            best_val = val_loss
                            best_file = os.path.join(
                                log_dir,
                                'recursion_{}_model_best_xent.ckpt'.format(
                                    recursion))
                            saver_best_xent.save(sess,
                                                 best_file,
                                                 global_step=step)
                            logging.info(
                                'Found new best crossentropy on validation set! - {} - '
                                'Saving recursion_{}_model_best_xent.ckpt'.
                                format(val_loss, recursion))

                    step += 1

    except Exception:
        raise
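
Both run_training variants recover the step counter by parsing the checkpoint filename, which TensorFlow suffixes with the global step. A minimal illustration (the path below is hypothetical):

path = 'logdir/exp/recursion_0_model.ckpt-5699'
init_step = int(path.split('/')[-1].split('-')[-1]) + 1  # 5700; +1 so the run does not start with an eval
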
def generate_adversarial_examples(input_folder,
                                  output_path,
                                  model_path,
                                  attack,
                                  attack_args,
                                  exp_config,
                                  add_gaussian=False):
    nx, ny = exp_config.image_size[:2]
    batch_size = 1
    num_channels = exp_config.nlabels

    image_tensor_shape = [batch_size] + list(exp_config.image_size) + [1]
    mask_tensor_shape = [batch_size] + list(exp_config.image_size)
    images_pl = tf.placeholder(tf.float32,
                               shape=image_tensor_shape,
                               name='images')
    labels_pl = tf.placeholder(tf.uint8,
                               shape=mask_tensor_shape,
                               name='labels')
    logits_pl = model.inference(images_pl,
                                exp_config=exp_config,
                                training=tf.constant(False, dtype=tf.bool))
    eval_loss = model.evaluation(logits_pl,
                                 labels_pl,
                                 images_pl,
                                 nlabels=exp_config.nlabels,
                                 loss_type=exp_config.loss_type)

    data = acdc_data.load_and_maybe_process_data(
        input_folder=sys_config.data_root,
        preprocessing_folder=sys_config.preproc_folder,
        mode=exp_config.data_mode,
        size=exp_config.image_size,
        target_resolution=exp_config.target_resolution,
        force_overwrite=False,
        split_test_train=True)

    images = data['images_test'][:20]
    labels = data['masks_test'][:20]

    print("Num images train {} test {}".format(len(data['images_train']),
                                               len(images)))

    saver = tf.train.Saver()
    init = tf.global_variables_initializer()

    baseline_closs = 0.0
    baseline_cdice = 0.0
    attack_closs = 0.0
    attack_cdice = 0.0
    l2_diff_sum = 0.0
    ln_diff_sum = 0.0
    ln_diff = 0.0
    l2_diff = 0.0
    batches = 0
    result_dict = []

    with tf.Session() as sess:
        results = []
        sess.run(init)
        checkpoint_path = utils.get_latest_model_checkpoint_path(
            model_path, 'model_best_dice.ckpt')
        saver.restore(sess, checkpoint_path)

        for batch in BackgroundGenerator(
                train.iterate_minibatches(images, labels, batch_size)):
            x, y = batch
            batches += 1
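
            # NOTE: only the 9th batch is evaluated; every other batch is
            # skipped by the guard below.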

            if batches != 9:
                continue

            non_adv_mask_out = sess.run(
                [tf.arg_max(tf.nn.softmax(logits_pl), dimension=-1)],
                feed_dict={images_pl: x})

            if attack == 'fgsm':
                adv_x = adv_attack.fgsm_run(x, y, images_pl, labels_pl,
                                            logits_pl, exp_config, sess,
                                            attack_args)
            elif attack == 'pgd':
                adv_x = adv_attack.pgd(x, y, images_pl, labels_pl, logits_pl,
                                       exp_config, sess, attack_args)
            elif attack == 'spgd':
                adv_x = adv_attack.pgd_conv(x, y, images_pl, labels_pl,
                                            logits_pl, exp_config, sess,
                                            **attack_args)
            else:
                raise NotImplementedError
            adv_x = [adv_x]

            if add_gaussian:
                print('adding gaussian noise')
                adv_x = adv_attack.add_gaussian_noise(
                    x,
                    adv_x[0],
                    sess,
                    eps=attack_args['eps'],
                    sizes=attack_args['sizes'],
                    weights=attack_args['weights'])

            for i in range(len(adv_x)):
                l2_diff = np.average(
                    np.squeeze(np.linalg.norm(adv_x[i] - x, axis=(1, 2))))
                ln_diff = np.average(
                    np.squeeze(
                        np.linalg.norm(adv_x[i] - x, axis=(1, 2), ord=np.inf)))

                l2_diff_sum += l2_diff
                ln_diff_sum += ln_diff

                print(l2_diff, ln_diff)

                adv_mask_out = sess.run(
                    [tf.arg_max(tf.nn.softmax(logits_pl), dimension=-1)],
                    feed_dict={images_pl: adv_x[i]})

                closs, cdice = sess.run(eval_loss,
                                        feed_dict={
                                            images_pl: x,
                                            labels_pl: y
                                        })
                baseline_closs = closs + baseline_closs
                baseline_cdice = cdice + baseline_cdice

                adv_closs, adv_cdice = sess.run(eval_loss,
                                                feed_dict={
                                                    images_pl: adv_x[i],
                                                    labels_pl: y
                                                })
                attack_closs = adv_closs + attack_closs
                attack_cdice = adv_cdice + attack_cdice

                partial_result = dict({
                    'attack': attack,
                    'attack_args': {
                        k: attack_args[k]
                        for k in ['eps', 'step_alpha', 'epochs']
                    },  #
                    'baseline_closs': closs,
                    'baseline_cdice': cdice,
                    'attack_closs': adv_closs,
                    'attack_cdice': adv_cdice,
                    'attack_l2_diff': l2_diff,
                    'attack_ln_diff': ln_diff
                })

                jsonString = json.dumps(str(partial_result))

                #results.append(copy.deepcopy(result_dict))

                with open(
                        "eval_results/{}-{}-{}-{}-metrics.json".format(
                            attack, add_gaussian, batches, i),
                        "w") as jsonFile:
                    jsonFile.write(jsonString)

                image_gt = "eval_results/ground-truth-{}-{}-{}-{}.pdf".format(
                    attack, add_gaussian, batches, i)
                plt.imshow(np.squeeze(x), cmap='gray')
                plt.imshow(np.squeeze(y), cmap='viridis', alpha=0.7)
                plt.axis('off')
                plt.tight_layout()
                plt.savefig(image_gt, format='pdf')
                plt.clf()

                image_benign = "eval_results/benign-{}-{}-{}-{}.pdf".format(
                    attack, add_gaussian, batches, i)
                plt.imshow(np.squeeze(x), cmap='gray')
                plt.imshow(np.squeeze(non_adv_mask_out),
                           cmap='viridis',
                           alpha=0.7)
                plt.axis('off')
                plt.tight_layout()
                plt.savefig(image_benign, format='pdf')
                plt.clf()

                image_adv = "eval_results/adversarial-{}-{}-{}-{}.pdf".format(
                    attack, add_gaussian, batches, i)
                plt.imshow(np.squeeze(adv_x[i]), cmap='gray')
                plt.imshow(np.squeeze(adv_mask_out), cmap='viridis', alpha=0.7)
                plt.axis('off')
                plt.tight_layout()
                plt.savefig(image_adv, format='pdf')
                plt.clf()

                plt.imshow(np.squeeze(adv_x[i]), cmap='gray')
                image_adv_input = "eval_results/adv-input-{}-{}-{}-{}.pdf".format(
                    attack, add_gaussian, batches, i)
                plt.tight_layout()
                plt.axis('off')
                plt.savefig(image_adv_input, format='pdf')
                plt.clf()

                plt.imshow(np.squeeze(x), cmap='gray')
                image_adv_input = "eval_results/benign-input-{}-{}-{}-{}.pdf".format(
                    attack, add_gaussian, batches, i)
                plt.axis('off')
                plt.tight_layout()
                plt.savefig(image_adv_input, format='pdf')
                plt.clf()

                print(attack_closs, attack_cdice, l2_diff, ln_diff)

        print("Evaluation results")
        print("{} Attack Params {}".format(attack, attack_args))
        print("Baseline metrics: Avg loss {}, Avg DICE Score {} ".format(
            baseline_closs / (batches * len(adv_x)),
            baseline_cdice / (batches * len(adv_x))))
        print(
            "{} Attack effectiveness: Avg loss {}, Avg DICE Score {} ".format(
                attack, attack_closs / (batches * len(adv_x)),
                attack_cdice / (batches * len(adv_x))))
        print(
            "{} Attack visibility: Avg l2-norm diff {} Avg l-inf-norm diff {}".
            format(attack, l2_diff_sum / (batches * len(adv_x)),
                   ln_diff_sum / (batches * len(adv_x))))
        result_dict = dict({
            'attack': attack,
            'attack_args':
            {k: attack_args[k]
             for k in ['eps', 'step_alpha', 'epochs']},  #
            'baseline_closs_avg': baseline_closs / batches,
            'baseline_cdice_avg': baseline_cdice / batches,
            'attack_closs_avg': attack_closs / batches,
            'attack_cdice_avg': attack_cdice / batches,
            'attack_l2_diff': l2_diff_sum / batches,
            'attack_ln_diff': ln_diff_sum / batches
        })

        results.append(copy.deepcopy(result_dict))
        print(results)

        jsonString = json.dumps(results)
        with open("eval_results/{}-results.json".format(attack),
                  "w") as jsonFile:
            jsonFile.write(jsonString)
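
adv_attack.fgsm_run itself is not shown above. For reference, a minimal sketch of the standard single-step FGSM it presumably implements, using only the placeholders already defined in generate_adversarial_examples (the function name and eps parameter here are illustrative, and the real attack_args may differ):

import numpy as np
import tensorflow as tf

def fgsm_sketch(x, y, images_pl, labels_pl, logits_pl, sess, eps=0.1):
    """One signed gradient step of size eps, i.e. x + eps * sign(dL/dx)."""
    # Cross-entropy between the network's logits and the ground-truth mask.
    loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=tf.cast(labels_pl, tf.int32), logits=logits_pl))
    # Gradient of the loss with respect to the input image.
    grad = tf.gradients(loss, images_pl)[0]
    grad_val = sess.run(grad, feed_dict={images_pl: x, labels_pl: y})
    # The sign step keeps the perturbation inside an L-infinity ball of radius eps.
    return x + eps * np.sign(grad_val)

Note that this builds new gradient ops on every call; in a loop, the loss and gradient tensors should be constructed once outside it (the repeated tf.arg_max calls in the evaluation loop above have the same graph-growth issue).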