def model_fn(features, mode, params):
    """Build the inference-only FPN detection graph for ``tf.estimator``.

    The backbone, the RPN proposal stage and the detection head are all
    constructed with the training flag switched off.

    Args:
        features: mapping of input tensors. ``"image"`` and
            ``"image_window"`` are always read; ``"gt_box_labels"`` is read
            only when the prediction dictionary is assembled.
        mode: one of the ``tf.estimator.ModeKeys`` values.
        params: mapping that holds the network configuration object under
            the ``"net_config"`` key.

    Returns:
        ``tf.estimator.EstimatorSpec`` with the prediction tensors when
        ``mode`` is ``PREDICT``; ``None`` for every other mode.
    """
    net_config = params["net_config"]
    is_training = False

    origin_image_batch = features["image"]
    image_window = features["image_window"]

    # ---------------------------- Backbone net ----------------------------
    pixel_means = tf.convert_to_tensor(net_config.PIXEL_MEANS,
                                       dtype=tf.float32)
    image_batch = origin_image_batch - pixel_means
    # Original author's note: is_training decides whether batch norm runs in
    # training mode, so it matters which value is passed here.
    _, share_net = get_network_byname(inputs=image_batch,
                                      config=net_config,
                                      is_training=is_training,
                                      reuse=tf.AUTO_REUSE)

    # -------------------------------- FPN ---------------------------------
    feature_pyramid = build_fpn.build_feature_pyramid(share_net, net_config)

    # -------------------------------- RPN ---------------------------------
    rpn = build_rpn.RPN(feature_pyramid=feature_pyramid,
                        image_window=image_window,
                        config=net_config)
    rpn_proposals_boxes, rpn_proposals_scores = rpn.rpn_proposals(is_training)

    # --------------------------- Fast RCNN head ---------------------------
    # No ground truth is available at inference time, hence None.
    fast_rcnn = build_head.FPNHead(feature_pyramid=feature_pyramid,
                                   rpn_proposals_boxes=rpn_proposals_boxes,
                                   origin_image=origin_image_batch,
                                   gtboxes_and_label=None,
                                   config=net_config,
                                   is_training=is_training,
                                   image_window=image_window)
    detections = fast_rcnn.head_detection()

    # ------------------------------ Outputs -------------------------------
    if mode != tf.estimator.ModeKeys.PREDICT:
        return None

    # The last axis of `detections` is sliced as 0..3 -> box, 4 -> class id,
    # 5 -> score.
    predictions = {
        "image": origin_image_batch,
        "predict_bbox": detections[:, :, :4],
        "predict_class_id": detections[:, :, 4],
        "predict_scores": detections[:, :, 5],
        "rpn_proposal_boxes": rpn_proposals_boxes,
        "rpn_proposals_scores": rpn_proposals_scores,
        "gt_box_labels": features["gt_box_labels"],
    }
    return tf.estimator.EstimatorSpec(mode, predictions=predictions)
def model_fn(features, labels, mode, params):
    """Build the training / evaluation FPN detection graph.

    The RPN proposal stage receives a mode-dependent training flag, while the
    backbone and the detection head are always built with
    ``is_training=False``.

    NOTE(review): the nesting of the ``with`` blocks below was reconstructed
    from a whitespace-collapsed source; confirm it against the original file
    (it influences op and optimizer-slot names).

    Args:
        features: mapping of input tensors; reads ``"image"``,
            ``"image_window"`` and, for the evaluation metric,
            ``"gt_box_labels"``.
        labels: mapping of target tensors; reads ``"gt_box_labels"``,
            ``"minibatch_indices"``, ``"minibatch_encode_gtboxes"`` and
            ``"minibatch_objects_one_hot"``.
        mode: one of the ``tf.estimator.ModeKeys`` values.
        params: mapping that holds the network configuration object under
            the ``"net_config"`` key.

    Returns:
        ``tf.estimator.EstimatorSpec`` for ``TRAIN`` (loss, train op and
        training hooks) or ``EVAL`` (loss and metric ops); ``None`` for any
        other mode that gets this far.
    """
    net_config = params["net_config"]
    is_training = mode == tf.estimator.ModeKeys.TRAIN

    # ---------------------------- Backbone net ----------------------------
    origin_image_batch = features["image"]
    pixel_means = tf.convert_to_tensor(net_config.PIXEL_MEANS,
                                       dtype=tf.float32)
    image_batch = origin_image_batch - pixel_means
    image_window = features["image_window"]
    # Original author's note: is_training decides whether batch norm runs in
    # training mode. The backbone is given False in every mode.
    _, share_net = get_network_byname(inputs=image_batch,
                                      config=net_config,
                                      is_training=False,
                                      reuse=tf.AUTO_REUSE)

    # -------------------------------- FPN ---------------------------------
    feature_pyramid = build_fpn.build_feature_pyramid(share_net, net_config)

    # -------------------------------- RPN ---------------------------------
    gtboxes_and_label_batch = labels.get("gt_box_labels")
    rpn = build_rpn.RPN(feature_pyramid=feature_pyramid,
                        image_window=image_window,
                        config=net_config)
    rpn_proposals_boxes, rpn_proposals_scores = rpn.rpn_proposals(is_training)
    rpn_location_loss, rpn_classification_loss = rpn.rpn_losses(
        labels["minibatch_indices"],
        labels["minibatch_encode_gtboxes"],
        labels["minibatch_objects_one_hot"])
    rpn_total_loss = rpn_classification_loss + rpn_location_loss

    # --------------------------- Fast RCNN head ---------------------------
    fpn_fast_rcnn_head = build_head.FPNHead(
        feature_pyramid=feature_pyramid,
        rpn_proposals_boxes=rpn_proposals_boxes,
        origin_image=origin_image_batch,
        gtboxes_and_label=gtboxes_and_label_batch,
        config=net_config,
        is_training=False,
        image_window=image_window)
    detections = fpn_fast_rcnn_head.head_detection()

    if net_config.DEBUG:
        # Dump and draw the first 50 proposals and the detections of the
        # first image in the batch.
        print_tensors(rpn_proposals_scores[0, :50], "scores")
        print_tensors(rpn_proposals_boxes[0, :50, :], "bbox")
        rpn_proposals_vision = draw_boxes_with_scores(
            origin_image_batch[0, :, :, :],
            rpn_proposals_boxes[0, :50, :],
            rpn_proposals_scores[0, :50])
        head_vision = draw_boxes_with_categories_and_scores(
            origin_image_batch[0, :, :, :],
            detections[0, :, :4],
            detections[0, :, 4],
            detections[0, :, 5],
            net_config.LABEL_TO_NAME)
        tf.summary.image("rpn_proposals_vision", rpn_proposals_vision)
        tf.summary.image("head_vision", head_vision)

    head_location_loss, head_classification_loss = (
        fpn_fast_rcnn_head.head_loss())
    head_total_loss = head_location_loss + head_classification_loss

    # ------------------------------- Losses -------------------------------
    with tf.name_scope("regularization_losses"):
        regularization_list = []
        for weight in tf.trainable_variables():
            # Variables whose name contains 'gamma' or 'beta' are excluded.
            if 'gamma' in weight.name or 'beta' in weight.name:
                continue
            scaled_l2 = (tf.nn.l2_loss(weight.read_value())
                         * net_config.WEIGHT_DECAY)
            element_count = tf.cast(tf.size(weight.read_value()), tf.float32)
            # The L2 penalty is divided by the variable's element count.
            regularization_list.append(scaled_l2 / element_count)
        regularization_loss = tf.add_n(regularization_list)

    raw_total_loss = regularization_loss + head_total_loss + rpn_total_loss
    # A NaN total loss is replaced by 0.0.
    total_loss = tf.cond(tf.is_nan(raw_total_loss),
                         lambda: 0.0,
                         lambda: raw_total_loss)
    print_tensors(head_total_loss, "head_loss")
    print_tensors(rpn_total_loss, "rpn_loss")

    # ------------------------------ Training ------------------------------
    global_step = tf.train.get_or_create_global_step()
    backbone_scope = net_config.BACKBONE_NET + "/"
    tf.train.init_from_checkpoint(
        net_config.CHECKPOINT_DIR,
        {backbone_scope: net_config.BACKBONE_NET + "/"})

    with tf.name_scope("optimizer"):
        # Learning rate drops by 10x once BOUNDARY[0] steps are reached.
        lr = tf.train.piecewise_constant(
            global_step,
            boundaries=[np.int64(net_config.BOUNDARY[0])],
            values=[net_config.LEARNING_RATE,
                    net_config.LEARNING_RATE / 10])
        optimizer = tf.train.MomentumOptimizer(lr,
                                               momentum=net_config.MOMENTUM)
        optimizer = tf.contrib.estimator.TowerOptimizer(optimizer)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies([tf.group(*update_ops)]):
            grads = optimizer.compute_gradients(total_loss)
            # clip gradients
            grads = tf.contrib.training.clip_gradient_norms(
                grads, net_config.CLIP_GRADIENT_NORM)
            train_op = optimizer.apply_gradients(grads, global_step)

    # ------------------------------ Summary -------------------------------
    tf.summary.scalar('rpn_location_loss', rpn_location_loss,
                      family="rpn_loss")
    tf.summary.scalar('rpn_classification_loss', rpn_classification_loss,
                      family="rpn_loss")
    tf.summary.scalar('rpn_total_loss', rpn_total_loss, family="rpn_loss")
    tf.summary.scalar('head_location_loss', head_location_loss,
                      family="head_loss")
    tf.summary.scalar('head_classification_loss', head_classification_loss,
                      family="head_loss")
    tf.summary.scalar('head_total_loss', head_total_loss, family="head_loss")
    tf.summary.scalar("regularization_loss", regularization_loss)
    tf.summary.scalar('total_loss', total_loss)
    tf.summary.scalar('learning_rate', lr)

    meta_hook = MetadataHook(
        save_steps=net_config.SAVE_EVERY_N_STEP * net_config.EPOCH / 2,
        output_dir=net_config.MODLE_DIR)
    summary_hook = tf.train.SummarySaverHook(
        save_steps=net_config.SAVE_EVERY_N_STEP,
        output_dir=net_config.MODLE_DIR,
        summary_op=tf.summary.merge_all())
    hooks = [summary_hook]
    if net_config.COMPUTE_TIME:
        hooks.append(meta_hook)

    if mode == tf.estimator.ModeKeys.TRAIN:
        return tf.estimator.EstimatorSpec(mode,
                                          loss=total_loss,
                                          train_op=train_op,
                                          training_hooks=hooks)

    # -------------------------------- EVAL --------------------------------
    metric_inputs = [
        features["gt_box_labels"][:, :, :4],
        features["gt_box_labels"][:, :, 4],
        detections[:, :, :4],
        detections[:, :, 4],
        detections[:, :, 5],
    ]
    metric_ap_dict = batch_slice(
        metric_inputs,
        lambda gt_boxes, gt_ids, boxes, ids, scores: compute_metric_ap(
            gt_boxes, gt_ids, boxes, ids, scores, net_config),
        net_config.PER_GPU_IMAGE)

    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(mode,
                                          loss=total_loss,
                                          eval_metric_ops=metric_ap_dict)
    return None
def model_fn(features, mode, params, config):
    """Build the inference graph of the reference-feature detector.

    Besides the query image branch (backbone, FPN, RPN), a set of reference
    images is loaded, pushed through the same backbone/FPN, pooled into one
    feature map per class and handed to the Fast RCNN head together with a
    learned "negative" feature map.

    NOTE(review): the nesting of the ``with`` blocks below was reconstructed
    from a whitespace-collapsed source; confirm it against the original file
    (it influences variable names such as those under
    ``reference_negative``).

    Args:
        features: mapping of input tensors; reads ``"image"`` and
            ``"image_window"``, plus ``"gt_box_labels"`` when the prediction
            dictionary is assembled.
        mode: one of the ``tf.estimator.ModeKeys`` values.
        params: mapping that holds the network configuration object under
            the ``"net_config"`` key.
        config: not used by this function.

    Returns:
        ``tf.estimator.EstimatorSpec`` with the prediction tensors when
        ``mode`` is ``PREDICT``; ``None`` for every other mode.
    """
    net_config = params["net_config"]
    is_training = False
    roi_size = net_config.ROI_SIZE
    num_foreground = net_config.NUM_CLASS - 1

    # ----------------------------- share net ------------------------------
    origin_image_batch = features["image"]
    image_window = features["image_window"]
    image_batch = origin_image_batch - net_config.PIXEL_MEANS
    # Original author's note: is_training decides whether batch norm runs in
    # training mode, so it matters which value is passed here.
    _, share_net = get_network_byname(inputs=image_batch,
                                      config=net_config,
                                      is_training=is_training,
                                      reuse=tf.AUTO_REUSE)

    # -------------------------------- fpn ---------------------------------
    feature_pyramid = build_fpn.build_feature_pyramid(share_net, net_config)

    # -------------------------------- rpn ---------------------------------
    rpn = build_rpn.RPN(feature_pyramid=feature_pyramid,
                        image_window=image_window,
                        config=net_config)
    rpn_proposals_boxes, rpn_proposals_scores = rpn.rpn_proposals(is_training)

    # -------------------------- reference images --------------------------
    reference_image = tf.cast(load_reference_image(), tf.float32)
    reference_image = reference_image - net_config.PIXEL_MEANS
    # Same backbone variables as the query branch (reuse=tf.AUTO_REUSE).
    _, reference_share_net = get_network_byname(inputs=reference_image,
                                                config=net_config,
                                                is_training=False,
                                                reuse=tf.AUTO_REUSE)
    reference_feature_pyramid = build_fpn.build_feature_pyramid(
        reference_share_net, net_config)

    # Average over the support images of every class. Original author's shape
    # note: (C*S, H, W, 256) ----> (C, 7, 7, 256).
    with tf.variable_scope('reference_feature_origision'):
        for level in list(reference_feature_pyramid.keys()):
            resized = tf.image.resize_bilinear(
                reference_feature_pyramid[level], (roi_size, roi_size))
            grouped = tf.reshape(
                resized,
                (num_foreground, net_config.NUM_SUPPROTS,
                 roi_size, roi_size, 256))
            reference_feature_pyramid[level] = tf.reduce_mean(grouped, axis=1)

    # Average the pooled features over all pyramid levels.
    average_fpn_feature = list(reference_feature_pyramid.values())
    reference_fpn_features = tf.reduce_mean(
        tf.stack(average_fpn_feature, axis=0), axis=0)

    # Derive the negative feature map from all class features at once.
    with tf.variable_scope("reference_negative"):
        with slim.arg_scope(
                [slim.conv2d],
                padding="SAME",
                weights_initializer=tf.glorot_uniform_initializer(),
                weights_regularizer=slim.l2_regularizer(
                    net_config.WEIGHT_DECAY)):
            # Fold the class axis into the channel axis:
            # (1, H, W, C*channels).
            class_last = tf.transpose(reference_fpn_features, (1, 2, 0, 3))
            positive_features = tf.reshape(
                class_last, (1, roi_size, roi_size, num_foreground * 256))
            # (1, H, W, channels)
            negative_feature = slim.conv2d(positive_features,
                                           num_outputs=256,
                                           kernel_size=[3, 3],
                                           stride=1)
    # The negative feature map is placed in front of the class features.
    total_refernece_feature = tf.concat(
        [negative_feature, reference_fpn_features], axis=0)

    # ----------------------------- Fast RCNN ------------------------------
    fast_rcnn = build_fast_rcnn.FastRCNN(
        feature_pyramid=feature_pyramid,
        rpn_proposals_boxes=rpn_proposals_boxes,
        origin_image=origin_image_batch,
        gtboxes_and_label=None,
        reference_feature=total_refernece_feature,
        config=net_config,
        is_training=is_training,
        image_window=image_window)
    detections = fast_rcnn.fast_rcnn_detection()

    # ------------------------------ Outputs -------------------------------
    if mode != tf.estimator.ModeKeys.PREDICT:
        return None

    predictions = {
        "image": origin_image_batch,
        "predict_bbox": detections[:, :, :4],
        "predict_class_id": detections[:, :, 4],
        "predict_scores": detections[:, :, 5],
        "rpn_proposal_boxes": rpn_proposals_boxes,
        "rpn_proposals_scores": rpn_proposals_scores,
        "gt_box_labels": features["gt_box_labels"],
    }
    return tf.estimator.EstimatorSpec(mode, predictions=predictions)
def model_fn(features, labels, mode, params, config):
    """Build the training graph of the FPN / Fast RCNN detector.

    The RPN proposal stage receives a mode-dependent training flag, while the
    backbone and the Fast RCNN head are always built with
    ``is_training=False``.

    NOTE(review): the nesting of the ``with`` blocks below was reconstructed
    from a whitespace-collapsed source; confirm it against the original file
    (it influences op and optimizer-slot names).

    Args:
        features: mapping of input tensors; reads ``"image"`` and
            ``"image_window"``.
        labels: mapping of target tensors; reads ``"gt_box_labels"``,
            ``"minibatch_indices"``, ``"minibatch_encode_gtboxes"`` and
            ``"minibatch_objects_one_hot"``.
        mode: one of the ``tf.estimator.ModeKeys`` values.
        params: mapping that holds the network configuration object under
            the ``"net_config"`` key.
        config: not used by this function.

    Returns:
        ``tf.estimator.EstimatorSpec`` with loss, train op and a summary
        hook when ``mode`` is ``TRAIN``; ``None`` for every other mode.
    """
    net_config = params["net_config"]
    is_training = mode == tf.estimator.ModeKeys.TRAIN

    # ----------------------------- share net ------------------------------
    origin_image_batch = features["image"]
    image_window = features["image_window"]
    image_batch = origin_image_batch - net_config.PIXEL_MEANS
    # Original author's note: is_training decides whether batch norm runs in
    # training mode. The backbone is given False in every mode.
    _, share_net = get_network_byname(inputs=image_batch,
                                      config=net_config,
                                      is_training=False,
                                      reuse=tf.AUTO_REUSE)

    # -------------------------------- fpn ---------------------------------
    feature_pyramid = build_fpn.build_feature_pyramid(share_net, net_config)

    # -------------------------------- rpn ---------------------------------
    gtboxes_and_label_batch = labels.get("gt_box_labels")
    rpn = build_rpn.RPN(feature_pyramid=feature_pyramid,
                        image_window=image_window,
                        config=net_config)
    rpn_proposals_boxes, rpn_proposals_scores = rpn.rpn_proposals(is_training)
    rpn_location_loss, rpn_classification_loss = rpn.rpn_losses(
        labels["minibatch_indices"],
        labels["minibatch_encode_gtboxes"],
        labels["minibatch_objects_one_hot"])
    rpn_total_loss = rpn_classification_loss + rpn_location_loss

    # ----------------------------- Fast RCNN ------------------------------
    fast_rcnn = build_fast_rcnn.FastRCNN(
        feature_pyramid=feature_pyramid,
        rpn_proposals_boxes=rpn_proposals_boxes,
        origin_image=origin_image_batch,
        gtboxes_and_label=gtboxes_and_label_batch,
        config=net_config,
        is_training=False,
        image_window=image_window)
    detections = fast_rcnn.fast_rcnn_detection()

    if DEBUG:
        # Draw the first 50 proposals and the detections of the first image
        # in the batch.
        rpn_proposals_vision = draw_boxes_with_scores(
            origin_image_batch[0, :, :, :],
            rpn_proposals_boxes[0, :50, :],
            rpn_proposals_scores[0, :50])
        fast_rcnn_vision = draw_boxes_with_categories_and_scores(
            origin_image_batch[0, :, :, :],
            detections[0, :, :4],
            detections[0, :, 4],
            detections[0, :, 5])
        tf.summary.image("rpn_proposals_vision", rpn_proposals_vision)
        tf.summary.image("fast_rcnn_vision", fast_rcnn_vision)

    fast_rcnn_location_loss, fast_rcnn_classification_loss = (
        fast_rcnn.fast_rcnn_loss())
    fast_rcnn_total_loss = (fast_rcnn_location_loss
                            + fast_rcnn_classification_loss)

    # ------------------------------- Losses -------------------------------
    with tf.variable_scope("regularization_losses"):
        regularization_list = []
        for weight in tf.trainable_variables():
            # Variables whose name contains 'gamma' or 'beta' are excluded.
            if 'gamma' in weight.name or 'beta' in weight.name:
                continue
            scaled_l2 = (tf.nn.l2_loss(weight.read_value())
                         * net_config.WEIGHT_DECAY)
            element_count = tf.cast(tf.size(weight.read_value()), tf.float32)
            # The L2 penalty is divided by the variable's element count.
            regularization_list.append(scaled_l2 / element_count)
        regularization_losses = tf.add_n(regularization_list)

    total_loss = regularization_losses + fast_rcnn_total_loss + rpn_total_loss

    # ------------------------------ Training ------------------------------
    global_step = slim.get_or_create_global_step()
    net_scope = net_config.NET_NAME + "/"
    tf.train.init_from_checkpoint(net_config.CHECKPOINT_DIR,
                                  {net_scope: net_config.NET_NAME + "/"})

    with tf.variable_scope("optimizer"):
        # Two-step schedule: base rate, then /10, then /100.
        step_boundaries = [np.int64(net_config.BOUNDARY[0]),
                           np.int64(net_config.BOUNDARY[1])]
        step_values = [net_config.LEARNING_RATE,
                       net_config.LEARNING_RATE / 10,
                       net_config.LEARNING_RATE / 100]
        lr = tf.train.piecewise_constant(global_step,
                                         boundaries=step_boundaries,
                                         values=step_values)
        optimizer = tf.train.MomentumOptimizer(lr,
                                               momentum=net_config.MOMENTUM)
        optimizer = tf.contrib.estimator.TowerOptimizer(optimizer)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies([tf.group(*update_ops)]):
            grads = optimizer.compute_gradients(total_loss)
            # clip gradients
            grads = tf.contrib.training.clip_gradient_norms(
                grads, net_config.CLIP_GRADIENT_NORM)
            train_op = optimizer.apply_gradients(grads, global_step)

    # ------------------------------ Summary -------------------------------
    tf.summary.scalar('rpn/rpn_location_loss', rpn_location_loss)
    tf.summary.scalar('rpn/rpn_classification_loss', rpn_classification_loss)
    tf.summary.scalar('rpn/rpn_total_loss', rpn_total_loss)
    tf.summary.scalar('fast_rcnn/fast_rcnn_location_loss',
                      fast_rcnn_location_loss)
    tf.summary.scalar('fast_rcnn/fast_rcnn_classification_loss',
                      fast_rcnn_classification_loss)
    tf.summary.scalar('fast_rcnn/fast_rcnn_total_loss', fast_rcnn_total_loss)
    tf.summary.scalar('learning_rate', lr)
    tf.summary.scalar('total_loss', total_loss)

    summary_hook = tf.train.SummarySaverHook(
        save_steps=net_config.SAVE_EVERY_N_STEP,
        output_dir=net_config.MODLE_DIR,
        summary_op=tf.summary.merge_all())

    if mode != tf.estimator.ModeKeys.TRAIN:
        return None
    return tf.estimator.EstimatorSpec(mode,
                                      loss=total_loss,
                                      train_op=train_op,
                                      training_hooks=[summary_hook])
def model_fn(features, labels, mode, params, config):
    """Build the reference-feature detector graph for train, eval and predict.

    Changes relative to the previous revision:

    * ``predict_class_id`` is read from index 4 and ``predict_scores`` from
      index 5 of the last axis of ``detections``. The two were swapped,
      contradicting the ``draw_boxes_with_categories_and_scores`` call in
      this very function, which passes index 4 as categories and index 5 as
      scores.
    * ``PREDICT`` mode is reachable. ``tf.estimator`` passes ``labels=None``
      in that mode, so every label-dependent op (RPN losses, head losses,
      optimizer) is skipped and the prediction spec is returned early.

    NOTE(review): the nesting of the ``with`` blocks below was reconstructed
    from a whitespace-collapsed source; confirm it against the original file
    (it influences variable names such as those under
    ``reference_negative`` and the optimizer slots).

    Args:
        features: mapping of input tensors; reads ``"image"`` and
            ``"image_window"``.
        labels: mapping of target tensors (``"gt_box_labels"``,
            ``"minibatch_indices"``, ``"minibatch_encode_gtboxes"``,
            ``"minibatch_objects_one_hot"``); not read in ``PREDICT`` mode.
        mode: one of the ``tf.estimator.ModeKeys`` values.
        params: mapping that holds the network configuration object under
            the ``"net_config"`` key.
        config: not used by this function.

    Returns:
        ``tf.estimator.EstimatorSpec`` for ``TRAIN`` (loss, train op, summary
        hook), ``EVAL`` (loss, predictions) or ``PREDICT`` (predictions);
        ``None`` for any other mode.
    """
    net_config = params["net_config"]
    is_predict = mode == tf.estimator.ModeKeys.PREDICT
    is_training = mode == tf.estimator.ModeKeys.TRAIN
    roi_size = net_config.ROI_SIZE
    num_foreground = net_config.NUM_CLASS - 1

    # ----------------------------- share net ------------------------------
    origin_image_batch = features["image"]
    image_window = features["image_window"]
    image_batch = origin_image_batch - net_config.PIXEL_MEANS
    # Original author's note: is_training decides whether batch norm runs in
    # training mode. The backbone is given False in every mode.
    _, share_net = get_network_byname(inputs=image_batch,
                                      config=net_config,
                                      is_training=False,
                                      reuse=tf.AUTO_REUSE)

    # -------------------------------- fpn ---------------------------------
    feature_pyramid = build_fpn.build_feature_pyramid(share_net, net_config)

    # -------------------------------- rpn ---------------------------------
    # No labels exist in PREDICT mode; the head then receives None as its
    # ground truth.
    gtboxes_and_label_batch = (None if is_predict
                               else labels.get("gt_box_labels"))
    rpn = build_rpn.RPN(feature_pyramid=feature_pyramid,
                        image_window=image_window,
                        config=net_config)
    rpn_proposals_boxes, rpn_proposals_scores = rpn.rpn_proposals(is_training)
    if not is_predict:
        rpn_location_loss, rpn_classification_loss = rpn.rpn_losses(
            labels["minibatch_indices"],
            labels["minibatch_encode_gtboxes"],
            labels["minibatch_objects_one_hot"])
        rpn_total_loss = rpn_classification_loss + rpn_location_loss

    # -------------------------- reference images --------------------------
    reference_image = tf.cast(load_reference_image(), tf.float32)
    reference_image = reference_image - net_config.PIXEL_MEANS
    # Same backbone variables as the query branch (reuse=tf.AUTO_REUSE).
    _, reference_share_net = get_network_byname(inputs=reference_image,
                                                config=net_config,
                                                is_training=False,
                                                reuse=tf.AUTO_REUSE)
    reference_feature_pyramid = build_fpn.build_feature_pyramid(
        reference_share_net, net_config)

    # Average over the support images of every class. Original author's shape
    # note: (C*S, H, W, 256) ----> (C, 7, 7, 256).
    with tf.variable_scope('reference_feature_origision'):
        for level in list(reference_feature_pyramid.keys()):
            resized = tf.image.resize_bilinear(
                reference_feature_pyramid[level], (roi_size, roi_size))
            grouped = tf.reshape(
                resized,
                (num_foreground, net_config.NUM_SUPPROTS,
                 roi_size, roi_size, 256))
            reference_feature_pyramid[level] = tf.reduce_mean(grouped, axis=1)

    # Average the pooled features over all pyramid levels.
    average_fpn_feature = list(reference_feature_pyramid.values())
    reference_fpn_features = tf.reduce_mean(
        tf.stack(average_fpn_feature, axis=0), axis=0)

    # Derive the negative feature map from all class features at once.
    with tf.variable_scope("reference_negative"):
        with slim.arg_scope(
                [slim.conv2d],
                padding="SAME",
                weights_initializer=tf.glorot_uniform_initializer(),
                weights_regularizer=slim.l2_regularizer(
                    net_config.WEIGHT_DECAY)):
            # Fold the class axis into the channel axis:
            # (1, H, W, C*channels).
            positive_features = tf.reshape(
                tf.transpose(reference_fpn_features, (1, 2, 0, 3)),
                (1, roi_size, roi_size, num_foreground * 256))
            # (1, H, W, channels)
            negative_feature = slim.conv2d(positive_features,
                                           num_outputs=256,
                                           kernel_size=[3, 3],
                                           stride=1)
    # The negative feature map is placed in front of the class features.
    total_refernece_feature = tf.concat(
        [negative_feature, reference_fpn_features], axis=0)

    # ----------------------------- Fast RCNN ------------------------------
    fast_rcnn = build_fast_rcnn.FastRCNN(
        feature_pyramid=feature_pyramid,
        rpn_proposals_boxes=rpn_proposals_boxes,
        origin_image=origin_image_batch,
        gtboxes_and_label=gtboxes_and_label_batch,
        reference_feature=total_refernece_feature,
        config=net_config,
        is_training=False,
        image_window=image_window)
    detections = fast_rcnn.fast_rcnn_detection()

    def build_predictions():
        # Last axis of `detections`: 0..3 -> box, 4 -> class id, 5 -> score
        # (the layout the drawing helper below is called with).
        return {
            "predict_bbox": detections[:, :, :4],
            "predict_class_id": detections[:, :, 4],
            "predict_scores": detections[:, :, 5],
        }

    if is_predict:
        # Nothing below can be built without labels.
        return tf.estimator.EstimatorSpec(mode,
                                          predictions=build_predictions())

    if DEBUG:
        # Draw the first 50 proposals and the detections of the first image
        # in the batch.
        rpn_proposals_vision = draw_boxes_with_scores(
            origin_image_batch[0, :, :, :],
            rpn_proposals_boxes[0, :50, :],
            rpn_proposals_scores[0, :50])
        fast_rcnn_vision = draw_boxes_with_categories_and_scores(
            origin_image_batch[0, :, :, :],
            detections[0, :, :4],
            detections[0, :, 4],
            detections[0, :, 5])
        tf.summary.image("rpn_proposals_vision", rpn_proposals_vision)
        tf.summary.image("fast_rcnn_vision", fast_rcnn_vision)

    fast_rcnn_location_loss, fast_rcnn_classification_loss = (
        fast_rcnn.fast_rcnn_loss())
    # The classification term is weighted 5x against the location term.
    fast_rcnn_total_loss = (5.0 * fast_rcnn_classification_loss
                            + fast_rcnn_location_loss)

    # ------------------------------- Losses -------------------------------
    with tf.variable_scope("regularization_losses"):
        # Variables whose name contains 'gamma' or 'beta' are excluded; the
        # L2 penalty of each remaining variable is divided by its element
        # count.
        regularization_list = [
            tf.nn.l2_loss(w.read_value()) * net_config.WEIGHT_DECAY /
            tf.cast(tf.size(w.read_value()), tf.float32)
            for w in tf.trainable_variables()
            if 'gamma' not in w.name and 'beta' not in w.name
        ]
        regularization_losses = tf.add_n(regularization_list)

    total_loss = regularization_losses + fast_rcnn_total_loss + rpn_total_loss

    # ------------------------------ Training ------------------------------
    global_step = slim.get_or_create_global_step()
    tf.train.init_from_checkpoint(
        net_config.CHECKPOINT_DIR,
        {net_config.NET_NAME + "/": net_config.NET_NAME + "/"})

    with tf.variable_scope("optimizer"):
        # Two-step schedule: base rate, then /10, then /100.
        lr = tf.train.piecewise_constant(
            global_step,
            boundaries=[np.int64(net_config.BOUNDARY[0]),
                        np.int64(net_config.BOUNDARY[1])],
            values=[net_config.LEARNING_RATE,
                    net_config.LEARNING_RATE / 10,
                    net_config.LEARNING_RATE / 100])
        optimizer = tf.train.MomentumOptimizer(lr,
                                               momentum=net_config.MOMENTUM)
        optimizer = tf.contrib.estimator.TowerOptimizer(optimizer)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies([tf.group(*update_ops)]):
            grads = optimizer.compute_gradients(total_loss)
            # Clip every available gradient to a norm of 5.0; variables
            # without a gradient are passed through untouched.
            grads = [
                (tf.clip_by_norm(g, 5.0), v) if g is not None else (g, v)
                for g, v in grads
            ]
            train_op = optimizer.apply_gradients(grads, global_step)

    # ------------------------------ Summary -------------------------------
    tf.summary.scalar('rpn/rpn_location_loss', rpn_location_loss)
    tf.summary.scalar('rpn/rpn_classification_loss', rpn_classification_loss)
    tf.summary.scalar('rpn/rpn_total_loss', rpn_total_loss)
    tf.summary.scalar('fast_rcnn/fast_rcnn_location_loss',
                      fast_rcnn_location_loss)
    tf.summary.scalar('fast_rcnn/fast_rcnn_classification_loss',
                      fast_rcnn_classification_loss)
    tf.summary.scalar('fast_rcnn/fast_rcnn_total_loss', fast_rcnn_total_loss)
    tf.summary.scalar('learning_rate', lr)
    tf.summary.scalar('total_loss', total_loss)

    summary_hook = tf.train.SummarySaverHook(
        save_steps=net_config.SAVE_EVERY_N_STEP,
        output_dir=net_config.MODLE_DIR,
        summary_op=tf.summary.merge_all())

    if mode == tf.estimator.ModeKeys.TRAIN:
        return tf.estimator.EstimatorSpec(mode,
                                          loss=total_loss,
                                          train_op=train_op,
                                          training_hooks=[summary_hook])

    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(mode,
                                          loss=total_loss,
                                          predictions=build_predictions())
    return None