def build_graph(self, *inputs):
    """Build the FPN Faster R-CNN (optionally Mask R-CNN) graph for one tower.

    Layout of ``inputs``, as indexed below (``L = len(config.ANCHOR_STRIDES_FPN)``):
      * ``inputs[0]``: the image, passed through ``self.preprocess``.
      * ``inputs[1 : 1 + 2 * L]``: per-level anchor targets, interleaved as
        (anchor_labels, anchor_boxes) for each FPN level.
      * ``inputs[1 + 2 * L]``: ground-truth boxes.
      * ``inputs[2 + 2 * L]``: ground-truth labels.
      * ``inputs[-1]``: ground-truth masks (only read when ``config.MODE_MASK``).

    Returns:
        Training: the ``total_cost`` tensor. It is the sum of the per-level
        RPN losses, the Fast R-CNN losses, the Mask R-CNN loss (0.0 when
        masks are disabled) and weight decay.
        Inference: nothing. When ``config.MODE_MASK`` is set, a sigmoid op
        named ``final_masks`` is added to the graph.
    """
    num_fpn_level = len(config.ANCHOR_STRIDES_FPN)
    assert len(config.ANCHOR_SIZES) == num_fpn_level, \
        "config.ANCHOR_SIZES must have one entry per FPN level"
    is_training = get_current_tower_context().is_training
    # Proposal top-k, used both per level and after merging all levels.
    fpn_nms_topk = (config.TRAIN_FPN_NMS_TOPK if is_training
                    else config.TEST_FPN_NMS_TOPK)

    image = inputs[0]
    num_anchor_inputs = 2 * num_fpn_level
    input_anchors = inputs[1:1 + num_anchor_inputs]
    multilevel_anchor_labels = input_anchors[0::2]
    multilevel_anchor_boxes = input_anchors[1::2]
    # Ground-truth tensors directly follow the per-level anchor inputs. Their
    # position is derived from num_fpn_level (instead of hard-coding 11/12,
    # which is only correct for exactly 5 levels) so it matches the slice above.
    gt_boxes = inputs[1 + num_anchor_inputs]
    gt_labels = inputs[2 + num_anchor_inputs]
    if config.MODE_MASK:
        gt_masks = inputs[-1]

    image = self.preprocess(image)  # 1CHW
    image_shape2d = tf.shape(image)[2:]  # h,w
    c2345 = resnet_fpn_backbone(image, config.RESNET_NUM_BLOCK)
    p23456 = fpn_model('fpn', c2345)

    # Multi-Level RPN Proposals
    # Anchors for every level are fetched once; the loop picks one level at a time.
    all_anchors_fpn = get_all_anchors_fpn()
    multilevel_proposals = []
    # Flat list: [label_loss_lvl0, box_loss_lvl0, label_loss_lvl1, box_loss_lvl1, ...]
    rpn_loss_collection = []
    for lvl in range(num_fpn_level):
        # Every level uses the name 'rpn'. Presumably the head's weights are
        # shared across levels; confirm in rpn_head.
        rpn_label_logits, rpn_box_logits = rpn_head(
            'rpn', p23456[lvl], config.FPN_NUM_CHANNEL, len(config.ANCHOR_RATIOS))
        with tf.name_scope('FPN_lvl{}'.format(lvl + 2)):
            anchors = tf.constant(all_anchors_fpn[lvl],
                                  name='rpn_anchor_lvl{}'.format(lvl + 2))
            anchors, anchor_labels, anchor_boxes = self.narrow_to_featuremap(
                p23456[lvl], anchors,
                multilevel_anchor_labels[lvl], multilevel_anchor_boxes[lvl])
            anchor_boxes_encoded = encode_bbox_target(anchor_boxes, anchors)
            pred_boxes_decoded = decode_bbox_target(rpn_box_logits, anchors)
            proposal_boxes, proposal_scores = generate_rpn_proposals(
                tf.reshape(pred_boxes_decoded, [-1, 4]),
                tf.reshape(rpn_label_logits, [-1]),
                image_shape2d, fpn_nms_topk)
            multilevel_proposals.append((proposal_boxes, proposal_scores))
            if is_training:
                label_loss, box_loss = rpn_losses(
                    anchor_labels, anchor_boxes_encoded,
                    rpn_label_logits, rpn_box_logits)
                rpn_loss_collection.extend([label_loss, box_loss])

    # Merge proposals from multi levels, pick top K
    proposal_boxes = tf.concat([x[0] for x in multilevel_proposals], axis=0)  # nx4
    proposal_scores = tf.concat([x[1] for x in multilevel_proposals], axis=0)  # n
    # Never ask top_k for more elements than were proposed.
    proposal_topk = tf.minimum(tf.size(proposal_scores), fpn_nms_topk)
    proposal_scores, topk_indices = tf.nn.top_k(
        proposal_scores, k=proposal_topk, sorted=False)
    proposal_boxes = tf.gather(proposal_boxes, topk_indices)

    if is_training:
        rcnn_boxes, rcnn_labels, fg_inds_wrt_gt = sample_fast_rcnn_targets(
            proposal_boxes, gt_boxes, gt_labels)
    else:
        # The boxes to be used to crop RoIs.
        rcnn_boxes = proposal_boxes

    roi_feature_fastrcnn = multilevel_roi_align(p23456[:4], rcnn_boxes, 7)
    fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_2fc_head(
        'fastrcnn', roi_feature_fastrcnn, config.NUM_CLASS)

    if is_training:
        # rpn loss is already defined above
        with tf.name_scope('rpn_losses'):
            rpn_total_label_loss = tf.add_n(rpn_loss_collection[::2], name='label_loss')
            rpn_total_box_loss = tf.add_n(rpn_loss_collection[1::2], name='box_loss')
            add_moving_summary(rpn_total_box_loss, rpn_total_label_loss)

        # fastrcnn loss:
        matched_gt_boxes = tf.gather(gt_boxes, fg_inds_wrt_gt)
        # Samples with label > 0 are foreground (fg inds w.r.t all samples).
        fg_inds_wrt_sample = tf.reshape(tf.where(rcnn_labels > 0), [-1])
        fg_sampled_boxes = tf.gather(rcnn_boxes, fg_inds_wrt_sample)
        fg_fastrcnn_box_logits = tf.gather(fastrcnn_box_logits, fg_inds_wrt_sample)
        fastrcnn_label_loss, fastrcnn_box_loss = self.fastrcnn_training(
            image, rcnn_labels, fg_sampled_boxes,
            matched_gt_boxes, fastrcnn_label_logits, fg_fastrcnn_box_logits)

        if config.MODE_MASK:
            # maskrcnn loss
            fg_labels = tf.gather(rcnn_labels, fg_inds_wrt_sample)
            roi_feature_maskrcnn = multilevel_roi_align(
                p23456[:4], fg_sampled_boxes, 14)
            mask_logits = maskrcnn_upXconv_head(
                'maskrcnn', roi_feature_maskrcnn, config.NUM_CLASS, 4)  # #fg x #cat x 28 x 28
            # Crop straight from the full gt_masks, using fg_inds_wrt_gt as
            # each box's mask index. This avoids materialising a gathered
            # fg x H x W copy of the matched masks; the crops are identical.
            target_masks_for_fg = crop_and_resize(
                tf.expand_dims(gt_masks, 1), fg_sampled_boxes,
                tf.to_int32(fg_inds_wrt_gt), 28, pad_border=False)  # fg x 1x28x28
            target_masks_for_fg = tf.squeeze(
                target_masks_for_fg, 1, 'sampled_fg_mask_targets')
            mrcnn_loss = maskrcnn_loss(mask_logits, fg_labels, target_masks_for_fg)
        else:
            mrcnn_loss = 0.0

        wd_cost = regularize_cost(
            '(?:group1|group2|group3|rpn|fastrcnn|maskrcnn)/.*W',
            l2_regularizer(1e-4), name='wd_cost')

        total_cost = tf.add_n(
            rpn_loss_collection + [fastrcnn_label_loss, fastrcnn_box_loss,
                                   mrcnn_loss, wd_cost],
            'total_cost')
        add_moving_summary(total_cost, wd_cost)
        return total_cost
    else:
        final_boxes, final_labels = self.fastrcnn_inference(
            image_shape2d, rcnn_boxes, fastrcnn_label_logits, fastrcnn_box_logits)

        if config.MODE_MASK:
            # Cascade inference needs roi transform with refined boxes.
            roi_feature_maskrcnn = multilevel_roi_align(p23456[:4], final_boxes, 14)
            mask_logits = maskrcnn_upXconv_head(
                'maskrcnn', roi_feature_maskrcnn, config.NUM_CLASS, 4)  # #fg x #cat x 28 x 28
            # Pick, for each result, the mask channel of its predicted class.
            # Label 0 is treated as background above, hence the -1. This
            # assumes mask_logits has no background channel; confirm in
            # maskrcnn_upXconv_head.
            indices = tf.stack([tf.range(tf.size(final_labels)),
                                tf.to_int32(final_labels) - 1], axis=1)
            final_mask_logits = tf.gather_nd(mask_logits, indices)  # #resultx28x28
            tf.sigmoid(final_mask_logits, name='final_masks')
def build_graph(self, *inputs):
    """Build the FPN Faster R-CNN (optionally Mask R-CNN) graph for one tower.

    Layout of ``inputs``, as indexed below (``L = len(config.ANCHOR_STRIDES_FPN)``):
      * ``inputs[0]``: the image, passed through ``self.preprocess``.
      * ``inputs[1 : 1 + 2 * L]``: per-level anchor targets, interleaved as
        (anchor_labels, anchor_boxes) for each FPN level.
      * ``inputs[1 + 2 * L]``: ground-truth boxes.
      * ``inputs[2 + 2 * L]``: ground-truth labels.
      * ``inputs[-1]``: ground-truth masks (only read when ``config.MODE_MASK``).

    Returns:
        Training: the ``total_cost`` tensor. It is the sum of the per-level
        RPN losses, the Fast R-CNN losses, the Mask R-CNN loss (0.0 when
        masks are disabled) and weight decay.
        Inference: nothing. When ``config.MODE_MASK`` is set, a sigmoid op
        named ``final_masks`` is added to the graph.
    """
    num_fpn_level = len(config.ANCHOR_STRIDES_FPN)
    assert len(config.ANCHOR_SIZES) == num_fpn_level, \
        "config.ANCHOR_SIZES must have one entry per FPN level"
    is_training = get_current_tower_context().is_training
    # Proposal top-k, used both per level and after merging all levels.
    fpn_nms_topk = (config.TRAIN_FPN_NMS_TOPK if is_training
                    else config.TEST_FPN_NMS_TOPK)

    image = inputs[0]
    num_anchor_inputs = 2 * num_fpn_level
    input_anchors = inputs[1:1 + num_anchor_inputs]
    multilevel_anchor_labels = input_anchors[0::2]
    multilevel_anchor_boxes = input_anchors[1::2]
    # Ground-truth tensors directly follow the per-level anchor inputs. Their
    # position is derived from num_fpn_level (instead of hard-coding 11/12,
    # which is only correct for exactly 5 levels) so it matches the slice above.
    gt_boxes = inputs[1 + num_anchor_inputs]
    gt_labels = inputs[2 + num_anchor_inputs]
    if config.MODE_MASK:
        gt_masks = inputs[-1]

    image = self.preprocess(image)  # 1CHW
    image_shape2d = tf.shape(image)[2:]  # h,w
    c2345 = resnet_fpn_backbone(image, config.RESNET_NUM_BLOCK)
    p23456 = fpn_model('fpn', c2345)

    # Multi-Level RPN Proposals
    # Anchors for every level are fetched once; the loop picks one level at a time.
    all_anchors_fpn = get_all_anchors_fpn()
    multilevel_proposals = []
    # Flat list: [label_loss_lvl0, box_loss_lvl0, label_loss_lvl1, box_loss_lvl1, ...]
    rpn_loss_collection = []
    for lvl in range(num_fpn_level):
        # Every level uses the name 'rpn'. Presumably the head's weights are
        # shared across levels; confirm in rpn_head.
        rpn_label_logits, rpn_box_logits = rpn_head(
            'rpn', p23456[lvl], config.FPN_NUM_CHANNEL, len(config.ANCHOR_RATIOS))
        with tf.name_scope('FPN_lvl{}'.format(lvl + 2)):
            anchors = tf.constant(all_anchors_fpn[lvl],
                                  name='rpn_anchor_lvl{}'.format(lvl + 2))
            anchors, anchor_labels, anchor_boxes = self.narrow_to_featuremap(
                p23456[lvl], anchors,
                multilevel_anchor_labels[lvl], multilevel_anchor_boxes[lvl])
            anchor_boxes_encoded = encode_bbox_target(anchor_boxes, anchors)
            pred_boxes_decoded = decode_bbox_target(rpn_box_logits, anchors)
            proposal_boxes, proposal_scores = generate_rpn_proposals(
                tf.reshape(pred_boxes_decoded, [-1, 4]),
                tf.reshape(rpn_label_logits, [-1]),
                image_shape2d, fpn_nms_topk)
            multilevel_proposals.append((proposal_boxes, proposal_scores))
            if is_training:
                label_loss, box_loss = rpn_losses(
                    anchor_labels, anchor_boxes_encoded,
                    rpn_label_logits, rpn_box_logits)
                rpn_loss_collection.extend([label_loss, box_loss])

    # Merge proposals from multi levels, pick top K
    proposal_boxes = tf.concat([x[0] for x in multilevel_proposals], axis=0)  # nx4
    proposal_scores = tf.concat([x[1] for x in multilevel_proposals], axis=0)  # n
    # Never ask top_k for more elements than were proposed.
    proposal_topk = tf.minimum(tf.size(proposal_scores), fpn_nms_topk)
    proposal_scores, topk_indices = tf.nn.top_k(
        proposal_scores, k=proposal_topk, sorted=False)
    proposal_boxes = tf.gather(proposal_boxes, topk_indices)

    if is_training:
        rcnn_boxes, rcnn_labels, fg_inds_wrt_gt = sample_fast_rcnn_targets(
            proposal_boxes, gt_boxes, gt_labels)
    else:
        # The boxes to be used to crop RoIs.
        rcnn_boxes = proposal_boxes

    roi_feature_fastrcnn = multilevel_roi_align(p23456[:4], rcnn_boxes, 7)
    fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_2fc_head(
        'fastrcnn', roi_feature_fastrcnn, config.NUM_CLASS)

    if is_training:
        # rpn loss is already defined above
        with tf.name_scope('rpn_losses'):
            rpn_total_label_loss = tf.add_n(rpn_loss_collection[::2], name='label_loss')
            rpn_total_box_loss = tf.add_n(rpn_loss_collection[1::2], name='box_loss')
            add_moving_summary(rpn_total_box_loss, rpn_total_label_loss)

        # fastrcnn loss:
        matched_gt_boxes = tf.gather(gt_boxes, fg_inds_wrt_gt)
        # Samples with label > 0 are foreground (fg inds w.r.t all samples).
        fg_inds_wrt_sample = tf.reshape(tf.where(rcnn_labels > 0), [-1])
        fg_sampled_boxes = tf.gather(rcnn_boxes, fg_inds_wrt_sample)
        fg_fastrcnn_box_logits = tf.gather(fastrcnn_box_logits, fg_inds_wrt_sample)
        fastrcnn_label_loss, fastrcnn_box_loss = self.fastrcnn_training(
            image, rcnn_labels, fg_sampled_boxes,
            matched_gt_boxes, fastrcnn_label_logits, fg_fastrcnn_box_logits)

        if config.MODE_MASK:
            # maskrcnn loss
            fg_labels = tf.gather(rcnn_labels, fg_inds_wrt_sample)
            roi_feature_maskrcnn = multilevel_roi_align(
                p23456[:4], fg_sampled_boxes, 14)
            mask_logits = maskrcnn_upXconv_head(
                'maskrcnn', roi_feature_maskrcnn, config.NUM_CLASS, 4)  # #fg x #cat x 28 x 28
            # fg_inds_wrt_gt selects, for each fg box, which gt mask to crop from.
            target_masks_for_fg = crop_and_resize(
                tf.expand_dims(gt_masks, 1), fg_sampled_boxes,
                fg_inds_wrt_gt, 28, pad_border=False)  # fg x 1x28x28
            target_masks_for_fg = tf.squeeze(
                target_masks_for_fg, 1, 'sampled_fg_mask_targets')
            mrcnn_loss = maskrcnn_loss(mask_logits, fg_labels, target_masks_for_fg)
        else:
            mrcnn_loss = 0.0

        wd_cost = regularize_cost(
            '(?:group1|group2|group3|rpn|fastrcnn|maskrcnn)/.*W',
            l2_regularizer(1e-4), name='wd_cost')

        total_cost = tf.add_n(rpn_loss_collection + [
            fastrcnn_label_loss, fastrcnn_box_loss, mrcnn_loss, wd_cost], 'total_cost')
        add_moving_summary(total_cost, wd_cost)
        return total_cost
    else:
        final_boxes, final_labels = self.fastrcnn_inference(
            image_shape2d, rcnn_boxes, fastrcnn_label_logits, fastrcnn_box_logits)

        if config.MODE_MASK:
            # Cascade inference needs roi transform with refined boxes.
            roi_feature_maskrcnn = multilevel_roi_align(p23456[:4], final_boxes, 14)
            mask_logits = maskrcnn_upXconv_head(
                'maskrcnn', roi_feature_maskrcnn, config.NUM_CLASS, 4)  # #fg x #cat x 28 x 28
            # Pick, for each result, the mask channel of its predicted class.
            # Label 0 is treated as background above, hence the -1. This
            # assumes mask_logits has no background channel; confirm in
            # maskrcnn_upXconv_head.
            indices = tf.stack([tf.range(tf.size(final_labels)),
                                tf.to_int32(final_labels) - 1], axis=1)
            final_mask_logits = tf.gather_nd(mask_logits, indices)  # #resultx28x28
            tf.sigmoid(final_mask_logits, name='final_masks')