def get_images_and_labels(dataset_split_name, preprocessing_name, image_preprocessing_fn, train_image_size,
                           tfRecords_dir, is_training=False, allow_smaller_final_batch=False, batch_size=None):
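    """Builds a TF-Slim input pipeline for `dataset_split_name` and returns
    batched preprocessed images, labels and filenames, plus the number of
    samples in the split. `batch_size` defaults to FLAGS.batch_size.
    """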
    pattern = paths_namings.file_pattern_tfrecords(FLAGS, tfRecords_dir, dataset_split_name)
    dataset_Ob = dataset_factory_tf.get_dataset_y(FLAGS.dataset_name, dataset_split_name, file_pattern=pattern)
    dataset = dataset_Ob.get_split_slim(tfRecords_dir=tfRecords_dir, n_channels=FLAGS.channels)

    if not batch_size:
        batch_size = FLAGS.batch_size
    images, raw_images, labels, filenames = load_batch_slim(dataset, batch_size, train_image_size, train_image_size,
                                                            preprocessing_name, image_preprocessing_fn,
                                                            num_readers=FLAGS.num_readers, num_preprocessing_threads=FLAGS.num_preprocessing_threads,
                                                            per_image_standardization=FLAGS.per_image_standardization,
                                                            vgg_sub_mean_pixel=FLAGS.vgg_sub_mean_pixel,
                                                            vgg_resize_side_in=FLAGS.vgg_resize_side,
                                                            vgg_use_aspect_preserving_resize=FLAGS.vgg_use_aspect_preserving_resize,
                                                            labels_offset=FLAGS.labels_offset,
                                                            is_training=is_training, allow_smaller_final_batch=allow_smaller_final_batch)
    return images, labels, filenames, dataset.num_samples
def main(_):
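    """Evaluation entry point: restores the latest checkpoint and evaluates the
    model on FLAGS.dataset_split_name_y (and, optionally, on a second split
    FLAGS.dataset_split_name_y2 using the EER threshold found on the first).
    """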
    use_placeholders = FLAGS.use_placeholders
    eval_batch_size = FLAGS.eval_batch_size

    tf.logging.set_verbosity(tf.logging.INFO)

    tfRecords_dir = FLAGS.tfRecords_dir
    if (FLAGS.tfRecords_dir is None) or use_placeholders:
        _, tfRecords_dir, imgs_sub_dirs, csv_files, categories_file = paths_namings.get_dataset_paths_and_settings(
            FLAGS.dataset_name)
        tf.logging.info('Data Dir: %s\nCategoriesFile: %s', tfRecords_dir,
                        categories_file)

    with tf.Graph().as_default():
        tf_global_step = slim.get_or_create_global_step()

        ######################
        # Select the dataset #
        ######################
        pattern = paths_namings.file_pattern_tfrecords(
            FLAGS, tfRecords_dir, FLAGS.dataset_split_name_y)
        dataset_Ob = dataset_factory_tf.get_dataset_y(
            FLAGS.dataset_name,
            FLAGS.dataset_split_name_y,
            file_pattern=pattern)
        dataset = dataset_Ob.get_split_slim(tfRecords_dir=tfRecords_dir,
                                            n_channels=FLAGS.channels)

        ####################
        # Select the model #
        ####################
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            is_training=False)

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name, is_training=False)

        eval_image_size = FLAGS.image_size or network_fn.default_image_size

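        # Two evaluation paths: feed decoded images through placeholders (with
        # optional oversampling, reading the file list from a CSV), or pull
        # preprocessed batches from the TFRecords via the slim queue pipeline.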
        if use_placeholders:
            images_pl = tf.placeholder(
                tf.float32,
                [None, eval_image_size, eval_image_size, FLAGS.channels])
            if FLAGS.model_name in ('cifarnet', 'spoofnet_y'):
                logits_pl, end_points_pl = network_fn(images_pl,
                                                      dropout_keep_prob=1)
            else:
                logits_pl, end_points_pl = network_fn(images_pl)

            if FLAGS.input_csv_file:
                filenames, _, labels = helpers_dataset.read_image_files_from_csv(
                    FLAGS.input_csv_file, delimiter=',')
            else:
                filenames, _, labels = helpers_dataset.read_image_files_from_csv(
                    csv_files[FLAGS.dataset_split_name_y],
                    imgs_sub_dirs[FLAGS.dataset_split_name_y], categories_file)
            labels -= FLAGS.labels_offset
            probabilities_op = end_points_pl['Predictions']

            classifier_pl = evaluation_y.Classifier_PL(
                eval_image_size,
                FLAGS.channels,
                preprocessing_name,
                encode_type=FLAGS.encode_type,
                oversample=FLAGS.oversample_at_eval,
                per_image_standardization=FLAGS.per_image_standardization,
                vgg_sub_mean_pixel=FLAGS.vgg_sub_mean_pixel,
                vgg_resize_side_in=FLAGS.vgg_resize_side,
                vgg_use_aspect_preserving_resize=FLAGS.vgg_use_aspect_preserving_resize)

        else:
            images, raw_images, labels_t, filenames_op = load_batch_slim(
                dataset,
                eval_batch_size,
                eval_image_size,
                eval_image_size,
                preprocessing_name,
                image_preprocessing_fn,
                num_preprocessing_threads=FLAGS.num_preprocessing_threads,
                per_image_standardization=FLAGS.per_image_standardization,
                vgg_sub_mean_pixel=FLAGS.vgg_sub_mean_pixel,
                vgg_resize_side_in=FLAGS.vgg_resize_side,
                vgg_use_aspect_preserving_resize=FLAGS.vgg_use_aspect_preserving_resize,
                labels_offset=FLAGS.labels_offset,
                is_training=False,
                allow_smaller_final_batch=True)
            _, end_points, names_to_values, names_to_updates, eval_ops_y = get_logits_and_valid_ops(
                images,
                labels_t,
                network_fn,
                scope='Evaluation/' + FLAGS.dataset_split_name_y)
            eval_ops_slim = list(names_to_updates.values())
            accuracy_value = names_to_values['Accuracy']
            probabilities_op = end_points['Predictions']

            num_batches = int(
                math.ceil(dataset.num_samples / float(eval_batch_size)))
            if FLAGS.max_num_batches:
                num_batches = FLAGS.max_num_batches

        ################### YY: optionally evaluate a second (test) split:
        if FLAGS.dataset_split_name_y2 is not None:
            if use_placeholders:
                filenames_2, _, labels_2 = helpers_dataset.read_image_files_from_csv(
                    csv_files[FLAGS.dataset_split_name_y2],
                    imgs_sub_dirs[FLAGS.dataset_split_name_y2],
                    categories_file)
                labels_2 -= FLAGS.labels_offset
            else:
                images_2, labels_t2, filenames_op_2, num_samples_2 = get_images_and_labels(
                    FLAGS.dataset_split_name_y2,
                    preprocessing_name,
                    image_preprocessing_fn,
                    eval_image_size,
                    tfRecords_dir,
                    is_training=False,
                    allow_smaller_final_batch=True,
                    batch_size=eval_batch_size)
                num_batches_2 = int(
                    math.ceil(num_samples_2 / float(eval_batch_size)))
                if FLAGS.max_num_batches:
                    num_batches_2 = FLAGS.max_num_batches
                _, end_points_2, names_to_values_2, names_to_updates_2, eval_ops_y_2 = get_logits_and_valid_ops(
                    images_2,
                    labels_t2,
                    network_fn,
                    scope='Evaluation/' + FLAGS.dataset_split_name_y2)
                eval_ops_slim_2 = list(names_to_updates_2.values())
                accuracy_value_2 = names_to_values_2['Accuracy']
                probabilities_op_2 = end_points_2['Predictions']

        ## ==========
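        # If an exponential moving average was kept during training, restore the
        # EMA shadow variables in place of the raw model weights.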
        if FLAGS.moving_average_decay:
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, tf_global_step)
            variables_to_restore = variable_averages.variables_to_restore(
                slim.get_model_variables())
            variables_to_restore[tf_global_step.op.name] = tf_global_step
        else:
            variables_to_restore = slim.get_variables_to_restore()

        ##############################################################
        checkpoints_dir_y = paths_namings.generate_checkpoints_dir(
            FLAGS, eval_image_size)
        checkpoints_dir_y = checkpoints_dir_y + '_slim/pr_' + preprocessing_name
        if tf.gfile.IsDirectory(checkpoints_dir_y):
            print(' ---- checking latest checkpoint in -- ', checkpoints_dir_y)
            checkpoint_path = tf.train.latest_checkpoint(checkpoints_dir_y)
            print(' ---- latest checkpoint -- ', checkpoint_path)
        else:
            checkpoint_path = checkpoints_dir_y
            print(' ---- using checkpoint file directly -- ', checkpoint_path)

        eval_dir = checkpoints_dir_y + '/eval'
        if FLAGS.use_slim_stream_eval:
            eval_dir += '_slim'

        tf.logging.info('Evaluating %s' % checkpoint_path)

        my_summary_op = tf.summary.merge_all()

        saver = tf.train.Saver(variables_to_restore)

        def restore_fn(sess):
            return saver.restore(sess, checkpoint_path)

        ##========================
        tf.logging.info('*********** Starting evaluation at ' +
                        time.strftime('%Y-%m-%d-%H:%M:%S', time.gmtime()))
        if use_placeholders:
            with tf.Session() as sess:
                writer = tf.summary.FileWriter(eval_dir, sess.graph)

                def eval_loop_pl(sess,
                                 name_,
                                 filenames_=None,
                                 labels_=None,
                                 threshold_=None):

                    images_pl_ = images_pl
                    probabilities_op_pl_ = probabilities_op

                    results_file_name1 = eval_dir + '/' + name_ + '_pl'
                    tag_name = 'Evaluation/%s/whole_set_accuracy_pl' % (name_)
                    if FLAGS.oversample_at_eval:
                        results_file_name1 = results_file_name1 + '_oversample'
                        tag_name += '_oversample'
                    results_file_name1 = results_file_name1 + '.txt'
                    accuracyy, acc_at_in_thr, eer_thr, eer_thr_max = classifier_pl.evaluate_loop_placeholder(
                        sess,
                        probabilities_op_pl_,
                        images_pl_,
                        filenames_,
                        categories_file,
                        labels_,
                        results_file=results_file_name1,
                        threshold=threshold_,
                        batch_size=eval_batch_size,
                        summary_op=my_summary_op,
                        summary_writer=writer)

                    summary_str = tf.Summary()
                    summary_str.value.add(tag=tag_name, simple_value=accuracyy)
                    if threshold_ is not None:
                        if isinstance(threshold_, list):
                            summary_str.value.add(
                                tag=tag_name + '/atGivenThr',
                                simple_value=max(acc_at_in_thr))
                        else:
                            summary_str.value.add(tag=tag_name + '/atGivenThr',
                                                  simple_value=acc_at_in_thr)
                    writer.add_summary(summary_str, sess.run(tf_global_step))
                    if threshold_ is None:
                        tf.logging.info('-----------  %s: [%.4f]', tag_name,
                                        accuracyy)
                    else:
                        if isinstance(threshold_, list):
                            tf.logging.info(
                                '-----------  %s: [%.4f]. At given threshold (%s): [%s]',
                                tag_name, accuracyy,
                                ','.join(str(thr) for thr in threshold_),
                                ','.join(str(acc) for acc in acc_at_in_thr))
                        else:
                            tf.logging.info(
                                '-----------  %s: [%.4f]. At given threshold (%.4f): [%.4f]',
                                tag_name, accuracyy, threshold_, acc_at_in_thr)

                    if eer_thr_max > 0:
                        eer_thr = [eer_thr, eer_thr_max]
                    return eer_thr

                # =============
                restore_fn(sess)
                eer_thr = eval_loop_pl(sess,
                                       name_=FLAGS.dataset_split_name_y,
                                       filenames_=filenames,
                                       labels_=labels)
                if FLAGS.dataset_split_name_y2 is not None:
                    eval_loop_pl(sess,
                                 name_=FLAGS.dataset_split_name_y2,
                                 filenames_=filenames_2,
                                 labels_=labels_2,
                                 threshold_=eer_thr)

        else:
            ## ===========================
            def eval_loop(sess,
                          eval_ops_,
                          num_batches_,
                          name_,
                          accuracy_value_=None,
                          probabilities_op_=None,
                          eval_ops_slim_=None,
                          threshold_=None,
                          filenames_op_=None):

                if FLAGS.use_slim_stream_eval:
                    evaluation_y.evaluate_loop_slim_streaming_metrics(
                        sess, num_batches_, eval_ops_)
                    tf.logging.info(
                        '-----------  %s: Final Streaming Accuracy[%s]: %.4f',
                        time.strftime('%Y-%m-%d-%H:%M:%S', time.gmtime()),
                        name_,
                        sess.run(accuracy_value_) * 100)
                    eer_thr, eer_thr_max = None, 0
                else:
                    if probabilities_op_ is not None:
                        eval_ops_.append(probabilities_op_)
                    results_file_name = eval_dir + '/incorrect_filenames_' + name_ + '.txt'
                    accuracyy, _, _, _, _, acc_at_in_thr, eer_thr, eer_thr_max = evaluation_y.evaluate_loop_y(
                        sess,
                        num_batches_,
                        eval_batch_size,
                        eval_ops_,
                        eval_ops_slim=eval_ops_slim_,
                        threshold=threshold_,
                        filenames=filenames_op_,
                        results_file=results_file_name)
                    if accuracy_value_ is not None:
                        tf.logging.info(
                            '-----------  %s: Final Streaming Accuracy[%s]: %.4f',
                            time.strftime('%Y-%m-%d-%H:%M:%S', time.gmtime()),
                            name_,
                            sess.run(accuracy_value_) * 100)
                    tag_name = 'Evaluation/%s/whole_set_accuracy' % (name_)
                    summary_str = tf.Summary()
                    summary_str.value.add(tag=tag_name, simple_value=accuracyy)
                    if threshold_ is not None:
                        if isinstance(threshold_, list):
                            summary_str.value.add(
                                tag=tag_name + '/atGivenThr',
                                simple_value=max(acc_at_in_thr))
                        else:
                            summary_str.value.add(tag=tag_name + '/atGivenThr',
                                                  simple_value=acc_at_in_thr)
                    sv.summary_writer.add_summary(summary_str,
                                                  sess.run(tf_global_step))
                    if threshold_ is None:
                        tf.logging.info('-----------  %s: [%.4f]', tag_name,
                                        accuracyy)
                    else:
                        if isinstance(threshold_, list):
                            tf.logging.info(
                                '-----------  %s: [%.4f]. At given threshold (%s): [%s]',
                                tag_name, accuracyy,
                                ','.join(str(thr) for thr in threshold_),
                                ','.join(str(acc) for acc in acc_at_in_thr))
                        else:
                            tf.logging.info(
                                '-----------  %s: [%.4f]. At given threshold (%.4f): [%.4f]',
                                tag_name, accuracyy, threshold_, acc_at_in_thr)

                if eer_thr_max > 0:
                    eer_thr = [eer_thr, eer_thr_max]
                return eer_thr

            sv = tf.train.Supervisor(logdir=eval_dir,
                                     summary_op=None,
                                     saver=None,
                                     init_fn=restore_fn)
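            # The Supervisor runs restore_fn at initialization and owns the
            # summary writer (sv.summary_writer) used inside eval_loop.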
            with sv.managed_session() as sess:
                #######################################
                if FLAGS.use_slim_stream_eval:
                    eval_ops = eval_ops_slim
                else:
                    eval_ops = eval_ops_y
                eer_threshold = eval_loop(sess,
                                          eval_ops,
                                          num_batches,
                                          name_=FLAGS.dataset_split_name_y,
                                          eval_ops_slim_=eval_ops_slim,
                                          accuracy_value_=accuracy_value,
                                          probabilities_op_=probabilities_op,
                                          filenames_op_=filenames_op)
                if FLAGS.dataset_split_name_y2 is not None:
                    eval_ops_2 = (eval_ops_slim_2 if FLAGS.use_slim_stream_eval
                                  else eval_ops_y_2)
                    eval_loop(sess,
                              eval_ops_2,
                              num_batches_2,
                              name_=FLAGS.dataset_split_name_y2,
                              eval_ops_slim_=eval_ops_slim_2,
                              accuracy_value_=accuracy_value_2,
                              probabilities_op_=probabilities_op_2,
                              filenames_op_=filenames_op_2,
                              threshold_=eer_threshold)

                summaries = sess.run(my_summary_op)
                sv.summary_computed(sess, summaries)

        tf.logging.info('********* Finished evaluation at ' +
                        time.strftime('%Y-%m-%d-%H:%M:%S', time.gmtime()))
def main(_):
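    """Training entry point: builds the input pipeline, network, losses and
    optimizer, then runs slim-style training with periodic streaming evaluation
    on the training and validation splits.
    """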
    tfRecords_dir = FLAGS.tfRecords_dir
    if (FLAGS.tfRecords_dir is None) or (FLAGS.use_placeholders):
        _, tfRecords_dir, imgs_sub_dirs, csv_files, categories_file = paths_namings.get_dataset_paths_and_settings(
            FLAGS.dataset_name)

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        # Create global_step
        global_step = slim.create_global_step()

        ######################
        # Select the dataset #
        ######################
        pattern = paths_namings.file_pattern_tfrecords(
            FLAGS, tfRecords_dir, FLAGS.dataset_split_name)
        dataset_Ob = dataset_factory_tf.get_dataset_y(FLAGS.dataset_name,
                                                      FLAGS.dataset_split_name,
                                                      file_pattern=pattern)
        dataset = dataset_Ob.get_split_slim(tfRecords_dir=tfRecords_dir,
                                            n_channels=FLAGS.channels)

        ######################
        # Select the network #
        ######################
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay,
            is_training=True)

        #####################################
        # Select the preprocessing function #
        #####################################
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name, is_training=True)

        ##############################################################
        # Create a dataset provider that loads data from the dataset #
        ##############################################################
        train_image_size = FLAGS.image_size or network_fn.default_image_size
        images, _, labels, _ = load_batch_slim(
            dataset,
            FLAGS.batch_size,
            train_image_size,
            train_image_size,
            preprocessing_name,
            image_preprocessing_fn,
            num_readers=FLAGS.num_readers,
            num_preprocessing_threads=FLAGS.num_preprocessing_threads,
            per_image_standardization=FLAGS.per_image_standardization,
            vgg_sub_mean_pixel=FLAGS.vgg_sub_mean_pixel,
            vgg_resize_side_in=FLAGS.vgg_resize_side,
            vgg_use_aspect_preserving_resize=FLAGS.vgg_use_aspect_preserving_resize,
            labels_offset=FLAGS.labels_offset,
            is_training=True)

        #############
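        # Forward pass on the training batch; batch_summ=True presumably adds
        # per-batch evaluation summaries under the 'Evals_Train/Batch' scope.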
        logits, end_points, _, _, eval_ops = get_logits_and_valid_ops(
            images,
            labels,
            network_fn,
            one_hot=False,
            batch_summ=True,
            scope='Evals_Train/Batch')

        #############################
        # Specify the loss function #
        #############################
        labels_one_hot = slim.one_hot_encoding(
            labels, dataset.num_classes - FLAGS.labels_offset)
        if 'AuxLogits' in end_points:
            tf.losses.softmax_cross_entropy(logits=end_points['AuxLogits'],
                                            onehot_labels=labels_one_hot,
                                            weights=0.4,
                                            scope='aux_loss')
        tf.losses.softmax_cross_entropy(logits=logits,
                                        onehot_labels=labels_one_hot,
                                        weights=1.0)

        #############
        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # Add summaries for end_points.
        for end_point in end_points:
            x = end_points[end_point]
            summaries.add(tf.summary.histogram('activations/' + end_point, x))
            summaries.add(
                tf.summary.scalar('sparsity/' + end_point,
                                  tf.nn.zero_fraction(x)))

        # Add summaries for losses.
        for loss in tf.get_collection(tf.GraphKeys.LOSSES, ''):
            summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

        # Add summaries for variables.
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        #########################################
        # Configure the optimization procedure. #
        #########################################
        learning_rate = _configure_learning_rate(dataset.num_samples,
                                                 global_step)
        optimizer = _configure_optimizer(learning_rate)
        summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        #################################
        # Configure the moving averages #
        #################################
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        #############
        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by network_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, '')  # YY: first_clone_scope

        if FLAGS.moving_average_decay:  # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = _get_variables_to_train()

        ############# Loss and Optimizing
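        # Single-GPU equivalent of slim's clone deployment: compute gradients of
        # the total (cross-entropy + regularization) loss and group the gradient
        # update with the other update_ops (batch norm, moving averages).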
        total_loss = tf.losses.get_total_loss()  # obtains the regularization losses as well
        clones_gradients = optimizer.compute_gradients(
            total_loss, var_list=variables_to_train)

        # Add total_loss to summary.
        summaries.add(tf.summary.scalar('total_loss', total_loss))

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        update_op = tf.group(*update_ops)
        train_tensor = control_flow_ops.with_dependencies([update_op],
                                                          total_loss,
                                                          name='train_op')
        train_op_y = train_tensor

        ###########################
        ###########################
        checkpoints_dir_y = paths_namings.generate_checkpoints_dir(
            FLAGS, train_image_size)
        checkpoints_dir_y = checkpoints_dir_y + '_slim/pr_' + preprocessing_name
        print(checkpoints_dir_y)

        #################################################################################
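        # A second, is_training=False copy of the network provides streaming
        # evaluation ops during training (variable sharing is assumed to be
        # handled inside get_logits_and_valid_ops).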
        network_fn_eval = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            is_training=False)

        ## ==========
        num_batches = int(
            math.ceil(dataset.num_samples / float(FLAGS.batch_size)))
        _, _, names_to_values, names_to_updates, eval_ops_tr = get_logits_and_valid_ops(
            images, labels, network_fn_eval, scope='Evals_Train')
        ## ==========
        images_val, labels_val, _, num_samples_val = get_images_and_labels(
            'validation',
            preprocessing_name,
            image_preprocessing_fn,
            train_image_size,
            tfRecords_dir,
            is_training=False)
        num_batches_valid = int(
            math.ceil(num_samples_val / float(FLAGS.batch_size)))
        _, _, names_to_values_valid, names_to_updates_valid, eval_ops_v = get_logits_and_valid_ops(
            images_val, labels_val, network_fn_eval, scope='Evals_Val')

        ########################################################################################
        # Add the summaries created by model_fn and either optimize_clones()
        # or _gather_clone_loss().
        summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES, ''))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        ########################################################################################
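        # Extra kwargs presumably consumed by train_step_fn_y: when to run the
        # in-training evaluations and which streaming ops / accuracy tensors to use.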
        with tf.name_scope('train_step_eval'):
            train_step_kwargs_extra = {}
            train_step_kwargs_extra['should_eval_train'] = math_ops.equal(
                math_ops.mod(global_step, FLAGS.validation_every_n_steps), 0)
            train_step_kwargs_extra['should_eval_val'] = math_ops.equal(
                math_ops.mod(global_step, FLAGS.test_every_n_steps), 0)
            train_step_kwargs_extra['num_batches_train'] = num_batches
            train_step_kwargs_extra['num_batches_val'] = num_batches_valid
            train_step_kwargs_extra['eval_ops_slim_train'] = list(
                names_to_updates.values())
            train_step_kwargs_extra['eval_ops_slim_val'] = list(
                names_to_updates_valid.values())
            train_step_kwargs_extra['stream_acc_slim_train'] = names_to_values[
                'Accuracy']
            train_step_kwargs_extra[
                'stream_acc_slim_val'] = names_to_values_valid['Accuracy']
            train_step_kwargs_extra['eval_ops_train'] = eval_ops_tr
            train_step_kwargs_extra['eval_ops_val'] = eval_ops_v

        ########################################################################################
        tf.logging.info(' ********** Starting Training %s' % checkpoints_dir_y)
        session_config = tf.ConfigProto()
        session_config.gpu_options.allow_growth = True

        slim_learning_y.train_y(  # slim.learning.train(
            train_op_y,
            logdir=checkpoints_dir_y,
            is_chief=True,
            init_fn=_get_init_fn(checkpoints_dir_y),
            summary_op=summary_op,
            number_of_steps=FLAGS.max_number_of_steps,
            log_every_n_steps=FLAGS.log_every_n_steps,
            save_summaries_secs=FLAGS.save_summaries_secs,
            save_interval_secs=FLAGS.save_interval_secs,
            eval_ops=list(names_to_updates.values()),  # YY
            num_evals=num_batches,
            eval_ops_valid=list(names_to_updates_valid.values()),  # YY
            num_evals_valid=num_batches_valid,
            session_config=session_config,
            train_step_fn=train_step_fn_y,
            train_step_kwargs_extra=train_step_kwargs_extra)
        tf.logging.info(' ********** Finished Training %s' % checkpoints_dir_y)