Example #1
def build_model():
    """Builds graph for model to train with rewrites for quantization.

    Returns:
      g: Graph with fake quantization ops and batch norm folding, suitable for
        training quantized weights.
      train_tensor: Train op for execution during training.
    """
    g = tf.Graph()
    with g.as_default(), tf.device(
            tf.train.replica_device_setter(FLAGS.ps_tasks)):
        inputs, labels, _ = get_batch(FLAGS.dataset_dir, FLAGS.batch_size)
        with slim.arg_scope(mobilenet_v2.training_scope()):
            logits, _ = mobilenet_v2.mobilenet_v2_050(
                inputs, num_classes=FLAGS.num_classes)
        labels = slim.one_hot_encoding(labels, FLAGS.num_classes)
        tf.losses.softmax_cross_entropy(labels, logits)

        # Call the rewriter to produce a graph with fake-quant ops and folded
        # batch norms. quant_delay postpones quantization until that many steps
        # have run, letting the float model stabilize first for better accuracy.
        if FLAGS.quantize:
            tf.contrib.quantize.create_training_graph(
                quant_delay=get_quant_delay())

        total_loss = tf.losses.get_total_loss(name='total_loss')
        # Configure the learning rate using an exponential decay.
        num_epochs_per_decay = 1
        imagenet_size = 51200
        decay_steps = int(imagenet_size / FLAGS.batch_size *
                          num_epochs_per_decay)

        learning_rate = tf.train.exponential_decay(
            get_learning_rate(),
            tf.train.get_or_create_global_step(),
            decay_steps,
            _LEARNING_RATE_DECAY_FACTOR,
            staircase=False)
        #opt = tf.train.GradientDescentOptimizer(learning_rate)
        opt = tf.train.RMSPropOptimizer(learning_rate, decay=0.9, momentum=0.9)
        train_tensor = slim.learning.create_train_op(total_loss, optimizer=opt)

    slim.summaries.add_scalar_summary(total_loss, 'total_loss', 'losses')
    slim.summaries.add_scalar_summary(learning_rate, 'learning_rate',
                                      'training')
    return g, train_tensor
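
A typical driver for this function is TF-Slim's training loop. A minimal sketch, assuming TF 1.x and the same FLAGS object; FLAGS.train_dir and FLAGS.max_steps are hypothetical flags, not from the snippet above:

def train():
    g, train_tensor = build_model()
    with g.as_default():
        slim.learning.train(
            train_tensor,
            logdir=FLAGS.train_dir,           # hypothetical: checkpoint/summary dir
            number_of_steps=FLAGS.max_steps,  # hypothetical: stop criterion
            save_summaries_secs=60,
            save_interval_secs=300)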
Example #2
def net(inputs,
        data_format='channels_last',
        depth_mult=1.0,
        VGG_PARAMS_FILE=None,
        is_train=False):

    if data_format != 'channels_last':
        raise ValueError(
            "net() currently supports only data_format='channels_last'")

    with tf.contrib.slim.arg_scope(
            mobilenet_v2.training_scope(is_training=is_train)):
        logits, endpoints = mobilenet_v2.mobilenet(inputs,
                                                   base_only=True,
                                                   reuse=tf.AUTO_REUSE,
                                                   final_endpoint="layer_19",
                                                   depth_multiplier=depth_mult)

    l15e = endpoints['layer_15/expansion_output']
    l19 = endpoints['layer_19']

    return [l15e, l19], endpoints
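
For reference, a caller would treat net() as a two-scale feature extractor; a minimal usage sketch (the input shape is an assumption):

# Hypothetical usage: fetch the two feature maps for a detection head.
images = tf.placeholder(tf.float32, [None, 320, 320, 3], name='images')
(l15e, l19), endpoints = net(images, depth_mult=1.0, is_train=False)
# l15e: 'layer_15/expansion_output' (stride 16); l19: 'layer_19' (stride 32).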
Example #3
def build_model():
    """Build the mobilenet_v1 model for evaluation.

  Returns:
    g: graph with rewrites after insertion of quantization ops and batch norm
    folding.
    eval_ops: eval ops for inference.
    variables_to_restore: List of variables to restore from checkpoint.
  """
    g = tf.Graph()
    with g.as_default():
        inputs, labels, _ = get_batch(FLAGS.dataset_dir, FLAGS.batch_size)
        scope = mobilenet_v2.training_scope(is_training=False,
                                            weight_decay=0.0)
        with slim.arg_scope(scope):
            _, end_points = mobilenet_v2.mobilenet_v2_050(
                inputs, is_training=False, num_classes=FLAGS.num_classes)

        if FLAGS.quantize:
            tf.contrib.quantize.create_eval_graph()

        eval_ops = metrics(end_points['Predictions'], labels)

    return g, eval_ops
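
A hedged sketch of how these eval ops might be consumed with TF-Slim's evaluation loop; FLAGS.eval_dir is a hypothetical flag, the rest mirror the flags used above:

def evaluate():
    g, eval_ops = build_model()
    with g.as_default():
        slim.evaluation.evaluation_loop(
            master='',
            checkpoint_dir=FLAGS.checkpoint_dir,
            logdir=FLAGS.eval_dir,  # hypothetical: where eval summaries go
            num_evals=int(FLAGS.num_examples / FLAGS.batch_size),
            eval_op=eval_ops)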
Example #4
def test_model():
    errfile = '../_error/other_errors.txt'

    sess = tf.Session()

    inputs, labels, name_batch = get_batch(FLAGS.dataset_dir,
                                           FLAGS.batch_size,
                                           shuffle=False)
    scope = mobilenet_v2.training_scope(is_training=False, weight_decay=0.0)
    with slim.arg_scope(scope):
        logits, end_points = mobilenet_v2.mobilenet_v2_050(
            inputs, is_training=False, num_classes=FLAGS.num_classes)

    # evaluate model, for classification
    correct_pred = tf.equal(tf.argmax(end_points['Predictions'], 1), labels)
    acc = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    sess.run(init_op)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    # saver for restore model
    saver = tf.train.Saver()
    print('[*] Try to load trained model...')
    ckpt_name = load(sess, saver, FLAGS.checkpoint_dir)

    step = 0
    accs = 0
    me_acc = 0
    errors_name = []
    max_steps = int(FLAGS.num_examples / FLAGS.batch_size)
    print('START TESTING...')
    try:
        while not coord.should_stop():
            for _step in range(step + 1, step + max_steps + 1):
                # test
                #_label, _logits, _points = sess.run([labels, logits, end_points])
                _name, _logits, _corr, _acc = sess.run(
                    [name_batch, logits, correct_pred, acc])
                if (~_corr).any():
                    errors_name.extend(list(_name[~_corr]))
                accs += _acc
                me_acc = accs / _step
                if _step % 20 == 0:
                    print(
                        time.strftime("%X"),
                        'global_step:{0}, current_acc:{1:.6f}'.format(
                            _step, me_acc))
    except tf.errors.OutOfRangeError:
        accuracy = 1 - len(errors_name) / FLAGS.num_examples
        print(time.strftime("%X"),
              'RESULT >>> current_acc:{0:.6f}'.format(accuracy))
        # print(errors_name)
        with open(errfile, 'a') as errorsfile:
            errorsfile.write('\n' + ckpt_name + '--' + str(accuracy))
            for err in errors_name:
                errorsfile.write('\n' + err.decode('utf-8'))
    finally:
        coord.request_stop()
        coord.join(threads)
        sess.close()
    print('FINISHED TESTING.')
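
test_model() (and Example #7 below) relies on a load(sess, saver, checkpoint_dir) helper that the snippet does not show. A minimal sketch of what it plausibly does, assuming `import os` and that it returns the checkpoint's base name:

def load(sess, saver, checkpoint_dir):
    # Sketch: restore the latest checkpoint and return its base name.
    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    if ckpt and ckpt.model_checkpoint_path:
        ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
        saver.restore(sess, os.path.join(checkpoint_dir, ckpt_name))
        return ckpt_name
    raise ValueError('no checkpoint found in %s' % checkpoint_dir)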
Example #5
def _mean_image_subtraction(image, img_mean=(123.68, 116.78, 103.94)):
    # NOTE: the snippet was truncated here; this header and the ImageNet RGB
    # means are a reconstruction, not from the original source.
    num_channels = image.get_shape().as_list()[-1]
    channels = tf.split(axis=2, num_or_size_splits=num_channels, value=image)
    for i in range(num_channels):
        channels[i] -= img_mean[i]
    return tf.concat(axis=2, values=channels)


# img_size, get_sample and the mnv2 alias are defined elsewhere in the
# original script and are assumed here.
image_raw_data = tf.placeholder(tf.string, None)
img_data = tf.image.decode_jpeg(image_raw_data)
image = tf.image.convert_image_dtype(img_data, dtype=tf.uint8)
img_show = image
image.set_shape([img_size, img_size, 3])
image = tf.to_float(image)
image = _mean_image_subtraction(image)
image = tf.expand_dims(image, 0)

with tf.contrib.slim.arg_scope(mnv2.training_scope(is_training=False)):
    logits, endpoint = mnv2.mobilenet(image, num_classes=4)
#logits=tf.Print(logits,[logits],'logits: ',summarize=32)
#logits=tf.sigmoid(logits)
#logits=tf.nn.softmax(logits)
#logits = tf.Print(logits, [logits], 'softmax: ', summarize=32)
saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, './output/chamo_3000.000000_0.002084/chamo.ckpt')
    while True:
        img, cur_id = get_sample(0)
        image_raw_data_jpg = tf.gfile.FastGFile('./re/chamo.jpg', 'rb').read()
        re = sess.run([logits, img_show],
                      feed_dict={image_raw_data: image_raw_data_jpg})
        print(re[0][0])
        show_img = re[1]
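        # Added sketch (not in the original): the commented-out lines above
        # suggest the logits were also inspected as probabilities; a plain
        # NumPy softmax does that outside the graph (assumes `import numpy as np`).
        logits_np = re[0][0]
        probs = np.exp(logits_np - logits_np.max())
        probs /= probs.sum()  # numerically stable softmax
        print(probs)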
Example #6
def model_fn(features, labels, mode):
    feature_dict = features
    labels = tf.reshape(feature_dict["labels"], [-1])

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)

    if FLAGS.model == 'mobilenet':
        scope = mobilenet_v2.training_scope(is_training=is_training)
        with tf.contrib.slim.arg_scope(scope):
            net, end_points = mobilenet_v2.mobilenet(
                feature_dict["features"],
                is_training=is_training,
                num_classes=FLAGS.num_classes)
            end_points['embedding'] = end_points['global_pool']
    elif FLAGS.model == 'mobilefacenet':
        scope = mobilenet_v2.training_scope(is_training=is_training)
        with tf.contrib.slim.arg_scope(scope):
            net, end_points = mobilefacenet.mobilefacenet(
                feature_dict["features"],
                is_training=is_training,
                num_classes=FLAGS.num_classes)
    elif FLAGS.model == 'metric_learning':
        with slim.arg_scope(
                inception_v1.inception_v1_arg_scope(weight_decay=0.0)):
            net, end_points = proxy_metric_learning.metric_learning(
                feature_dict["features"],
                is_training=is_training,
                num_classes=FLAGS.num_classes)
    elif FLAGS.model == 'faceresnet':
        net, end_points = mobilefacenet.faceresnet(
            feature_dict["features"],
            is_training=is_training,
            num_classes=FLAGS.num_classes)
    elif FLAGS.model == 'cifar100':
        net, end_points = cifar.nin(feature_dict["features"],
                                    labels,
                                    is_training=is_training,
                                    num_classes=FLAGS.num_classes)
    else:
        raise ValueError("Unknown model %s" % FLAGS.model)

    small_embeddings = tf.squeeze(end_points['embedding'])
    logits = end_points['Logits']
    predictions = tf.cast(tf.argmax(logits, 1), dtype=tf.int32)

    if FLAGS.final_activation in ['soft_thresh', 'none']:
        abs_embeddings = tf.abs(small_embeddings)
    else:
        abs_embeddings = small_embeddings
    nnz_small_embeddings = tf.cast(tf.less(FLAGS.zero_threshold,
                                           abs_embeddings),
                                   dtype=tf.float32)
    small_sparsity = tf.reduce_sum(nnz_small_embeddings, axis=1)
    small_sparsity = tf.reduce_mean(small_sparsity)

    mean_nnz_col = tf.reduce_mean(nnz_small_embeddings, axis=0)
    sum_nnz_col = tf.reduce_sum(nnz_small_embeddings, axis=0)
    mean_flops_ub = tf.reduce_sum(sum_nnz_col *
                                  (sum_nnz_col - 1)) / (FLAGS.batch_size *
                                                        (FLAGS.batch_size - 1))
    mean_flops = tf.reduce_sum(mean_nnz_col * mean_nnz_col)

    l1_norm_row = tf.reduce_sum(abs_embeddings, axis=1)
    l1_norm_col = tf.reduce_mean(abs_embeddings, axis=0)

    mean_l1_norm = tf.reduce_mean(l1_norm_row)
    mean_flops_sur = tf.reduce_sum(l1_norm_col * l1_norm_col)
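
    # Interpretation of the statistics above (added commentary, not from the
    # original source): nnz_small_embeddings marks entries whose magnitude
    # exceeds FLAGS.zero_threshold, so small_sparsity is the mean number of
    # non-zeros per embedding. mean_flops_ub estimates the expected number of
    # multiplications when comparing two embeddings drawn from the batch:
    # dimension j contributes only when it is non-zero in both, hence the
    # sum_nnz_col * (sum_nnz_col - 1) pair count normalized by
    # batch_size * (batch_size - 1). mean_flops_sur swaps the
    # non-differentiable non-zero indicator for the mean absolute activation
    # per dimension (an L1 surrogate), giving a loss that gradients can reduce.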

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions_dict = {
            'predictions': predictions,
            'true_labels': feature_dict["labels"],
            'true_label_texts': feature_dict["label_texts"],
            'small_embeddings': small_embeddings,
            'sparsity/small': small_sparsity,
            'sparsity/flops': mean_flops_ub,
            'filename': features["filename"]
        }
        return tf.estimator.EstimatorSpec(mode, predictions=predictions_dict)

    if FLAGS.model == 'mobilenet':
        cr_ent_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=labels)
    elif 'face' in FLAGS.model:
        cr_ent_loss = mobilefacenet.arcface_loss(logits=logits,
                                                 labels=labels,
                                                 out_num=FLAGS.num_classes)
    elif FLAGS.model == 'metric_learning':
        cr_ent_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=labels)
        # cr_ent_loss = mobilefacenet.arcface_loss(logits=logits, labels=labels, out_num=FLAGS.num_classes)
    elif FLAGS.model == 'cifar100':
        # cr_ent_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels)
        cr_ent_loss = tf.contrib.losses.metric_learning.triplet_semihard_loss(
            labels, small_embeddings, margin=0.1)
        # cr_ent_loss = mobilefacenet.arcface_loss(logits=logits,
        #     labels=labels, out_num=FLAGS.num_classes, m=0.0)
        # cr_ent_loss = triplet_semihard_loss(labels=labels, embeddings=small_embeddings,
        #     pairwise_distance=lambda embed: pairwise_distance_euclid(embed, squared=True), margin=0.3)
    else:
        raise ValueError('Unknown model %s' % FLAGS.model)
    cr_ent_loss = tf.reduce_mean(
        cr_ent_loss) + tf.losses.get_regularization_loss()

    ema = tf.train.ExponentialMovingAverage(decay=0.99)
    ema_op = ema.apply([mean_l1_norm])
    moving_l1_norm = ema.average(mean_l1_norm)

    global_step = tf.train.get_or_create_global_step()
    all_ops = [ema_op]

    l1_weight = tf.Variable(0.0, name='l1_weight', trainable=False)

    if FLAGS.l1_weighing_scheme == 'constant':
        l1_weight = FLAGS.l1_parameter
    elif FLAGS.l1_weighing_scheme == 'dynamic_1':
        l1_weight = FLAGS.l1_parameter / moving_l1_norm
        l1_weight = tf.stop_gradient(l1_weight)
        l1_weight = tf.train.piecewise_constant(x=global_step,
                                                boundaries=[5],
                                                values=[0.0, l1_weight])
    elif FLAGS.l1_weighing_scheme == 'dynamic_2':
        if FLAGS.sparsity_type == "flops_sur":
            update_lr = 1e-5
        else:
            update_lr = 1e-4
        update = update_lr * (FLAGS.l1_parameter - cr_ent_loss)
        assign_op = tf.assign(l1_weight, tf.nn.relu(l1_weight + update))
        all_ops.append(assign_op)
    elif FLAGS.l1_weighing_scheme == 'dynamic_3':
        update_lr = 1e-4
        global_step = tf.train.get_or_create_global_step()
        upper_bound = FLAGS.l1_parameter - (
            FLAGS.l1_parameter - 12.0) * tf.cast(global_step, tf.float32) / 5e5
        upper_bound = tf.cast(upper_bound, tf.float32)
        update = update_lr * tf.sign(upper_bound - cr_ent_loss)
        assign_op = tf.assign(l1_weight, l1_weight + tf.nn.relu(update))
        all_ops.append(assign_op)
    elif FLAGS.l1_weighing_scheme == 'dynamic_4':
        l1_weight = FLAGS.l1_parameter * tf.minimum(
            1.0, (tf.cast(global_step, tf.float32) / FLAGS.l1_p_steps)**2)
    elif FLAGS.l1_weighing_scheme is None:
        l1_weight = 0.0
    else:
        raise ValueError('Unknown l1_weighing_scheme %s' %
                         FLAGS.l1_weighing_scheme)

    if FLAGS.sparsity_type == 'l1_norm':
        sparsity_loss = l1_weight * mean_l1_norm
    elif FLAGS.sparsity_type == 'flops_sur':
        sparsity_loss = l1_weight * mean_flops_sur
    elif FLAGS.sparsity_type is None:
        sparsity_loss = 0.0
    else:
        raise ValueError("Unknown sparsity_type %d" % FLAGS.sparsity_type)

    total_loss = cr_ent_loss + sparsity_loss
    accuracy = tf.reduce_mean(
        tf.cast(tf.equal(predictions, labels), dtype=tf.float32))

    if mode == tf.estimator.ModeKeys.EVAL:
        global_accuracy = tf.metrics.mean(accuracy)
        global_small_sparsity = tf.metrics.mean(small_sparsity)
        global_flops = tf.metrics.mean(mean_flops_ub)
        global_l1_norm = tf.metrics.mean(mean_l1_norm)
        global_mean_flops_sur = tf.metrics.mean(mean_flops_sur)
        global_cr_ent_loss = tf.metrics.mean(cr_ent_loss)

        metrics = {
            'accuracy': global_accuracy,
            'sparsity/small': global_small_sparsity,
            'sparsity/flops': global_flops,
            'l1_norm': global_l1_norm,
            'l1_norm/flops_sur': global_mean_flops_sur,
            'loss/cr_ent_loss': global_cr_ent_loss,
        }
        return tf.estimator.EstimatorSpec(mode,
                                          loss=total_loss,
                                          eval_metric_ops=metrics)

    if mode == tf.estimator.ModeKeys.TRAIN:
        base_lrate = FLAGS.learning_rate
        learning_rate = tf.train.piecewise_constant(
            x=global_step,
            boundaries=[
                FLAGS.decay_step if FLAGS.decay_step is not None else int(1e7)
            ],
            values=[base_lrate, base_lrate / 10.0])

        tf.summary.image("input", feature_dict["features"], max_outputs=1)
        tf.summary.scalar('sparsity/small', small_sparsity)
        tf.summary.scalar('sparsity/flops', mean_flops_ub)
        tf.summary.scalar('accuracy', accuracy)
        tf.summary.scalar('l1_norm', mean_l1_norm)
        tf.summary.scalar('l1_norm/ema',
                          moving_l1_norm)  # comment out when using 'mom'
        tf.summary.scalar('l1_norm/l1_weight', l1_weight)
        tf.summary.scalar('l1_norm/flops_sur', mean_flops_sur)
        tf.summary.scalar('learning_rate', learning_rate)
        tf.summary.scalar('loss/cr_ent_loss', cr_ent_loss)
        tf.summary.scalar(
            'sparsity/ratio', mean_flops * FLAGS.embedding_size /
            (small_sparsity * small_sparsity))
        try:
            tf.summary.scalar('loss/upper_bound', upper_bound)
        except NameError:
            print("Skipping 'upper_bound' summary")
        for variable in tf.trainable_variables():
            if 'soft_thresh' in variable.name:
                print('Adding summary for lambda')
                tf.summary.scalar('lambda', variable)

        # Histogram summaries
        # gamma = tf.get_default_graph().get_tensor_by_name('MobilenetV2/Conv_2/BatchNorm/gamma:0')
        # pre_relu = tf.get_default_graph().get_tensor_by_name(
        #       'MobilenetV2/Conv_2/BatchNorm/FusedBatchNorm:0')
        # pre_relu = tf.squeeze(pre_relu)
        # tf.summary.histogram('gamma', gamma)
        # tf.summary.histogram('pre_relu', pre_relu[:, 237])
        # tf.summary.histogram('small_activations', nnz_small_embeddings)
        # tf.summary.histogram('small_activations/log', tf.log(nnz_small_embeddings + 1e-10))
        # fl_sur_ratio = (mean_nnz_col * mean_nnz_col) / (l1_norm_col * l1_norm_col)
        # tf.summary.histogram('fl_sur_ratio', fl_sur_ratio)
        # tf.summary.histogram('fl_sur_ratio/log', tf.log(fl_sur_ratio + 1e-10))
        # l1_sur_ratio = (mean_nnz_col * mean_nnz_col) / l1_norm_col
        # tf.summary.histogram('l1_sur_ratio', l1_sur_ratio)
        # tf.summary.histogram('l1_sur_ratio/log', tf.log(l1_sur_ratio + 1e-10))

        if FLAGS.optimizer == 'mom':
            optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                                   momentum=FLAGS.momentum)
        elif FLAGS.optimizer == 'sgd':
            optimizer = tf.train.GradientDescentOptimizer(
                learning_rate=learning_rate)
        elif FLAGS.optimizer == 'adam':
            optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                               epsilon=0.0001)
        elif FLAGS.optimizer == 'rmsprop':
            optimizer = tf.train.RMSPropOptimizer(learning_rate=learning_rate,
                                                  epsilon=0.01,
                                                  momentum=FLAGS.momentum)
        else:
            raise ValueError("Unknown optimizer %s" % FLAGS.optimizer)

        # Make the optimizer distributed
        optimizer = hvd.DistributedOptimizer(optimizer)

        train_op = tf.contrib.slim.learning.create_train_op(
            total_loss,
            optimizer,
            global_step=tf.train.get_or_create_global_step())
        all_ops.append(train_op)
        merged = tf.group(*all_ops)

        return tf.estimator.EstimatorSpec(mode,
                                          loss=total_loss,
                                          train_op=merged)
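
Because the model_fn wraps its optimizer in hvd.DistributedOptimizer, it is meant to run under Horovod. A hedged wiring sketch; train_input_fn, FLAGS.model_dir and FLAGS.max_steps are assumptions:

import horovod.tensorflow as hvd

def main(_):
    hvd.init()  # one process per GPU
    session_config = tf.ConfigProto()
    session_config.gpu_options.visible_device_list = str(hvd.local_rank())
    run_config = tf.estimator.RunConfig(model_dir=FLAGS.model_dir,
                                        session_config=session_config)
    estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config)
    # Broadcast rank 0's initial weights so all workers start identically.
    bcast_hook = hvd.BroadcastGlobalVariablesHook(0)
    estimator.train(input_fn=train_input_fn,
                    hooks=[bcast_hook],
                    max_steps=FLAGS.max_steps // hvd.size())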
Example #7
def build_model(apply_or_model=False, apply_and_model=False):
    """Builds the test model and writes it out as a .pb file.

    Args:
        apply_or_model, apply_and_model: whether to apply the or/and model.
    """
    g = tf.Graph()
    with g.as_default(), tf.device(
            tf.train.replica_device_setter(FLAGS.ps_tasks)):
        anchors = anchor_generator.generate_anchors(**_anchors_figure)
        box_pred = box_predictor.SSDBoxPredictor(FLAGS.is_training,
                                                 FLAGS.num_classes,
                                                 box_code_size=4)
        batchnorm_updates_collections = (None if FLAGS.inplace_batchnorm_update
                                         else tf.GraphKeys.UPDATE_OPS)
        anchors = tf.convert_to_tensor(anchors,
                                       dtype=tf.float32,
                                       name='anchors')
        convert_ratio = tf.convert_to_tensor(_convert_ratio,
                                             tf.float32,
                                             name='convert_ratio')
        value_to_ratio = tf.convert_to_tensor(_value_to_ratio,
                                              tf.float32,
                                              name='value_to_ratio')

        img_tensor = tf.placeholder(
            tf.float32,
            [1, FLAGS.original_image_height, FLAGS.original_image_width, 3],
            name='input_img')
        grid_size_tensor = tf.placeholder(tf.float32, [2], 'input_grid_size')
        preimg_batch, grid_points_tl = preprocess(img_tensor, grid_size_tensor,
                                                  FLAGS.image_size,
                                                  value_to_ratio,
                                                  apply_or_model)

        with slim.arg_scope([slim.batch_norm], is_training=(
            FLAGS.is_training and not FLAGS.freeze_batchnorm),
            updates_collections=batchnorm_updates_collections),\
            slim.arg_scope(
                mobilenet_v2.training_scope(is_training=None, bn_decay=0.997)):
            _, image_features = mobilenet_v2.mobilenet_base(
                preimg_batch,
                final_endpoint='layer_18',
                depth_multiplier=FLAGS.depth_multiplier,
                finegrain_classification_mode=True)
            feature_maps = feature_map_generator.pooling_pyramid_feature_maps(
                base_feature_map_depth=0,
                num_layers=2,
                image_features={'image_features': image_features['layer_18']})
            pred_dict = box_pred.predict(list(feature_maps.values()), [1, 1])
            box_encodings = tf.concat(pred_dict['box_encodings'], axis=1)
            if box_encodings.shape.ndims == 4 and box_encodings.shape[2] == 1:
                box_encodings = tf.squeeze(box_encodings, axis=2)
            class_predictions_with_background = tf.concat(
                pred_dict['class_predictions_with_background'], axis=1)
        detection_boxes, detection_scores = postprocess(
            anchors,
            box_encodings,
            class_predictions_with_background,
            convert_ratio,
            grid_points_tl,
            num_classes=FLAGS.num_classes,
            score_threshold=FLAGS.score_threshold,
            apply_and_model=apply_and_model)
        input_boxes = tf.placeholder_with_default(detection_boxes[:1],
                                                  [None, 4],
                                                  name='input_boxes')
        if apply_or_model or apply_and_model:
            return g, img_tensor, input_boxes, detection_boxes, detection_scores
        num_batch = shape_utils.combined_static_and_dynamic_shape(input_boxes)
        input_scores = tf.tile([0.7], [num_batch[0]])
        total_boxes = tf.concat([detection_boxes, input_boxes], 0)
        total_scores = tf.concat([detection_scores, input_scores], 0)
        result_dict = non_max_suppression(
            total_boxes,
            total_scores,
            max_output_size=FLAGS.max_output_size,
            iou_threshold=FLAGS.iou_threshold)

        output_node_names = [
            'Non_max_suppression/result_boxes',
            'Non_max_suppression/result_scores',
            'Non_max_suppression/abnormal_indices',
            'Non_max_suppression/abnormal_inter_idx',
            'Non_max_suppression/abnormal_inter'
        ]
        init_op = tf.global_variables_initializer()
        with tf.Session() as sess:
            sess.run(init_op)
            # saver for restore model
            saver = tf.train.Saver()
            print('[*] Try to load trained model...')
            ckpt_name = load(sess, saver, FLAGS.checkpoint_dir)
            write_pb_model(FLAGS.checkpoint_dir + ckpt_name + '.pb', sess,
                           g.as_graph_def(), output_node_names)
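
The .pb written here (presumably frozen by write_pb_model via the session and output node names) can be loaded back for inference. A minimal sketch; the tensor names follow the placeholders and output_node_names above:

def load_pb_model(pb_path):
    graph_def = tf.GraphDef()
    with tf.gfile.GFile(pb_path, 'rb') as f:
        graph_def.ParseFromString(f.read())
    g = tf.Graph()
    with g.as_default():
        tf.import_graph_def(graph_def, name='')
    img_tensor = g.get_tensor_by_name('input_img:0')
    boxes = g.get_tensor_by_name('Non_max_suppression/result_boxes:0')
    scores = g.get_tensor_by_name('Non_max_suppression/result_scores:0')
    return g, img_tensor, boxes, scores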
Example #8
def build_model():
    g = tf.Graph()
    with g.as_default(), tf.device(
            tf.train.replica_device_setter(FLAGS.ps_tasks)):
        matcher = argmax_matcher.ArgMaxMatcher(matched_threshold=0.5,
                                               unmatched_threshold=0.4,
                                               force_match_for_each_row=True)
        anchors = anchor_generator.generate_anchors(**_anchors_figure)
        box_pred = box_predictor.SSDBoxPredictor(FLAGS.is_training,
                                                 FLAGS.num_classes,
                                                 box_code_size=4)
        batchnorm_updates_collections = (None if FLAGS.inplace_batchnorm_update
                                         else tf.GraphKeys.UPDATE_OPS)
        with tf.variable_scope('inputs'):
            img_batch, bbox_batch, bbox_num_batch, _ = get_batch(
                FLAGS.dataset_dir, FLAGS.batch_size)
            img_batch = tf.cast(img_batch, tf.float32) / 127.5 - 1
            img_batch = tf.identity(img_batch, name='gt_imgs')
            bbox_list = []
            for i in range(FLAGS.batch_size):
                gt_boxes = tf.identity(bbox_batch[i][:bbox_num_batch[i]],
                                       name='gt_boxes')
                bbox_list.append(gt_boxes)
            anchors = tf.convert_to_tensor(anchors,
                                           dtype=tf.float32,
                                           name='anchors')
        with slim.arg_scope([slim.batch_norm],
                is_training=(FLAGS.is_training and not FLAGS.freeze_batchnorm),
                updates_collections=batchnorm_updates_collections),\
            slim.arg_scope(
                mobilenet_v2.training_scope(is_training=None, bn_decay=0.997)):
            _, image_features = mobilenet_v2.mobilenet_base(
                img_batch,
                final_endpoint='layer_18',
                depth_multiplier=FLAGS.depth_multiplier,
                finegrain_classification_mode=True)

            feature_maps = feature_map_generator.pooling_pyramid_feature_maps(
                base_feature_map_depth=0,
                num_layers=2,
                image_features={'image_features': image_features['layer_18']})

            pred_dict = box_pred.predict(list(feature_maps.values()), [1, 1])
            box_encodings = tf.concat(pred_dict['box_encodings'], axis=1)
            if box_encodings.shape.ndims == 4 and box_encodings.shape[2] == 1:
                box_encodings = tf.squeeze(box_encodings, axis=2)
            class_predictions_with_background = tf.concat(
                pred_dict['class_predictions_with_background'], axis=1)

        losses_dict = loss_op.loss(box_encodings,
                                   class_predictions_with_background,
                                   bbox_list,
                                   anchors,
                                   matcher,
                                   random_example=False)
        for loss_tensor in losses_dict.values():
            tf.losses.add_loss(loss_tensor)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

        # Configure the learning rate using an exponential decay.
        num_epochs_per_decay = 3
        imagenet_size = 10240
        decay_steps = int(imagenet_size / FLAGS.batch_size *
                          num_epochs_per_decay)

        learning_rate = tf.train.exponential_decay(
            FLAGS.learning_rate,
            tf.train.get_or_create_global_step(),
            decay_steps,
            _LEARNING_RATE_DECAY_FACTOR,
            staircase=True)

        opt = tf.train.AdamOptimizer(learning_rate)

        total_losses = []
        cls_loc_losses = tf.get_collection(tf.GraphKeys.LOSSES)
        cls_loc_loss = tf.add_n(cls_loc_losses, name='cls_loc_loss')
        total_losses.append(cls_loc_loss)
        regularization_losses = tf.get_collection(
            tf.GraphKeys.REGULARIZATION_LOSSES)
        regularization_loss = tf.add_n(regularization_losses,
                                       name='regularization_loss')
        total_losses.append(regularization_loss)
        total_loss = tf.add_n(total_losses, name='total_loss')

        grads_and_vars = opt.compute_gradients(total_loss)

        total_loss = tf.check_numerics(total_loss, 'LossTensor is inf or nan.')
        grad_updates = opt.apply_gradients(
            grads_and_vars, global_step=tf.train.get_or_create_global_step())
        update_ops.append(grad_updates)
        update_op = tf.group(*update_ops, name='update_barrier')
        with tf.control_dependencies([update_op]):
            train_tensor = tf.identity(total_loss, name='train_op')

    slim.summaries.add_scalar_summary(cls_loc_loss, 'cls_loc_loss', 'losses')
    slim.summaries.add_scalar_summary(regularization_loss,
                                      'regularization_loss', 'losses')
    slim.summaries.add_scalar_summary(total_loss, 'total_loss', 'losses')
    slim.summaries.add_scalar_summary(learning_rate, 'learning_rate',
                                      'training')
    return g, train_tensor
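
As an alternative to slim.learning.train (Example #1), this bare train tensor can be driven by hand. A sketch with MonitoredTrainingSession, which starts the input queue runners and handles checkpointing; FLAGS.train_dir and FLAGS.max_steps are assumptions:

g, train_tensor = build_model()
with g.as_default():
    with tf.train.MonitoredTrainingSession(
            checkpoint_dir=FLAGS.train_dir) as sess:
        step = 0
        while not sess.should_stop() and step < FLAGS.max_steps:
            loss = sess.run(train_tensor)  # one update; yields total_loss
            step += 1
            if step % 100 == 0:
                print('step %d, loss %.4f' % (step, loss))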