def lighr_head_model_fn(features, labels, mode, params):
    """Our model_fn for ResNet to be used with our Estimator."""
    num_anchors_list = labels['num_anchors_list']
    num_feature_layers = len(num_anchors_list)

    shape = labels['targets'][-1]
    glabels = labels['targets'][:num_feature_layers][0]
    gtargets = labels['targets'][num_feature_layers:2 * num_feature_layers][0]
    gscores = labels['targets'][2 * num_feature_layers:3 *
                                num_feature_layers][0]
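    # labels['targets'] is a flat list packed per feature layer: the first
    # num_feature_layers entries hold the class labels, the next
    # num_feature_layers the box regression targets, the following
    # num_feature_layers the matching scores, and the last entry the image
    # shape. Only the first feature layer of each group is used here.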

    #features = tf.ones([4,480,480,3]) * 0.5
    with tf.variable_scope(params['model_scope'],
                           default_name=None,
                           values=[features],
                           reuse=tf.AUTO_REUSE):
        rpn_feat_map, backbone_feat = xception_body.XceptionBody(
            features,
            params['num_classes'],
            is_training=(mode == tf.estimator.ModeKeys.TRAIN),
            data_format=params['data_format'])
        #rpn_feat_map = tf.Print(rpn_feat_map,[tf.shape(rpn_feat_map), rpn_feat_map,backbone_feat])
        rpn_cls_score, rpn_bbox_pred = xception_body.get_rpn(
            rpn_feat_map, num_anchors_list[0],
            (mode == tf.estimator.ModeKeys.TRAIN), params['data_format'],
            'rpn_head')

        large_sep_feature = xception_body.large_sep_kernel(
            backbone_feat, 256, 10 * 7 * 7,
            (mode == tf.estimator.ModeKeys.TRAIN), params['data_format'],
            'large_sep_feature')

        if params['data_format'] == 'channels_first':
            rpn_cls_score = tf.transpose(rpn_cls_score, [0, 2, 3, 1])
            rpn_bbox_pred = tf.transpose(rpn_bbox_pred, [0, 2, 3, 1])

        rpn_cls_score = tf.reshape(rpn_cls_score, [-1, 2])
        rpn_object_score = tf.nn.softmax(rpn_cls_score)[:, -1]

        #with tf.device('/cpu:0'):
        rpn_object_score = tf.reshape(rpn_object_score,
                                      [params['batch_size'], -1])
        rpn_location_pred = tf.reshape(rpn_bbox_pred,
                                       [params['batch_size'], -1, 4])

        #rpn_location_pred = tf.Print(rpn_location_pred,[tf.shape(rpn_location_pred), rpn_location_pred])

        rpn_bboxes_pred = labels['rpn_decode_fn'](rpn_location_pred)
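        # labels['rpn_decode_fn'] is expected to map the predicted per-anchor
        # regression values back to absolute box coordinates, likely giving
        # rpn_bboxes_pred a shape of [batch_size, num_anchors, 4].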

        #rpn_bboxes_pred = tf.Print(rpn_bboxes_pred,[tf.shape(rpn_bboxes_pred), rpn_bboxes_pred])
        # rpn loss here
        cls_pred = tf.reshape(rpn_cls_score, [-1, 2])
        location_pred = tf.reshape(rpn_bbox_pred, [-1, 4])
        glabels = tf.reshape(glabels, [-1])
        gscores = tf.reshape(gscores, [-1])
        gtargets = tf.reshape(gtargets, [-1, 4])

        expected_num_fg_rois = tf.cast(
            tf.round(
                tf.cast(params['batch_size'] * params['rpn_anchors_per_image'],
                        tf.float32) * params['rpn_fg_ratio']), tf.int32)
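        # expected_num_fg_rois = round(batch_size * rpn_anchors_per_image * rpn_fg_ratio).
        # For illustration only (not the repository defaults): with batch_size=1,
        # rpn_anchors_per_image=256 and rpn_fg_ratio=0.25 this asks for 64
        # foreground anchors per batch.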

        def select_samples(cls_pred, location_pred, glabels, gscores,
                           gtargets):
            def upsampel_impl(now_count, need_count):
                # oversample with repetition: tile every existing index and pad
                # with a random remainder so that need_count indices are returned
                left_count = need_count - now_count
                select_indices = tf.random_shuffle(
                    tf.range(now_count))[:tf.floormod(left_count, now_count)]
                select_indices = tf.concat([
                    tf.tile(tf.range(now_count),
                            [tf.floor_div(left_count, now_count) + 1]),
                    select_indices
                ],
                                           axis=0)

                return select_indices

            def downsample_impl(now_count, need_count):
                # subsample without replacement: take a random subset of need_count indices
                select_indices = tf.random_shuffle(
                    tf.range(now_count))[:need_count]
                return select_indices

            positive_mask = glabels > 0
            positive_indices = tf.squeeze(tf.where(positive_mask), axis=-1)
            n_positives = tf.shape(positive_indices)[0]
            # either downsample or take all
            fg_select_indices = tf.cond(
                n_positives < expected_num_fg_rois, lambda: positive_indices,
                lambda: tf.gather(
                    positive_indices,
                    downsample_impl(n_positives, expected_num_fg_rois)))
            # the number of RoIs taken as positives is now min(n_positives, expected_num_fg_rois)

            #negtive_mask = tf.logical_and(tf.logical_and(tf.logical_not(tf.logical_or(positive_mask, glabels < 0)), gscores < params['rpn_neg_threshold']), gscores > 0.)
            negtive_mask = tf.equal(
                glabels,
                0)  #tf.logical_and(tf.equal(glabels, 0), gscores > 0.)
            negtive_indices = tf.squeeze(tf.where(negtive_mask), axis=-1)
            n_negtives = tf.shape(negtive_indices)[0]

            expected_num_bg_rois = params[
                'batch_size'] * params['rpn_anchors_per_image'] - tf.minimum(
                    n_positives, expected_num_fg_rois)
            # either downsample or take all
            bg_select_indices = tf.cond(
                n_negtives < expected_num_bg_rois, lambda: negtive_indices,
                lambda: tf.gather(
                    negtive_indices,
                    downsample_impl(n_negtives, expected_num_bg_rois)))
            # the number of RoIs taken as negatives is now min(n_negtives, expected_num_bg_rois)

            keep_indices = tf.concat([fg_select_indices, bg_select_indices],
                                     axis=0)
            n_keeps = tf.shape(keep_indices)[0]
            # n_keeps is now at most batch_size * rpn_anchors_per_image
            final_keep_indices = tf.cond(
                n_keeps <
                params['batch_size'] * params['rpn_anchors_per_image'],
                lambda: tf.gather(
                    keep_indices,
                    upsampel_impl(
                        n_keeps, params['batch_size'] * params[
                            'rpn_anchors_per_image'])), lambda: keep_indices)

            return tf.gather(cls_pred, final_keep_indices), tf.gather(
                location_pred, final_keep_indices), tf.cast(
                    tf.gather(
                        tf.clip_by_value(glabels, 0, params['num_classes']),
                        final_keep_indices) > 0,
                    tf.int64), tf.gather(gscores,
                                         final_keep_indices), tf.gather(
                                             gtargets, final_keep_indices)
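        # select_samples balances the sampled anchors: at most expected_num_fg_rois
        # positives, the rest filled with negatives, and the union is then
        # oversampled with repetition if needed so that exactly
        # batch_size * rpn_anchors_per_image anchors feed the RPN losses below.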

        cls_pred, location_pred, glabels, gscores, gtargets = select_samples(
            cls_pred, location_pred, glabels, gscores, gtargets)

        # Calculate loss, which includes softmax cross entropy and L2 regularization.
        rpn_cross_entropy = tf.losses.sparse_softmax_cross_entropy(
            labels=glabels, logits=cls_pred)

        # Create a tensor named cross_entropy for logging purposes.
        rpn_cross_entropy = tf.identity(rpn_cross_entropy,
                                        name='rpn_cross_entropy_loss')
        tf.summary.scalar('rpn_cross_entropy_loss', rpn_cross_entropy)

        total_positive_mask = (glabels > 0)
        gtargets = tf.boolean_mask(gtargets,
                                   tf.stop_gradient(total_positive_mask))
        location_pred = tf.boolean_mask(location_pred,
                                        tf.stop_gradient(total_positive_mask))
        #gtargets = tf.Print(gtargets, [gtargets], message='gtargets:', summarize=100)

        rpn_l1_distance = modified_smooth_l1(location_pred, gtargets, sigma=1.)
        rpn_loc_loss = tf.reduce_mean(tf.reduce_sum(
            rpn_l1_distance, axis=-1)) / params['rpn_fg_ratio']
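        # The smooth-L1 distance is summed over the 4 box coordinates, averaged over
        # the positive anchors, and then scaled by 1 / rpn_fg_ratio to up-weight the
        # localization term relative to the classification term.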
        rpn_loc_loss = tf.identity(rpn_loc_loss, name='rpn_location_loss')
        tf.summary.scalar('rpn_location_loss', rpn_loc_loss)
        tf.losses.add_loss(rpn_loc_loss)

        rpn_loss = tf.identity(rpn_loc_loss + rpn_cross_entropy,
                               name='rpn_loss')
        tf.summary.scalar('rpn_loss', rpn_loss)
        #print(rpn_loc_loss)

        proposals_bboxes, proposals_targets, proposals_labels, proposals_scores = xception_body.get_proposals(
            rpn_object_score, rpn_bboxes_pred, labels['rpn_encode_fn'],
            params['rpn_pre_nms_top_n'], params['rpn_post_nms_top_n'],
            params['rpn_nms_thres'], params['rpn_min_size'],
            (mode == tf.estimator.ModeKeys.TRAIN), params['data_format'])
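        # get_proposals presumably turns the per-anchor objectness scores and decoded
        # boxes into training RoIs: keep the rpn_pre_nms_top_n highest-scoring boxes,
        # apply NMS with rpn_nms_thres, drop boxes smaller than rpn_min_size, and
        # return up to rpn_post_nms_top_n proposals together with their re-encoded
        # regression targets, labels and scores.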

        #proposals_targets = tf.Print(proposals_targets, [proposals_targets], message='proposals_targets0:')
        def head_loss_func(cls_score, bboxes_reg, select_indices,
                           proposals_targets, proposals_labels):
            if select_indices is not None:
                proposals_targets = tf.gather(proposals_targets,
                                              select_indices,
                                              axis=1)
                proposals_labels = tf.gather(proposals_labels,
                                             select_indices,
                                             axis=1)
            # Classification loss for the head: softmax cross entropy over the proposal labels.
            head_cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=proposals_labels, logits=cls_score)

            total_positive_mask = tf.cast((proposals_labels > 0), tf.float32)
            # proposals_targets = tf.boolean_mask(proposals_targets, tf.stop_gradient(total_positive_mask))
            # bboxes_reg = tf.boolean_mask(bboxes_reg, tf.stop_gradient(total_positive_mask))
            head_loc_loss = modified_smooth_l1(bboxes_reg,
                                               proposals_targets,
                                               sigma=1.)
            head_loc_loss = tf.reduce_sum(head_loc_loss,
                                          axis=-1) * total_positive_mask
            if (params['using_ohem'] and
                (select_indices is not None)) or (not params['using_ohem']):
                head_cross_entropy_loss = tf.reduce_mean(head_cross_entropy)
                head_cross_entropy_loss = tf.identity(
                    head_cross_entropy_loss, name='head_cross_entropy_loss')
                tf.summary.scalar('head_cross_entropy_loss',
                                  head_cross_entropy_loss)

                head_location_loss = tf.reduce_mean(
                    head_loc_loss) / params['fg_ratio']
                head_location_loss = tf.identity(head_location_loss,
                                                 name='head_location_loss')
                tf.summary.scalar('head_location_loss', head_location_loss)

            return head_cross_entropy + head_loc_loss / params['fg_ratio']
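        # head_loss_func returns a per-RoI loss (cross entropy plus the scaled
        # smooth-L1 term). With OHEM enabled, get_head presumably uses this vector to
        # pick the ohem_roi_one_image hardest RoIs (passed back in via select_indices)
        # before the loss summaries are recorded.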

        head_loss = xception_body.get_head(
            large_sep_feature,
            lambda input_, bboxes_, grid_width_, grid_height_: ps_roi_align(
                input_, bboxes_, grid_width_, grid_height_, pool_method), 7,
            7, lambda cls, bbox, indices: head_loss_func(
                cls, bbox, indices, proposals_targets, proposals_labels),
            proposals_bboxes, params['num_classes'],
            (mode == tf.estimator.ModeKeys.TRAIN), params['using_ohem'],
            params['ohem_roi_one_image'], params['data_format'], 'final_head')

        # Create a tensor named cross_entropy for logging purposes.
        head_loss = tf.identity(head_loss, name='head_loss')
        tf.summary.scalar('head_loss', head_loss)

        tf.losses.add_loss(head_loss)

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode=mode, predictions=None)

    # Add weight decay to the loss. We exclude the batch norm variables because
    # doing so leads to a small improvement in accuracy.
    loss = rpn_cross_entropy + rpn_loc_loss + head_loss + params[
        'weight_decay'] * tf.add_n([
            tf.nn.l2_loss(v) for v in tf.trainable_variables() if
            (('batch_normalization' not in v.name) and ('_bn' not in v.name))
        ])  #_bn
    total_loss = tf.identity(loss, name='total_loss')

    if mode == tf.estimator.ModeKeys.TRAIN:
        global_step = tf.train.get_or_create_global_step()

        lr_values = [
            params['learning_rate'] * decay
            for decay in params['lr_decay_factors']
        ]
        learning_rate = tf.train.piecewise_constant(
            tf.cast(global_step, tf.int32),
            [int(_) for _ in params['decay_boundaries']], lr_values)
        truncated_learning_rate = tf.maximum(
            learning_rate,
            tf.constant(params['end_learning_rate'],
                        dtype=learning_rate.dtype))
        # Create a tensor named learning_rate for logging purposes.
        tf.identity(truncated_learning_rate, name='learning_rate')
        tf.summary.scalar('learning_rate', truncated_learning_rate)

        optimizer = tf.train.MomentumOptimizer(
            learning_rate=truncated_learning_rate, momentum=params['momentum'])

        # Batch norm requires update_ops to be added as a train_op dependency.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(loss, global_step)
    else:
        train_op = None

    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=None,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=None,
        scaffold=tf.train.Scaffold(
            init_fn=train_helper.get_init_fn_for_scaffold(FLAGS)))
def keypoint_model_fn(features, labels, mode, params):
    targets = labels['targets']
    shape = labels['shape']
    classid = labels['classid']
    key_v = labels['key_v']
    isvalid = labels['isvalid']
    norm_value = labels['norm_value']

    cur_batch_size = tf.shape(features)[0]

    with tf.variable_scope(params['model_scope'], default_name=None, values=[features], reuse=tf.AUTO_REUSE):
        pred_outputs = hg.create_model(features, params['num_stacks'], params['feats_channals'],
                            config.class_num_joints[(params['model_scope'] if 'all' not in params['model_scope'] else '*')], params['num_modules'],
                            (mode == tf.estimator.ModeKeys.TRAIN), params['data_format'])

    if params['data_format'] == 'channels_last':
        pred_outputs = [
            tf.transpose(pred_outputs[ind], [0, 3, 1, 2],
                         name='outputs_trans_{}'.format(ind))
            for ind in list(range(len(pred_outputs)))
        ]

    score_map = pred_outputs[-1]

    pred_x, pred_y = get_keypoint(
        features, targets, score_map, params['heatmap_size'],
        params['train_image_size'], params['train_image_size'],
        (params['model_scope'] if 'all' not in params['model_scope'] else '*'),
        clip_at_zero=True, data_format=params['data_format'])

    # this is important!!!
    targets = 255. * targets
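    # Rescale the ground-truth heatmaps to match the magnitude of the network
    # outputs; the targets appear to be stored scaled down by a factor of 255.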
    #with tf.control_dependencies([pred_x, pred_y]):
    ne_mertric = mertric.normalized_error(targets, score_map, norm_value, key_v, isvalid,
                             cur_batch_size,
                             config.class_num_joints[(params['model_scope'] if 'all' not in params['model_scope'] else '*')],
                             params['heatmap_size'],
                             params['train_image_size'])

    # last_pred_mse = tf.metrics.mean_squared_error(score_map, targets,
    #                             weights=1.0 / tf.cast(cur_batch_size, tf.float32),
    #                             name='last_pred_mse')

    all_visible = tf.logical_and(key_v>0, isvalid>0)
    targets = tf.boolean_mask(targets, all_visible)
    pred_outputs = [
        tf.boolean_mask(pred_outputs[ind], all_visible,
                        name='boolean_mask_{}'.format(ind))
        for ind in list(range(len(pred_outputs)))
    ]

    sq_diff = tf.reduce_sum(tf.squared_difference(targets, pred_outputs[-1]), axis=-1)
    last_pred_mse = tf.metrics.mean_absolute_error(sq_diff, tf.zeros_like(sq_diff), name='last_pred_mse')
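    # sq_diff is the per-element sum of squared differences; taking its
    # mean_absolute_error against zeros is just a running mean of sq_diff, so
    # last_pred_mse behaves as an MSE-style streaming eval metric for the last stack.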

    metrics = {'normalized_error': ne_mertric, 'last_pred_mse':last_pred_mse}
    predictions = {'normalized_error': ne_mertric[1]}
    ne_mertric = tf.identity(ne_mertric[1], name='ne_mertric')


    mse_loss_list = []
    for pred_ind in list(range(len(pred_outputs))):
        mse_loss_list.append(tf.losses.mean_squared_error(targets, pred_outputs[pred_ind],
                            weights=1.0 / tf.cast(cur_batch_size, tf.float32),
                            scope='loss_{}'.format(pred_ind),
                            loss_collection=None,#tf.GraphKeys.LOSSES,
                            reduction=tf.losses.Reduction.MEAN))# SUM, SUM_OVER_BATCH_SIZE, default mean by all elements

    mse_loss = tf.multiply(params['mse_weight'], tf.add_n(mse_loss_list), name='mse_loss')
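    # Intermediate supervision: every stack of the hourglass is penalized against
    # the same target heatmaps, and the per-stack MSE losses are summed before
    # applying mse_weight.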
    tf.summary.scalar('mse', mse_loss)
    tf.losses.add_loss(mse_loss)

    # bce_loss_list = []
    # for pred_ind in list(range(len(pred_outputs))):
    #     bce_loss_list.append(tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=pred_outputs[pred_ind], labels=targets, name='loss_{}'.format(pred_ind)), name='loss_mean_{}'.format(pred_ind)))

    # mse_loss = tf.multiply(params['mse_weight'] / params['num_stacks'], tf.add_n(bce_loss_list), name='mse_loss')
    # tf.summary.scalar('mse', mse_loss)
    # tf.losses.add_loss(mse_loss)

    # Add weight decay to the loss. We exclude the batch norm variables because
    # doing so leads to a small improvement in accuracy.
    loss = mse_loss + params['weight_decay'] * tf.add_n(
                              [tf.nn.l2_loss(v) for v in tf.trainable_variables()
                               if 'batch_normalization' not in v.name])
    total_loss = tf.identity(loss, name='total_loss')
    tf.summary.scalar('loss', total_loss)

    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(mode=mode, loss=loss, predictions=predictions, eval_metric_ops=metrics)

    if mode == tf.estimator.ModeKeys.TRAIN:
        global_step = tf.train.get_or_create_global_step()

        lr_values = [params['warmup_learning_rate']] + [
            params['learning_rate'] * decay
            for decay in params['lr_decay_factors']
        ]
        learning_rate = tf.train.piecewise_constant(
            tf.cast(global_step, tf.int32),
            [params['warmup_steps']] + [
                int(float(ep) * params['steps_per_epoch'])
                for ep in params['decay_boundaries']
            ], lr_values)
        truncated_learning_rate = tf.maximum(
            learning_rate,
            tf.constant(params['end_learning_rate'], dtype=learning_rate.dtype),
            name='learning_rate')
        tf.summary.scalar('lr', truncated_learning_rate)

        optimizer = tf.train.MomentumOptimizer(learning_rate=truncated_learning_rate,
                                                momentum=params['momentum'])

        # Batch norm requires update_ops to be added as a train_op dependency.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(loss, global_step)
    else:
        train_op = None

    return tf.estimator.EstimatorSpec(
                          mode=mode,
                          predictions=predictions,
                          loss=loss,
                          train_op=train_op,
                          eval_metric_ops=metrics,
                          scaffold=tf.train.Scaffold(init_fn=train_helper.get_init_fn_for_scaffold(FLAGS)))
def xdet_model_fn(features, labels, mode, params):
    """Our model_fn for ResNet to be used with our Estimator."""
    num_anchors_list = labels['num_anchors_list']
    num_feature_layers = len(num_anchors_list)

    shape = labels['targets'][-1]
    glabels = labels['targets'][:num_feature_layers][0]
    gtargets = labels['targets'][num_feature_layers:2 * num_feature_layers][0]
    gscores = labels['targets'][2 * num_feature_layers:3 *
                                num_feature_layers][0]

    with tf.variable_scope(params['model_scope'],
                           default_name=None,
                           values=[features],
                           reuse=tf.AUTO_REUSE):
        backbone = xdet_body_v2.xdet_resnet_v2(params['resnet_size'],
                                               params['data_format'])
        body_cls_output, body_regress_output = backbone(
            inputs=features, is_training=(mode == tf.estimator.ModeKeys.TRAIN))

        cls_pred, location_pred = xdet_body_v2.xdet_head(
            body_cls_output,
            body_regress_output,
            params['num_classes'],
            num_anchors_list[0], (mode == tf.estimator.ModeKeys.TRAIN),
            data_format=params['data_format'])

    if params['data_format'] == 'channels_first':
        cls_pred = tf.transpose(cls_pred, [0, 2, 3, 1])
        location_pred = tf.transpose(location_pred, [0, 2, 3, 1])

    bboxes_pred = labels['decode_fn'](
        location_pred
    )  #(tf.reshape(location_pred, tf.shape(location_pred).as_list()[0:-1] + [-1, 4]))

    cls_pred = tf.reshape(cls_pred, [-1, params['num_classes']])
    location_pred = tf.reshape(location_pred, [-1, 4])
    glabels = tf.reshape(glabels, [-1])
    gscores = tf.reshape(gscores, [-1])
    gtargets = tf.reshape(gtargets, [-1, 4])

    # raw masks: positives have overlap > 0.5, negatives have overlap < 0.3
    # each positive example has exactly one label
    positive_mask = glabels > 0  #tf.logical_and(glabels > 0, gscores > params['match_threshold'])
    fpositive_mask = tf.cast(positive_mask, tf.float32)
    n_positives = tf.reduce_sum(fpositive_mask)
    # negative examples are those whose max overlap is still lower than neg_threshold;
    # note that some positives may also have a low Jaccard overlap
    # anchors whose gscores is 0 were either ignored during anchor encoding or have
    # zero overlap with all ground-truth boxes
    #negtive_mask = tf.logical_and(tf.logical_and(tf.logical_not(tf.logical_or(positive_mask, glabels < 0)), gscores < params['neg_threshold']), gscores > 0.)
    #gscores = tf.Print(gscores, [tf.reduce_sum(tf.cast(gscores > 0., tf.float32))])
    #glabels = tf.Print(glabels, [glabels, tf.reduce_sum(tf.cast(tf.equal(glabels, 0), tf.float32))], message='glabels: ', summarize=1000)

    negtive_mask = tf.logical_and(tf.equal(glabels, 0), gscores > 0.)
    #negtive_mask = tf.Print(negtive_mask, [tf.reduce_sum(tf.cast(negtive_mask, tf.float32))])
    #negtive_mask = tf.logical_and(tf.logical_and(tf.logical_not(positive_mask), gscores < params['neg_threshold']), gscores > 0.)
    #negtive_mask = tf.logical_and(gscores < params['neg_threshold'], tf.logical_not(positive_mask))
    fnegtive_mask = tf.cast(negtive_mask, tf.float32)
    n_negtives = tf.reduce_sum(fnegtive_mask)

    n_neg_to_select = tf.cast(params['negative_ratio'] * n_positives, tf.int32)
    n_neg_to_select = tf.minimum(n_neg_to_select,
                                 tf.cast(n_negtives, tf.int32))
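    # Keep at most negative_ratio negatives per positive, capped by the number of
    # negatives actually present in the batch.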

    # hard negative mining for classification
    predictions_for_bg = tf.nn.softmax(cls_pred)[:, 0]
    prob_for_negtives = tf.where(
        negtive_mask,
        0. - predictions_for_bg,
        # ignore all the positives
        0. - tf.ones_like(predictions_for_bg))
    topk_prob_for_bg, _ = tf.nn.top_k(prob_for_negtives, k=n_neg_to_select)
    selected_neg_mask = prob_for_negtives > topk_prob_for_bg[-1]
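    # Hard negative mining: prob_for_negtives is the negated background probability
    # (positives are pushed to -1 so they can never be selected). top_k therefore
    # ranks the negatives the model is most confident are foreground, and
    # selected_neg_mask keeps roughly the n_neg_to_select hardest of them (those
    # strictly above the k-th score).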

    # # random select negtive examples for classification
    # selected_neg_mask = tf.random_uniform(tf.shape(gscores), minval=0, maxval=1.) < tf.where(
    #                                                                                     tf.greater(n_negtives, 0),
    #                                                                                     tf.divide(tf.cast(n_neg_to_select, tf.float32), n_negtives),
    #                                                                                     tf.zeros_like(tf.cast(n_neg_to_select, tf.float32)),
    #                                                                                     name='rand_select_negtive')

    # include both the selected negatives and all positive examples
    final_mask = tf.stop_gradient(
        tf.logical_or(tf.logical_and(negtive_mask, selected_neg_mask),
                      positive_mask))
    total_examples = tf.reduce_sum(tf.cast(final_mask, tf.float32))

    # add mask for glabels and cls_pred here
    glabels = tf.boolean_mask(tf.clip_by_value(glabels, 0, FLAGS.num_classes),
                              tf.stop_gradient(final_mask))
    cls_pred = tf.boolean_mask(cls_pred, tf.stop_gradient(final_mask))
    location_pred = tf.boolean_mask(location_pred,
                                    tf.stop_gradient(positive_mask))
    gtargets = tf.boolean_mask(gtargets, tf.stop_gradient(positive_mask))
    predictions = {
        'classes':
        tf.argmax(cls_pred, axis=-1),
        'probabilities':
        tf.reduce_max(tf.nn.softmax(cls_pred, name='softmax_tensor'), axis=-1),
        'bboxes_predict':
        tf.reshape(bboxes_pred, [-1, 4])
    }

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    # Calculate loss, which includes softmax cross entropy and L2 regularization.
    cross_entropy = tf.cond(
        n_positives > 0., lambda: tf.losses.sparse_softmax_cross_entropy(
            labels=glabels, logits=cls_pred), lambda: 0.)
    #cross_entropy = tf.losses.sparse_softmax_cross_entropy(labels=glabels, logits=cls_pred)

    # Create a tensor named cross_entropy for logging purposes.
    tf.identity(cross_entropy, name='cross_entropy_loss')
    tf.summary.scalar('cross_entropy_loss', cross_entropy)

    loc_loss = tf.cond(
        n_positives > 0., lambda: modified_smooth_l1(
            location_pred, tf.stop_gradient(gtargets), sigma=1.),
        lambda: tf.zeros_like(location_pred))
    #loc_loss = modified_smooth_l1(location_pred, tf.stop_gradient(gtargets))
    loc_loss = tf.reduce_mean(tf.reduce_sum(loc_loss, axis=-1))
    loc_loss = tf.identity(loc_loss, name='location_loss')
    tf.summary.scalar('location_loss', loc_loss)
    tf.losses.add_loss(loc_loss)

    # Add weight decay to the loss. We exclude the batch norm variables because
    # doing so leads to a small improvement in accuracy.
    loss = 1.2 * (cross_entropy +
                  loc_loss) + params['weight_decay'] * tf.add_n([
                      tf.nn.l2_loss(v) for v in tf.trainable_variables()
                      if 'batch_normalization' not in v.name
                  ])
    total_loss = tf.identity(loss, name='total_loss')

    if mode == tf.estimator.ModeKeys.TRAIN:
        global_step = tf.train.get_or_create_global_step()

        lr_values = [
            params['learning_rate'] * decay
            for decay in params['lr_decay_factors']
        ]
        learning_rate = tf.train.piecewise_constant(
            tf.cast(global_step, tf.int32),
            [int(_) for _ in params['decay_boundaries']], lr_values)
        truncated_learning_rate = tf.maximum(
            learning_rate,
            tf.constant(params['end_learning_rate'],
                        dtype=learning_rate.dtype))
        # Create a tensor named learning_rate for logging purposes.
        tf.identity(truncated_learning_rate, name='learning_rate')
        tf.summary.scalar('learning_rate', truncated_learning_rate)

        optimizer = tf.train.MomentumOptimizer(
            learning_rate=truncated_learning_rate, momentum=params['momentum'])

        # Batch norm requires update_ops to be added as a train_op dependency.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(loss, global_step)
    else:
        train_op = None

    cls_accuracy = tf.metrics.accuracy(glabels, predictions['classes'])
    metrics = {'cls_accuracy': cls_accuracy}

    # Create a tensor named train_accuracy for logging purposes.
    tf.identity(cls_accuracy[1], name='cls_accuracy')
    tf.summary.scalar('cls_accuracy', cls_accuracy[1])

    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=metrics,
        scaffold=tf.train.Scaffold(
            init_fn=train_helper.get_init_fn_for_scaffold(FLAGS)))
def xdet_model_fn(features, labels, mode, params):
    """Our model_fn for ResNet to be used with our Estimator."""

    shape = labels['shape']
    loc_targets = labels['loc_targets']
    cls_targets = labels['cls_targets']
    match_scores = labels['match_scores']

    global global_anchor_info
    decode_fn = global_anchor_info['decode_fn']
    num_anchors_per_layer = global_anchor_info['num_anchors_per_layer']
    all_num_anchors_depth = global_anchor_info['all_num_anchors_depth']

    with tf.variable_scope(params['model_scope'],
                           default_name=None,
                           values=[features],
                           reuse=tf.AUTO_REUSE):
        backbone = xdet_body_v4.xdet_resnet_v4(params['resnet_size'],
                                               params['data_format'])
        backbone_outputs = backbone(
            inputs=features, is_training=(mode == tf.estimator.ModeKeys.TRAIN))

        cls_pred, location_pred = xdet_body_v4.xdet_head(
            backbone_outputs,
            params['num_classes'],
            all_num_anchors_depth, (mode == tf.estimator.ModeKeys.TRAIN),
            data_format=params['data_format'])

        if params['data_format'] == 'channels_first':
            cls_pred = [tf.transpose(pred, [0, 2, 3, 1]) for pred in cls_pred]
            location_pred = [
                tf.transpose(pred, [0, 2, 3, 1]) for pred in location_pred
            ]

        cls_pred = [
            tf.reshape(pred,
                       [tf.shape(features)[0], -1, params['num_classes']])
            for pred in cls_pred
        ]
        location_pred = [
            tf.reshape(pred, [tf.shape(features)[0], -1, 4])
            for pred in location_pred
        ]

        cls_pred = tf.concat(cls_pred, axis=1)
        location_pred = tf.concat(location_pred, axis=1)

        cls_pred = tf.reshape(cls_pred, [-1, params['num_classes']])
        location_pred = tf.reshape(location_pred, [-1, 4])

    with tf.device('/cpu:0'):
        with tf.control_dependencies([cls_pred, location_pred]):
            with tf.name_scope('post_forward'):
                #bboxes_pred = decode_fn(location_pred)
                bboxes_pred = tf.map_fn(
                    lambda _preds: decode_fn(_preds),
                    tf.reshape(location_pred, [tf.shape(features)[0], -1, 4]),
                    dtype=[tf.float32] * len(num_anchors_per_layer),
                    back_prop=False)
                #cls_targets = tf.Print(cls_targets, [tf.shape(bboxes_pred[0]),tf.shape(bboxes_pred[1]),tf.shape(bboxes_pred[2]),tf.shape(bboxes_pred[3])])
                bboxes_pred = [
                    tf.reshape(preds, [-1, 4]) for preds in bboxes_pred
                ]
                bboxes_pred = tf.concat(bboxes_pred, axis=0)
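                # decode_fn is mapped over the batch and returns one tensor per
                # feature layer; the per-layer boxes are then flattened back into
                # a single [-1, 4] tensor for the predictions dict.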

                flaten_cls_targets = tf.reshape(cls_targets, [-1])
                flaten_match_scores = tf.reshape(match_scores, [-1])
                flaten_loc_targets = tf.reshape(loc_targets, [-1, 4])

                # each positive example has exactly one label
                positive_mask = flaten_cls_targets > 0
                n_positives = tf.count_nonzero(positive_mask)

                batch_n_positives = tf.count_nonzero(cls_targets, -1)

                batch_negtive_mask = tf.equal(
                    cls_targets, 0
                )  #tf.logical_and(tf.equal(cls_targets, 0), match_scores > 0.)
                batch_n_negtives = tf.count_nonzero(batch_negtive_mask, -1)

                batch_n_neg_select = tf.cast(
                    params['negative_ratio'] *
                    tf.cast(batch_n_positives, tf.float32), tf.int32)
                batch_n_neg_select = tf.minimum(
                    batch_n_neg_select, tf.cast(batch_n_negtives, tf.int32))

                # hard negative mining for classification
                predictions_for_bg = tf.nn.softmax(
                    tf.reshape(
                        cls_pred,
                        [tf.shape(features)[0], -1, params['num_classes']
                         ]))[:, :, 0]
                prob_for_negtives = tf.where(
                    batch_negtive_mask,
                    0. - predictions_for_bg,
                    # ignore all the positives
                    0. - tf.ones_like(predictions_for_bg))
                topk_prob_for_bg, _ = tf.nn.top_k(
                    prob_for_negtives, k=tf.shape(prob_for_negtives)[1])
                score_at_k = tf.gather_nd(
                    topk_prob_for_bg,
                    tf.stack([
                        tf.range(tf.shape(features)[0]), batch_n_neg_select - 1
                    ],
                             axis=-1))

                selected_neg_mask = prob_for_negtives >= tf.expand_dims(
                    score_at_k, axis=-1)
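                # Per-image hard negative mining: the negatives in each image are
                # ranked by (negated) background probability, score_at_k picks the
                # score of each image's batch_n_neg_select-th hardest negative, and
                # selected_neg_mask keeps everything at or above that threshold.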

                # include both the selected negatives and all positive examples
                final_mask = tf.stop_gradient(
                    tf.logical_or(
                        tf.reshape(
                            tf.logical_and(batch_negtive_mask,
                                           selected_neg_mask), [-1]),
                        positive_mask))
                total_examples = tf.count_nonzero(final_mask)

                cls_pred = tf.boolean_mask(cls_pred, final_mask)
                location_pred = tf.boolean_mask(
                    location_pred, tf.stop_gradient(positive_mask))
                flaten_cls_targets = tf.boolean_mask(
                    tf.clip_by_value(flaten_cls_targets, 0,
                                     params['num_classes']), final_mask)
                flaten_loc_targets = tf.stop_gradient(
                    tf.boolean_mask(flaten_loc_targets, positive_mask))

                predictions = {
                    'classes':
                    tf.argmax(cls_pred, axis=-1),
                    'probabilities':
                    tf.reduce_max(tf.nn.softmax(cls_pred,
                                                name='softmax_tensor'),
                                  axis=-1),
                    'loc_predict':
                    bboxes_pred
                }

                cls_accuracy = tf.metrics.accuracy(flaten_cls_targets,
                                                   predictions['classes'])
                metrics = {'cls_accuracy': cls_accuracy}

                # Create a tensor named train_accuracy for logging purposes.
                tf.identity(cls_accuracy[1], name='cls_accuracy')
                tf.summary.scalar('cls_accuracy', cls_accuracy[1])

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    # Calculate loss, which includes softmax cross entropy and L2 regularization.
    #cross_entropy = tf.cond(n_positives > 0, lambda: tf.losses.sparse_softmax_cross_entropy(labels=flaten_cls_targets, logits=cls_pred), lambda: 0.)# * (params['negative_ratio'] + 1.)
    #flaten_cls_targets=tf.Print(flaten_cls_targets, [flaten_loc_targets],summarize=50000)
    cross_entropy = tf.losses.sparse_softmax_cross_entropy(
        labels=flaten_cls_targets,
        logits=cls_pred) * (params['negative_ratio'] + 1.)
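    # The mean is taken over the selected positives and negatives (roughly
    # 1 : negative_ratio), so scaling by (negative_ratio + 1) makes the term
    # roughly equal to the total cross entropy per positive example.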
    # Create a tensor named cross_entropy for logging purposes.
    tf.identity(cross_entropy, name='cross_entropy_loss')
    tf.summary.scalar('cross_entropy_loss', cross_entropy)

    #loc_loss = tf.cond(n_positives > 0, lambda: modified_smooth_l1(location_pred, tf.stop_gradient(flaten_loc_targets), sigma=1.), lambda: tf.zeros_like(location_pred))
    loc_loss = modified_smooth_l1(location_pred, flaten_loc_targets, sigma=1.)
    #loc_loss = modified_smooth_l1(location_pred, tf.stop_gradient(gtargets))
    loc_loss = tf.reduce_mean(tf.reduce_sum(loc_loss, axis=-1),
                              name='location_loss')
    tf.summary.scalar('location_loss', loc_loss)
    tf.losses.add_loss(loc_loss)

    l2_loss_vars = []
    for trainable_var in tf.trainable_variables():
        #if 'batch_normalization' not in trainable_var.name:
        l2_loss_vars.append(tf.nn.l2_loss(trainable_var))

    # Add weight decay to the loss. Note that, unlike the other model_fns above,
    # the batch norm variables are not excluded here (the filter is commented out).
    total_loss = tf.add(cross_entropy + loc_loss,
                        tf.multiply(params['weight_decay'],
                                    tf.add_n(l2_loss_vars),
                                    name='l2_loss'),
                        name='total_loss')

    if mode == tf.estimator.ModeKeys.TRAIN:
        global_step = tf.train.get_or_create_global_step()

        lr_values = [
            params['learning_rate'] * decay
            for decay in params['lr_decay_factors']
        ]
        learning_rate = tf.train.piecewise_constant(
            tf.cast(global_step, tf.int32),
            [int(_) for _ in params['decay_boundaries']], lr_values)
        truncated_learning_rate = tf.maximum(learning_rate,
                                             tf.constant(
                                                 params['end_learning_rate'],
                                                 dtype=learning_rate.dtype),
                                             name='learning_rate')
        # Create a tensor named learning_rate for logging purposes.
        tf.summary.scalar('learning_rate', truncated_learning_rate)

        optimizer = tf.train.MomentumOptimizer(
            learning_rate=truncated_learning_rate, momentum=params['momentum'])

        # Batch norm requires update_ops to be added as a train_op dependency.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(total_loss, global_step)
    else:
        train_op = None

    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=total_loss,
        train_op=train_op,
        eval_metric_ops=metrics,
        scaffold=tf.train.Scaffold(
            init_fn=train_helper.get_init_fn_for_scaffold(FLAGS)))