Example #1
# NOTE: the snippet opens mid-registry; the dict's name and this entry's key
# are truncated in the source. The key below is inferred from its log
# sub-directory and may differ from the original.
    'large_detxt_cpn': {
        'backbone':
        lambda inputs, output_channals, heatmap_size, istraining, data_format:
        detxt_cpn.cascaded_pyramid_net(inputs,
                                       output_channals,
                                       heatmap_size,
                                       istraining,
                                       data_format,
                                       net_depth=101),
        'logs_sub_dir': 'logs_large_detxt_cpn'
    },
    'simple_net': {
        'backbone':
        lambda inputs, output_channals, heatmap_size, istraining, data_format:
        simple_xt.simple_net(inputs,
                             output_channals,
                             heatmap_size,
                             istraining,
                             data_format,
                             net_depth=101),
        'logs_sub_dir': 'logs_simple_net'
    },
    'head_seresnext50_cpn': {
        'backbone': seresnet_cpn.head_xt_cascaded_pyramid_net,
        'logs_sub_dir': 'logs_head_sext_cpn'
    },
}
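
# A hedged usage sketch (assumption, not in the source): registries like the
# one above are typically consumed by looking up the selected model name and
# calling the stored builder; the helper below is illustrative only.
def build_backbone(registry, name, inputs, output_channals, heatmap_size,
                   istraining, data_format):
    entry = registry[name]
    # Every 'backbone' value shares the same five-argument signature.
    pred_outputs = entry['backbone'](inputs, output_channals, heatmap_size,
                                     istraining, data_format)
    return pred_outputs, entry['logs_sub_dir']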


def input_pipeline(model_scope=FLAGS.model_scope):
    preprocessing_fn = lambda org_image, file_name, shape: preprocessing.preprocess_for_test_raw_output(
        org_image,
Example #2
def keypoint_model_fn(features, labels, mode, params):
    # Unpack the label dictionary produced by the input pipeline.
    targets = labels['targets']
    shape = labels['shape']
    classid = labels['classid']
    key_v = labels['key_v']
    isvalid = labels['isvalid']
    norm_value = labels['norm_value']

    cur_batch_size = tf.shape(features)[0]
    #features= tf.ones_like(features)

    # Build the backbone under the model scope; AUTO_REUSE lets the same
    # variables be shared if the model_fn is instantiated more than once.
    with tf.variable_scope(params['model_scope'],
                           default_name=None,
                           values=[features],
                           reuse=tf.AUTO_REUSE):
        # The joint count is looked up per sub-task scope; the '*' entry
        # covers training on all categories at once.
        pred_outputs = simple_xt.simple_net(
            features,
            config.class_num_joints[(params['model_scope'] if 'all'
                                     not in params['model_scope'] else '*')],
            params['heatmap_size'], (mode == tf.estimator.ModeKeys.TRAIN),
            params['data_format'])[0]

    # Canonicalize predictions to NCHW so downstream ops share one layout.
    if params['data_format'] == 'channels_last':
        pred_outputs = tf.transpose(pred_outputs, [0, 3, 1, 2],
                                    name='outputs_trans')

    score_map = pred_outputs

    # Decode per-joint (x, y) image coordinates from the predicted heatmaps.
    pred_x, pred_y = get_keypoint(
        features,
        targets,
        score_map,
        params['heatmap_size'],
        params['train_image_size'],
        params['train_image_size'],
        (params['model_scope'] if 'all' not in params['model_scope'] else '*'),
        clip_at_zero=True,
        data_format=params['data_format'])

    # This rescaling matters: the target heatmaps are presumably normalized
    # to [0, 1] upstream, so bring them back to the 0-255 range before the
    # metric and loss computations below.
    targets = 255. * targets

    #with tf.control_dependencies([pred_x, pred_y]):
    # Streaming normalized-error metric; key_v and isvalid presumably gate
    # which keypoints count, and norm_value scales the distances.
    ne_mertric = mertric.normalized_error(
        targets, score_map, norm_value, key_v, isvalid, cur_batch_size,
        config.class_num_joints[(params['model_scope'] if 'all'
                                 not in params['model_scope'] else '*')],
        params['heatmap_size'], params['train_image_size'])

    # Mask out joints that are not both annotated visible (key_v > 0) and
    # valid (isvalid > 0) so they contribute nothing to the loss.
    visible_mask = tf.cast(tf.logical_and(key_v > 0, isvalid > 0), tf.float32)
    all_visible = tf.expand_dims(tf.expand_dims(visible_mask, axis=-1),
                                 axis=-1)
    targets = targets * all_visible
    pred_outputs = pred_outputs * all_visible

    # Squared differences summed over the last axis; mean_absolute_error
    # against zeros then gives a streaming mean of those sums, i.e. an
    # MSE-style evaluation metric.
    sq_diff = tf.reduce_sum(tf.squared_difference(targets, pred_outputs),
                            axis=-1)
    last_pred_mse = tf.metrics.mean_absolute_error(sq_diff,
                                                   tf.zeros_like(sq_diff),
                                                   name='last_pred_mse')

    metrics = {'normalized_error': ne_mertric, 'last_pred_mse': last_pred_mse}
    predictions = {'normalized_error': ne_mertric[1]}
    ne_mertric = tf.identity(ne_mertric[1], name='ne_mertric')

    base_learning_rate = params['learning_rate']
    mse_loss_list = []
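    # OHKM (online hard keypoint mining): rather than averaging the loss over
    # all joints, keep only the hardest joints (highest per-joint loss) in
    # each example and back-propagate through those alone.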
    if params['use_ohkm']:
        base_learning_rate = 1. * base_learning_rate
        # Per-(example, joint) loss: the element-wise squared-error map is
        # reshaped to [batch, num_joints, H*W] and averaged spatially.
        num_joints = config.class_num_joints[(
            params['model_scope']
            if 'all' not in params['model_scope'] else '*')]
        element_loss = tf.losses.mean_squared_error(
            targets, pred_outputs, weights=1.0, loss_collection=None,
            reduction=tf.losses.Reduction.NONE)
        temp_loss = tf.reduce_mean(
            tf.reshape(element_loss, [cur_batch_size, num_joints, -1]),
            axis=-1)

        # Keep the hardest half of the joints for each example.
        num_topk = num_joints // 2
        gather_col = tf.nn.top_k(temp_loss, k=num_topk, sorted=True)[1]
        # Pair each selected joint index with its batch row to build
        # (row, col) indices for tf.gather_nd.
        gather_row = tf.reshape(
            tf.tile(tf.reshape(tf.range(cur_batch_size), [-1, 1]),
                    [1, num_topk]), [-1, 1])
        gather_indcies = tf.stop_gradient(
            tf.stack([gather_row, tf.reshape(gather_col, [-1, 1])], axis=-1))

        select_targets = tf.gather_nd(targets, gather_indcies)
        select_heatmap = tf.gather_nd(pred_outputs, gather_indcies)

        mse_loss_list.append(
            tf.losses.mean_squared_error(
                select_targets,
                select_heatmap,
                weights=1.0 / tf.cast(cur_batch_size, tf.float32),
                scope='loss',
                loss_collection=None,  #tf.GraphKeys.LOSSES,
                # mean all elements of all pixels in all batch
                reduction=tf.losses.Reduction.MEAN))
    else:
        mse_loss_list.append(
            tf.losses.mean_squared_error(
                targets,
                pred_outputs,
                weights=1.0 / tf.cast(cur_batch_size, tf.float32),
                scope='loss',
                loss_collection=None,  #tf.GraphKeys.LOSSES,
                # mean all elements of all pixels in all batch
                reduction=tf.losses.Reduction.MEAN)
        )  # SUM, SUM_OVER_BATCH_SIZE, default mean by all elements
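
    # Worked illustration (not from the source) of the OHKM index gather
    # above: with temp_loss = [[0.1, 0.9, 0.4],
    #                          [0.2, 0.1, 0.8]] and num_topk = 2,
    # tf.nn.top_k yields column indices [[1, 2], [2, 0]]; tiling the batch
    # rows gives [[0], [0], [1], [1]], and stacking produces (row, col) pairs
    # [[0, 1], [0, 2], [1, 2], [1, 0]], which gather_nd uses to pull the
    # hardest joints' heatmaps out of targets and pred_outputs.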

    mse_loss = tf.multiply(params['mse_weight'],
                           tf.add_n(mse_loss_list),
                           name='mse_loss')
    tf.summary.scalar('mse', mse_loss)
    tf.losses.add_loss(mse_loss)

    # bce_loss_list = []
    # for pred_ind in list(range(len(pred_outputs))):
    #     bce_loss_list.append(tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=pred_outputs[pred_ind], labels=targets_list[pred_ind]/255., name='loss_{}'.format(pred_ind)), name='loss_mean_{}'.format(pred_ind)))

    # mse_loss = tf.multiply(params['mse_weight'] / params['num_stacks'], tf.add_n(bce_loss_list), name='mse_loss')
    # tf.summary.scalar('mse', mse_loss)
    # tf.losses.add_loss(mse_loss)

    # Add weight decay to the loss. We exclude the batch norm variables because
    # doing so leads to a small improvement in accuracy.
    loss = mse_loss + params['weight_decay'] * tf.add_n([
        tf.nn.l2_loss(v) for v in tf.trainable_variables()
        if 'batch_normalization' not in v.name
    ])
    total_loss = tf.identity(loss, name='total_loss')
    tf.summary.scalar('loss', total_loss)

    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(mode=mode,
                                          loss=loss,
                                          predictions=predictions,
                                          eval_metric_ops=metrics)

    if mode == tf.estimator.ModeKeys.TRAIN:
        global_step = tf.train.get_or_create_global_step()

        # Piecewise-constant schedule: hold the warmup rate for the first
        # `warmup_steps`, then step the base rate down by each decay factor
        # at the configured epoch boundaries.
        lr_values = [params['warmup_learning_rate']] + [
            base_learning_rate * decay for decay in params['lr_decay_factors']
        ]
        learning_rate = tf.train.piecewise_constant(
            tf.cast(global_step, tf.int32), [params['warmup_steps']] + [
                int(float(ep) * params['steps_per_epoch'])
                for ep in params['decay_boundaries']
            ], lr_values)
        # Floor the schedule at `end_learning_rate`.
        truncated_learning_rate = tf.maximum(
            learning_rate,
            tf.constant(params['end_learning_rate'],
                        dtype=learning_rate.dtype),
            name='learning_rate')
        tf.summary.scalar('lr', truncated_learning_rate)

        optimizer = tf.train.MomentumOptimizer(
            learning_rate=truncated_learning_rate, momentum=params['momentum'])

        # Wrap the optimizer so gradients are aggregated across GPU towers.
        optimizer = tf_replicate_model_fn.TowerOptimizer(optimizer)

        # Batch norm requires update_ops to be added as a train_op dependency.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(loss, global_step)
    else:
        train_op = None

    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=metrics,
        scaffold=tf.train.Scaffold(
            init_fn=train_helper.get_init_fn_for_scaffold_(
                params['checkpoint_path'], params['model_dir'],
                params['checkpoint_exclude_scopes'], params['model_scope'],
                params['checkpoint_model_scope'],
                params['ignore_missing_vars'])))
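
# Hedged wiring sketch (assumption, not from the original source): a model_fn
# with this signature is normally handed to tf.estimator.Estimator, with
# `params` supplying every key read above (heatmap_size, data_format,
# use_ohkm, mse_weight, weight_decay, the LR-schedule keys, and the
# checkpoint keys).
#
#   estimator = tf.estimator.Estimator(
#       model_fn=keypoint_model_fn,
#       model_dir=FLAGS.model_dir,
#       params={'model_scope': FLAGS.model_scope,
#               'heatmap_size': 64,
#               'data_format': 'channels_first',
#               ...})  # remaining keys omitted here
#   estimator.train(input_fn=input_pipeline)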