コード例 #1
0
def main(_):
    """Build the multi-clone training graph and run the TF-Slim training loop.

    The graph is assembled in this order: deployment config and global step,
    network and anchors, input batch plus prefetch queue, one model clone per
    device (via ``model_deploy.create_clones``), summaries, optional moving
    averages, optimizer, and finally the train op that is handed to
    ``slim.learning.train``.

    Args:
        _: Unused positional argument.

    Raises:
        ValueError: If ``FLAGS.dataset_dir`` is empty.
    """
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.DEBUG)

    with tf.Graph().as_default():
        ######################
        # Config model_deploy#
        ######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        # Create global_step on the device chosen for variables.
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        # Initialize the net with the default params, overriding only the
        # match threshold from the command line.
        network_fn = nets_factory.get_network(FLAGS.model_name)
        params = network_fn.default_params
        params = params._replace(match_threshold=FLAGS.match_threshold)
        net = network_fn(params)
        out_shape = net.params.img_shape
        anchors = net.anchors(out_shape)

        # Create the batched dataset on the inputs device.
        with tf.device(deploy_config.inputs_device()):
            # NOTE(review): the keyword is spelled `shuffe` at every call site
            # in view; presumably it matches get_batch's signature, so do not
            # correct the spelling here alone.
            b_image, b_glocalisations, b_gscores = \
            load_batch.get_batch(FLAGS.dataset_dir,
                  FLAGS.num_readers,
                  FLAGS.batch_size,
                  out_shape,
                  net,
                  anchors,
                  FLAGS,
                  file_pattern = FLAGS.file_pattern,
                  is_training = True,
                  shuffe = FLAGS.shuffle_data)

            # Flatten the per-anchor-layer targets and concatenate them into
            # a single scores tensor and a single [-1, 4] localisation tensor.
            allgscores = []
            allglocalization = []
            for i in range(len(anchors)):
                allgscores.append(tf.reshape(b_gscores[i], [-1]))
                allglocalization.append(
                    tf.reshape(b_glocalisations[i], [-1, 4]))

            b_gscores = tf.concat(allgscores, 0)
            b_glocalisations = tf.concat(allglocalization, 0)

            batch_queue = slim.prefetch_queue.prefetch_queue(
                tf_utils.reshape_list([b_image, b_glocalisations, b_gscores]),
                num_threads=8,
                capacity=16 * deploy_config.num_clones)

        # =================================================================== #
        # Define the model running on every GPU.
        # =================================================================== #
        def clone_fn(batch_queue):
            """Build one model clone: dequeue a batch, run the net, add losses.

            Returns the network end points; the losses are registered by
            ``net.losses`` rather than returned.
            """
            # Dequeue batch. One shape entry per dequeued group (image,
            # localisations, scores).
            batch_shape = [1] * 3
            b_image, b_glocalisations, b_gscores = \
             tf_utils.reshape_list(batch_queue.dequeue(), batch_shape)
            # Construct the network.
            arg_scope = net.arg_scope(weight_decay=FLAGS.weight_decay,
                                      data_format=FLAGS.data_format)
            with slim.arg_scope(arg_scope):
                localisations, logits, end_points = \
                 net.net(b_image, is_training=True, use_batch=FLAGS.use_batch)
            # Add loss function.
            net.losses(logits,
                       localisations,
                       b_glocalisations,
                       b_gscores,
                       negative_ratio=FLAGS.negative_ratio,
                       use_hard_neg=FLAGS.use_hard_neg,
                       alpha=FLAGS.loss_alpha,
                       label_smoothing=FLAGS.label_smoothing)
            return end_points

        # Snapshot the summaries that exist before the clones are created.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        clones = model_deploy.create_clones(deploy_config, clone_fn,
                                            [batch_queue])
        first_clone_scope = deploy_config.clone_scope(0)

        # Only the update ops of the first clone are collected.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        end_points = clones[0].outputs
        for end_point in end_points:
            x = end_points[end_point]
            summaries.add(tf.summary.histogram('activations/' + end_point, x))

        for loss in tf.get_collection('EXTRA_LOSSES', first_clone_scope):
            summaries.add(tf.summary.scalar(loss.op.name, loss))

        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))

        #################################
        # Configure the moving averages #
        #################################
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = tf_utils.configure_learning_rate(
                FLAGS, FLAGS.num_samples, global_step)
            optimizer = tf_utils.configure_optimizer(FLAGS, learning_rate)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        if FLAGS.fine_tune:
            # Use a context manager so the file handle is closed even if
            # unpickling fails.
            # NOTE(review): pickle executes arbitrary code on load; this file
            # must come from a trusted source.
            with open('nets/multiplier_300.pkl', 'rb') as multiplier_file:
                gradient_multipliers = pickle.load(multiplier_file)
        else:
            gradient_multipliers = None
        # NOTE(review): gradient_multipliers is loaded but never applied to
        # the gradients below; confirm whether fine-tuning is meant to scale
        # them (e.g. through slim.learning.create_train_op).

        if FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train.
        variables_to_train = tf_utils.get_variables_to_train(FLAGS)

        # Total loss across clones and the gradients to apply.
        total_loss, clones_gradients = model_deploy.optimize_clones(
            clones, optimizer, var_list=variables_to_train)
        # Add total_loss to summary.
        summaries.add(tf.summary.scalar('total_loss', total_loss))

        # Create gradient updates.
        grad_updates = optimizer.apply_gradients(clones_gradients,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        # Evaluating train_tensor returns total_loss after all update ops
        # (including the gradient step) have been triggered.
        update_op = tf.group(*update_ops)
        train_tensor = control_flow_ops.with_dependencies([update_op],
                                                          total_loss,
                                                          name='train_op')

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        # =================================================================== #
        # Kicks off the training.
        # =================================================================== #
        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=FLAGS.gpu_memory_fraction,
            allocator_type="BFC")
        config = tf.ConfigProto(
            gpu_options=gpu_options,
            log_device_placement=False,
            allow_soft_placement=True,
            inter_op_parallelism_threads=0,
            intra_op_parallelism_threads=1,
        )
        saver = tf.train.Saver(max_to_keep=5,
                               keep_checkpoint_every_n_hours=1.0,
                               write_version=2,
                               pad_step_number=False)

        slim.learning.train(train_tensor,
                            logdir=FLAGS.train_dir,
                            master='',
                            is_chief=True,
                            init_fn=tf_utils.get_init_fn(FLAGS),
                            summary_op=summary_op,
                            number_of_steps=FLAGS.max_number_of_steps,
                            log_every_n_steps=FLAGS.log_every_n_steps,
                            save_summaries_secs=FLAGS.save_summaries_secs,
                            saver=saver,
                            save_interval_secs=FLAGS.save_interval_secs,
                            session_config=config,
                            sync_optimizer=None)
コード例 #2
0
ファイル: eval.py プロジェクト: mjh2040/TextBox_pp
def main(_):
    """Build the evaluation graph and run a single or looping evaluation.

    Builds the network in inference mode, decodes and filters its detections,
    matches them against the ground truth to get TP/FP statistics, derives
    precision/recall/average-precision summaries, and then either evaluates
    one checkpoint (``slim.evaluation.evaluate_once``) or polls a checkpoint
    directory (``slim.evaluation.evaluation_loop``), depending on
    ``FLAGS.wait_for_checkpoints``.

    Args:
        _: Unused positional argument.

    Raises:
        ValueError: If ``FLAGS.dataset_dir`` is empty, or if
            ``FLAGS.checkpoint_path`` is a directory that contains no
            checkpoint (single-evaluation mode only).
    """
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        tf_global_step = slim.get_or_create_global_step()

        # Initialize the net.
        network_fn = nets_factory.get_network(FLAGS.model_name)
        net = network_fn()
        out_shape = net.params.img_shape
        # NOTE(review): the network's own image shape is immediately
        # overridden by a fixed 300x300 evaluation shape; confirm this is
        # intended for every model selectable via --model_name.
        out_shape = (300, 300)
        anchors = net.anchors(out_shape)
        # =================================================================== #
        # Create a dataset provider and batches.
        # =================================================================== #
        with tf.device('/cpu:0'):
            # NOTE(review): the keyword is spelled `shuffe` at every call site
            # in view; presumably it matches get_batch's signature.
            b_image, glabels, b_gbboxes, g_bbox_img, b_glocalisations, b_gscores =\
                load_batch.get_batch(FLAGS.dataset_dir,
                    FLAGS.num_readers,
                    FLAGS.batch_size,
                    out_shape,
                    net,
                    anchors,
                    FLAGS,
                    file_pattern =  '*.tfrecord',
                    is_training = False,
                    shuffe = FLAGS.shuffle_data)
        # No ground-truth box is flagged as difficult.
        b_gdifficults = tf.zeros(tf.shape(glabels), dtype=tf.int64)
        arg_scope = net.arg_scope(data_format=DATA_FORMAT)
        with slim.arg_scope(arg_scope):
            localisations, logits, end_points  = \
             net.net(b_image, is_training=False, use_batch=FLAGS.use_batch)

        # Per-layer class probabilities.
        predictions = []
        for i in range(len(logits)):
            predictions.append(slim.softmax(logits[i]))

        # Performing post-processing on CPU: loop-intensive, usually more efficient.
        with tf.device('/device:CPU:0'):
            # Detected objects from the network output.
            localisations = net.bboxes_decode(localisations, anchors)
            rscores, rbboxes = \
             net.detected_bboxes(predictions, localisations,
                   select_threshold=FLAGS.select_threshold,
                   nms_threshold=FLAGS.nms_threshold,
                   clipping_bbox=None,
                   top_k=FLAGS.select_top_k,
                   keep_top_k=FLAGS.keep_top_k)
            # Compute TP and FP statistics.
            num_gbboxes, tp, fp, rscores = \
             tfe.bboxes_matching_batch(rscores.keys(), rscores, rbboxes,
                     glabels, b_gbboxes, b_gdifficults,
                     matching_threshold=FLAGS.matching_threshold)

        # Variables to restore: moving avg. or normal weights.
        if FLAGS.moving_average_decay:
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, tf_global_step)
            variables_to_restore = variable_averages.variables_to_restore(
                slim.get_model_variables())
            variables_to_restore[tf_global_step.op.name] = tf_global_step
        else:
            variables_to_restore = slim.get_variables_to_restore()

        # =================================================================== #
        # Evaluation metrics.
        # =================================================================== #
        with tf.device(FLAGS.gpu_eval):
            dict_metrics = {}
            # Extra losses as well.
            for loss in tf.get_collection('EXTRA_LOSSES'):
                dict_metrics[loss.op.name] = slim.metrics.streaming_mean(loss)

            # Add metrics to summaries.
            for name, metric in dict_metrics.items():
                summary_name = name
                op = tf.summary.scalar(summary_name, metric[0], collections=[])
                tf.add_to_collection(tf.GraphKeys.SUMMARIES, op)

            # FP and TP metrics.
            tp_fp_metric = tfe.streaming_tp_fp_arrays(num_gbboxes, tp, fp,
                                                      rscores)
            for c in tp_fp_metric[0].keys():
                dict_metrics['tp_fp_%s' % c] = (tp_fp_metric[0][c],
                                                tp_fp_metric[1][c])

            # Add to summaries precision/recall values.
            icdar2013 = {}
            for c in tp_fp_metric[0].keys():
                # Precision and recall values.
                prec, rec = tfe.precision_recall(*tp_fp_metric[0][c])

                # The same summary name is requested for every class c.
                op = tf.summary.scalar('precision',
                                       tf.reduce_mean(prec),
                                       collections=[])
                tf.add_to_collection(tf.GraphKeys.SUMMARIES, op)

                op = tf.summary.scalar('recall',
                                       tf.reduce_mean(rec),
                                       collections=[])
                tf.add_to_collection(tf.GraphKeys.SUMMARIES, op)

                # Average precision (VOC12-style helper).
                v = tfe.average_precision_voc12(prec, rec)
                summary_name = 'ICDAR13/%s' % c
                op = tf.summary.scalar(summary_name, v, collections=[])
                tf.add_to_collection(tf.GraphKeys.SUMMARIES, op)
                icdar2013[c] = v

            # Mean of the per-class average precisions.
            summary_name = 'ICDAR13/mAP'
            mAP = tf.add_n(list(icdar2013.values())) / len(icdar2013)
            op = tf.summary.scalar(summary_name, mAP, collections=[])
            # Also print mAP whenever the summary is evaluated.
            op = tf.Print(op, [mAP], summary_name)
            tf.add_to_collection(tf.GraphKeys.SUMMARIES, op)

        # Split into values and updates ops.
        names_to_values, names_to_updates = slim.metrics.aggregate_metric_map(
            dict_metrics)

        # =================================================================== #
        # Evaluation loop.
        # =================================================================== #
        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=FLAGS.gpu_memory_fraction)
        config = tf.ConfigProto(log_device_placement=False,
                                gpu_options=gpu_options)

        # Number of batches. math.ceil returns a float on Python 2, so force
        # an integer batch count.
        if FLAGS.max_num_batches:
            num_batches = FLAGS.max_num_batches
        else:
            num_batches = int(math.ceil(FLAGS.num_samples /
                                        float(FLAGS.batch_size)))

        if not FLAGS.wait_for_checkpoints:
            if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
                checkpoint_path = tf.train.latest_checkpoint(
                    FLAGS.checkpoint_path)
                # latest_checkpoint returns None when the directory holds no
                # checkpoint; fail loudly instead of handing None to
                # evaluate_once.
                if checkpoint_path is None:
                    raise ValueError(
                        'No checkpoint found in directory %s' %
                        FLAGS.checkpoint_path)
            else:
                checkpoint_path = FLAGS.checkpoint_path
            tf.logging.info('Evaluating %s' % checkpoint_path)

            # Standard evaluation loop.
            start = time.time()
            slim.evaluation.evaluate_once(
                master=FLAGS.master,
                checkpoint_path=checkpoint_path,
                logdir=FLAGS.eval_dir,
                num_evals=num_batches,
                eval_op=list(names_to_updates.values()),
                variables_to_restore=variables_to_restore,
                session_config=config)
            # Log time spent.
            elapsed = time.time()
            elapsed = elapsed - start
            print('Time spent : %.3f seconds.' % elapsed)
            print('Time spent per BATCH: %.3f seconds.' %
                  (elapsed / num_batches))

        else:
            checkpoint_path = FLAGS.checkpoint_path
            tf.logging.info('Evaluating %s' % checkpoint_path)

            # Waiting loop: re-evaluate every 60 seconds, without limit.
            slim.evaluation.evaluation_loop(
                master=FLAGS.master,
                checkpoint_dir=checkpoint_path,
                logdir=FLAGS.eval_dir,
                num_evals=num_batches,
                eval_op=list(names_to_updates.values()),
                variables_to_restore=variables_to_restore,
                eval_interval_secs=60,
                max_number_of_evaluations=np.inf,
                session_config=config,
                timeout=None)
コード例 #3
0
def main(_):
    """Build a single-device training graph and run a manual session loop.

    Unlike a ``slim.learning.train`` driver, this entry point owns the
    session: it initializes variables, restores the latest checkpoint found
    in ``FLAGS.train_dir`` (if any), then runs ``FLAGS.max_number_of_steps``
    training steps, writing summaries and checkpoints itself.

    Args:
        _: Unused positional argument.

    Raises:
        ValueError: If ``FLAGS.dataset_dir`` is empty.
    """
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.DEBUG)

    with tf.Graph().as_default():

        # Initialize the net.
        net = txtbox_300.TextboxNet()
        out_shape = net.params.img_shape
        anchors = net.anchors(out_shape)

        # Create global_step.
        global_step = slim.create_global_step()

        # Create the batched dataset.
        with tf.device(FLAGS.gpu_data):
            b_image, b_glocalisations, b_gscores = \
            load_batch.get_batch(FLAGS.dataset_dir,
                  FLAGS.num_readers,
                  FLAGS.batch_size,
                  out_shape,
                  net,
                  anchors,
                  FLAGS.num_preprocessing_threads,
                  is_training = True)

        with tf.device(FLAGS.gpu_train):
            arg_scope = net.arg_scope(weight_decay=FLAGS.weight_decay)

            with slim.arg_scope(arg_scope):
                localisations, logits, end_points = \
                  net.net(b_image, is_training=True)

            # Add loss function.
            total_loss = net.losses(logits,
                                    localisations,
                                    b_glocalisations,
                                    b_gscores,
                                    match_threshold=FLAGS.match_threshold,
                                    negative_ratio=FLAGS.negative_ratio,
                                    alpha=FLAGS.loss_alpha,
                                    label_smoothing=FLAGS.label_smoothing)

        # Gather summaries; they are picked up later by merge_all().
        for end_point in end_points:
            x = end_points[end_point]
            tf.summary.histogram('activations/' + end_point, x)
            tf.summary.scalar('sparsity/' + end_point, tf.nn.zero_fraction(x))

        for loss in tf.get_collection(tf.GraphKeys.LOSSES):
            tf.summary.scalar(loss.op.name, loss)

        for loss in tf.get_collection('EXTRA_LOSSES'):
            tf.summary.scalar(loss.op.name, loss)

        for variable in slim.get_model_variables():
            tf.summary.histogram(variable.op.name, variable)

        with tf.device(FLAGS.gpu_train):
            learning_rate = tf_utils.configure_learning_rate(
                FLAGS, FLAGS.num_samples, global_step)
            # Configure the optimization procedure.
            optimizer = tf_utils.configure_optimizer(FLAGS, learning_rate)
            tf.summary.scalar('learning_rate', learning_rate)

            # Training op.
            train_op = slim.learning.create_train_op(total_loss, optimizer)

        merged = tf.summary.merge_all()

        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=FLAGS.gpu_memory_fraction)
        config = tf.ConfigProto(gpu_options=gpu_options,
                                log_device_placement=False,
                                allow_soft_placement=True)

        # Checkpoint directory. Tensorflow assumes this directory already
        # exists so we need to create it.
        checkpoint_dir = FLAGS.train_dir
        checkpoint_prefix = os.path.join(checkpoint_dir, "model.ckpt")
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        with tf.Session(config=config) as sess:
            sess.run(tf.global_variables_initializer())
            train_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)
            # Close the writer on every exit path so buffered summaries are
            # flushed to disk even when training is interrupted.
            try:
                saver = tf.train.Saver(max_to_keep=1,
                                       keep_checkpoint_every_n_hours=1.0,
                                       pad_step_number=False)
                # Resume from the most recent checkpoint, if one exists.
                path = tf.train.latest_checkpoint(FLAGS.train_dir)
                if path:
                    saver.restore(sess, path)
                    # Single-argument call form prints identically under the
                    # Python 2 print statement.
                    print(sess.run([global_step]))
                with slim.queues.QueueRunners(sess):
                    for i in xrange(FLAGS.max_number_of_steps):
                        loss, _ , summary_, global_step_= \
                        sess.run([total_loss,train_op,merged,global_step])
                        if i % 10 == 0:
                            print(loss)
                        if global_step_ % 2 == 0:
                            train_writer.add_summary(summary_, global_step_)
                        if global_step_ % 100 == 0:
                            # Re-read the step only when it is needed for the
                            # checkpoint name; this avoids one extra
                            # sess.run per training step.
                            current_step = tf.train.global_step(
                                sess, global_step)
                            path = saver.save(sess,
                                              checkpoint_prefix,
                                              global_step=current_step)
                            print("Saved model checkpoint to {}\n".format(path))
            finally:
                train_writer.close()
コード例 #4
0
def main(_):
    """Assemble the SSD training graph and hand it to slim.learning.train.

    Args:
        _: Unused positional argument.

    Raises:
        ValueError: If ``FLAGS.dataset_dir`` is empty.
    """
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.DEBUG)
    with tf.Graph().as_default():
        step_tensor = slim.create_global_step()

        # SSD network, its input shape and the anchors derived from it.
        detector = ssd_vgg_300.SSDNet()
        input_shape = detector.params.img_shape
        anchor_boxes = detector.anchors(input_shape)

        # Training batch: images plus encoded ground truth.
        images, gt_classes, gt_locs, gt_scores = load_batch.get_batch(
            FLAGS.dataset_dir,
            FLAGS.num_readers,
            FLAGS.batch_size,
            input_shape,
            detector,
            anchor_boxes,
            FLAGS.num_preprocessing_threads,
            is_training=True)

        with tf.device(FLAGS.gpu_train):
            scope = detector.arg_scope(weight_decay=FLAGS.weight_decay)
            with slim.arg_scope(scope):
                # Only the localisations and logits feed the loss.
                _predictions, pred_locs, pred_logits, _end_points = \
                    detector.net(images, is_training=True)

            # Loss hyper-parameters all come from the command line.
            loss_options = {
                'match_threshold': FLAGS.match_threshold,
                'negative_ratio': FLAGS.negative_ratio,
                'alpha': FLAGS.loss_alpha,
                'label_smoothing': FLAGS.label_smoothing,
            }
            total_loss = detector.losses(pred_logits,
                                         pred_locs,
                                         gt_classes,
                                         gt_locs,
                                         gt_scores,
                                         **loss_options)

        # Collect summaries: pre-existing ones, then one histogram per model
        # variable, then one scalar per extra loss.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
        summaries.update(
            tf.summary.histogram(model_var.op.name, model_var)
            for model_var in slim.get_model_variables())
        summaries.update(
            tf.summary.scalar(extra_loss.op.name, extra_loss)
            for extra_loss in tf.get_collection('EXTRA_LOSSES'))

        with tf.device(FLAGS.gpu_train):
            # Learning-rate schedule and optimizer.
            lr = tf_utils.configure_learning_rate(
                FLAGS, FLAGS.num_samples, step_tensor)
            optimizer = tf_utils.configure_optimizer(FLAGS, lr)
            summaries.add(tf.summary.scalar('learning_rate', lr))

            train_op = slim.learning.create_train_op(total_loss, optimizer)

        # =================================================================== #
        # Session configuration, checkpointing and the training loop.
        # =================================================================== #
        session_config = tf.ConfigProto(
            gpu_options=tf.GPUOptions(
                per_process_gpu_memory_fraction=FLAGS.gpu_memory_fraction),
            log_device_placement=False,
            allow_soft_placement=True)
        checkpoint_saver = tf.train.Saver(max_to_keep=1,
                                          keep_checkpoint_every_n_hours=1.0,
                                          write_version=2,
                                          pad_step_number=False)
        train_options = {
            'logdir': FLAGS.train_dir,
            'master': '',
            'is_chief': True,
            'init_fn': tf_utils.get_init_fn(FLAGS),
            'number_of_steps': FLAGS.max_number_of_steps,
            'log_every_n_steps': FLAGS.log_every_n_steps,
            'save_summaries_secs': FLAGS.save_summaries_secs,
            'saver': checkpoint_saver,
            'save_interval_secs': FLAGS.save_interval_secs,
            'session_config': session_config,
            'sync_optimizer': None,
        }
        slim.learning.train(train_op, **train_options)
コード例 #5
0
def main(_):
    """Build a single-device training graph and run slim.learning.train.

    The train op is assembled by hand: the collected update ops, the optional
    moving-average update and the gradient step are grouped together, and the
    returned tensor is ``total_loss`` with a control dependency on that group.

    Args:
        _: Unused positional argument.

    Raises:
        ValueError: If ``FLAGS.dataset_dir`` is empty.
    """
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    tf.logging.set_verbosity(tf.logging.DEBUG)

    with tf.Graph().as_default():
        # Take the network's default params and override only the match
        # threshold from the command line.
        network_fn = nets_factory.get_network(FLAGS.model_name)
        params = network_fn.default_params
        params = params._replace(match_threshold=FLAGS.match_threshold)
        # Initialize the net.
        net = network_fn(params)
        out_shape = net.params.img_shape
        anchors = net.anchors(out_shape)

        # Create global_step.

        global_step = slim.create_global_step()
        # Create the batched dataset.
        # NOTE(review): the keyword is spelled `shuffe`; presumably it matches
        # load_batch.get_batch's signature, so verify before renaming.

        b_image, b_glocalisations, b_gscores = \
        load_batch.get_batch(FLAGS.dataset_dir,
                             FLAGS.num_readers,
                             FLAGS.batch_size,
                             out_shape,
                             net,
                             anchors,
                             FLAGS,
                             file_pattern = FLAGS.file_pattern,
                             is_training = True,
                             shuffe = FLAGS.shuffle_data)

        with tf.device(FLAGS.gpu_train):
            #with tf.device(FLAGS.gpu_train):

            arg_scope = net.arg_scope(weight_decay=FLAGS.weight_decay)

            with slim.arg_scope(arg_scope):
                localisations, logits, end_points = \
                        net.net(b_image, is_training=True, use_batch=FLAGS.use_batch)
            # Add loss function.
            total_loss = net.losses(logits,
                                    localisations,
                                    b_glocalisations,
                                    b_gscores,
                                    negative_ratio=FLAGS.negative_ratio,
                                    use_hard_neg=FLAGS.use_hard_neg,
                                    alpha=FLAGS.loss_alpha,
                                    label_smoothing=FLAGS.label_smoothing)

        # Gather summaries.
        # NOTE: the triple-quoted blocks below are bare string statements that
        # hold disabled summary code; they are never executed.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
        '''
        for end_point in end_points:
            x = end_points[end_point]
            summaries.add(tf.summary.histogram('activations/' + end_point, x))
            summaries.add(tf.summary.scalar('sparsity/' + end_point,
                                            tf.nn.zero_fraction(x)))

        #for loss in tf.get_collection(tf.GraphKeys.LOSSES):
        #   summaries.add(tf.summary.scalar(loss.op.name, loss))
        '''
        # Only the extra losses get scalar summaries in this variant.
        for loss in tf.get_collection('EXTRA_LOSSES'):
            summaries.add(tf.summary.scalar(loss.op.name, loss))
        '''
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))
        '''
        # Update ops registered so far; more are appended below.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

        #################################
        # Configure the moving averages #
        #################################
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        with tf.device(FLAGS.gpu_train):
            learning_rate = tf_utils.configure_learning_rate(
                FLAGS, FLAGS.num_samples, global_step)
            # Configure the optimization procedure.
            optimizer = tf_utils.configure_optimizer(FLAGS, learning_rate)
            #summaries.add(tf.summary.scalar('learning_rate', learning_rate))

            ## Training
            #loss = tf.get_collection(tf.GraphKeys.LOSSES)
            #total_loss = tf.add_n(loss)
            '''
            if FLAGS.fine_tune:
                gradient_multipliers = pickle.load(open('nets/multiplier_300.pkl','rb'))
            else:
                gradient_multipliers = None
            '''
        if FLAGS.moving_average_decay:
            # Update ops executed locally by trainer.
            update_ops.append(
                variable_averages.apply(moving_average_variables))

        # Variables to train, their gradients, and the op that applies them
        # (which also increments global_step).
        variables_to_train = tf_utils.get_variables_to_train(FLAGS)
        vars_grad = optimizer.compute_gradients(total_loss, variables_to_train)
        grad_updates = optimizer.apply_gradients(vars_grad,
                                                 global_step=global_step)
        update_ops.append(grad_updates)

        # Evaluating train_op returns total_loss after every update op
        # (including the gradient step) has been triggered.
        update_op = tf.group(*update_ops)
        train_op = control_flow_ops.with_dependencies([update_op],
                                                      total_loss,
                                                      name='train_op')
        #train_op = slim.learning.create_train_op(total_loss, optimizer, gradient_multipliers=gradient_multipliers)

        # =================================================================== #
        # Kicks off the training.
        # =================================================================== #
        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=FLAGS.gpu_memory_fraction)
        config = tf.ConfigProto(gpu_options=gpu_options,
                                log_device_placement=False,
                                allow_soft_placement=True)
        saver = tf.train.Saver(max_to_keep=5,
                               keep_checkpoint_every_n_hours=1.0,
                               write_version=2,
                               pad_step_number=False)

        # No summary_op is passed here, so the local `summaries` set is not
        # forwarded to the training loop.
        slim.learning.train(train_op,
                            logdir=FLAGS.train_dir,
                            master='',
                            is_chief=True,
                            init_fn=tf_utils.get_init_fn(FLAGS),
                            number_of_steps=FLAGS.max_number_of_steps,
                            log_every_n_steps=FLAGS.log_every_n_steps,
                            save_summaries_secs=FLAGS.save_summaries_secs,
                            saver=saver,
                            save_interval_secs=FLAGS.save_interval_secs,
                            session_config=config,
                            sync_optimizer=None)