def loss(self, inputs, truth_boxes, truth_labels, truth_instances):
    """Compute, store and return the total training loss.

    The RPN and RCNN loss terms are always computed. The mask loss term is
    computed and added only when ``self.train_box_only`` is falsy.

    Every individual term and the total are stored on ``self``
    (``rpn_cls_loss``, ``rpn_reg_loss``, ``rcnn_cls_loss``,
    ``rcnn_reg_loss``, optionally ``mask_cls_loss``, and ``total_loss``).

    Args:
        inputs, truth_boxes, truth_labels, truth_instances: not read by this
            method. The loss terms come from attributes already set on
            ``self`` (presumably by the preceding forward pass; confirm).

    Returns:
        The value stored in ``self.total_loss``.
    """
    self.rpn_cls_loss, self.rpn_reg_loss = rpn_loss(
        self.rpn_logits_flat,
        self.rpn_deltas_flat,
        self.rpn_labels,
        self.rpn_label_weights,
        self.rpn_targets,
        self.rpn_target_weights,
    )
    self.rcnn_cls_loss, self.rcnn_reg_loss = rcnn_loss(
        self.rcnn_logits,
        self.rcnn_deltas,
        self.rcnn_labels,
        self.rcnn_targets,
    )

    # Detection (box) part of the loss, shared by both modes. Summed left to
    # right so the association order matches the original expression.
    box_loss = (
        self.rpn_cls_loss
        + self.rpn_reg_loss
        + self.rcnn_cls_loss
        + self.rcnn_reg_loss
    )

    if self.train_box_only:
        # NOTE(review): self.mask_cls_loss is not assigned in this branch, so it
        # keeps whatever value it had before (if any). Confirm that nothing
        # downstream (e.g. logging) reads it in box-only mode.
        self.total_loss = box_loss
    else:
        self.mask_cls_loss = mask_loss(
            self.mask_logits, self.mask_labels, self.mask_instances
        )
        self.total_loss = box_loss + self.mask_cls_loss

    return self.total_loss
def loss_train_rcnn(self, inputs, truth_boxes, truth_labels, truth_instances):
    """Compute, store and return the RPN + RCNN loss (no mask term).

    Stores ``rpn_cls_loss``, ``rpn_reg_loss``, ``rcnn_cls_loss``,
    ``rcnn_reg_loss`` and their sum ``total_loss`` on ``self``.

    Args:
        inputs, truth_boxes, truth_labels, truth_instances: not read by this
            method. The loss terms come from attributes already set on
            ``self`` (presumably by the preceding forward pass; confirm).

    Returns:
        The value stored in ``self.total_loss``.
    """
    # Region-proposal-network terms.
    rpn_terms = rpn_loss(
        self.rpn_logits_flat,
        self.rpn_deltas_flat,
        self.rpn_labels,
        self.rpn_label_weights,
        self.rpn_targets,
        self.rpn_target_weights,
    )
    self.rpn_cls_loss, self.rpn_reg_loss = rpn_terms

    # Second-stage (RCNN) terms.
    rcnn_terms = rcnn_loss(
        self.rcnn_logits,
        self.rcnn_deltas,
        self.rcnn_labels,
        self.rcnn_targets,
    )
    self.rcnn_cls_loss, self.rcnn_reg_loss = rcnn_terms

    # Accumulate left to right, in the same order as a single chained sum.
    running = self.rpn_cls_loss + self.rpn_reg_loss
    running = running + self.rcnn_cls_loss
    running = running + self.rcnn_reg_loss
    self.total_loss = running
    return self.total_loss
def loss(self, inputs, truth_boxes, truth_labels, truth_instances):
    """Compute, store and return the total RPN + RCNN + mask training loss.

    All five loss terms are always computed here. Each term and their sum are
    stored on ``self`` (``rpn_cls_loss``, ``rpn_reg_loss``, ``rcnn_cls_loss``,
    ``rcnn_reg_loss``, ``mask_cls_loss``, ``total_loss``).

    Args:
        inputs, truth_boxes, truth_labels, truth_instances: not read by this
            method. The loss terms come from attributes already set on
            ``self`` (presumably by the preceding forward pass; confirm).

    Returns:
        The value stored in ``self.total_loss``.
    """
    self.rpn_cls_loss, self.rpn_reg_loss = rpn_loss(
        self.rpn_logits_flat,
        self.rpn_deltas_flat,
        self.rpn_labels,
        self.rpn_label_weights,
        self.rpn_targets,
        self.rpn_target_weights,
    )
    self.rcnn_cls_loss, self.rcnn_reg_loss = rcnn_loss(
        self.rcnn_logits,
        self.rcnn_deltas,
        self.rcnn_labels,
        self.rcnn_targets,
    )
    self.mask_cls_loss = mask_loss(
        self.mask_logits,
        self.mask_labels,
        self.mask_instances,
    )

    self.total_loss = (
        self.rpn_cls_loss
        + self.rpn_reg_loss
        + self.rcnn_cls_loss
        + self.rcnn_reg_loss
        + self.mask_cls_loss
    )
    return self.total_loss