Example 1
    def restore_fn(self, checkpoint_path, from_detection_checkpoint=True):
        """Return callable for loading a checkpoint into the tensorflow graph.

        Args:
          checkpoint_path: path to checkpoint to restore.
          from_detection_checkpoint: whether to restore from a full detection
            checkpoint (with compatible variable names) or to restore from a
            classification checkpoint for initialization prior to training.

        Returns:
          a callable which takes a tf.Session as input and loads a checkpoint when
            run.
        """
        variables_to_restore = {}
        for variable in tf.global_variables():  # tf.all_variables() is deprecated
            if variable.op.name.startswith(self._extract_features_scope):
                var_name = variable.op.name
                if not from_detection_checkpoint:
                    var_name = (re.split(
                        '^' + self._extract_features_scope + '/',
                        var_name)[-1])
                variables_to_restore[var_name] = variable
        # TODO: Load variables selectively using scopes.
        variables_to_restore = (
            variables_helper.get_variables_available_in_checkpoint(
                variables_to_restore, checkpoint_path))
        saver = tf.train.Saver(variables_to_restore)

        def restore(sess):
            saver.restore(sess, checkpoint_path)

        return restore
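The returned callable is meant to be invoked once a session exists; a minimal usage sketch, assuming a hypothetical model instance and checkpoint path:

# Minimal usage sketch (`model` and the checkpoint path are hypothetical).
restore = model.restore_fn('/tmp/classification.ckpt',
                           from_detection_checkpoint=False)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())  # initialize all variables first
    restore(sess)  # then overwrite feature-extractor weights from the checkpoint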
    def test_return_all_variables_from_checkpoint(self):
        variables = [
            tf.Variable(1.0, name='weights'),
            tf.Variable(1.0, name='biases')
        ]
        checkpoint_path = os.path.join(self.get_temp_dir(), 'graph.pb')
        init_op = tf.global_variables_initializer()
        saver = tf.train.Saver(variables)
        with self.test_session() as sess:
            sess.run(init_op)
            saver.save(sess, checkpoint_path)
        out_variables = variables_helper.get_variables_available_in_checkpoint(
            variables, checkpoint_path)
        self.assertItemsEqual(out_variables, variables)
    def test_return_variables_available_in_checkpoint(self):
        checkpoint_path = os.path.join(self.get_temp_dir(), 'graph.pb')
        weight_variable = tf.Variable(1.0, name='weights')
        global_step = tf.train.get_or_create_global_step()
        graph1_variables = [weight_variable, global_step]
        init_op = tf.global_variables_initializer()
        saver = tf.train.Saver(graph1_variables)
        with self.test_session() as sess:
            sess.run(init_op)
            saver.save(sess, checkpoint_path)

        graph2_variables = graph1_variables + [tf.Variable(1.0, name='biases')]
        out_variables = variables_helper.get_variables_available_in_checkpoint(
            graph2_variables, checkpoint_path, include_global_step=False)
        self.assertItemsEqual(out_variables, [weight_variable])
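The behavior these tests exercise can be approximated with a checkpoint reader. A rough sketch of the name-and-shape filtering the tests imply, not the library's actual implementation (the container-mirroring behavior is inferred from the list and dict tests in this listing):

def get_variables_available_in_checkpoint_sketch(variables, checkpoint_path,
                                                 include_global_step=True):
    # Rough approximation of the Object Detection API's
    # variables_helper.get_variables_available_in_checkpoint.
    if isinstance(variables, list):
        variable_map = {v.op.name: v for v in variables}
    else:
        variable_map = variables
    reader = tf.train.NewCheckpointReader(checkpoint_path)
    ckpt_shapes = reader.get_variable_to_shape_map()
    if not include_global_step:
        ckpt_shapes.pop('global_step', None)
    # Keep a variable only if the checkpoint has it with a matching shape.
    available = {
        name: var for name, var in variable_map.items()
        if name in ckpt_shapes and ckpt_shapes[name] == var.shape.as_list()
    }
    # Mirror the input container: list in, list out; dict in, dict out.
    return list(available.values()) if isinstance(variables, list) else available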
Example 4
    def restore_from_classification_checkpoint_fn(
            self, checkpoint_path, first_stage_feature_extractor_scope,
            second_stage_feature_extractor_scope):
        """Returns callable for loading a checkpoint into the tensorflow graph.

        Note that this overrides the default implementation in
        faster_rcnn_meta_arch.FasterRCNNFeatureExtractor which does not work for
        InceptionResnetV2 checkpoints.

        TODO: revisit whether it's possible to force the `Repeat` namescope as
        created in `_extract_box_classifier_features` to start counting at 2 (e.g.
        `Repeat_2`) so that the default restore_fn can be used.

        Args:
          checkpoint_path: Path to checkpoint to restore.
          first_stage_feature_extractor_scope: A scope name for the first stage
            feature extractor.
          second_stage_feature_extractor_scope: A scope name for the second stage
            feature extractor.

        Returns:
          a callable which takes a tf.Session as input and loads a checkpoint when
            run.
        """
        variables_to_restore = {}
        for variable in tf.global_variables():
            if variable.op.name.startswith(
                    first_stage_feature_extractor_scope):
                var_name = variable.op.name.replace(
                    first_stage_feature_extractor_scope + '/', '')
                variables_to_restore[var_name] = variable
            if variable.op.name.startswith(
                    second_stage_feature_extractor_scope):
                var_name = variable.op.name.replace(
                    second_stage_feature_extractor_scope +
                    '/InceptionResnetV2/Repeat', 'InceptionResnetV2/Repeat_2')
                var_name = var_name.replace(
                    second_stage_feature_extractor_scope + '/', '')
                variables_to_restore[var_name] = variable
        variables_to_restore = (
            variables_helper.get_variables_available_in_checkpoint(
                variables_to_restore, checkpoint_path))
        saver = tf.train.Saver(variables_to_restore)

        def restore(sess):
            saver.restore(sess, checkpoint_path)

        return restore
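The two replace calls amount to a pure string rewrite of second-stage variable names; an illustration on a hypothetical variable name:

# Name-mapping illustration; the scope and block names are hypothetical.
scope = 'SecondStageFeatureExtractor'
graph_name = scope + '/InceptionResnetV2/Repeat/block8_1/Conv2d_1x1/weights'
ckpt_name = graph_name.replace(scope + '/InceptionResnetV2/Repeat',
                               'InceptionResnetV2/Repeat_2')
ckpt_name = ckpt_name.replace(scope + '/', '')  # no-op after the first rewrite
# ckpt_name == 'InceptionResnetV2/Repeat_2/block8_1/Conv2d_1x1/weights'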
    def test_return_variables_available_in_checkpoint_with_dict_inputs(self):
        checkpoint_path = os.path.join(self.get_temp_dir(), 'graph.pb')
        graph1_variables = [
            tf.Variable(1.0, name='ckpt_weights'),
        ]
        init_op = tf.global_variables_initializer()
        saver = tf.train.Saver(graph1_variables)
        with self.test_session() as sess:
            sess.run(init_op)
            saver.save(sess, checkpoint_path)

        graph2_variables_dict = {
            'ckpt_weights': tf.Variable(1.0, name='weights'),
            'ckpt_biases': tf.Variable(1.0, name='biases')
        }
        out_variables = variables_helper.get_variables_available_in_checkpoint(
            graph2_variables_dict, checkpoint_path)
        self.assertIsInstance(out_variables, dict)
        self.assertItemsEqual(out_variables.keys(), ['ckpt_weights'])
        self.assertEqual(out_variables['ckpt_weights'].op.name, 'weights')
    def test_return_variables_with_correct_sizes(self):
        checkpoint_path = os.path.join(self.get_temp_dir(), 'graph.pb')
        bias_variable = tf.Variable(3.0, name='biases')
        global_step = tf.train.get_or_create_global_step()
        graph1_variables = [
            tf.Variable([[1.0, 2.0], [3.0, 4.0]], name='weights'),
            bias_variable, global_step
        ]
        init_op = tf.global_variables_initializer()
        saver = tf.train.Saver(graph1_variables)
        with self.test_session() as sess:
            sess.run(init_op)
            saver.save(sess, checkpoint_path)

        graph2_variables = [
            tf.Variable([1.0, 2.0],
                        name='weights'),  # Note the new variable shape.
            bias_variable,
            global_step
        ]

        out_variables = variables_helper.get_variables_available_in_checkpoint(
            graph2_variables, checkpoint_path, include_global_step=True)
        self.assertItemsEqual(out_variables, [bias_variable, global_step])
Example 7
def main(_):

    with tf.Graph().as_default() as graph:
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
        global_summaries = set([])

        num_batches_epoch = num_samples // (FLAGS.batch_size *
                                            FLAGS.num_clones)
        print(num_batches_epoch)

        #######################
        # Config model_deploy #
        #######################
        config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.ps_tasks)

        # Create global_step
        with tf.device(config.variables_device()):
            global_step = slim.create_global_step()

        ######################
        # Select the dataset #
        ######################
        with tf.device(config.inputs_device()):
            # Train Process
            dataset = get_split('train', FLAGS.dataset_dir)
            provider = slim.dataset_data_provider.DatasetDataProvider(
                dataset,
                num_readers=FLAGS.num_readers,
                common_queue_capacity=FLAGS.batch_size * 20,
                common_queue_min=FLAGS.batch_size * 10)
            [image_a, image_b,
             label] = provider.get(['image_a', 'image_b', 'label'])
            probe = image_a

            galleries = tf.unstack(image_b)
            galleries_process = []
            probe = process_image(probe)
            probe.set_shape([FLAGS.target_height, FLAGS.target_width, 3])

            gallery_target = tf.slice(image_b, [label, 0, 0, 0],
                                      [1, -1, -1, -1])
            gallery_target = tf.squeeze(gallery_target, axis=[0])
            gallery = process_image(gallery_target)
            gallery.set_shape([FLAGS.target_height, FLAGS.target_width, 3])
            galleries_process.append(gallery)

            for idx in range(FLAGS.top_k - 1):
                # Skip the target gallery image located at position `label`.
                img_idx = tf.cond(idx >= label, lambda: idx + 1, lambda: idx)
                gallery_other = tf.slice(image_b, [img_idx, 0, 0, 0],
                                         [1, -1, -1, -1])
                gallery_other = tf.squeeze(gallery_other, axis=[0])
                gallery = process_image(gallery_other)
                gallery.set_shape([FLAGS.target_height, FLAGS.target_width, 3])
                galleries_process.append(gallery)

            label_new = 0
            galleries_process = tf.stack(galleries_process)

            probe_batch, galleries_batch, labels = tf.train.batch(
                [probe, galleries_process, label_new],
                batch_size=FLAGS.batch_size,
                num_threads=8,
                capacity=FLAGS.batch_size * 10)

            inputs_queue = prefetch_queue(
                [probe_batch, galleries_batch, labels])

        ######################
        # Select the network #
        ######################
        def model_fn(inputs_queue):
            probe_batch, galleries_batch, labels = inputs_queue.dequeue()
            probe_batch_tile = tf.tile(tf.expand_dims(probe_batch, axis=1),
                                       [1, FLAGS.top_k, 1, 1, 1])
            shape = probe_batch_tile.get_shape().as_list()
            probe_batch_reshape = tf.reshape(
                probe_batch_tile, [-1, shape[2], shape[3], shape[4]])
            galleries_batch_reshape = tf.reshape(
                galleries_batch, [-1, shape[2], shape[3], shape[4]])
            images_a = probe_batch_reshape
            images_b = galleries_batch_reshape

            model = find_class_by_name(FLAGS.model, [models])()

            logits = model.create_model(images_a,
                                        images_b,
                                        reuse=False,
                                        is_training=True)
            logits = tf.reshape(logits, [FLAGS.batch_size, -1])
            label_onehot = tf.one_hot(labels, FLAGS.top_k)
            crossentropy_loss = tf.losses.softmax_cross_entropy(
                onehot_labels=label_onehot, logits=logits)

            tf.summary.histogram('images_a', images_a)

        clones = model_deploy.create_clones(config, model_fn, [inputs_queue])
        first_clone_scope = clones[0].scope

        #################################
        # Configure the moving averages #
        #################################
        if FLAGS.moving_average_decay:
            moving_average_variables = slim.get_model_variables()
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, global_step)
        else:
            moving_average_variables, variable_averages = None, None

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(config.optimizer_device()):

            learning_rate_step_boundaries = [
                int(num_batches_epoch * num_epoches * 0.60),
                int(num_batches_epoch * num_epoches * 0.75),
                int(num_batches_epoch * num_epoches * 0.90)
            ]
            learning_rate_sequence = [FLAGS.learning_rate]
            learning_rate_sequence += [
                FLAGS.learning_rate * 0.1, FLAGS.learning_rate * 0.01,
                FLAGS.learning_rate * 0.001
            ]
            learning_rate = learning_schedules.manual_stepping(
                global_step, learning_rate_step_boundaries,
                learning_rate_sequence)
            # Alternative schedule:
            # learning_rate = learning_schedules.exponential_decay_with_burnin(
            #     global_step, FLAGS.learning_rate,
            #     num_batches_epoch * num_epoches, 0.001 / FLAGS.learning_rate,
            #     burnin_learning_rate=0.01, burnin_steps=5000)
            if FLAGS.optimizer == 'adam':
                opt = tf.train.AdamOptimizer(learning_rate)
            elif FLAGS.optimizer == 'momentum':
                opt = tf.train.MomentumOptimizer(learning_rate, momentum=0.9)
            else:
                raise ValueError('Unsupported optimizer: %s' % FLAGS.optimizer)
            summaries.add(tf.summary.scalar('learning_rate', learning_rate))

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)
        with tf.device(config.optimizer_device()):
            training_optimizer = opt

        # Create ops required to initialize the model from a given checkpoint.
        init_fn = None
        if FLAGS.model == 'DCSL':
            if FLAGS.weights is None:
                # if not FLAGS.moving_average_decay:
                variables = slim.get_model_variables('InceptionResnetV2')
                init_fn = slim.assign_from_checkpoint_fn(
                    os.path.join(FLAGS.checkpoints_dir,
                                 'inception_resnet_v2.ckpt'),
                    slim.get_model_variables('InceptionResnetV2'))
        if FLAGS.model == 'DCSL_inception_v1':
            if FLAGS.weights is None:
                # if not FLAGS.moving_average_decay:
                variables = slim.get_model_variables('InceptionV1')
                init_fn = slim.assign_from_checkpoint_fn(
                    os.path.join(FLAGS.checkpoints_dir, 'inception_v1.ckpt'),
                    slim.get_model_variables('InceptionV1'))
        if FLAGS.model == 'DCSL_NAS':
            #             if FLAGS.weights is None:
            #                 # if not FLAGS.moving_average_decay:
            #                 variables = slim.get_model_variables('NAS')
            #                 init_fn = slim.assign_from_checkpoint_fn(
            #                     os.path.join(FLAGS.checkpoints_dir, 'nasnet-a_large_04_10_2017/model.ckpt'),
            #                     slim.get_model_variables('NAS'))
            def restore_map():
                variables_to_restore = {}
                for variable in tf.global_variables():
                    for scope_name in ['NAS']:
                        if variable.op.name.startswith(scope_name):
                            var_name = variable.op.name.replace(
                                scope_name + '/', '')
                            # var_name = variable.op.name
                            variables_to_restore[
                                var_name +
                                '/ExponentialMovingAverage'] = variable
                            # variables_to_restore[var_name] = variable
                return variables_to_restore

            var_map = restore_map()
            # restore_var = [v for v in tf.global_variables() if 'global_step' not in v.name]
            available_var_map = (
                variables_helper.get_variables_available_in_checkpoint(
                    var_map, FLAGS.weights))
            init_saver = tf.train.Saver(available_var_map)

            def initializer_fn(sess):
                init_saver.restore(sess, FLAGS.weights)

            init_fn = initializer_fn

        if FLAGS.model == 'MultiHeadAttentionBaseModel_set':
            if FLAGS.weights is None:
                # if not FLAGS.moving_average_decay:
                variables = slim.get_model_variables('InceptionV1')
                init_fn = slim.assign_from_checkpoint_fn(
                    os.path.join(FLAGS.checkpoints_dir, 'inception_v1.ckpt'),
                    slim.get_model_variables('InceptionV1'))
            else:
                restore_var = [
                    v for v in slim.get_model_variables()
                    if 'Score' not in v.name
                ]
                init_fn = slim.assign_from_checkpoint_fn(
                    FLAGS.weights, restore_var)
        if FLAGS.model == 'MultiHeadAttentionBaseModel_set_share':
            if FLAGS.weights is None:
                # if not FLAGS.moving_average_decay:
                variables = slim.get_model_variables('InceptionV1')
                init_fn = slim.assign_from_checkpoint_fn(
                    os.path.join(FLAGS.checkpoints_dir, 'inception_v1.ckpt'),
                    slim.get_model_variables('InceptionV1'))
            else:
                restore_var = [
                    v for v in slim.get_model_variables()
                    if 'Score' not in v.name
                ]
                init_fn = slim.assign_from_checkpoint_fn(
                    FLAGS.weights, restore_var)

        if FLAGS.model == 'MultiHeadAttentionBaseModel_set_share_softmatch':
            if FLAGS.weights is None:
                # if not FLAGS.moving_average_decay:
                variables = slim.get_model_variables('InceptionV1')
                init_fn = slim.assign_from_checkpoint_fn(
                    os.path.join(FLAGS.checkpoints_dir, 'inception_v1.ckpt'),
                    slim.get_model_variables('InceptionV1'))
            else:
                restore_var = [
                    v for v in slim.get_model_variables()
                    if 'Score' not in v.name
                ]
                init_fn = slim.assign_from_checkpoint_fn(
                    FLAGS.weights, restore_var)
        if FLAGS.model == 'MultiHeadAttentionBaseModel_set_share_softmatch_v2':
            if FLAGS.weights is None:
                # if not FLAGS.moving_average_decay:
                variables = slim.get_model_variables('InceptionV1')
                init_fn = slim.assign_from_checkpoint_fn(
                    os.path.join(FLAGS.checkpoints_dir, 'inception_v1.ckpt'),
                    slim.get_model_variables('InceptionV1'))
            else:
                restore_var = [
                    v for v in slim.get_model_variables()
                    if 'Score' not in v.name
                ]
                init_fn = slim.assign_from_checkpoint_fn(
                    FLAGS.weights, restore_var)

        if FLAGS.model == 'MultiHeadAttentionBaseModel_set_share_res50':
            if FLAGS.weights is None:
                # if not FLAGS.moving_average_decay:
                variables = slim.get_model_variables('resnet_v2_50')
                init_fn = slim.assign_from_checkpoint_fn(
                    os.path.join(FLAGS.checkpoints_dir, 'resnet_v2_50.ckpt'),
                    slim.get_model_variables('resnet_v2_50'))
        if FLAGS.model == 'MultiHeadAttentionBaseModel_set_inv3':
            # if not FLAGS.moving_average_decay:
            variables = slim.get_model_variables('InceptionV3')
            init_fn = slim.assign_from_checkpoint_fn(
                os.path.join(FLAGS.checkpoints_dir, 'inception_v3.ckpt'),
                slim.get_model_variables('InceptionV3'))

        # compute and update gradients
        with tf.device(config.optimizer_device()):
            if FLAGS.moving_average_decay:
                update_ops.append(
                    variable_averages.apply(moving_average_variables))

            # Variables to train.
            all_trainable = tf.trainable_variables()

            #  and returns a train_tensor and summary_op
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones,
                training_optimizer,
                regularization_losses=None,
                var_list=all_trainable)

            grad_mult = utils.get_model_gradient_multipliers(
                FLAGS.last_layer_gradient_multiplier)
            grads_and_vars = slim.learning.multiply_gradients(
                grads_and_vars, grad_mult)
            # Optionally clip gradients
            # with tf.name_scope('clip_grads'):
            #     grads_and_vars = slim.learning.clip_gradient_norms(grads_and_vars, 10)

            total_loss = tf.check_numerics(total_loss,
                                           'LossTensor is inf or nan.')

            # Create gradient updates.
            grad_updates = training_optimizer.apply_gradients(
                grads_and_vars, global_step=global_step)
            update_ops.append(grad_updates)

            update_op = tf.group(*update_ops)
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        # Add summaries.
        for loss_tensor in tf.losses.get_losses():
            global_summaries.add(
                tf.summary.scalar(loss_tensor.op.name, loss_tensor))
        global_summaries.add(
            tf.summary.scalar('TotalLoss', tf.losses.get_total_loss()))

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))
        summaries |= global_summaries
        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        # GPU settings
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)
        session_config.gpu_options.allow_growth = False
        # Save checkpoints regularly.
        keep_checkpoint_every_n_hours = 2.0

        saver = tf.train.Saver(
            keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)

        ###########################
        # Kicks off the training. #
        ###########################
        slim.learning.train(train_tensor,
                            logdir=logdir,
                            master=FLAGS.master,
                            is_chief=(FLAGS.task == 0),
                            session_config=session_config,
                            startup_delay_steps=10,
                            summary_op=summary_op,
                            init_fn=init_fn,
                            number_of_steps=num_batches_epoch *
                            FLAGS.num_epoches,
                            save_summaries_secs=240,
                            sync_optimizer=None,
                            saver=saver)
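The manual_stepping call above implements a piecewise-constant decay; assuming its semantics match the stock TF1 op, an equivalent sketch:

# Piecewise-constant equivalent (assumes manual_stepping's semantics):
# learning_rate_sequence[i] applies until global_step reaches
# learning_rate_step_boundaries[i]; the last rate applies afterwards.
learning_rate = tf.train.piecewise_constant(
    global_step,
    boundaries=learning_rate_step_boundaries,  # 60%, 75%, 90% of training
    values=learning_rate_sequence)             # [lr, lr/10, lr/100, lr/1000]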
def train(train_config,
          train_dir,
          master,
          task=0,
          num_clones=1,
          worker_replicas=1,
          clone_on_cpu=False,
          ps_tasks=0,
          worker_job_name='lonely_worker',
          is_chief=True):
    """Training function for detection models.

    Args:
      train_config: Configuration of parameters for model training.
      train_dir: Directory to write checkpoints and training summaries to.
      master: BNS name of the TensorFlow master to use.
      task: The task id of this training instance.
      num_clones: The number of clones to run per machine.
      worker_replicas: The number of worker replicas to train with.
      clone_on_cpu: True if clones should be forced to run on CPU.
      ps_tasks: Number of parameter server tasks.
      worker_job_name: Name of the worker job.
      is_chief: Whether this replica is the chief replica.
    """

    with tf.Graph().as_default():
        # Build a configuration specifying multi-GPU and multi-replicas.
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=num_clones,
            clone_on_cpu=clone_on_cpu,
            replica_id=task,
            num_replicas=worker_replicas,
            num_ps_tasks=ps_tasks,
            worker_job_name=worker_job_name)

        # Place the global step on the device storing the variables.
        with tf.device(deploy_config.variables_device()):
            global_step = tf.train.create_global_step()

        with tf.device(deploy_config.inputs_device()):
            train_config.batch_size = train_config.batch_size // num_clones
            train_config['input_path'] = train_config.train_file_path
            input_queue = input_queue_generator(train_config)
        # Gather initial summaries.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
        global_summaries = set([])

        # get num of classes
        num_classes = len(get_label_map_dict(train_config.label_map_file))
        model_fn = partial(_create_losses,
                           num_classes=num_classes,
                           train_config=train_config)
        clones = model_deploy.create_clones(deploy_config, model_fn,
                                            [input_queue])
        first_clone_scope = clones[0].scope

        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by model_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        with tf.device(deploy_config.optimizer_device()):
            training_optimizer = _create_optimizer(train_config.optimizer,
                                                   train_config.lr,
                                                   train_config.decay_steps,
                                                   global_summaries)
        # Create ops required to initialize the model from a given checkpoint.
        init_fn = None
        if train_config.fine_tune_checkpoint:
            var_map = {}
            for variable in tf.global_variables():  # tf.all_variables() is deprecated
                if variable.op.name.startswith("Retina_FPN"):
                    var_name = variable.op.name
                    var_map[var_name] = variable
            available_var_map = (
                variables_helper.get_variables_available_in_checkpoint(
                    var_map, train_config.fine_tune_checkpoint))
            init_saver = tf.train.Saver(available_var_map)

            def initializer_fn(sess):
                init_saver.restore(sess, train_config.fine_tune_checkpoint)

            init_fn = initializer_fn

        with tf.device(deploy_config.optimizer_device()):
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones, training_optimizer, regularization_losses=None)
            total_loss = tf.check_numerics(total_loss,
                                           'LossTensor is inf or nan.')

            # Optionally multiply bias gradients by train_config.bias_grad_multiplier.
            if train_config.bias_grad_multiplier:
                biases_regex_list = ['.*/biases']
                grads_and_vars = variables_helper.multiply_gradients_matching_regex(
                    grads_and_vars,
                    biases_regex_list,
                    multiplier=train_config.bias_grad_multiplier)

            # Optionally freeze some layers by setting their gradients to be zero.
            if train_config.freeze_variables:
                grads_and_vars = variables_helper.freeze_gradients_matching_regex(
                    grads_and_vars, train_config.freeze_variables)

            # Optionally clip gradients
            if train_config.gradient_clipping_by_norm > 0:
                with tf.name_scope('clip_grads'):
                    grads_and_vars = slim.learning.clip_gradient_norms(
                        grads_and_vars, train_config.gradient_clipping_by_norm)

            # Create gradient updates.
            grad_updates = training_optimizer.apply_gradients(
                grads_and_vars, global_step=global_step)
            update_ops.append(grad_updates)

            update_op = tf.group(*update_ops)
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        # Add summaries.
        for model_var in slim.get_model_variables():
            global_summaries.add(
                tf.summary.histogram(model_var.op.name, model_var))
        for loss_tensor in tf.losses.get_losses():
            global_summaries.add(
                tf.summary.scalar(loss_tensor.op.name, loss_tensor))
        global_summaries.add(
            tf.summary.scalar('TotalLoss', tf.losses.get_total_loss()))

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))
        summaries |= global_summaries

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        # Soft placement allows placing on CPU ops without GPU implementation.
        gpu_memory_fraction = 0.8
        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=gpu_memory_fraction,
            allow_growth=True)
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False,
                                        gpu_options=gpu_options)

        # Save checkpoints regularly.
        keep_checkpoint_every_n_hours = train_config.keep_checkpoint_every_n_hours
        saver = tf.train.Saver(
            keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)

        slim.learning.train(train_tensor,
                            logdir=train_dir,
                            master=master,
                            is_chief=is_chief,
                            session_config=session_config,
                            startup_delay_steps=15,
                            init_fn=init_fn,
                            summary_op=summary_op,
                            number_of_steps=None,
                            save_summaries_secs=120,
                            sync_optimizer=None,
                            saver=saver)
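This train() reads a number of fields off train_config, including one dict-style write (train_config['input_path']), which suggests an easydict-style object. A hypothetical config sketch showing the expected fields (all values are placeholders):

# Hypothetical config; EasyDict supports both the attribute access and the
# item assignment (train_config['input_path'] = ...) used above.
from easydict import EasyDict

train_config = EasyDict(
    batch_size=32,
    train_file_path='/path/to/train.record',
    label_map_file='/path/to/label_map.pbtxt',
    optimizer='momentum',
    lr=0.004,
    decay_steps=80000,
    fine_tune_checkpoint='/path/to/retina_fpn.ckpt',
    bias_grad_multiplier=2.0,
    freeze_variables=[],              # regexes of variables to freeze
    gradient_clipping_by_norm=10.0,
    keep_checkpoint_every_n_hours=2.0)

train(train_config, train_dir='/tmp/train_dir', master='')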
Example 9
def train(create_tensor_dict_fn, create_model_fn, train_config, master, task,
          num_clones, worker_replicas, clone_on_cpu, ps_tasks, worker_job_name,
          is_chief, train_dir):
    """Training function for detection models.

  Args:
    create_tensor_dict_fn: a function to create a tensor input dictionary.
    create_model_fn: a function that creates a DetectionModel and generates
                     losses.
    train_config: a train_pb2.TrainConfig protobuf.
    master: BNS name of the TensorFlow master to use.
    task: The task id of this training instance.
    num_clones: The number of clones to run per machine.
    worker_replicas: The number of work replicas to train with.
    clone_on_cpu: True if clones should be forced to run on CPU.
    ps_tasks: Number of parameter server tasks.
    worker_job_name: Name of the worker job.
    is_chief: Whether this replica is the chief replica.
    train_dir: Directory to write checkpoints and training summaries to.
  """

    detection_model = create_model_fn()
    data_augmentation_options = [
        preprocessor_builder.build(step)
        for step in train_config.data_augmentation_options
    ]

    with tf.Graph().as_default():
        # Build a configuration specifying multi-GPU and multi-replicas.
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=num_clones,
            clone_on_cpu=clone_on_cpu,
            replica_id=task,
            num_replicas=worker_replicas,
            num_ps_tasks=ps_tasks,
            worker_job_name=worker_job_name)

        # Place the global step on the device storing the variables.
        with tf.device(deploy_config.variables_device()):
            global_step = slim.create_global_step()

        with tf.device(deploy_config.inputs_device()):
            input_queue = create_input_queue(
                train_config.batch_size // num_clones, create_tensor_dict_fn,
                train_config.batch_queue_capacity,
                train_config.num_batch_queue_threads,
                train_config.prefetch_queue_capacity,
                data_augmentation_options)

        # Gather initial summaries.
        # TODO(rathodv): See if summaries can be added/extracted from global tf
        # collections so that they don't have to be passed around.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
        global_summaries = set([])

        model_fn = functools.partial(_create_losses,
                                     create_model_fn=create_model_fn,
                                     train_config=train_config)
        clones = model_deploy.create_clones(deploy_config, model_fn,
                                            [input_queue])
        first_clone_scope = clones[0].scope

        # Gather update_ops from the first clone. These contain, for example,
        # the updates for the batch_norm variables created by model_fn.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                       first_clone_scope)

        with tf.device(deploy_config.optimizer_device()):
            training_optimizer = optimizer_builder.build(
                train_config.optimizer, global_summaries)

        sync_optimizer = None
        if train_config.sync_replicas:
            training_optimizer = tf.train.SyncReplicasOptimizer(
                training_optimizer,
                replicas_to_aggregate=train_config.replicas_to_aggregate,
                total_num_replicas=train_config.worker_replicas)
            sync_optimizer = training_optimizer

        # Create ops required to initialize the model from a given checkpoint.
        init_fn = None
        if train_config.fine_tune_checkpoint:
            var_map = detection_model.restore_map(
                from_detection_checkpoint=train_config.from_detection_checkpoint)
            available_var_map = (
                variables_helper.get_variables_available_in_checkpoint(
                    var_map, train_config.fine_tune_checkpoint))
            init_saver = tf.train.Saver(available_var_map)

            def initializer_fn(sess):
                init_saver.restore(sess, train_config.fine_tune_checkpoint)

            init_fn = initializer_fn

        with tf.device(deploy_config.optimizer_device()):
            total_loss, grads_and_vars = model_deploy.optimize_clones(
                clones, training_optimizer, regularization_losses=None)
            total_loss = tf.check_numerics(total_loss,
                                           'LossTensor is inf or nan.')

            # Optionally multiply bias gradients by train_config.bias_grad_multiplier.
            if train_config.bias_grad_multiplier:
                biases_regex_list = ['.*/biases']
                grads_and_vars = variables_helper.multiply_gradients_matching_regex(
                    grads_and_vars,
                    biases_regex_list,
                    multiplier=train_config.bias_grad_multiplier)

            # Optionally freeze some layers by setting their gradients to be zero.
            if train_config.freeze_variables:
                grads_and_vars = variables_helper.freeze_gradients_matching_regex(
                    grads_and_vars, train_config.freeze_variables)

            # Optionally clip gradients
            if train_config.gradient_clipping_by_norm > 0:
                with tf.name_scope('clip_grads'):
                    grads_and_vars = slim.learning.clip_gradient_norms(
                        grads_and_vars, train_config.gradient_clipping_by_norm)

            # Create gradient updates.
            grad_updates = training_optimizer.apply_gradients(
                grads_and_vars, global_step=global_step)
            update_ops.append(grad_updates)

            update_op = tf.group(*update_ops)
            with tf.control_dependencies([update_op]):
                train_tensor = tf.identity(total_loss, name='train_op')

        # Add summaries.
        for model_var in slim.get_model_variables():
            global_summaries.add(
                tf.summary.histogram(model_var.op.name, model_var))
        for loss_tensor in tf.losses.get_losses():
            global_summaries.add(
                tf.summary.scalar(loss_tensor.op.name, loss_tensor))
        global_summaries.add(
            tf.summary.scalar('TotalLoss', tf.losses.get_total_loss()))

        # Add the summaries from the first clone. These contain the summaries
        # created by model_fn and either optimize_clones() or _gather_clone_loss().
        summaries |= set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))
        summaries |= global_summaries

        # Merge all summaries together.
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        # Soft placement allows placing on CPU ops without GPU implementation.
        session_config = tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=False)

        # Save checkpoints regularly.
        keep_checkpoint_every_n_hours = train_config.keep_checkpoint_every_n_hours
        saver = tf.train.Saver(
            keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)

        slim.learning.train(
            train_tensor,
            logdir=train_dir,
            master=master,
            is_chief=is_chief,
            session_config=session_config,
            startup_delay_steps=train_config.startup_delay_steps,
            init_fn=init_fn,
            summary_op=summary_op,
            number_of_steps=(train_config.num_steps
                             if train_config.num_steps else None),
            save_summaries_secs=120,
            sync_optimizer=sync_optimizer,
            saver=saver)
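A sketch of how a binary might call this train(), in the style of the Object Detection API's train.py; model_config, train_config, and create_tensor_dict_fn come from a parsed pipeline config (not shown):

# Hypothetical wiring; assumes an Object Detection API-style pipeline config.
import functools
from object_detection.builders import model_builder

create_model_fn = functools.partial(model_builder.build,
                                    model_config=model_config,
                                    is_training=True)
train(create_tensor_dict_fn, create_model_fn, train_config,
      master='', task=0, num_clones=1, worker_replicas=1, clone_on_cpu=False,
      ps_tasks=0, worker_job_name='worker', is_chief=True,
      train_dir='/tmp/train_dir')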
  def model_fn(features, labels, mode, params=None):
    """Constructs the object detection model.

    Args:
      features: Dictionary of feature tensors, returned from `input_fn`.
      labels: Dictionary of groundtruth tensors if mode is TRAIN or EVAL,
        otherwise None.
      mode: Mode key from tf.estimator.ModeKeys.
      params: Parameter dictionary passed from the estimator.

    Returns:
      An `EstimatorSpec` that encapsulates the model and its serving
        configurations.
    """
    params = params or {}
    total_loss, train_op, detections, export_outputs = None, None, None, None
    is_training = mode == tf.estimator.ModeKeys.TRAIN
    detection_model = detection_model_fn(is_training=is_training,
                                         add_summaries=(not use_tpu))
    scaffold_fn = None

    if mode == tf.estimator.ModeKeys.TRAIN:
      labels = unstack_batch(
          labels,
          unpad_groundtruth_tensors=train_config.unpad_groundtruth_tensors)
    elif mode == tf.estimator.ModeKeys.EVAL:
      labels = unstack_batch(labels, unpad_groundtruth_tensors=False)

    if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
      gt_boxes_list = labels[fields.InputDataFields.groundtruth_boxes]
      gt_classes_list = labels[fields.InputDataFields.groundtruth_classes]
      gt_masks_list = None
      if fields.InputDataFields.groundtruth_instance_masks in labels:
        gt_masks_list = labels[
            fields.InputDataFields.groundtruth_instance_masks]
      gt_keypoints_list = None
      if fields.InputDataFields.groundtruth_keypoints in labels:
        gt_keypoints_list = labels[fields.InputDataFields.groundtruth_keypoints]
      detection_model.provide_groundtruth(
          groundtruth_boxes_list=gt_boxes_list,
          groundtruth_classes_list=gt_classes_list,
          groundtruth_masks_list=gt_masks_list,
          groundtruth_keypoints_list=gt_keypoints_list)

    preprocessed_images = features[fields.InputDataFields.image]
    prediction_dict = detection_model.predict(
        preprocessed_images, features[fields.InputDataFields.true_image_shape])
    detections = detection_model.postprocess(
        prediction_dict, features[fields.InputDataFields.true_image_shape])

    if mode == tf.estimator.ModeKeys.TRAIN:
      if train_config.fine_tune_checkpoint and hparams.load_pretrained:
        asg_map = detection_model.restore_map(
            from_detection_checkpoint=train_config.from_detection_checkpoint,
            load_all_detection_checkpoint_vars=(
                train_config.load_all_detection_checkpoint_vars))
        available_var_map = (
            variables_helper.get_variables_available_in_checkpoint(
                asg_map, train_config.fine_tune_checkpoint,
                include_global_step=False))
        if use_tpu:
          def tpu_scaffold():
            tf.train.init_from_checkpoint(train_config.fine_tune_checkpoint,
                                          available_var_map)
            return tf.train.Scaffold()
          scaffold_fn = tpu_scaffold
        else:
          tf.train.init_from_checkpoint(train_config.fine_tune_checkpoint,
                                        available_var_map)

    if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
      losses_dict = detection_model.loss(
          prediction_dict, features[fields.InputDataFields.true_image_shape])
      losses = [loss_tensor for loss_tensor in losses_dict.values()]
      total_loss = tf.add_n(losses, name='total_loss')

    if mode == tf.estimator.ModeKeys.TRAIN:
      global_step = tf.train.get_or_create_global_step()
      training_optimizer, optimizer_summary_vars = optimizer_builder.build(
          train_config.optimizer)

      if use_tpu:
        training_optimizer = tpu_optimizer.CrossShardOptimizer(
            training_optimizer)

      # Optionally freeze some layers by setting their gradients to be zero.
      trainable_variables = None
      if train_config.freeze_variables:
        trainable_variables = tf.contrib.framework.filter_variables(
            tf.trainable_variables(),
            exclude_patterns=train_config.freeze_variables)

      clip_gradients_value = None
      if train_config.gradient_clipping_by_norm > 0:
        clip_gradients_value = train_config.gradient_clipping_by_norm

      if not use_tpu:
        for var in optimizer_summary_vars:
          tf.summary.scalar(var.op.name, var)
      summaries = [] if use_tpu else None
      train_op = tf.contrib.layers.optimize_loss(
          loss=total_loss,
          global_step=global_step,
          learning_rate=None,
          clip_gradients=clip_gradients_value,
          optimizer=training_optimizer,
          variables=trainable_variables,
          summaries=summaries,
          name='')  # Preventing scope prefix on all variables.

    if mode == tf.estimator.ModeKeys.PREDICT:
      export_outputs = {
          tf.saved_model.signature_constants.PREDICT_METHOD_NAME:
              tf.estimator.export.PredictOutput(detections)
      }

    eval_metric_ops = None
    if mode == tf.estimator.ModeKeys.EVAL:
      # Detection summaries during eval.
      class_agnostic = (fields.DetectionResultFields.detection_classes
                        not in detections)
      groundtruth = _get_groundtruth_data(detection_model, class_agnostic)
      eval_dict = eval_util.result_dict_for_single_example(
          tf.expand_dims(features[fields.InputDataFields.original_image][0], 0),
          features[inputs.HASH_KEY][0],
          detections,
          groundtruth,
          class_agnostic=class_agnostic,
          scale_to_absolute=False)

      if class_agnostic:
        category_index = label_map_util.create_class_agnostic_category_index()
      else:
        category_index = label_map_util.create_category_index_from_labelmap(
            eval_input_config.label_map_path)
      detection_and_groundtruth = vis_utils.draw_side_by_side_evaluation_image(
          eval_dict, category_index, max_boxes_to_draw=20, min_score_thresh=0.2)
      if not use_tpu:
        tf.summary.image('Detections_Left_Groundtruth_Right',
                         detection_and_groundtruth)

      # Eval metrics on a single image.
      detection_fields = fields.DetectionResultFields()
      input_data_fields = fields.InputDataFields()
      coco_evaluator = coco_evaluation.CocoDetectionEvaluator(
          list(category_index.values()))
      eval_metric_ops = coco_evaluator.get_estimator_eval_metric_ops(
          image_id=eval_dict[input_data_fields.key],
          groundtruth_boxes=eval_dict[input_data_fields.groundtruth_boxes],
          groundtruth_classes=eval_dict[input_data_fields.groundtruth_classes],
          detection_boxes=eval_dict[detection_fields.detection_boxes],
          detection_scores=eval_dict[detection_fields.detection_scores],
          detection_classes=eval_dict[detection_fields.detection_classes])

    if use_tpu:
      return tf.contrib.tpu.TPUEstimatorSpec(
          mode=mode,
          scaffold_fn=scaffold_fn,
          predictions=detections,
          loss=total_loss,
          train_op=train_op,
          eval_metrics=eval_metric_ops,
          export_outputs=export_outputs)
    else:
      return tf.estimator.EstimatorSpec(
          mode=mode,
          predictions=detections,
          loss=total_loss,
          train_op=train_op,
          eval_metric_ops=eval_metric_ops,
          export_outputs=export_outputs)
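A model_fn like this is closed over its configs and then handed to an Estimator; a hypothetical wiring sketch (run_config and train_input_fn are placeholders, not the source's API):

# Hypothetical wiring; the names below are placeholders.
if use_tpu:
    estimator = tf.contrib.tpu.TPUEstimator(
        model_fn=model_fn,
        config=run_config,                    # a tf.contrib.tpu.RunConfig
        use_tpu=True,
        train_batch_size=train_config.batch_size)
else:
    estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config)
estimator.train(input_fn=train_input_fn, max_steps=train_config.num_steps)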
Example 11
    def model_fn(features, labels, mode, params=None):
        """Constructs the object detection model.

    Args:
      features: Dictionary of feature tensors, returned from `input_fn`.
      labels: Dictionary of groundtruth tensors if mode is TRAIN or EVAL,
        otherwise None.
      mode: Mode key from tf.estimator.ModeKeys.
      params: Parameter dictionary passed from the estimator.

    Returns:
      An `EstimatorSpec` that encapsulates the model and its serving
        configurations.
    """
        params = params or {}
        total_loss, train_op, detections, export_outputs = None, None, None, None
        is_training = mode == tf.estimator.ModeKeys.TRAIN

        # Make sure to set the Keras learning phase. True during training,
        # False for inference.
        tf.keras.backend.set_learning_phase(is_training)
        detection_model = detection_model_fn(is_training=is_training,
                                             add_summaries=(not use_tpu))
        scaffold_fn = None

        if mode == tf.estimator.ModeKeys.TRAIN:
            labels = unstack_batch(
                labels,
                unpad_groundtruth_tensors=train_config.unpad_groundtruth_tensors)
        elif mode == tf.estimator.ModeKeys.EVAL:
            # For evaluating on train data, it is necessary to check whether
            # groundtruth must be unpadded.
            boxes_shape = (
                labels[fields.InputDataFields.groundtruth_boxes]
                .get_shape().as_list())
            unpad_groundtruth_tensors = (boxes_shape[1] is not None
                                         and not use_tpu)
            labels = unstack_batch(
                labels, unpad_groundtruth_tensors=unpad_groundtruth_tensors)

        if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
            gt_boxes_list = labels[fields.InputDataFields.groundtruth_boxes]
            gt_classes_list = labels[
                fields.InputDataFields.groundtruth_classes]
            gt_masks_list = None
            if fields.InputDataFields.groundtruth_instance_masks in labels:
                gt_masks_list = labels[
                    fields.InputDataFields.groundtruth_instance_masks]
            gt_keypoints_list = None
            if fields.InputDataFields.groundtruth_keypoints in labels:
                gt_keypoints_list = labels[
                    fields.InputDataFields.groundtruth_keypoints]
            gt_weights_list = None
            if fields.InputDataFields.groundtruth_weights in labels:
                gt_weights_list = labels[
                    fields.InputDataFields.groundtruth_weights]
            gt_confidences_list = None
            if fields.InputDataFields.groundtruth_confidences in labels:
                gt_confidences_list = labels[
                    fields.InputDataFields.groundtruth_confidences]
            gt_is_crowd_list = None
            if fields.InputDataFields.groundtruth_is_crowd in labels:
                gt_is_crowd_list = labels[
                    fields.InputDataFields.groundtruth_is_crowd]
            detection_model.provide_groundtruth(
                groundtruth_boxes_list=gt_boxes_list,
                groundtruth_classes_list=gt_classes_list,
                groundtruth_confidences_list=gt_confidences_list,
                groundtruth_masks_list=gt_masks_list,
                groundtruth_keypoints_list=gt_keypoints_list,
                groundtruth_weights_list=gt_weights_list,
                groundtruth_is_crowd_list=gt_is_crowd_list)

        preprocessed_images = features[fields.InputDataFields.image]
        if use_tpu and train_config.use_bfloat16:
            with tf.contrib.tpu.bfloat16_scope():
                prediction_dict = detection_model.predict(
                    preprocessed_images,
                    features[fields.InputDataFields.true_image_shape])
                for k, v in prediction_dict.items():
                    if v.dtype == tf.bfloat16:
                        prediction_dict[k] = tf.cast(v, tf.float32)
        else:
            prediction_dict = detection_model.predict(
                preprocessed_images,
                features[fields.InputDataFields.true_image_shape])
        if mode in (tf.estimator.ModeKeys.EVAL, tf.estimator.ModeKeys.PREDICT):
            detections = detection_model.postprocess(
                prediction_dict,
                features[fields.InputDataFields.true_image_shape])

        if mode == tf.estimator.ModeKeys.TRAIN:
            if train_config.fine_tune_checkpoint and hparams.load_pretrained:
                if not train_config.fine_tune_checkpoint_type:
                    # train_config.from_detection_checkpoint field is deprecated. For
                    # backward compatibility, set train_config.fine_tune_checkpoint_type
                    # based on train_config.from_detection_checkpoint.
                    if train_config.from_detection_checkpoint:
                        train_config.fine_tune_checkpoint_type = 'detection'
                    else:
                        train_config.fine_tune_checkpoint_type = 'classification'
                asg_map = detection_model.restore_map(
                    fine_tune_checkpoint_type=(
                        train_config.fine_tune_checkpoint_type),
                    load_all_detection_checkpoint_vars=(
                        train_config.load_all_detection_checkpoint_vars))
                available_var_map = (
                    variables_helper.get_variables_available_in_checkpoint(
                        asg_map,
                        train_config.fine_tune_checkpoint,
                        include_global_step=False))
                if use_tpu:

                    def tpu_scaffold():
                        tf.train.init_from_checkpoint(
                            train_config.fine_tune_checkpoint,
                            available_var_map)
                        return tf.train.Scaffold()

                    scaffold_fn = tpu_scaffold
                else:
                    tf.train.init_from_checkpoint(
                        train_config.fine_tune_checkpoint, available_var_map)

        if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
            losses_dict = detection_model.loss(
                prediction_dict,
                features[fields.InputDataFields.true_image_shape])
            losses = [loss_tensor for loss_tensor in losses_dict.values()]
            if train_config.add_regularization_loss:
                regularization_losses = detection_model.regularization_losses()
                if regularization_losses:
                    regularization_loss = tf.add_n(regularization_losses,
                                                   name='regularization_loss')
                    losses.append(regularization_loss)
                    losses_dict[
                        'Loss/regularization_loss'] = regularization_loss
            total_loss = tf.add_n(losses, name='total_loss')
            losses_dict['Loss/total_loss'] = total_loss

            if 'graph_rewriter_config' in configs:
                graph_rewriter_fn = graph_rewriter_builder.build(
                    configs['graph_rewriter_config'], is_training=is_training)
                graph_rewriter_fn()

            # TODO(rathodv): Stop creating optimizer summary vars in EVAL mode once we
            # can write learning rate summaries on TPU without host calls.
            global_step = tf.train.get_or_create_global_step()
            training_optimizer, optimizer_summary_vars = optimizer_builder.build(
                train_config.optimizer)

        if mode == tf.estimator.ModeKeys.TRAIN:
            if use_tpu:
                training_optimizer = tf.contrib.tpu.CrossShardOptimizer(
                    training_optimizer)

            # Optionally freeze some layers by setting their gradients to be zero.
            trainable_variables = None
            include_variables = (train_config.update_trainable_variables
                                 if train_config.update_trainable_variables
                                 else None)
            exclude_variables = (train_config.freeze_variables
                                 if train_config.freeze_variables else None)
            trainable_variables = tf.contrib.framework.filter_variables(
                tf.trainable_variables(),
                include_patterns=include_variables,
                exclude_patterns=exclude_variables)

            clip_gradients_value = None
            if train_config.gradient_clipping_by_norm > 0:
                clip_gradients_value = train_config.gradient_clipping_by_norm

            if not use_tpu:
                for var in optimizer_summary_vars:
                    tf.summary.scalar(var.op.name, var)
            summaries = [] if use_tpu else None
            if train_config.summarize_gradients:
                summaries = [
                    'gradients', 'gradient_norm', 'global_gradient_norm'
                ]
            train_op = tf.contrib.layers.optimize_loss(
                loss=total_loss,
                global_step=global_step,
                learning_rate=None,
                clip_gradients=clip_gradients_value,
                optimizer=training_optimizer,
                update_ops=detection_model.updates(),
                variables=trainable_variables,
                summaries=summaries,
                name='')  # Preventing scope prefix on all variables.

        if mode == tf.estimator.ModeKeys.PREDICT:
            exported_output = exporter_lib.add_output_tensor_nodes(detections)
            export_outputs = {
                tf.saved_model.signature_constants.PREDICT_METHOD_NAME:
                tf.estimator.export.PredictOutput(exported_output)
            }

        eval_metric_ops = None
        scaffold = None
        if mode == tf.estimator.ModeKeys.EVAL:
            class_agnostic = (fields.DetectionResultFields.detection_classes
                              not in detections)
            groundtruth = _prepare_groundtruth_for_eval(
                detection_model, class_agnostic,
                eval_input_config.max_number_of_boxes)
            use_original_images = fields.InputDataFields.original_image in features
            if use_original_images:
                eval_images = features[fields.InputDataFields.original_image]
                true_image_shapes = tf.slice(
                    features[fields.InputDataFields.true_image_shape], [0, 0],
                    [-1, 3])
                original_image_spatial_shapes = features[
                    fields.InputDataFields.original_image_spatial_shape]
            else:
                eval_images = features[fields.InputDataFields.image]
                true_image_shapes = None
                original_image_spatial_shapes = None

            eval_dict = eval_util.result_dict_for_batched_example(
                eval_images,
                features[inputs.HASH_KEY],
                detections,
                groundtruth,
                class_agnostic=class_agnostic,
                scale_to_absolute=True,
                original_image_spatial_shapes=original_image_spatial_shapes,
                true_image_shapes=true_image_shapes)

            if class_agnostic:
                category_index = (
                    label_map_util.create_class_agnostic_category_index())
            else:
                category_index = label_map_util.create_category_index_from_labelmap(
                    eval_input_config.label_map_path)
            vis_metric_ops = None
            if not use_tpu and use_original_images:
                eval_metric_op_vis = vis_utils.VisualizeSingleFrameDetections(
                    category_index,
                    max_examples_to_draw=eval_config.num_visualizations,
                    max_boxes_to_draw=eval_config.max_num_boxes_to_visualize,
                    min_score_thresh=eval_config.min_score_threshold,
                    use_normalized_coordinates=False)
                vis_metric_ops = eval_metric_op_vis.get_estimator_eval_metric_ops(
                    eval_dict)

            # Eval metrics on a single example.
            eval_metric_ops = eval_util.get_eval_metric_ops_for_evaluators(
                eval_config, list(category_index.values()), eval_dict)
            for loss_key, loss_tensor in losses_dict.items():
                eval_metric_ops[loss_key] = tf.metrics.mean(loss_tensor)
            for var in optimizer_summary_vars:
                eval_metric_ops[var.op.name] = (var, tf.no_op())
            if vis_metric_ops is not None:
                eval_metric_ops.update(vis_metric_ops)
            eval_metric_ops = {str(k): v for k, v in eval_metric_ops.items()}

            if eval_config.use_moving_averages:
                variable_averages = tf.train.ExponentialMovingAverage(0.0)
                variables_to_restore = variable_averages.variables_to_restore()
                keep_checkpoint_every_n_hours = (
                    train_config.keep_checkpoint_every_n_hours)
                saver = tf.train.Saver(
                    variables_to_restore,
                    keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)
                scaffold = tf.train.Scaffold(saver=saver)

        # EVAL executes on CPU, so use regular non-TPU EstimatorSpec.
        if use_tpu and mode != tf.estimator.ModeKeys.EVAL:
            return tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                scaffold_fn=scaffold_fn,
                predictions=detections,
                loss=total_loss,
                train_op=train_op,
                eval_metrics=eval_metric_ops,
                export_outputs=export_outputs)
        else:
            if scaffold is None:
                keep_checkpoint_every_n_hours = (
                    train_config.keep_checkpoint_every_n_hours)
                saver = tf.train.Saver(
                    sharded=True,
                    keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
                    save_relative_paths=True)
                tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
                scaffold = tf.train.Scaffold(saver=saver)
            return tf.estimator.EstimatorSpec(mode=mode,
                                              predictions=detections,
                                              loss=total_loss,
                                              train_op=train_op,
                                              eval_metric_ops=eval_metric_ops,
                                              export_outputs=export_outputs,
                                              scaffold=scaffold)