Example #1
 def execute(self, x1, x2):
     x1 = nn.relu(self.fc1a(x1))
     x1 = self.fc1b(x1)
     x2 = nn.relu(self.fc2a(x2))
     x2 = self.fc2b(x2)
     x = jt.contrib.concat((x1, x2), dim=0)
     return nn.log_softmax(x, dim=1)
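A minimal sketch of how the enclosing module might be defined. Only the attribute names (fc1a, fc1b, fc2a, fc2b) come from the execute above; the class name and all sizes below are assumptions.

import jittor as jt
from jittor import nn

class TwoBranchNet(nn.Module):  # hypothetical name
    def __init__(self, in1=128, in2=128, hidden=64, num_classes=10):  # sizes assumed
        super().__init__()
        # One two-layer MLP head per input; execute concatenates the two
        # heads' outputs along dim=0 before taking the log-softmax.
        self.fc1a = nn.Linear(in1, hidden)
        self.fc1b = nn.Linear(hidden, num_classes)
        self.fc2a = nn.Linear(in2, hidden)
        self.fc2b = nn.Linear(hidden, num_classes)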
Example #2
 def execute(self, x):
     x = nn.relu(nn.max_pool2d(self.conv1(x), 2))
     x = nn.relu(nn.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
     x = x.view(-1, 320)
     x = nn.relu(self.fc1(x))
     x = self.fc2(x)
     return nn.log_softmax(x, dim=1), x
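The view to (-1, 320) fixes the usual MNIST geometry: two 5x5 convolutions, each followed by 2x2 max-pooling, take a 1x28x28 input down to 20x4x4 = 320 features. A sketch of a matching __init__; the channel counts are inferred from that arithmetic, not stated in the source.

import jittor as jt
from jittor import nn

class ConvNet(nn.Module):  # hypothetical name
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv(1, 10, 5)   # 28x28 -> 24x24, pooled to 12x12
        self.conv2 = nn.Conv(10, 20, 5)  # 12x12 -> 8x8, pooled to 4x4
        self.conv2_drop = nn.Dropout()   # plain dropout; the classic recipe uses 2D (channel) dropout
        self.fc1 = nn.Linear(320, 50)    # 20 * 4 * 4 = 320
        self.fc2 = nn.Linear(50, 10)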
Example #3
    def focal_conf_loss(self, conf_data, conf_t):
        """
        Focal loss as described in https://arxiv.org/pdf/1708.02002.pdf
        Adapted from https://github.com/clcarwin/focal_loss_pytorch/blob/master/focalloss.py
        Note that this uses softmax and not the original sigmoid from the paper.
        """
        conf_t = conf_t.view(-1) # [batch_size*num_priors]
        conf_data = conf_data.view(-1, conf_data.shape[-1]) # [batch_size*num_priors, num_classes]

        # Ignore neutral samples (class < 0)
        keep = (conf_t >= 0).float()
        conf_t[conf_t < 0] = 0 # so that gather doesn't drum up a fuss

        logpt = nn.log_softmax(conf_data, dim=-1)
        logpt = logpt.gather(1, conf_t.unsqueeze(-1))
        logpt = logpt.view(-1)
        pt    = logpt.exp()

        # I adapted the alpha_t calculation here from
        # https://github.com/pytorch/pytorch/blob/master/modules/detectron/softmax_focal_loss_op.cu
        # You'd think you want all the alphas to sum to one, but in the original implementation they
        # just give background an alpha of 1-alpha and each foreground an alpha of alpha.
        background = (conf_t == 0).float()
        at = (1 - cfg.focal_loss_alpha) * background + cfg.focal_loss_alpha * (1 - background)

        loss = -at * (1 - pt) ** cfg.focal_loss_gamma * logpt

        # See comment above for keep
        return cfg.conf_alpha * (loss * keep).sum()
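In equation form, the per-sample quantity assembled above is the softmax focal loss with the Detectron-style alpha weighting described in the comment:

\[
\mathrm{FL}(p_t) = -\alpha_t \,(1 - p_t)^{\gamma} \log p_t,
\qquad
\alpha_t =
\begin{cases}
1 - \alpha & \text{if the target class is background } (c_t = 0),\\
\alpha & \text{otherwise,}
\end{cases}
\]

where p_t is the softmax probability of the target class, gamma = cfg.focal_loss_gamma, and alpha = cfg.focal_loss_alpha; the summed loss is masked by keep (dropping neutral samples) and scaled by cfg.conf_alpha.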
Example #4
    def execute(self):
        x, edge_index, edge_weight = data.x, data.edge_index, data.edge_attr
        x = nn.dropout(x, self.dropout)
        x = x_0 = nn.relu(self.lins[0](x))

        for conv in self.convs:
            x = nn.dropout(x, self.dropout)
            x = conv(x, x_0, edge_index, edge_weight)
            x = nn.relu(x)

        x = nn.dropout(x, self.dropout)
        x = self.lins[1](x)

        return nn.log_softmax(x, dim=-1)
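The conv(x, x_0, edge_index, edge_weight) signature, where the first-layer output x_0 is fed to every subsequent layer, matches a GCNII-style layer with an initial-residual connection. A sketch of the module state this execute implies; GCN2Conv and its import are hypothetical stand-ins for whatever graph-convolution class the project actually uses, and the sizes are assumptions.

from jittor import nn
from my_graph_lib import GCN2Conv  # hypothetical import: substitute the project's GCNII layer

class GCNII(nn.Module):  # hypothetical name
    def __init__(self, num_features, hidden, num_classes, num_layers, dropout=0.5):
        super().__init__()
        self.dropout = dropout
        # lins[0] embeds the raw features; lins[1] is the classifier head.
        self.lins = nn.ModuleList([
            nn.Linear(num_features, hidden),
            nn.Linear(hidden, num_classes),
        ])
        self.convs = nn.ModuleList([GCN2Conv(hidden) for _ in range(num_layers)])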
Example #5
 def execute(self):
     x, edge_index = data.x, data.edge_index
     x = self.conv1(x, edge_index)
     return nn.log_softmax(x, dim=1)
Example #6
 def execute(self):
     x, edge_index, edge_weight = data.x, data.edge_index, data.edge_attr
     x = nn.relu(self.conv1(x, edge_index, edge_weight))
     x = nn.dropout(x)
     x = self.conv2(x, edge_index, edge_weight)
     return nn.log_softmax(x, dim=1)
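Examples #4-#6 all read a module-level data object (a PyG-style Data with .x, .edge_index, .edge_attr) instead of taking it as an argument, so the model is called with no inputs. A hedged sketch of the training loop this implies; data.train_mask and data.y follow the PyG convention and are assumptions, and the NLL is written out with gather since the models return log-probabilities.

import jittor as jt
from jittor import nn

model = GCN()  # the module whose execute is shown above
optimizer = jt.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(200):
    model.train()
    log_probs = model()                        # reads the global `data`
    pred = log_probs[data.train_mask]          # boolean-mask the training nodes
    y = data.y[data.train_mask]
    loss = -pred.gather(1, y.unsqueeze(1)).mean()  # NLL of the target class
    optimizer.step(loss)                       # Jittor: backward + update in one call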
Example #7
    def execute(self, net, predictions, targets, masks, num_crowds):
        """Multibox Loss
        Args:
            predictions (tuple): A tuple containing loc preds, conf preds,
            mask preds, and prior boxes from SSD net.
                loc shape: jt.size(batch_size,num_priors,4)
                conf shape: jt.size(batch_size,num_priors,num_classes)
                masks shape: jt.size(batch_size,num_priors,mask_dim)
                priors shape: jt.size(num_priors,4)
                proto* shape: jt.size(batch_size,mask_h,mask_w,mask_dim)

            targets (list<tensor>): Ground truth boxes and labels for a batch,
                shape: [batch_size][num_objs,5] (last idx is the label).

            masks (list<tensor>): Ground truth masks for each object in each image,
                shape: [batch_size][num_objs,im_height,im_width]

            num_crowds (list<int>): Number of crowd annotations per image in the batch. The crowd
                annotations should be the last num_crowds elements of targets and masks.
            
            * Only if mask_type == lincomb
        """

        loc_data  = predictions['loc']
        conf_data = predictions['conf']
        mask_data = predictions['mask']
        priors    = predictions['priors']

        if cfg.mask_type == mask_type.lincomb:
            proto_data = predictions['proto']

        score_data = predictions['score'] if cfg.use_mask_scoring   else None   
        inst_data  = predictions['inst']  if cfg.use_instance_coeff else None
        
        labels = [None] * len(targets) # Used in sem segm loss

        batch_size = loc_data.shape[0]
        num_priors = priors.shape[0]
        num_classes = self.num_classes

        # Match priors (default boxes) and ground truth boxes
        # These tensors will be created with the same device as loc_data
        loc_t = jt.empty((batch_size, num_priors, 4),dtype=loc_data.dtype)
        gt_box_t = jt.empty((batch_size, num_priors, 4),dtype=loc_data.dtype)
        conf_t = jt.empty((batch_size, num_priors)).int32()
        idx_t = jt.empty((batch_size, num_priors)).int32()

        if cfg.use_class_existence_loss:
            class_existence_t = jt.empty((batch_size, num_classes-1),dtype=loc_data.dtype)

        # jt.sync(list(predictions.values()))

        for idx in range(batch_size):
            truths      = targets[idx][:, :-1]
            labels[idx] = targets[idx][:, -1].int32()

            if cfg.use_class_existence_loss:
                # Construct a one-hot vector for each object and collapse it into an existence vector with max
                # Also it's fine to include the crowd annotations here
                class_existence_t[idx,:] = jt.eye(num_classes-1)[labels[idx]].max(dim=0)[0]

            # Split the crowd annotations because they come bundled in
            cur_crowds = num_crowds[idx]
            if cur_crowds > 0:
                split = lambda x: (x[-cur_crowds:], x[:-cur_crowds])
                crowd_boxes, truths = split(truths)

                # We don't use the crowd labels or masks
                _, labels[idx] = split(labels[idx])
                _, masks[idx]  = split(masks[idx])
            else:
                crowd_boxes = None

            
            match(self.pos_threshold, self.neg_threshold,
                  truths, priors, labels[idx], crowd_boxes,
                  loc_t, conf_t, idx_t, idx, loc_data[idx])
                  
            gt_box_t[idx,:,:] = truths[idx_t[idx]]

        # wrap targets
        loc_t.stop_grad()
        conf_t.stop_grad()
        idx_t.stop_grad()

        pos = conf_t > 0
        num_pos = pos.sum(dim=1, keepdims=True)
        
        # Shape: [batch,num_priors,4]
        pos_idx = pos.unsqueeze(pos.ndim).expand_as(loc_data)
        
        losses = {}

        # Localization Loss (Smooth L1)
        if cfg.train_boxes:
            loc_p = loc_data[pos_idx].view(-1, 4)
            loc_t = loc_t[pos_idx].view(-1, 4)
            # print(loc_t)
            losses['B'] = nn.smooth_l1_loss(loc_p, loc_t, reduction='sum') * cfg.bbox_alpha

        if cfg.train_masks:
            if cfg.mask_type == mask_type.direct:
                if cfg.use_gt_bboxes:
                    pos_masks = []
                    for idx in range(batch_size):
                        pos_masks.append(masks[idx][idx_t[idx, pos[idx]]])
                    masks_t = jt.contrib.concat(pos_masks, 0)
                    masks_p = mask_data[pos, :].view(-1, cfg.mask_dim)
                    losses['M'] = nn.bce_loss(jt.clamp(masks_p, 0, 1), masks_t, size_average=False) * cfg.mask_alpha
                else:
                    losses['M'] = self.direct_mask_loss(pos_idx, idx_t, loc_data, mask_data, priors, masks)
            elif cfg.mask_type == mask_type.lincomb:
                ret = self.lincomb_mask_loss(pos, idx_t, loc_data, mask_data, priors, proto_data, masks, gt_box_t, score_data, inst_data, labels)
                if cfg.use_maskiou:
                    loss, maskiou_targets = ret
                else:
                    loss = ret
                losses.update(loss)

                if cfg.mask_proto_loss is not None:
                    if cfg.mask_proto_loss == 'l1':
                        losses['P'] = jt.mean(jt.abs(proto_data)) / self.l1_expected_area * self.l1_alpha
                    elif cfg.mask_proto_loss == 'disj':
                        losses['P'] = -jt.mean(jt.max(nn.log_softmax(proto_data, dim=-1), dim=-1)[0])

        # Confidence loss
        if cfg.use_focal_loss:
            if cfg.use_sigmoid_focal_loss:
                losses['C'] = self.focal_conf_sigmoid_loss(conf_data, conf_t)
            elif cfg.use_objectness_score:
                losses['C'] = self.focal_conf_objectness_loss(conf_data, conf_t)
            else:
                losses['C'] = self.focal_conf_loss(conf_data, conf_t)
        else:
            if cfg.use_objectness_score:
                losses['C'] = self.conf_objectness_loss(conf_data, conf_t, batch_size, loc_p, loc_t, priors)
            else:
                losses['C'] = self.ohem_conf_loss(conf_data, conf_t, pos, batch_size)

        # Mask IoU Loss
        if cfg.use_maskiou and maskiou_targets is not None:
            losses['I'] = self.mask_iou_loss(net, maskiou_targets)

        # These losses also don't depend on anchors
        if cfg.use_class_existence_loss:
            losses['E'] = self.class_existence_loss(predictions['classes'], class_existence_t)
        if cfg.use_semantic_segmentation_loss:
            losses['S'] = self.semantic_segmentation_loss(predictions['segm'], masks, labels)

        # Divide all losses by the number of positives.
        # Don't do it for loss[P] because that doesn't depend on the anchors.
        total_num_pos = num_pos.sum().float()
        for k in losses:
            if k not in ('P', 'E', 'S'):
                losses[k] /= total_num_pos
            else:
                losses[k] /= batch_size

        # Loss Key:
        #  - B: Box Localization Loss
        #  - C: Class Confidence Loss
        #  - M: Mask Loss
        #  - P: Prototype Loss
        #  - D: Coefficient Diversity Loss
        #  - E: Class Existence Loss
        #  - S: Semantic Segmentation Loss
        return losses
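A hedged sketch of how the returned dict is typically reduced to a single scalar for the optimizer step; each entry is already alpha-weighted and normalized above, so a plain sum is the natural reduction (criterion here stands for this Multibox loss module).

losses = criterion(net, predictions, targets, masks, num_crowds)
total_loss = sum(losses[k] for k in losses)
optimizer.step(total_loss)  # Jittor optimizers take the loss directly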