lambda inputs, output_channals, heatmap_size, istraining, data_format: detxt_cpn.cascaded_pyramid_net(inputs, output_channals, heatmap_size, istraining, data_format, net_depth=101), 'logs_sub_dir': 'logs_large_detxt_cpn' }, 'simple_net': { 'backbone': lambda inputs, output_channals, heatmap_size, istraining, data_format: simple_xt.simple_net(inputs, output_channals, heatmap_size, istraining, data_format, net_depth=101), 'logs_sub_dir': 'logs_simple_net' }, 'head_seresnext50_cpn': { 'backbone': seresnet_cpn.head_xt_cascaded_pyramid_net, 'logs_sub_dir': 'logs_head_sext_cpn' }, } def input_pipeline(model_scope=FLAGS.model_scope): preprocessing_fn = lambda org_image, file_name, shape: preprocessing.preprocess_for_test_raw_output( org_image,
def keypoint_model_fn(features, labels, mode, params):
    """tf.estimator model_fn for heatmap-based keypoint detection.

    Builds a `simple_xt.simple_net` backbone under `params['model_scope']`,
    computes an MSE heatmap loss (optionally with online hard keypoint
    mining, OHKM) plus L2 weight decay, reports a normalized-error eval
    metric, and — in TRAIN mode — creates a momentum optimizer with a
    warmup + piecewise-constant learning-rate schedule wrapped in
    `TowerOptimizer` for multi-GPU replication.

    Args:
        features: batched input image tensor; layout NHWC or NCHW per
            params['data_format'].
        labels: dict of tensors:
            'targets'    - ground-truth heatmaps (scaled down; see note below),
            'shape'      - unused here,
            'classid'    - unused here,
            'key_v'      - per-keypoint visibility flags (>0 means visible),
            'isvalid'    - per-keypoint validity flags (>0 means valid),
            'norm_value' - normalization distance for the NE metric.
        mode: one of tf.estimator.ModeKeys (TRAIN / EVAL / PREDICT).
        params: hyper-parameter dict (model_scope, heatmap_size,
            train_image_size, data_format, use_ohkm, learning_rate,
            mse_weight, weight_decay, momentum, warmup/decay schedule keys,
            checkpoint restore keys).

    Returns:
        tf.estimator.EstimatorSpec for the given mode.
    """
    targets = labels['targets']
    shape = labels['shape']        # NOTE(review): unused in this function
    classid = labels['classid']    # NOTE(review): unused in this function
    key_v = labels['key_v']
    isvalid = labels['isvalid']
    norm_value = labels['norm_value']
    # Dynamic batch size (last batch of an epoch may be smaller).
    cur_batch_size = tf.shape(features)[0]
    #features= tf.ones_like(features)

    # All backbone variables live under the model scope; AUTO_REUSE lets the
    # same scope be rebuilt (e.g. one tower per GPU) sharing variables.
    with tf.variable_scope(params['model_scope'], default_name=None, values=[features], reuse=tf.AUTO_REUSE):
        # '*' selects the all-category joint count when training one joint
        # model over every category ('all' in the scope name).
        pred_outputs = simple_xt.simple_net(
            features,
            config.class_num_joints[(params['model_scope'] if 'all' not in params['model_scope'] else '*')],
            params['heatmap_size'],
            (mode == tf.estimator.ModeKeys.TRAIN),
            params['data_format'])[0]

    # Normalize predictions to NCHW so the loss/metric code below has a
    # single layout to deal with.
    if params['data_format'] == 'channels_last':
        pred_outputs = tf.transpose(pred_outputs, [0, 3, 1, 2], name='outputs_trans')

    score_map = pred_outputs

    # Decode heatmaps to (x, y) keypoint coordinates in input-image space.
    pred_x, pred_y = get_keypoint(
        features, targets, score_map,
        params['heatmap_size'], params['train_image_size'], params['train_image_size'],
        (params['model_scope'] if 'all' not in params['model_scope'] else '*'),
        clip_at_zero=True, data_format=params['data_format'])

    # this is important!!!
    # NOTE(review): targets appear to arrive scaled by 1/255 from the input
    # pipeline; restore the 0..255 range so they match the heatmap scale the
    # metric and MSE loss expect — TODO confirm against the pipeline.
    targets = 255. * targets
    #with tf.control_dependencies([pred_x, pred_y]):
    # Normalized error (NE) metric: keypoint distance divided by norm_value,
    # counted only over visible & valid keypoints.
    ne_mertric = mertric.normalized_error(
        targets, score_map, norm_value, key_v, isvalid,
        cur_batch_size,
        config.class_num_joints[(params['model_scope'] if 'all' not in params['model_scope'] else '*')],
        params['heatmap_size'],
        params['train_image_size'])

    # Mask out heatmap channels for keypoints that are invisible or invalid
    # so they contribute nothing to the loss. Shape broadcast: the two
    # expand_dims turn the per-keypoint flag into a (..., 1, 1) multiplier
    # over each H x W heatmap.
    all_visible = tf.expand_dims(tf.expand_dims(tf.cast(tf.logical_and(key_v > 0, isvalid > 0), tf.float32), axis=-1), axis=-1)
    targets = targets * all_visible
    pred_outputs = pred_outputs * all_visible

    # Per-row summed squared error, tracked as a streaming metric.
    # (mean_absolute_error against zeros == mean of sq_diff; the argument
    # order is labels-first but MAE is symmetric, so it is harmless.)
    sq_diff = tf.reduce_sum(tf.squared_difference(targets, pred_outputs), axis=-1)
    last_pred_mse = tf.metrics.mean_absolute_error(sq_diff, tf.zeros_like(sq_diff), name='last_pred_mse')

    metrics = {'normalized_error': ne_mertric, 'last_pred_mse': last_pred_mse}
    predictions = {'normalized_error': ne_mertric[1]}
    # Expose the metric value under a stable op name for logging hooks.
    ne_mertric = tf.identity(ne_mertric[1], name='ne_mertric')

    base_learning_rate = params['learning_rate']
    mse_loss_list = []
    if params['use_ohkm']:
        # Online hard keypoint mining: keep only the hardest half of the
        # keypoints (largest per-keypoint MSE) in the loss.
        base_learning_rate = 1. * base_learning_rate
        # Per-(batch, keypoint) mean squared error over the flattened heatmap.
        temp_loss = tf.reduce_mean(
            tf.reshape(
                tf.losses.mean_squared_error(targets, pred_outputs,
                                             weights=1.0,
                                             loss_collection=None,
                                             reduction=tf.losses.Reduction.NONE),
                [cur_batch_size,
                 config.class_num_joints[(params['model_scope'] if 'all' not in params['model_scope'] else '*')],
                 -1]),
            axis=-1)
        num_topk = config.class_num_joints[(params['model_scope'] if 'all' not in params['model_scope'] else '*')] // 2
        # Column indices of the hardest keypoints per example...
        gather_col = tf.nn.top_k(temp_loss, k=num_topk, sorted=True)[1]
        # ...paired with their row (batch) indices for gather_nd.
        gather_row = tf.reshape(tf.tile(tf.reshape(tf.range(cur_batch_size), [-1, 1]), [1, num_topk]), [-1, 1])
        # stop_gradient: the selection itself must not be differentiated.
        gather_indcies = tf.stop_gradient(tf.stack([gather_row, tf.reshape(gather_col, [-1, 1])], axis=-1))

        select_targets = tf.gather_nd(targets, gather_indcies)
        select_heatmap = tf.gather_nd(pred_outputs, gather_indcies)

        mse_loss_list.append(tf.losses.mean_squared_error(select_targets, select_heatmap,
                                                          weights=1.0 / tf.cast(cur_batch_size, tf.float32),
                                                          scope='loss',
                                                          loss_collection=None,  #tf.GraphKeys.LOSSES,
                                                          # mean all elements of all pixels in all batch
                                                          reduction=tf.losses.Reduction.MEAN))
    else:
        # Plain MSE over every (visible) keypoint heatmap, normalized by the
        # dynamic batch size via `weights`.
        mse_loss_list.append(tf.losses.mean_squared_error(targets, pred_outputs,
                                                          weights=1.0 / tf.cast(cur_batch_size, tf.float32),
                                                          scope='loss',
                                                          loss_collection=None,  #tf.GraphKeys.LOSSES,
                                                          # mean all elements of all pixels in all batch
                                                          reduction=tf.losses.Reduction.MEAN))  # SUM, SUM_OVER_BATCH_SIZE, default mean by all elements

    mse_loss = tf.multiply(params['mse_weight'], tf.add_n(mse_loss_list), name='mse_loss')
    tf.summary.scalar('mse', mse_loss)
    # Register with tf.GraphKeys.LOSSES (the per-term calls above opted out
    # via loss_collection=None so only the combined loss is collected).
    tf.losses.add_loss(mse_loss)

    # NOTE(review): alternative sigmoid cross-entropy loss kept for reference.
    # bce_loss_list = []
    # for pred_ind in list(range(len(pred_outputs))):
    #     bce_loss_list.append(tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=pred_outputs[pred_ind], labels=targets_list[pred_ind]/255., name='loss_{}'.format(pred_ind)), name='loss_mean_{}'.format(pred_ind)))
    # mse_loss = tf.multiply(params['mse_weight'] / params['num_stacks'], tf.add_n(bce_loss_list), name='mse_loss')
    # tf.summary.scalar('mse', mse_loss)
    # tf.losses.add_loss(mse_loss)

    # Add weight decay to the loss. We exclude the batch norm variables because
    # doing so leads to a small improvement in accuracy.
    loss = mse_loss + params['weight_decay'] * tf.add_n(
        [tf.nn.l2_loss(v) for v in tf.trainable_variables() if 'batch_normalization' not in v.name])
    total_loss = tf.identity(loss, name='total_loss')
    tf.summary.scalar('loss', total_loss)

    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(mode=mode, loss=loss, predictions=predictions, eval_metric_ops=metrics)

    if mode == tf.estimator.ModeKeys.TRAIN:
        global_step = tf.train.get_or_create_global_step()
        # Schedule: warmup LR until warmup_steps, then base_lr * decay factor
        # at each epoch boundary in decay_boundaries, floored at
        # end_learning_rate.
        lr_values = [params['warmup_learning_rate']] + [base_learning_rate * decay for decay in params['lr_decay_factors']]
        learning_rate = tf.train.piecewise_constant(
            tf.cast(global_step, tf.int32),
            [params['warmup_steps']] + [int(float(ep) * params['steps_per_epoch']) for ep in params['decay_boundaries']],
            lr_values)
        truncated_learning_rate = tf.maximum(
            learning_rate,
            tf.constant(params['end_learning_rate'], dtype=learning_rate.dtype),
            name='learning_rate')
        tf.summary.scalar('lr', truncated_learning_rate)

        optimizer = tf.train.MomentumOptimizer(learning_rate=truncated_learning_rate,
                                               momentum=params['momentum'])
        # TowerOptimizer aggregates gradients across replicated towers.
        optimizer = tf_replicate_model_fn.TowerOptimizer(optimizer)
        # Batch norm requires update_ops to be added as a train_op dependency.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(loss, global_step)
    else:
        train_op = None

    # Scaffold init_fn handles (partial) restore from a pre-trained
    # checkpoint, remapping model scopes and optionally ignoring missing vars.
    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=metrics,
        scaffold=tf.train.Scaffold(init_fn=train_helper.get_init_fn_for_scaffold_(
            params['checkpoint_path'], params['model_dir'],
            params['checkpoint_exclude_scopes'], params['model_scope'],
            params['checkpoint_model_scope'], params['ignore_missing_vars'])))