def fastrcnn_training(self, image, rcnn_labels, fg_rcnn_boxes, gt_boxes_per_fg,
                      rcnn_label_logits, fg_rcnn_box_logits):
    """
    Compute the Fast R-CNN classification and box-regression losses, and
    emit a TensorBoard visualization of the sampled foreground patches.

    Args:
        image (NCHW): input image batch
        rcnn_labels (n): labels for each sampled target
        fg_rcnn_boxes (fg x 4): proposal boxes for each sampled foreground target
        gt_boxes_per_fg (fg x 4): matching gt boxes for each sampled foreground target
        rcnn_label_logits (n): label logits for each sampled target
        fg_rcnn_box_logits (fg x 4): box logits for each sampled foreground target

    Returns:
        (label_loss, box_loss): two scalar loss tensors.
    """
    # Visualize the foreground patches that were sampled for training.
    with tf.name_scope('fg_sample_patch_viz'):
        batch_indices = tf.zeros(tf.shape(fg_rcnn_boxes)[0], dtype=tf.int32)
        patches = crop_and_resize(image, fg_rcnn_boxes, batch_indices, 300)
        patches = tf.transpose(patches, [0, 2, 3, 1])   # NCHW -> NHWC for tf.summary
        patches = tf.reverse(patches, axis=[-1])  # BGR->RGB
        tf.summary.image('viz', patches, max_outputs=30)

    # Regression targets: encoded gt deltas, scaled by the configured weights.
    reg_weights = tf.constant(config.FASTRCNN_BBOX_REG_WEIGHTS)
    regression_targets = encode_bbox_target(gt_boxes_per_fg, fg_rcnn_boxes) * reg_weights

    label_loss, box_loss = fastrcnn_losses(
        rcnn_labels, rcnn_label_logits,
        regression_targets, fg_rcnn_box_logits)
    return label_loss, box_loss
def fastrcnn_training(self, image, rcnn_labels, fg_rcnn_boxes, gt_boxes_per_fg, rcnn_label_logits, fg_rcnn_box_logits):
    """
    Compute Fast R-CNN classification & box-regression losses; also adds a
    TensorBoard image summary of the sampled foreground patches.

    Args:
        image (NCHW): input image batch
        rcnn_labels (n): labels for each sampled target
        fg_rcnn_boxes (fg x 4): proposal boxes for each sampled foreground target
        gt_boxes_per_fg (fg x 4): matching gt boxes for each sampled foreground target
        rcnn_label_logits (n): label logits for each sampled target
        fg_rcnn_box_logits (fg x 4): box logits for each sampled foreground target

    Returns:
        (fastrcnn_label_loss, fastrcnn_box_loss): scalar loss tensors.
    """
    with tf.name_scope('fg_sample_patch_viz'):
        # All boxes belong to the single image in the batch, hence index 0.
        fg_sampled_patches = crop_and_resize(
            image, fg_rcnn_boxes,
            tf.zeros(tf.shape(fg_rcnn_boxes)[0], dtype=tf.int32), 300)
        fg_sampled_patches = tf.transpose(fg_sampled_patches, [0, 2, 3, 1])  # NCHW -> NHWC
        fg_sampled_patches = tf.reverse(fg_sampled_patches, axis=[-1])  # BGR->RGB
        tf.summary.image('viz', fg_sampled_patches, max_outputs=30)

    # Regression targets: encoded gt offsets scaled by the configured weights.
    encoded_boxes = encode_bbox_target(
        gt_boxes_per_fg,
        fg_rcnn_boxes) * tf.constant(config.FASTRCNN_BBOX_REG_WEIGHTS)
    fastrcnn_label_loss, fastrcnn_box_loss = fastrcnn_losses(
        rcnn_labels, rcnn_label_logits, encoded_boxes, fg_rcnn_box_logits)
    return fastrcnn_label_loss, fastrcnn_box_loss
def _build_graph(self, inputs):
    """
    Build the training/inference graph for the C4 Faster/Mask R-CNN model.

    Args:
        inputs: list of input tensors, ordered as
            (image, anchor_labels, anchor_boxes, gt_boxes, gt_labels[, gt_masks]);
            gt_masks is present only when ``config.MODE_MASK`` is set.

    Side effects:
        Training: sets ``self.cost`` and adds moving summaries.
        Inference: creates named output tensors ('final_probs', 'final_boxes',
        'final_labels', and 'final_masks' when masks are enabled).
    """
    is_training = get_current_tower_context().is_training
    if config.MODE_MASK:
        image, anchor_labels, anchor_boxes, gt_boxes, gt_labels, gt_masks = inputs
    else:
        image, anchor_labels, anchor_boxes, gt_boxes, gt_labels = inputs
    fm_anchors = self._get_anchors(image)
    image = self._preprocess(image)     # 1CHW
    image_shape2d = tf.shape(image)[2:]
    anchor_boxes_encoded = encode_bbox_target(anchor_boxes, fm_anchors)

    # Backbone + RPN.
    featuremap = pretrained_resnet_conv4(image, config.RESNET_NUM_BLOCK[:3])
    rpn_label_logits, rpn_box_logits = rpn_head('rpn', featuremap, 1024, config.NUM_ANCHOR)
    decoded_boxes = decode_bbox_target(rpn_box_logits, fm_anchors)  # fHxfWxNAx4, floatbox
    proposal_boxes, proposal_scores = generate_rpn_proposals(
        tf.reshape(decoded_boxes, [-1, 4]),
        tf.reshape(rpn_label_logits, [-1]),
        image_shape2d)

    if is_training:
        # sample proposal boxes in training
        rcnn_sampled_boxes, rcnn_labels, fg_inds_wrt_gt = sample_fast_rcnn_targets(
            proposal_boxes, gt_boxes, gt_labels)
        boxes_on_featuremap = rcnn_sampled_boxes * (1.0 / config.ANCHOR_STRIDE)
    else:
        # use all proposal boxes in inference
        boxes_on_featuremap = proposal_boxes * (1.0 / config.ANCHOR_STRIDE)
    roi_resized = roi_align(featuremap, boxes_on_featuremap, 14)

    # HACK to work around https://github.com/tensorflow/tensorflow/issues/14657
    # (failure on an empty RoI batch); ff_false supplies empty placeholders.
    def ff_true():
        feature_fastrcnn = resnet_conv5(
            roi_resized, config.RESNET_NUM_BLOCK[-1])   # nxcx7x7
        fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_head(
            'fastrcnn', feature_fastrcnn, config.NUM_CLASS)
        return feature_fastrcnn, fastrcnn_label_logits, fastrcnn_box_logits

    def ff_false():
        ncls = config.NUM_CLASS
        return tf.zeros([0, 2048, 7, 7]), tf.zeros([0, ncls]), tf.zeros([0, ncls - 1, 4])

    feature_fastrcnn, fastrcnn_label_logits, fastrcnn_box_logits = tf.cond(
        tf.size(boxes_on_featuremap) > 0, ff_true, ff_false)

    if is_training:
        # rpn loss
        rpn_label_loss, rpn_box_loss = rpn_losses(
            anchor_labels, anchor_boxes_encoded, rpn_label_logits, rpn_box_logits)

        # fastrcnn loss
        fg_inds_wrt_sample = tf.reshape(tf.where(rcnn_labels > 0), [-1])  # fg inds w.r.t all samples
        fg_sampled_boxes = tf.gather(rcnn_sampled_boxes, fg_inds_wrt_sample)

        # TensorBoard visualization of the sampled foreground patches.
        with tf.name_scope('fg_sample_patch_viz'):
            fg_sampled_patches = crop_and_resize(
                image, fg_sampled_boxes,
                tf.zeros_like(fg_inds_wrt_sample, dtype=tf.int32), 300)
            fg_sampled_patches = tf.transpose(fg_sampled_patches, [0, 2, 3, 1])
            fg_sampled_patches = tf.reverse(fg_sampled_patches, axis=[-1])  # BGR->RGB
            tf.summary.image('viz', fg_sampled_patches, max_outputs=30)

        matched_gt_boxes = tf.gather(gt_boxes, fg_inds_wrt_gt)
        # Box-regression targets, scaled by the configured weights.
        encoded_boxes = encode_bbox_target(
            matched_gt_boxes,
            fg_sampled_boxes) * tf.constant(config.FASTRCNN_BBOX_REG_WEIGHTS)
        fastrcnn_label_loss, fastrcnn_box_loss = fastrcnn_losses(
            rcnn_labels, fastrcnn_label_logits, encoded_boxes,
            tf.gather(fastrcnn_box_logits, fg_inds_wrt_sample))

        if config.MODE_MASK:
            # maskrcnn loss, computed on foreground RoIs only
            fg_labels = tf.gather(rcnn_labels, fg_inds_wrt_sample)
            fg_feature = tf.gather(feature_fastrcnn, fg_inds_wrt_sample)
            mask_logits = maskrcnn_head(
                'maskrcnn', fg_feature, config.NUM_CLASS)   # #fg x #cat x 14x14
            gt_masks_for_fg = tf.gather(gt_masks, fg_inds_wrt_gt)   # nfg x H x W
            target_masks_for_fg = crop_and_resize(
                tf.expand_dims(gt_masks_for_fg, 1),
                fg_sampled_boxes,
                tf.range(tf.size(fg_inds_wrt_gt)), 14)  # nfg x 1x14x14
            target_masks_for_fg = tf.squeeze(target_masks_for_fg, 1, 'sampled_fg_mask_targets')
            mrcnn_loss = maskrcnn_loss(mask_logits, fg_labels, target_masks_for_fg)
        else:
            mrcnn_loss = 0.0

        # Weight decay on backbone and head weights.
        wd_cost = regularize_cost(
            '(?:group1|group2|group3|rpn|fastrcnn|maskrcnn)/.*W',
            l2_regularizer(1e-4), name='wd_cost')

        self.cost = tf.add_n([
            rpn_label_loss, rpn_box_loss,
            fastrcnn_label_loss, fastrcnn_box_loss,
            mrcnn_loss, wd_cost], 'total_cost')
        add_moving_summary(self.cost, wd_cost)
    else:
        # Inference: decode per-category boxes for every proposal, then filter.
        label_probs = tf.nn.softmax(
            fastrcnn_label_logits, name='fastrcnn_all_probs')   # #proposal x #Class
        anchors = tf.tile(
            tf.expand_dims(proposal_boxes, 1),
            [1, config.NUM_CLASS - 1, 1])   # #proposal x #Cat x 4
        decoded_boxes = decode_bbox_target(
            fastrcnn_box_logits / tf.constant(config.FASTRCNN_BBOX_REG_WEIGHTS),
            anchors)
        decoded_boxes = clip_boxes(decoded_boxes, image_shape2d, name='fastrcnn_all_boxes')

        # indices: Nx2. Each index into (#proposal, #category)
        pred_indices, final_probs = fastrcnn_predictions(decoded_boxes, label_probs)
        final_probs = tf.identity(final_probs, 'final_probs')
        final_boxes = tf.gather_nd(decoded_boxes, pred_indices, name='final_boxes')
        final_labels = tf.add(pred_indices[:, 1], 1, name='final_labels')

        if config.MODE_MASK:
            # HACK to work around https://github.com/tensorflow/tensorflow/issues/14657
            def f1():
                roi_resized = roi_align(
                    featuremap, final_boxes * (1.0 / config.ANCHOR_STRIDE), 14)
                feature_maskrcnn = resnet_conv5(roi_resized, config.RESNET_NUM_BLOCK[-1])
                mask_logits = maskrcnn_head(
                    'maskrcnn', feature_maskrcnn, config.NUM_CLASS)  # #result x #cat x 14x14
                # Pick the mask channel of each detection's predicted category.
                indices = tf.stack([
                    tf.range(tf.size(final_labels)),
                    tf.to_int32(final_labels) - 1], axis=1)
                final_mask_logits = tf.gather_nd(mask_logits, indices)  # #resultx14x14
                return tf.sigmoid(final_mask_logits)

            final_masks = tf.cond(
                tf.size(final_probs) > 0, f1, lambda: tf.zeros([0, 14, 14]))
            tf.identity(final_masks, name='final_masks')
def _build_graph(self, inputs):
    """
    Build the training/inference graph (C4 backbone Faster/Mask R-CNN).

    Args:
        inputs: (image, anchor_labels, anchor_boxes, gt_boxes, gt_labels[, gt_masks]);
            gt_masks only when ``config.MODE_MASK`` is set.

    Side effects:
        Training: sets ``self.cost`` and adds moving summaries.
        Inference: creates named outputs 'final_probs', 'final_boxes',
        'final_labels' (and 'final_masks' when masks are enabled).
    """
    is_training = get_current_tower_context().is_training
    if config.MODE_MASK:
        image, anchor_labels, anchor_boxes, gt_boxes, gt_labels, gt_masks = inputs
    else:
        image, anchor_labels, anchor_boxes, gt_boxes, gt_labels = inputs
    fm_anchors = self._get_anchors(image)
    image = self._preprocess(image)     # 1CHW
    image_shape2d = tf.shape(image)[2:]
    anchor_boxes_encoded = encode_bbox_target(anchor_boxes, fm_anchors)

    # Backbone and RPN head.
    featuremap = pretrained_resnet_conv4(image, config.RESNET_NUM_BLOCK[:3])
    rpn_label_logits, rpn_box_logits = rpn_head('rpn', featuremap, 1024, config.NUM_ANCHOR)
    decoded_boxes = decode_bbox_target(rpn_box_logits, fm_anchors)  # fHxfWxNAx4, floatbox
    proposal_boxes, proposal_scores = generate_rpn_proposals(
        tf.reshape(decoded_boxes, [-1, 4]),
        tf.reshape(rpn_label_logits, [-1]),
        image_shape2d)

    if is_training:
        # sample proposal boxes in training
        rcnn_sampled_boxes, rcnn_labels, fg_inds_wrt_gt = sample_fast_rcnn_targets(
            proposal_boxes, gt_boxes, gt_labels)
        boxes_on_featuremap = rcnn_sampled_boxes * (1.0 / config.ANCHOR_STRIDE)
    else:
        # use all proposal boxes in inference
        boxes_on_featuremap = proposal_boxes * (1.0 / config.ANCHOR_STRIDE)
    roi_resized = roi_align(featuremap, boxes_on_featuremap, 14)

    # HACK to work around https://github.com/tensorflow/tensorflow/issues/14657
    # when the RoI batch can be empty; ff_false returns empty placeholders.
    def ff_true():
        feature_fastrcnn = resnet_conv5(roi_resized, config.RESNET_NUM_BLOCK[-1])  # nxcx7x7
        fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_head('fastrcnn', feature_fastrcnn, config.NUM_CLASS)
        return feature_fastrcnn, fastrcnn_label_logits, fastrcnn_box_logits

    def ff_false():
        ncls = config.NUM_CLASS
        return tf.zeros([0, 2048, 7, 7]), tf.zeros([0, ncls]), tf.zeros([0, ncls - 1, 4])

    feature_fastrcnn, fastrcnn_label_logits, fastrcnn_box_logits = tf.cond(
        tf.size(boxes_on_featuremap) > 0, ff_true, ff_false)

    if is_training:
        # rpn loss
        rpn_label_loss, rpn_box_loss = rpn_losses(
            anchor_labels, anchor_boxes_encoded, rpn_label_logits, rpn_box_logits)

        # fastrcnn loss
        fg_inds_wrt_sample = tf.reshape(tf.where(rcnn_labels > 0), [-1])  # fg inds w.r.t all samples
        fg_sampled_boxes = tf.gather(rcnn_sampled_boxes, fg_inds_wrt_sample)

        # Visualize the sampled foreground patches on TensorBoard.
        with tf.name_scope('fg_sample_patch_viz'):
            fg_sampled_patches = crop_and_resize(
                image, fg_sampled_boxes,
                tf.zeros_like(fg_inds_wrt_sample, dtype=tf.int32), 300)
            fg_sampled_patches = tf.transpose(fg_sampled_patches, [0, 2, 3, 1])
            fg_sampled_patches = tf.reverse(fg_sampled_patches, axis=[-1])  # BGR->RGB
            tf.summary.image('viz', fg_sampled_patches, max_outputs=30)

        matched_gt_boxes = tf.gather(gt_boxes, fg_inds_wrt_gt)
        # Regression targets scaled by the configured weights.
        encoded_boxes = encode_bbox_target(
            matched_gt_boxes,
            fg_sampled_boxes) * tf.constant(config.FASTRCNN_BBOX_REG_WEIGHTS)
        fastrcnn_label_loss, fastrcnn_box_loss = fastrcnn_losses(
            rcnn_labels, fastrcnn_label_logits, encoded_boxes,
            tf.gather(fastrcnn_box_logits, fg_inds_wrt_sample))

        if config.MODE_MASK:
            # maskrcnn loss on the foreground RoIs
            fg_labels = tf.gather(rcnn_labels, fg_inds_wrt_sample)
            fg_feature = tf.gather(feature_fastrcnn, fg_inds_wrt_sample)
            mask_logits = maskrcnn_head('maskrcnn', fg_feature, config.NUM_CLASS)  # #fg x #cat x 14x14
            gt_masks_for_fg = tf.gather(gt_masks, fg_inds_wrt_gt)   # nfg x H x W
            target_masks_for_fg = crop_and_resize(
                tf.expand_dims(gt_masks_for_fg, 1),
                fg_sampled_boxes,
                tf.range(tf.size(fg_inds_wrt_gt)), 14)  # nfg x 1x14x14
            target_masks_for_fg = tf.squeeze(target_masks_for_fg, 1, 'sampled_fg_mask_targets')
            mrcnn_loss = maskrcnn_loss(mask_logits, fg_labels, target_masks_for_fg)
        else:
            mrcnn_loss = 0.0

        # Weight decay over all trained weights.
        wd_cost = regularize_cost(
            '(?:group1|group2|group3|rpn|fastrcnn|maskrcnn)/.*W',
            l2_regularizer(1e-4), name='wd_cost')

        self.cost = tf.add_n([
            rpn_label_loss, rpn_box_loss,
            fastrcnn_label_loss, fastrcnn_box_loss,
            mrcnn_loss, wd_cost], 'total_cost')
        add_moving_summary(self.cost, wd_cost)
    else:
        # Inference: per-category decoded boxes, then per-class filtering.
        label_probs = tf.nn.softmax(fastrcnn_label_logits, name='fastrcnn_all_probs')  # #proposal x #Class
        anchors = tf.tile(tf.expand_dims(proposal_boxes, 1),
                          [1, config.NUM_CLASS - 1, 1])  # #proposal x #Cat x 4
        decoded_boxes = decode_bbox_target(
            fastrcnn_box_logits / tf.constant(config.FASTRCNN_BBOX_REG_WEIGHTS),
            anchors)
        decoded_boxes = clip_boxes(decoded_boxes, image_shape2d, name='fastrcnn_all_boxes')

        # indices: Nx2. Each index into (#proposal, #category)
        pred_indices, final_probs = fastrcnn_predictions(decoded_boxes, label_probs)
        final_probs = tf.identity(final_probs, 'final_probs')
        final_boxes = tf.gather_nd(decoded_boxes, pred_indices, name='final_boxes')
        final_labels = tf.add(pred_indices[:, 1], 1, name='final_labels')

        if config.MODE_MASK:
            # HACK to work around https://github.com/tensorflow/tensorflow/issues/14657
            def f1():
                roi_resized = roi_align(
                    featuremap, final_boxes * (1.0 / config.ANCHOR_STRIDE), 14)
                feature_maskrcnn = resnet_conv5(roi_resized, config.RESNET_NUM_BLOCK[-1])
                mask_logits = maskrcnn_head(
                    'maskrcnn', feature_maskrcnn, config.NUM_CLASS)  # #result x #cat x 14x14
                # Select each detection's predicted-category mask channel.
                indices = tf.stack([tf.range(tf.size(final_labels)),
                                    tf.to_int32(final_labels) - 1], axis=1)
                final_mask_logits = tf.gather_nd(mask_logits, indices)  # #resultx14x14
                return tf.sigmoid(final_mask_logits)

            final_masks = tf.cond(tf.size(final_probs) > 0, f1, lambda: tf.zeros([0, 14, 14]))
            tf.identity(final_masks, name='final_masks')
def _build_graph(self, inputs):
    """
    Build the training/inference graph for the C4 Faster R-CNN model
    (no mask branch in this variant).

    Args:
        inputs: (image, anchor_labels, anchor_boxes, gt_boxes, gt_labels).

    Side effects:
        Training: sets ``self.cost`` and adds moving summaries.
        Inference: creates named outputs 'fastrcnn_all_probs', 'final_probs',
        'final_boxes' and 'final_labels'.
    """
    is_training = get_current_tower_context().is_training
    image, anchor_labels, anchor_boxes, gt_boxes, gt_labels = inputs
    fm_anchors = self._get_anchors(image)
    image = self._preprocess(image)     # 1CHW
    image_shape2d = tf.shape(image)[2:]
    anchor_boxes_encoded = encode_bbox_target(anchor_boxes, fm_anchors)

    # Backbone + RPN head.
    featuremap = pretrained_resnet_conv4(image, config.RESNET_NUM_BLOCK[:3])
    rpn_label_logits, rpn_box_logits = rpn_head('rpn', featuremap, 1024, config.NUM_ANCHOR)
    decoded_boxes = decode_bbox_target(rpn_box_logits, fm_anchors)  # fHxfWxNAx4, floatbox
    proposal_boxes, proposal_scores = generate_rpn_proposals(
        tf.reshape(decoded_boxes, [-1, 4]),
        tf.reshape(rpn_label_logits, [-1]),
        image_shape2d)

    if is_training:
        # sample proposal boxes in training
        rcnn_sampled_boxes, rcnn_labels, fg_inds_wrt_gt = sample_fast_rcnn_targets(
            proposal_boxes, gt_boxes, gt_labels)
        boxes_on_featuremap = rcnn_sampled_boxes * (1.0 / config.ANCHOR_STRIDE)
    else:
        # use all proposal boxes in inference
        boxes_on_featuremap = proposal_boxes * (1.0 / config.ANCHOR_STRIDE)
    roi_resized = roi_align(featuremap, boxes_on_featuremap, 14)
    feature_fastrcnn = resnet_conv5(roi_resized, config.RESNET_NUM_BLOCK[-1])  # nxcx7x7
    fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_head(
        'fastrcnn', feature_fastrcnn, config.NUM_CLASS)

    if is_training:
        # rpn loss
        rpn_label_loss, rpn_box_loss = rpn_losses(
            anchor_labels, anchor_boxes_encoded, rpn_label_logits, rpn_box_logits)

        # fastrcnn loss
        fg_inds_wrt_sample = tf.reshape(tf.where(rcnn_labels > 0), [-1])  # fg inds w.r.t all samples
        fg_sampled_boxes = tf.gather(rcnn_sampled_boxes, fg_inds_wrt_sample)

        # TensorBoard visualization of the sampled foreground patches.
        # NOTE(review): unlike other variants in this file, no BGR->RGB
        # reverse is applied here before the image summary.
        with tf.name_scope('fg_sample_patch_viz'):
            fg_sampled_patches = crop_and_resize(
                image, fg_sampled_boxes,
                tf.zeros_like(fg_inds_wrt_sample, dtype=tf.int32), 300)
            fg_sampled_patches = tf.transpose(fg_sampled_patches, [0, 2, 3, 1])
            tf.summary.image('viz', fg_sampled_patches, max_outputs=30)

        matched_gt_boxes = tf.gather(gt_boxes, fg_inds_wrt_gt)
        # Regression targets scaled by the configured weights.
        encoded_boxes = encode_bbox_target(
            matched_gt_boxes,
            fg_sampled_boxes) * tf.constant(config.FASTRCNN_BBOX_REG_WEIGHTS)
        fastrcnn_label_loss, fastrcnn_box_loss = fastrcnn_losses(
            rcnn_labels, fastrcnn_label_logits, encoded_boxes,
            tf.gather(fastrcnn_box_logits, fg_inds_wrt_sample))

        # Weight decay on backbone and head weights.
        wd_cost = regularize_cost(
            '(?:group1|group2|group3|rpn|fastrcnn)/.*W',
            l2_regularizer(1e-4), name='wd_cost')

        self.cost = tf.add_n([
            rpn_label_loss, rpn_box_loss,
            fastrcnn_label_loss, fastrcnn_box_loss,
            wd_cost], 'total_cost')
        add_moving_summary(self.cost, wd_cost)
    else:
        # Inference: per-category decoded boxes, then per-class filtering.
        label_probs = tf.nn.softmax(fastrcnn_label_logits, name='fastrcnn_all_probs')  # #proposal x #Class
        anchors = tf.tile(tf.expand_dims(proposal_boxes, 1),
                          [1, config.NUM_CLASS - 1, 1])  # #proposal x #Cat x 4
        decoded_boxes = decode_bbox_target(
            fastrcnn_box_logits / tf.constant(config.FASTRCNN_BBOX_REG_WEIGHTS),
            anchors)
        decoded_boxes = clip_boxes(decoded_boxes, image_shape2d, name='fastrcnn_all_boxes')

        # indices: Nx2. Each index into (#proposal, #category)
        pred_indices, final_probs = fastrcnn_predictions(decoded_boxes, label_probs)
        final_probs = tf.identity(final_probs, 'final_probs')
        # Outputs are exposed only through their graph names.
        tf.gather_nd(decoded_boxes, pred_indices, name='final_boxes')
        tf.add(pred_indices[:, 1], 1, name='final_labels')
def _build_graph(self, inputs):
    """
    Build the training/inference graph for an early Faster R-CNN variant
    (single image input; batch dimension is added inside this function).

    Args:
        inputs: (image, anchor_labels, anchor_boxes, gt_boxes, gt_labels);
            ``image`` arrives without a batch dimension (HWC layout is
            implied by the NHWC->NCHW transpose below).

    Side effects:
        Training: sets ``self.cost`` and adds moving summaries.
        Inference: creates named tensors 'fastrcnn_all_probs',
        'fastrcnn_fg_probs' and 'fastrcnn_fg_boxes'.
    """
    is_training = get_current_tower_context().is_training
    image, anchor_labels, anchor_boxes, gt_boxes, gt_labels = inputs
    image = tf.expand_dims(image, 0)

    # FSxFSxNAx4 (FS=MAX_SIZE//ANCHOR_STRIDE)
    with tf.name_scope('anchors'):
        all_anchors = tf.constant(get_all_anchors(), name='all_anchors', dtype=tf.float32)
        # Slice the precomputed anchor grid down to this image's featuremap size.
        fm_anchors = tf.slice(
            all_anchors, [0, 0, 0, 0],
            tf.stack([
                tf.shape(image)[1] // config.ANCHOR_STRIDE,
                tf.shape(image)[2] // config.ANCHOR_STRIDE, -1, -1]),
            name='fm_anchors')
        anchor_boxes_encoded = encode_bbox_target(anchor_boxes, fm_anchors)

    image = image_preprocess(image, bgr=True)
    image = tf.transpose(image, [0, 3, 1, 2])   # NHWC -> NCHW

    # resnet50
    featuremap = pretrained_resnet_conv4(image, [3, 4, 6])
    rpn_label_logits, rpn_box_logits = rpn_head(featuremap)
    rpn_label_loss, rpn_box_loss = rpn_losses(
        anchor_labels, anchor_boxes_encoded, rpn_label_logits, rpn_box_logits)

    decoded_boxes = decode_bbox_target(rpn_box_logits, fm_anchors)  # (fHxfWxNA)x4, floatbox
    proposal_boxes, proposal_scores = generate_rpn_proposals(
        decoded_boxes, tf.reshape(rpn_label_logits, [-1]), tf.shape(image)[2:])

    if is_training:
        # Sample targets; this variant's sampler also returns encoded targets.
        rcnn_sampled_boxes, rcnn_encoded_boxes, rcnn_labels = sample_fast_rcnn_targets(
            proposal_boxes, gt_boxes, gt_labels)
        boxes_on_featuremap = rcnn_sampled_boxes * (1.0 / config.ANCHOR_STRIDE)
        roi_resized = roi_align(featuremap, boxes_on_featuremap, 14)
        feature_fastrcnn = resnet_conv5(roi_resized)   # nxc
        fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_head(feature_fastrcnn, config.NUM_CLASS)
        fastrcnn_label_loss, fastrcnn_box_loss = fastrcnn_losses(
            rcnn_labels, rcnn_encoded_boxes, fastrcnn_label_logits, fastrcnn_box_logits)

        # Weight decay on backbone and head weights.
        wd_cost = regularize_cost(
            '(?:group1|group2|group3|rpn|fastrcnn)/.*W',
            l2_regularizer(1e-4), name='wd_cost')

        self.cost = tf.add_n([
            rpn_label_loss, rpn_box_loss,
            fastrcnn_label_loss, fastrcnn_box_loss, wd_cost], 'total_cost')
        for k in self.cost, wd_cost:
            add_moving_summary(k)
    else:
        roi_resized = roi_align(
            featuremap,
            proposal_boxes * (1.0 / config.ANCHOR_STRIDE), 14)
        feature_fastrcnn = resnet_conv5(roi_resized)   # nxc
        label_logits, fastrcnn_box_logits = fastrcnn_head(feature_fastrcnn, config.NUM_CLASS)
        label_probs = tf.nn.softmax(label_logits, name='fastrcnn_all_probs')  # NP,
        labels = tf.argmax(label_logits, axis=1)
        # Keep only proposals predicted as foreground, with their box logits.
        fg_ind, fg_box_logits = fastrcnn_predict_boxes(labels, fastrcnn_box_logits)
        fg_label_probs = tf.gather(label_probs, fg_ind, name='fastrcnn_fg_probs')
        fg_boxes = tf.gather(proposal_boxes, fg_ind)

        # Undo the regression-weight scaling before decoding boxes.
        fg_box_logits = fg_box_logits / tf.constant(config.FASTRCNN_BBOX_REG_WEIGHTS)
        decoded_boxes = decode_bbox_target(fg_box_logits, fg_boxes)  # Nfx4, floatbox
        decoded_boxes = tf.identity(decoded_boxes, name='fastrcnn_fg_boxes')
def build_graph(self, *inputs):
    """
    Build the training/inference graph for the FPN Faster/Mask R-CNN model.

    Args:
        *inputs: ordered as
            image,
            (anchor_labels, anchor_boxes) for each FPN level (2 tensors per level),
            gt_boxes, gt_labels[, gt_masks]
            where gt_masks is present only when ``config.MODE_MASK`` is set.

    Returns:
        The total training cost tensor in training mode; ``None`` in
        inference mode (results are exposed as named output tensors,
        e.g. 'final_masks' when masks are enabled).
    """
    num_fpn_level = len(config.ANCHOR_STRIDES_FPN)
    assert len(config.ANCHOR_SIZES) == num_fpn_level
    is_training = get_current_tower_context().is_training
    image = inputs[0]
    input_anchors = inputs[1:1 + 2 * num_fpn_level]
    multilevel_anchor_labels = input_anchors[0::2]
    multilevel_anchor_boxes = input_anchors[1::2]
    # gt_boxes/gt_labels immediately follow the per-level anchor inputs.
    # Computed from num_fpn_level instead of the previous hard-coded
    # inputs[11]/inputs[12], which silently assumed exactly 5 FPN levels.
    gt_boxes = inputs[1 + 2 * num_fpn_level]
    gt_labels = inputs[2 + 2 * num_fpn_level]
    if config.MODE_MASK:
        gt_masks = inputs[-1]

    image = self.preprocess(image)     # 1CHW
    image_shape2d = tf.shape(image)[2:]     # h,w
    c2345 = resnet_fpn_backbone(image, config.RESNET_NUM_BLOCK)
    p23456 = fpn_model('fpn', c2345)

    # Multi-Level RPN Proposals
    multilevel_proposals = []
    rpn_loss_collection = []    # alternating [label_loss, box_loss] per level
    for lvl in range(num_fpn_level):
        rpn_label_logits, rpn_box_logits = rpn_head(
            'rpn', p23456[lvl], config.FPN_NUM_CHANNEL, len(config.ANCHOR_RATIOS))
        with tf.name_scope('FPN_lvl{}'.format(lvl + 2)):
            anchors = tf.constant(get_all_anchors_fpn()[lvl],
                                  name='rpn_anchor_lvl{}'.format(lvl + 2))
            # Restrict anchors/targets to the actual featuremap extent.
            anchors, anchor_labels, anchor_boxes = \
                self.narrow_to_featuremap(p23456[lvl], anchors,
                                          multilevel_anchor_labels[lvl],
                                          multilevel_anchor_boxes[lvl])
            anchor_boxes_encoded = encode_bbox_target(anchor_boxes, anchors)
            pred_boxes_decoded = decode_bbox_target(rpn_box_logits, anchors)
            proposal_boxes, proposal_scores = generate_rpn_proposals(
                tf.reshape(pred_boxes_decoded, [-1, 4]),
                tf.reshape(rpn_label_logits, [-1]),
                image_shape2d,
                config.TRAIN_FPN_NMS_TOPK if is_training else config.TEST_FPN_NMS_TOPK)
            multilevel_proposals.append((proposal_boxes, proposal_scores))
        if is_training:
            label_loss, box_loss = rpn_losses(
                anchor_labels, anchor_boxes_encoded,
                rpn_label_logits, rpn_box_logits)
            rpn_loss_collection.extend([label_loss, box_loss])

    # Merge proposals from multi levels, pick top K
    proposal_boxes = tf.concat([x[0] for x in multilevel_proposals], axis=0)  # nx4
    proposal_scores = tf.concat([x[1] for x in multilevel_proposals], axis=0)  # n
    proposal_topk = tf.minimum(
        tf.size(proposal_scores),
        config.TRAIN_FPN_NMS_TOPK if is_training else config.TEST_FPN_NMS_TOPK)
    proposal_scores, topk_indices = tf.nn.top_k(proposal_scores, k=proposal_topk, sorted=False)
    proposal_boxes = tf.gather(proposal_boxes, topk_indices)

    if is_training:
        rcnn_boxes, rcnn_labels, fg_inds_wrt_gt = sample_fast_rcnn_targets(
            proposal_boxes, gt_boxes, gt_labels)
    else:
        # The boxes to be used to crop RoIs.
        rcnn_boxes = proposal_boxes

    roi_feature_fastrcnn = multilevel_roi_align(p23456[:4], rcnn_boxes, 7)
    fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_2fc_head(
        'fastrcnn', roi_feature_fastrcnn, config.NUM_CLASS)

    if is_training:
        # rpn loss is already defined above
        with tf.name_scope('rpn_losses'):
            rpn_total_label_loss = tf.add_n(rpn_loss_collection[::2], name='label_loss')
            rpn_total_box_loss = tf.add_n(rpn_loss_collection[1::2], name='box_loss')
            add_moving_summary(rpn_total_box_loss, rpn_total_label_loss)

        # fastrcnn loss:
        matched_gt_boxes = tf.gather(gt_boxes, fg_inds_wrt_gt)
        fg_inds_wrt_sample = tf.reshape(tf.where(rcnn_labels > 0), [-1])  # fg inds w.r.t all samples
        fg_sampled_boxes = tf.gather(rcnn_boxes, fg_inds_wrt_sample)
        fg_fastrcnn_box_logits = tf.gather(fastrcnn_box_logits, fg_inds_wrt_sample)
        fastrcnn_label_loss, fastrcnn_box_loss = self.fastrcnn_training(
            image, rcnn_labels, fg_sampled_boxes,
            matched_gt_boxes, fastrcnn_label_logits, fg_fastrcnn_box_logits)

        if config.MODE_MASK:
            # maskrcnn loss on foreground RoIs
            fg_labels = tf.gather(rcnn_labels, fg_inds_wrt_sample)
            roi_feature_maskrcnn = multilevel_roi_align(p23456[:4], fg_sampled_boxes, 14)
            mask_logits = maskrcnn_upXconv_head(
                'maskrcnn', roi_feature_maskrcnn, config.NUM_CLASS, 4)  # #fg x #cat x 28 x 28
            matched_gt_masks = tf.gather(gt_masks, fg_inds_wrt_gt)  # fg x H x W
            target_masks_for_fg = crop_and_resize(
                tf.expand_dims(matched_gt_masks, 1),
                fg_sampled_boxes,
                tf.range(tf.size(fg_inds_wrt_gt)), 28,
                pad_border=False)  # fg x 1x28x28
            target_masks_for_fg = tf.squeeze(target_masks_for_fg, 1, 'sampled_fg_mask_targets')
            mrcnn_loss = maskrcnn_loss(mask_logits, fg_labels, target_masks_for_fg)
        else:
            mrcnn_loss = 0.0

        # Weight decay on backbone and head weights.
        wd_cost = regularize_cost(
            '(?:group1|group2|group3|rpn|fastrcnn|maskrcnn)/.*W',
            l2_regularizer(1e-4), name='wd_cost')

        total_cost = tf.add_n(
            rpn_loss_collection + [
                fastrcnn_label_loss, fastrcnn_box_loss, mrcnn_loss, wd_cost],
            'total_cost')
        add_moving_summary(total_cost, wd_cost)
        return total_cost
    else:
        final_boxes, final_labels = self.fastrcnn_inference(
            image_shape2d, rcnn_boxes, fastrcnn_label_logits, fastrcnn_box_logits)

        if config.MODE_MASK:
            # Cascade inference needs roi transform with refined boxes.
            roi_feature_maskrcnn = multilevel_roi_align(p23456[:4], final_boxes, 14)
            mask_logits = maskrcnn_upXconv_head(
                'maskrcnn', roi_feature_maskrcnn, config.NUM_CLASS, 4)  # #fg x #cat x 28 x 28
            # Select the mask channel of each detection's predicted category.
            indices = tf.stack([tf.range(tf.size(final_labels)),
                                tf.to_int32(final_labels) - 1], axis=1)
            final_mask_logits = tf.gather_nd(mask_logits, indices)  # #resultx28x28
            tf.sigmoid(final_mask_logits, name='final_masks')
def build_graph(self, *inputs):
    """
    Build the training/inference graph for the C4 Faster/Mask R-CNN model
    (TF >= 1.6 aware variant).

    Args:
        *inputs: (image, anchor_labels, anchor_boxes, gt_boxes, gt_labels[, gt_masks]);
            gt_masks only when ``config.MODE_MASK`` is set.

    Returns:
        The total training cost tensor in training mode; ``None`` in
        inference mode (results exposed as named tensors, e.g. 'final_masks').
    """
    is_training = get_current_tower_context().is_training
    if config.MODE_MASK:
        image, anchor_labels, anchor_boxes, gt_boxes, gt_labels, gt_masks = inputs
    else:
        image, anchor_labels, anchor_boxes, gt_boxes, gt_labels = inputs
    image = self.preprocess(image)     # 1CHW

    featuremap = resnet_c4_backbone(image, config.RESNET_NUM_BLOCK[:3])
    rpn_label_logits, rpn_box_logits = rpn_head('rpn', featuremap, 1024, config.NUM_ANCHOR)
    # Restrict anchors/targets to the actual featuremap extent.
    fm_anchors, anchor_labels, anchor_boxes = self.narrow_to_featuremap(
        featuremap, get_all_anchors(), anchor_labels, anchor_boxes)
    anchor_boxes_encoded = encode_bbox_target(anchor_boxes, fm_anchors)

    image_shape2d = tf.shape(image)[2:]     # h,w
    pred_boxes_decoded = decode_bbox_target(rpn_box_logits, fm_anchors)  # fHxfWxNAx4, floatbox
    proposal_boxes, proposal_scores = generate_rpn_proposals(
        tf.reshape(pred_boxes_decoded, [-1, 4]),
        tf.reshape(rpn_label_logits, [-1]),
        image_shape2d,
        config.TRAIN_PRE_NMS_TOPK if is_training else config.TEST_PRE_NMS_TOPK,
        config.TRAIN_POST_NMS_TOPK if is_training else config.TEST_POST_NMS_TOPK)

    if is_training:
        # sample proposal boxes in training
        rcnn_boxes, rcnn_labels, fg_inds_wrt_gt = sample_fast_rcnn_targets(
            proposal_boxes, gt_boxes, gt_labels)
    else:
        # The boxes to be used to crop RoIs.
        # Use all proposal boxes in inference
        rcnn_boxes = proposal_boxes

    boxes_on_featuremap = rcnn_boxes * (1.0 / config.ANCHOR_STRIDE)
    roi_resized = roi_align(featuremap, boxes_on_featuremap, 14)

    # HACK to work around https://github.com/tensorflow/tensorflow/issues/14657
    # which was fixed in TF 1.6
    def ff_true():
        feature_fastrcnn = resnet_conv5(roi_resized, config.RESNET_NUM_BLOCK[-1])   # nxcx7x7
        feature_gap = GlobalAvgPooling('gap', feature_fastrcnn, data_format='channels_first')
        fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_outputs(
            'fastrcnn', feature_gap, config.NUM_CLASS)
        # Return C5 feature to be shared with mask branch
        return feature_fastrcnn, fastrcnn_label_logits, fastrcnn_box_logits

    def ff_false():
        # Empty placeholders with the same structure as ff_true's outputs.
        ncls = config.NUM_CLASS
        return tf.zeros([0, 2048, 7, 7]), tf.zeros([0, ncls]), tf.zeros([0, ncls - 1, 4])

    if get_tf_version_number() >= 1.6:
        feature_fastrcnn, fastrcnn_label_logits, fastrcnn_box_logits = ff_true()
    else:
        logger.warn("This example may drop support for TF < 1.6 soon.")
        feature_fastrcnn, fastrcnn_label_logits, fastrcnn_box_logits = tf.cond(
            tf.size(boxes_on_featuremap) > 0, ff_true, ff_false)

    if is_training:
        # rpn loss
        rpn_label_loss, rpn_box_loss = rpn_losses(
            anchor_labels, anchor_boxes_encoded, rpn_label_logits, rpn_box_logits)

        # fastrcnn loss
        matched_gt_boxes = tf.gather(gt_boxes, fg_inds_wrt_gt)

        fg_inds_wrt_sample = tf.reshape(tf.where(rcnn_labels > 0), [-1])  # fg inds w.r.t all samples
        fg_sampled_boxes = tf.gather(rcnn_boxes, fg_inds_wrt_sample)
        fg_fastrcnn_box_logits = tf.gather(fastrcnn_box_logits, fg_inds_wrt_sample)

        fastrcnn_label_loss, fastrcnn_box_loss = self.fastrcnn_training(
            image, rcnn_labels, fg_sampled_boxes,
            matched_gt_boxes, fastrcnn_label_logits, fg_fastrcnn_box_logits)

        if config.MODE_MASK:
            # maskrcnn loss
            fg_labels = tf.gather(rcnn_labels, fg_inds_wrt_sample)
            # In training, mask branch shares the same C5 feature.
            fg_feature = tf.gather(feature_fastrcnn, fg_inds_wrt_sample)
            mask_logits = maskrcnn_upXconv_head(
                'maskrcnn', fg_feature, config.NUM_CLASS, num_convs=0)  # #fg x #cat x 14x14
            matched_gt_masks = tf.gather(gt_masks, fg_inds_wrt_gt)  # nfg x H x W
            target_masks_for_fg = crop_and_resize(
                tf.expand_dims(matched_gt_masks, 1),
                fg_sampled_boxes,
                tf.range(tf.size(fg_inds_wrt_gt)), 14,
                pad_border=False)   # nfg x 1x14x14
            target_masks_for_fg = tf.squeeze(target_masks_for_fg, 1, 'sampled_fg_mask_targets')
            mrcnn_loss = maskrcnn_loss(mask_logits, fg_labels, target_masks_for_fg)
        else:
            mrcnn_loss = 0.0

        # Weight decay on backbone and head weights.
        wd_cost = regularize_cost(
            '(?:group1|group2|group3|rpn|fastrcnn|maskrcnn)/.*W',
            l2_regularizer(1e-4), name='wd_cost')

        total_cost = tf.add_n([
            rpn_label_loss, rpn_box_loss,
            fastrcnn_label_loss, fastrcnn_box_loss,
            mrcnn_loss, wd_cost], 'total_cost')
        add_moving_summary(total_cost, wd_cost)
        return total_cost
    else:
        final_boxes, final_labels = self.fastrcnn_inference(
            image_shape2d, rcnn_boxes, fastrcnn_label_logits, fastrcnn_box_logits)

        if config.MODE_MASK:
            # HACK to work around https://github.com/tensorflow/tensorflow/issues/14657
            def f1():
                roi_resized = roi_align(
                    featuremap, final_boxes * (1.0 / config.ANCHOR_STRIDE), 14)
                feature_maskrcnn = resnet_conv5(roi_resized, config.RESNET_NUM_BLOCK[-1])
                mask_logits = maskrcnn_upXconv_head(
                    'maskrcnn', feature_maskrcnn, config.NUM_CLASS, 0)  # #result x #cat x 14x14
                # Select each detection's predicted-category mask channel.
                indices = tf.stack([tf.range(tf.size(final_labels)),
                                    tf.to_int32(final_labels) - 1], axis=1)
                final_mask_logits = tf.gather_nd(mask_logits, indices)  # #resultx14x14
                return tf.sigmoid(final_mask_logits)

            final_masks = tf.cond(tf.size(final_labels) > 0, f1, lambda: tf.zeros([0, 14, 14]))
            tf.identity(final_masks, name='final_masks')
def _build_graph(self, inputs):
    """Build the train/inference graph for the detector, optionally with a
    second classification head.

    Args:
        inputs: list of input tensors. Its composition depends on
            config.USE_SECOND_HEAD, config.MODE_MASK and
            config.PROVIDE_BOXES_AS_INPUT (see the unpacking below).

    Side effects:
        In training mode, sets ``self.cost`` and adds moving summaries.
        In inference mode, exposes results through named tensors such as
        'final_probs', 'final_boxes', 'final_labels', 'final_masks'.
    """
    is_training = get_current_tower_context().is_training
    # Unpack `inputs` according to the active configuration flags.
    if config.USE_SECOND_HEAD:
        # USE SECOND HEAD
        if config.MODE_MASK:
            if config.PROVIDE_BOXES_AS_INPUT:
                image, anchor_labels, anchor_boxes, gt_boxes, gt_labels, second_gt_labels, gt_masks, input_boxes = inputs
            else:
                image, anchor_labels, anchor_boxes, gt_boxes, gt_labels, second_gt_labels, gt_masks = inputs
        else:
            image, anchor_labels, anchor_boxes, gt_boxes, gt_labels, second_gt_labels = inputs
    else:
        # NO SECOND HEAD
        if config.MODE_MASK:
            if config.PROVIDE_BOXES_AS_INPUT:
                image, anchor_labels, anchor_boxes, gt_boxes, gt_labels, gt_masks, input_boxes = inputs
            else:
                image, anchor_labels, anchor_boxes, gt_boxes, gt_labels, gt_masks = inputs
        else:
            image, anchor_labels, anchor_boxes, gt_boxes, gt_labels = inputs
        second_gt_labels = None

    fm_anchors = self._get_anchors(image)
    # Shape before preprocessing is kept to rescale externally-provided boxes
    # (used only when PROVIDE_BOXES_AS_INPUT is on).
    image_shape2d_before_resize = tf.shape(image)[:2]
    image = self._preprocess(image)  # 1CHW
    image_shape2d = tf.shape(image)[2:]
    anchor_boxes_encoded = encode_bbox_target(anchor_boxes, fm_anchors)

    # Backbone + RPN.
    featuremap = pretrained_resnet_conv4(image, config.RESNET_NUM_BLOCK[:3])
    rpn_label_logits, rpn_box_logits = rpn_head('rpn', featuremap, 1024, config.NUM_ANCHOR)
    decoded_boxes = decode_bbox_target(rpn_box_logits, fm_anchors)  # fHxfWxNAx4, floatbox
    proposal_boxes, proposal_scores = generate_rpn_proposals(
        tf.reshape(decoded_boxes, [-1, 4]),
        tf.reshape(rpn_label_logits, [-1]),
        image_shape2d)

    if config.PROVIDE_BOXES_AS_INPUT:
        # Replace RPN proposals with externally-supplied boxes, rescaled from
        # the original image size to the preprocessed size.
        old_height, old_width = image_shape2d_before_resize[0], image_shape2d_before_resize[1]
        new_height, new_width = image_shape2d[0], image_shape2d[1]
        height_scale = new_height / old_height
        width_scale = new_width / old_width
        # TODO: check the order of dimensions!
        scale = tf.stack([width_scale, height_scale, width_scale, height_scale], axis=0)
        proposal_boxes = input_boxes * tf.cast(scale, tf.float32)

    secondclassification_labels = None
    if is_training:
        # sample proposal boxes in training
        rcnn_sampled_boxes, rcnn_labels, fg_inds_wrt_gt, bg_inds = sample_fast_rcnn_targets(
            proposal_boxes, gt_boxes, gt_labels)
        if config.USE_SECOND_HEAD:
            # Second-head targets: gathered gt labels for fg samples,
            # zeros (background class) for bg samples. stop_gradient since
            # labels must not receive gradients.
            secondclassification_labels = tf.stop_gradient(
                tf.concat([
                    tf.gather(second_gt_labels, fg_inds_wrt_gt),
                    tf.zeros_like(bg_inds, dtype=tf.int64)
                ], axis=0, name='second_sampled_labels'))
        boxes_on_featuremap = rcnn_sampled_boxes * (1.0 / config.ANCHOR_STRIDE)
    else:
        # use all proposal boxes in inference
        boxes_on_featuremap = proposal_boxes * (1.0 / config.ANCHOR_STRIDE)
        rcnn_labels = rcnn_sampled_boxes = fg_inds_wrt_gt = None

    roi_resized = roi_align(featuremap, boxes_on_featuremap, 14)

    # HACK to work around https://github.com/tensorflow/tensorflow/issues/14657
    def ff_true():
        # Branch taken when there is at least one box.
        feature_fastrcnn = resnet_conv5(roi_resized, config.RESNET_NUM_BLOCK[-1])  # nxcx7x7
        if config.TRAIN_HEADS_ONLY:
            # Freeze the backbone features; only the heads get gradients.
            feature_fastrcnn = tf.stop_gradient(feature_fastrcnn)
        fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_head(
            'fastrcnn', feature_fastrcnn, config.NUM_CLASS)
        return feature_fastrcnn, fastrcnn_label_logits, fastrcnn_box_logits

    def ff_false():
        # Empty-box branch: return correctly-shaped empty tensors.
        ncls = config.NUM_CLASS
        return tf.zeros([0, 2048, 7, 7]), tf.zeros([0, ncls]), tf.zeros([0, ncls - 1, 4])

    feature_fastrcnn, fastrcnn_label_logits, fastrcnn_box_logits = tf.cond(
        tf.size(boxes_on_featuremap) > 0, ff_true, ff_false)
    # Global-average-pooled per-box feature (n x c), exported in inference.
    feature_fastrcnn_pooled = tf.reduce_mean(feature_fastrcnn, axis=[2, 3])

    if config.USE_SECOND_HEAD:
        def if_true():
            secondclassification_label_logits = secondclassification_head(
                'secondclassification', feature_fastrcnn, config.SECOND_NUM_CLASS)
            return secondclassification_label_logits

        def if_false():
            ncls = config.SECOND_NUM_CLASS
            return tf.zeros([0, ncls])

        secondclassification_label_logits = tf.cond(
            tf.size(boxes_on_featuremap) > 0, if_true, if_false)
    else:
        secondclassification_label_logits = None

    if is_training:
        # rpn loss
        rpn_label_loss, rpn_box_loss = rpn_losses(
            anchor_labels, anchor_boxes_encoded, rpn_label_logits, rpn_box_logits)

        # fastrcnn loss
        fg_inds_wrt_sample = tf.reshape(tf.where(rcnn_labels > 0), [-1])  # fg inds w.r.t all samples
        fg_sampled_boxes = tf.gather(rcnn_sampled_boxes, fg_inds_wrt_sample)

        # Visualize the sampled foreground patches in TensorBoard.
        with tf.name_scope('fg_sample_patch_viz'):
            fg_sampled_patches = crop_and_resize(
                image, fg_sampled_boxes,
                tf.zeros_like(fg_inds_wrt_sample, dtype=tf.int32), 300)
            fg_sampled_patches = tf.transpose(fg_sampled_patches, [0, 2, 3, 1])
            fg_sampled_patches = tf.reverse(fg_sampled_patches, axis=[-1])  # BGR->RGB
            tf.summary.image('viz', fg_sampled_patches, max_outputs=30)

        matched_gt_boxes = tf.gather(gt_boxes, fg_inds_wrt_gt)
        # Regression targets are encoded w.r.t. the sampled fg proposals and
        # scaled by the configured box-regression weights.
        encoded_boxes = encode_bbox_target(
            matched_gt_boxes,
            fg_sampled_boxes) * tf.constant(config.FASTRCNN_BBOX_REG_WEIGHTS)
        fastrcnn_label_loss, fastrcnn_box_loss = fastrcnn_losses(
            rcnn_labels, fastrcnn_label_logits,
            encoded_boxes,
            tf.gather(fastrcnn_box_logits, fg_inds_wrt_sample))

        if config.USE_SECOND_HEAD:
            secondclassification_label_loss = secondclassification_losses(
                secondclassification_labels, secondclassification_label_logits)
        else:
            secondclassification_label_loss = None

        if config.MODE_MASK:
            # maskrcnn loss
            fg_labels = tf.gather(rcnn_labels, fg_inds_wrt_sample)
            fg_feature = tf.gather(feature_fastrcnn, fg_inds_wrt_sample)
            mask_logits = maskrcnn_head(
                'maskrcnn', fg_feature, config.NUM_CLASS)  # #fg x #cat x 14x14
            gt_masks_for_fg = tf.gather(gt_masks, fg_inds_wrt_gt)  # nfg x H x W
            target_masks_for_fg = crop_and_resize(
                tf.expand_dims(gt_masks_for_fg, 1),
                fg_sampled_boxes,
                tf.range(tf.size(fg_inds_wrt_gt)), 14)  # nfg x 1x14x14
            target_masks_for_fg = tf.squeeze(target_masks_for_fg, 1, 'sampled_fg_mask_targets')
            mrcnn_loss = maskrcnn_loss(mask_logits, fg_labels, target_masks_for_fg)
        else:
            mrcnn_loss = 0.0

        wd_cost = regularize_cost(
            '(?:group1|group2|group3|rpn|fastrcnn|maskrcnn)/.*W',
            l2_regularizer(1e-4), name='wd_cost')

        if config.TRAIN_HEADS_ONLY:
            # don't include the rpn loss
            if config.USE_SECOND_HEAD:
                self.cost = tf.add_n([
                    fastrcnn_label_loss, fastrcnn_box_loss,
                    secondclassification_label_loss, mrcnn_loss, wd_cost
                ], 'total_cost')
            else:
                self.cost = tf.add_n([
                    fastrcnn_label_loss, fastrcnn_box_loss, mrcnn_loss, wd_cost
                ], 'total_cost')
        else:
            if config.USE_SECOND_HEAD:
                self.cost = tf.add_n([
                    rpn_label_loss, rpn_box_loss,
                    fastrcnn_label_loss, fastrcnn_box_loss,
                    secondclassification_label_loss, mrcnn_loss, wd_cost
                ], 'total_cost')
            else:
                self.cost = tf.add_n([
                    rpn_label_loss, rpn_box_loss,
                    fastrcnn_label_loss, fastrcnn_box_loss,
                    mrcnn_loss, wd_cost
                ], 'total_cost')
        add_moving_summary(self.cost, wd_cost)
    else:
        # Inference: decode per-class boxes from the fast-rcnn head outputs.
        label_probs = tf.nn.softmax(
            fastrcnn_label_logits, name='fastrcnn_all_probs')  # #proposal x #Class
        anchors = tf.tile(
            tf.expand_dims(proposal_boxes, 1),
            [1, config.NUM_CLASS - 1, 1])  # #proposal x #Cat x 4
        decoded_boxes = decode_bbox_target(
            fastrcnn_box_logits / tf.constant(config.FASTRCNN_BBOX_REG_WEIGHTS),
            anchors)
        decoded_boxes = clip_boxes(decoded_boxes, image_shape2d, name='fastrcnn_all_boxes')

        # indices: Nx2. Each index into (#proposal, #category)
        pred_indices, final_probs = fastrcnn_predictions(decoded_boxes, label_probs)
        final_probs = tf.identity(final_probs, 'final_probs')
        final_boxes = tf.gather_nd(decoded_boxes, pred_indices, name='final_boxes')
        pred_indices_only_boxes = pred_indices[:, 0]
        # bugfix (?): changing this 1->0 fixes quite some issues. Why?
        # Category index is shifted by +1 to map back to label space
        # (category 0 in pred_indices corresponds to class label 1).
        final_labels = tf.add(pred_indices[:, 1], 1, name='final_labels')
        # pred_indices[:, 1]
        final_posterior = tf.gather(label_probs, pred_indices_only_boxes, name='final_posterior')
        feature_fastrcnn_pooled = tf.gather(
            feature_fastrcnn_pooled, pred_indices_only_boxes, name="feature_fastrcnn_pooled")

        if config.USE_SECOND_HEAD:
            second_label_probs = tf.nn.softmax(
                secondclassification_label_logits,
                name='secondclassification_all_probs')  # #proposal x #Class
            second_final_posterior = tf.gather(
                second_label_probs, pred_indices_only_boxes,
                name='second_final_posterior')
            second_final_labels = tf.add(
                tf.argmax(second_final_posterior, -1), 1,
                name='second_final_labels')

        if config.MODE_MASK:
            # HACK to work around https://github.com/tensorflow/tensorflow/issues/14657
            def f1():
                roi_resized = roi_align(
                    featuremap, final_boxes * (1.0 / config.ANCHOR_STRIDE), 14)
                feature_maskrcnn = resnet_conv5(roi_resized, config.RESNET_NUM_BLOCK[-1])
                mask_logits = maskrcnn_head(
                    'maskrcnn', feature_maskrcnn, config.NUM_CLASS)  # #result x #cat x 14x14
                # Select the mask of the predicted category for each result.
                indices = tf.stack([
                    tf.range(tf.size(final_labels)),
                    tf.to_int32(final_labels) - 1
                ], axis=1)
                final_mask_logits = tf.gather_nd(mask_logits, indices)  # #resultx14x14
                return tf.sigmoid(final_mask_logits)

            final_masks = tf.cond(
                tf.size(final_probs) > 0, f1, lambda: tf.zeros([0, 14, 14]))
            tf.identity(final_masks, name='final_masks')
def build_graph(self, *inputs):
    """Build the train/inference graph for the FPN-based detector.

    Args:
        *inputs: image, followed by interleaved (anchor_labels, anchor_boxes)
            pairs for each FPN level, then gt_boxes, gt_labels and,
            when config.MODE_MASK is on, gt_masks as the last input.
            NOTE(review): gt_boxes/gt_labels are read from fixed positions
            inputs[11]/inputs[12], which assumes 5 FPN levels — verify
            against config.ANCHOR_STRIDES_FPN.

    Returns:
        total_cost tensor in training mode; None in inference mode
        (results are exposed via named tensors, e.g. 'final_masks').
    """
    num_fpn_level = len(config.ANCHOR_STRIDES_FPN)
    assert len(config.ANCHOR_SIZES) == num_fpn_level
    is_training = get_current_tower_context().is_training
    image = inputs[0]
    # Per-level anchor labels/boxes are interleaved: [labels0, boxes0, labels1, ...]
    input_anchors = inputs[1: 1 + 2 * num_fpn_level]
    multilevel_anchor_labels = input_anchors[0::2]
    multilevel_anchor_boxes = input_anchors[1::2]
    gt_boxes, gt_labels = inputs[11], inputs[12]
    if config.MODE_MASK:
        gt_masks = inputs[-1]

    image = self.preprocess(image)  # 1CHW
    image_shape2d = tf.shape(image)[2:]  # h,w

    c2345 = resnet_fpn_backbone(image, config.RESNET_NUM_BLOCK)
    p23456 = fpn_model('fpn', c2345)

    # Multi-Level RPN Proposals
    multilevel_proposals = []
    rpn_loss_collection = []
    for lvl in range(num_fpn_level):
        # The 'rpn' head is shared across levels (same variable scope name).
        rpn_label_logits, rpn_box_logits = rpn_head(
            'rpn', p23456[lvl], config.FPN_NUM_CHANNEL, len(config.ANCHOR_RATIOS))
        with tf.name_scope('FPN_lvl{}'.format(lvl + 2)):
            anchors = tf.constant(get_all_anchors_fpn()[lvl],
                                  name='rpn_anchor_lvl{}'.format(lvl + 2))
            anchors, anchor_labels, anchor_boxes = \
                self.narrow_to_featuremap(p23456[lvl], anchors,
                                          multilevel_anchor_labels[lvl],
                                          multilevel_anchor_boxes[lvl])
            anchor_boxes_encoded = encode_bbox_target(anchor_boxes, anchors)
            pred_boxes_decoded = decode_bbox_target(rpn_box_logits, anchors)
            proposal_boxes, proposal_scores = generate_rpn_proposals(
                tf.reshape(pred_boxes_decoded, [-1, 4]),
                tf.reshape(rpn_label_logits, [-1]),
                image_shape2d,
                config.TRAIN_FPN_NMS_TOPK if is_training else config.TEST_FPN_NMS_TOPK)
            multilevel_proposals.append((proposal_boxes, proposal_scores))
            if is_training:
                label_loss, box_loss = rpn_losses(
                    anchor_labels, anchor_boxes_encoded,
                    rpn_label_logits, rpn_box_logits)
                rpn_loss_collection.extend([label_loss, box_loss])

    # Merge proposals from multi levels, pick top K
    proposal_boxes = tf.concat([x[0] for x in multilevel_proposals], axis=0)  # nx4
    proposal_scores = tf.concat([x[1] for x in multilevel_proposals], axis=0)  # n
    # tf.minimum guards against fewer proposals than TOPK.
    proposal_topk = tf.minimum(
        tf.size(proposal_scores),
        config.TRAIN_FPN_NMS_TOPK if is_training else config.TEST_FPN_NMS_TOPK)
    proposal_scores, topk_indices = tf.nn.top_k(
        proposal_scores, k=proposal_topk, sorted=False)
    proposal_boxes = tf.gather(proposal_boxes, topk_indices)

    if is_training:
        rcnn_boxes, rcnn_labels, fg_inds_wrt_gt = sample_fast_rcnn_targets(
            proposal_boxes, gt_boxes, gt_labels)
    else:
        # The boxes to be used to crop RoIs.
        rcnn_boxes = proposal_boxes

    # RoIAlign over the first 4 FPN levels (P2-P5), 7x7 output.
    roi_feature_fastrcnn = multilevel_roi_align(p23456[:4], rcnn_boxes, 7)
    fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_2fc_head(
        'fastrcnn', roi_feature_fastrcnn, config.NUM_CLASS)

    if is_training:
        # rpn loss is already defined above
        with tf.name_scope('rpn_losses'):
            # rpn_loss_collection alternates [label, box, label, box, ...].
            rpn_total_label_loss = tf.add_n(rpn_loss_collection[::2], name='label_loss')
            rpn_total_box_loss = tf.add_n(rpn_loss_collection[1::2], name='box_loss')
            add_moving_summary(rpn_total_box_loss, rpn_total_label_loss)

        # fastrcnn loss:
        matched_gt_boxes = tf.gather(gt_boxes, fg_inds_wrt_gt)
        fg_inds_wrt_sample = tf.reshape(tf.where(rcnn_labels > 0), [-1])  # fg inds w.r.t all samples
        fg_sampled_boxes = tf.gather(rcnn_boxes, fg_inds_wrt_sample)
        fg_fastrcnn_box_logits = tf.gather(fastrcnn_box_logits, fg_inds_wrt_sample)
        fastrcnn_label_loss, fastrcnn_box_loss = self.fastrcnn_training(
            image, rcnn_labels, fg_sampled_boxes,
            matched_gt_boxes, fastrcnn_label_logits, fg_fastrcnn_box_logits)

        if config.MODE_MASK:
            # maskrcnn loss
            fg_labels = tf.gather(rcnn_labels, fg_inds_wrt_sample)
            roi_feature_maskrcnn = multilevel_roi_align(
                p23456[:4], fg_sampled_boxes, 14)
            mask_logits = maskrcnn_upXconv_head(
                'maskrcnn', roi_feature_maskrcnn, config.NUM_CLASS, 4)  # #fg x #cat x 28 x 28

            target_masks_for_fg = crop_and_resize(
                tf.expand_dims(gt_masks, 1),
                fg_sampled_boxes,
                fg_inds_wrt_gt, 28,
                pad_border=False)  # fg x 1x28x28
            target_masks_for_fg = tf.squeeze(target_masks_for_fg, 1, 'sampled_fg_mask_targets')
            mrcnn_loss = maskrcnn_loss(mask_logits, fg_labels, target_masks_for_fg)
        else:
            mrcnn_loss = 0.0

        wd_cost = regularize_cost(
            '(?:group1|group2|group3|rpn|fastrcnn|maskrcnn)/.*W',
            l2_regularizer(1e-4), name='wd_cost')

        total_cost = tf.add_n(rpn_loss_collection + [
            fastrcnn_label_loss, fastrcnn_box_loss,
            mrcnn_loss, wd_cost], 'total_cost')
        add_moving_summary(total_cost, wd_cost)
        return total_cost
    else:
        final_boxes, final_labels = self.fastrcnn_inference(
            image_shape2d, rcnn_boxes, fastrcnn_label_logits, fastrcnn_box_logits)
        if config.MODE_MASK:
            # Cascade inference needs roi transform with refined boxes.
            roi_feature_maskrcnn = multilevel_roi_align(
                p23456[:4], final_boxes, 14)
            mask_logits = maskrcnn_upXconv_head(
                'maskrcnn', roi_feature_maskrcnn, config.NUM_CLASS, 4)  # #fg x #cat x 28 x 28
            # Pick the mask channel of the predicted category for each result.
            indices = tf.stack([tf.range(tf.size(final_labels)),
                                tf.to_int32(final_labels) - 1], axis=1)
            final_mask_logits = tf.gather_nd(mask_logits, indices)  # #resultx28x28
            tf.sigmoid(final_mask_logits, name='final_masks')
def build_graph(self, *inputs):
    """Build the train/inference graph for the C4-backbone detector.

    Args:
        *inputs: image, anchor_labels, anchor_boxes, gt_boxes, gt_labels,
            plus gt_masks when config.MODE_MASK is on.

    Returns:
        total_cost tensor in training mode; None in inference mode
        (results are exposed via named tensors such as 'final_masks').
    """
    is_training = get_current_tower_context().is_training
    if config.MODE_MASK:
        image, anchor_labels, anchor_boxes, gt_boxes, gt_labels, gt_masks = inputs
    else:
        image, anchor_labels, anchor_boxes, gt_boxes, gt_labels = inputs
    image = self.preprocess(image)  # 1CHW

    featuremap = resnet_c4_backbone(image, config.RESNET_NUM_BLOCK[:3])
    rpn_label_logits, rpn_box_logits = rpn_head('rpn', featuremap, 1024, config.NUM_ANCHOR)

    # Clip anchors/labels/boxes to the actual featuremap extent.
    fm_anchors, anchor_labels, anchor_boxes = self.narrow_to_featuremap(
        featuremap, get_all_anchors(), anchor_labels, anchor_boxes)
    anchor_boxes_encoded = encode_bbox_target(anchor_boxes, fm_anchors)

    image_shape2d = tf.shape(image)[2:]  # h,w
    pred_boxes_decoded = decode_bbox_target(rpn_box_logits, fm_anchors)  # fHxfWxNAx4, floatbox
    proposal_boxes, proposal_scores = generate_rpn_proposals(
        tf.reshape(pred_boxes_decoded, [-1, 4]),
        tf.reshape(rpn_label_logits, [-1]),
        image_shape2d,
        config.TRAIN_PRE_NMS_TOPK if is_training else config.TEST_PRE_NMS_TOPK,
        config.TRAIN_POST_NMS_TOPK if is_training else config.TEST_POST_NMS_TOPK)

    if is_training:
        # sample proposal boxes in training
        rcnn_boxes, rcnn_labels, fg_inds_wrt_gt = sample_fast_rcnn_targets(
            proposal_boxes, gt_boxes, gt_labels)
    else:
        # The boxes to be used to crop RoIs.
        # Use all proposal boxes in inference
        rcnn_boxes = proposal_boxes

    boxes_on_featuremap = rcnn_boxes * (1.0 / config.ANCHOR_STRIDE)
    roi_resized = roi_align(featuremap, boxes_on_featuremap, 14)

    # HACK to work around https://github.com/tensorflow/tensorflow/issues/14657
    # which was fixed in TF 1.6
    def ff_true():
        # Branch taken when there is at least one box.
        feature_fastrcnn = resnet_conv5(roi_resized, config.RESNET_NUM_BLOCK[-1])  # nxcx7x7
        feature_gap = GlobalAvgPooling('gap', feature_fastrcnn, data_format='channels_first')
        fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_outputs('fastrcnn', feature_gap, config.NUM_CLASS)
        # Return C5 feature to be shared with mask branch
        return feature_fastrcnn, fastrcnn_label_logits, fastrcnn_box_logits

    def ff_false():
        # Empty-box branch: return correctly-shaped empty tensors.
        ncls = config.NUM_CLASS
        return tf.zeros([0, 2048, 7, 7]), tf.zeros([0, ncls]), tf.zeros([0, ncls - 1, 4])

    if get_tf_version_number() >= 1.6:
        feature_fastrcnn, fastrcnn_label_logits, fastrcnn_box_logits = ff_true()
    else:
        logger.warn("This example may drop support for TF < 1.6 soon.")
        feature_fastrcnn, fastrcnn_label_logits, fastrcnn_box_logits = tf.cond(
            tf.size(boxes_on_featuremap) > 0, ff_true, ff_false)

    if is_training:
        # rpn loss
        rpn_label_loss, rpn_box_loss = rpn_losses(
            anchor_labels, anchor_boxes_encoded, rpn_label_logits, rpn_box_logits)

        # fastrcnn loss
        matched_gt_boxes = tf.gather(gt_boxes, fg_inds_wrt_gt)

        fg_inds_wrt_sample = tf.reshape(tf.where(rcnn_labels > 0), [-1])  # fg inds w.r.t all samples
        fg_sampled_boxes = tf.gather(rcnn_boxes, fg_inds_wrt_sample)
        fg_fastrcnn_box_logits = tf.gather(fastrcnn_box_logits, fg_inds_wrt_sample)
        fastrcnn_label_loss, fastrcnn_box_loss = self.fastrcnn_training(
            image, rcnn_labels, fg_sampled_boxes,
            matched_gt_boxes, fastrcnn_label_logits, fg_fastrcnn_box_logits)

        if config.MODE_MASK:
            # maskrcnn loss
            fg_labels = tf.gather(rcnn_labels, fg_inds_wrt_sample)
            # In training, mask branch shares the same C5 feature.
            fg_feature = tf.gather(feature_fastrcnn, fg_inds_wrt_sample)
            mask_logits = maskrcnn_upXconv_head(
                'maskrcnn', fg_feature, config.NUM_CLASS, num_convs=0)  # #fg x #cat x 14x14

            target_masks_for_fg = crop_and_resize(
                tf.expand_dims(gt_masks, 1),
                fg_sampled_boxes,
                fg_inds_wrt_gt, 14,
                pad_border=False)  # nfg x 1x14x14
            target_masks_for_fg = tf.squeeze(target_masks_for_fg, 1, 'sampled_fg_mask_targets')
            mrcnn_loss = maskrcnn_loss(mask_logits, fg_labels, target_masks_for_fg)
        else:
            mrcnn_loss = 0.0

        wd_cost = regularize_cost(
            '(?:group1|group2|group3|rpn|fastrcnn|maskrcnn)/.*W',
            l2_regularizer(1e-4), name='wd_cost')

        total_cost = tf.add_n([
            rpn_label_loss, rpn_box_loss,
            fastrcnn_label_loss, fastrcnn_box_loss,
            mrcnn_loss, wd_cost], 'total_cost')
        add_moving_summary(total_cost, wd_cost)
        return total_cost
    else:
        final_boxes, final_labels = self.fastrcnn_inference(
            image_shape2d, rcnn_boxes, fastrcnn_label_logits, fastrcnn_box_logits)

        if config.MODE_MASK:
            # HACK to work around https://github.com/tensorflow/tensorflow/issues/14657
            def f1():
                roi_resized = roi_align(
                    featuremap, final_boxes * (1.0 / config.ANCHOR_STRIDE), 14)
                feature_maskrcnn = resnet_conv5(roi_resized, config.RESNET_NUM_BLOCK[-1])
                mask_logits = maskrcnn_upXconv_head(
                    'maskrcnn', feature_maskrcnn, config.NUM_CLASS, 0)  # #result x #cat x 14x14
                # Pick the mask channel of the predicted category per result.
                indices = tf.stack([tf.range(tf.size(final_labels)),
                                    tf.to_int32(final_labels) - 1], axis=1)
                final_mask_logits = tf.gather_nd(mask_logits, indices)  # #resultx14x14
                return tf.sigmoid(final_mask_logits)

            final_masks = tf.cond(
                tf.size(final_labels) > 0, f1, lambda: tf.zeros([0, 14, 14]))
            tf.identity(final_masks, name='final_masks')
def _build_graph(self, inputs):
    """Build the train/inference graph for a basic Faster R-CNN (no mask head).

    Args:
        inputs: [image, anchor_labels, anchor_boxes, gt_boxes, gt_labels].

    Side effects:
        In training mode, sets ``self.cost`` and adds moving summaries.
        In inference mode, exposes results via named tensors
        ('fastrcnn_all_probs', 'fastrcnn_fg_probs', 'fastrcnn_fg_boxes').
    """
    is_training = get_current_tower_context().is_training
    image, anchor_labels, anchor_boxes, gt_boxes, gt_labels = inputs
    fm_anchors = self._get_anchors(image)
    image = self._preprocess(image)
    anchor_boxes_encoded = encode_bbox_target(anchor_boxes, fm_anchors)

    # Backbone + RPN. NOTE: rpn losses are built unconditionally here,
    # even in inference mode (they are only added to self.cost in training).
    featuremap = pretrained_resnet_conv4(image, config.RESNET_NUM_BLOCK[:3])
    rpn_label_logits, rpn_box_logits = rpn_head(featuremap, 1024, config.NUM_ANCHOR)
    rpn_label_loss, rpn_box_loss = rpn_losses(
        anchor_labels, anchor_boxes_encoded, rpn_label_logits, rpn_box_logits)

    decoded_boxes = decode_bbox_target(
        rpn_box_logits, fm_anchors, config.ANCHOR_STRIDE)  # (fHxfWxNA)x4, floatbox
    proposal_boxes, proposal_scores = generate_rpn_proposals(
        decoded_boxes,
        tf.reshape(rpn_label_logits, [-1]),
        tf.shape(image)[2:])

    if is_training:
        # Sample proposals and build the fast-rcnn losses.
        rcnn_sampled_boxes, rcnn_encoded_boxes, rcnn_labels = sample_fast_rcnn_targets(
            proposal_boxes, gt_boxes, gt_labels)
        boxes_on_featuremap = rcnn_sampled_boxes * (1.0 / config.ANCHOR_STRIDE)
        roi_resized = roi_align(featuremap, boxes_on_featuremap, 14)
        feature_fastrcnn = resnet_conv5_gap(roi_resized, config.RESNET_NUM_BLOCK[-1])  # nxc
        fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_head(
            feature_fastrcnn, config.NUM_CLASS)
        fastrcnn_label_loss, fastrcnn_box_loss = fastrcnn_losses(
            rcnn_labels, rcnn_encoded_boxes,
            fastrcnn_label_logits, fastrcnn_box_logits)

        wd_cost = regularize_cost(
            '(?:group1|group2|group3|rpn|fastrcnn)/.*W',
            l2_regularizer(1e-4), name='wd_cost')

        self.cost = tf.add_n([
            rpn_label_loss, rpn_box_loss,
            fastrcnn_label_loss, fastrcnn_box_loss,
            wd_cost], 'total_cost')

        for k in self.cost, wd_cost:
            add_moving_summary(k)
    else:
        # Inference: run the head on all proposals and decode fg boxes.
        roi_resized = roi_align(
            featuremap, proposal_boxes * (1.0 / config.ANCHOR_STRIDE), 14)
        feature_fastrcnn = resnet_conv5_gap(roi_resized, config.RESNET_NUM_BLOCK[-1])  # nxc
        label_logits, fastrcnn_box_logits = fastrcnn_head(
            feature_fastrcnn, config.NUM_CLASS)
        label_probs = tf.nn.softmax(label_logits, name='fastrcnn_all_probs')  # NP,
        labels = tf.argmax(label_logits, axis=1)
        fg_ind, fg_box_logits = fastrcnn_predict_boxes(labels, fastrcnn_box_logits)
        fg_label_probs = tf.gather(label_probs, fg_ind, name='fastrcnn_fg_probs')
        fg_boxes = tf.gather(proposal_boxes, fg_ind)

        # Undo the regression-weight scaling before decoding.
        fg_box_logits = fg_box_logits / tf.constant(config.FASTRCNN_BBOX_REG_WEIGHTS)
        decoded_boxes = decode_bbox_target(
            fg_box_logits, fg_boxes, config.ANCHOR_STRIDE)  # Nfx4, floatbox
        decoded_boxes = tf.identity(decoded_boxes, name='fastrcnn_fg_boxes')