Example #1
    def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
        """Classification loss (NLL)
        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
        """
        assert 'pred_logits' in outputs
        src_logits = outputs['pred_logits']

        idx = self._get_src_permutation_idx(indices)
        target_classes_o = torch.cat(
            [t["labels"][J] for t, (_, J) in zip(targets, indices)])
        target_classes = torch.full(src_logits.shape[:2],
                                    self.num_classes,
                                    dtype=torch.int64,
                                    device=src_logits.device)
        target_classes[idx] = target_classes_o

        loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes,
                                  self.empty_weight)
        losses = {'loss_ce': loss_ce}

        if log:
            # TODO this should probably be a separate loss, not hacked in this one here
            losses['class_error'] = 100 - accuracy(src_logits[idx],
                                                   target_classes_o)[0]
        return losses
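
All of these variants rely on the same permutation helper to turn the matcher output into tensor indices. For reference, the original DETR codebase implements it along these lines (reproduced here as a sketch for context, not taken from the examples on this page):

    def _get_src_permutation_idx(self, indices):
        # batch index of every matched prediction (which image it came from)
        batch_idx = torch.cat(
            [torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        # query index of every matched prediction within its image
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx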
Example #2
    def loss_obj_labels(self,
                        outputs,
                        targets,
                        indices,
                        num_interactions,
                        log=True):
        assert 'pred_obj_logits' in outputs
        src_logits = outputs['pred_obj_logits']

        idx = self._get_src_permutation_idx(indices)
        target_classes_o = torch.cat(
            [t['obj_labels'][J] for t, (_, J) in zip(targets, indices)])
        target_classes = torch.full(src_logits.shape[:2],
                                    self.num_obj_classes,
                                    dtype=torch.int64,
                                    device=src_logits.device)
        target_classes[idx] = target_classes_o

        loss_obj_ce = F.cross_entropy(src_logits.transpose(1, 2),
                                      target_classes, self.empty_weight)
        losses = {'loss_obj_ce': loss_obj_ce}

        if log:
            losses['obj_class_error'] = 100 - accuracy(src_logits[idx],
                                                       target_classes_o)[0]
        return losses
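
The class_error term in these examples depends on an accuracy helper. DETR's util module defines it roughly as follows (a sketch for context; the exact code may differ per repository):

    @torch.no_grad()
    def accuracy(output, target, topk=(1,)):
        """Computes the precision@k for the specified values of k."""
        if target.numel() == 0:
            return [torch.zeros([], device=output.device)]
        maxk = max(topk)
        batch_size = target.size(0)
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res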
Example #3
    def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
        """Classification loss (NLL)
        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
        """
        assert 'pred_logits' in outputs
        src_logits = outputs['pred_logits']

        idx = self._get_src_permutation_idx(indices)
        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
        target_classes = torch.full(src_logits.shape[:2], self.num_classes,
                                    dtype=torch.int64, device=src_logits.device)
        target_classes[idx] = target_classes_o

        target_classes_onehot = torch.zeros([
            src_logits.shape[0], src_logits.shape[1], src_logits.shape[2] + 1
        ],
                                            dtype=src_logits.dtype,
                                            layout=src_logits.layout,
                                            device=src_logits.device)
        target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)

        target_classes_onehot = target_classes_onehot[:, :, :-1]
        loss_ce = sigmoid_focal_loss(src_logits, target_classes_onehot,
                                     num_boxes, alpha=self.focal_alpha,
                                     gamma=2) * src_logits.shape[1]
        losses = {'loss_ce': loss_ce}

        if log:
            # TODO this should probably be a separate loss, not hacked in this one here
            losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0]
        return losses
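
Examples #3, #6 and #8 swap cross entropy for sigmoid_focal_loss. Deformable DETR defines that function roughly as below (a sketch; per-repository variants add extra flags such as mean_in_dim1 in Example #6):

    def sigmoid_focal_loss(inputs, targets, num_boxes, alpha=0.25, gamma=2):
        """Focal loss on per-class sigmoid logits, normalized by num_boxes."""
        prob = inputs.sigmoid()
        ce_loss = F.binary_cross_entropy_with_logits(inputs, targets,
                                                     reduction="none")
        p_t = prob * targets + (1 - prob) * (1 - targets)
        loss = ce_loss * ((1 - p_t) ** gamma)
        if alpha >= 0:
            alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
            loss = alpha_t * loss
        # average over the class dimension per query, then normalize by num_boxes
        return loss.mean(1).sum() / num_boxes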
Example #4
    def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
        """Classification loss (NLL)
        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
        """
        assert 'human_pred_logits' in outputs
        assert 'object_pred_logits' in outputs
        assert 'action_pred_logits' in outputs
        human_src_logits = outputs['human_pred_logits']
        object_src_logits = outputs['object_pred_logits']
        action_src_logits = outputs['action_pred_logits']

        idx = self._get_src_permutation_idx(indices)

        human_target_classes_o = torch.cat(
            [t["human_labels"][J] for t, (_, J) in zip(targets, indices)])
        object_target_classes_o = torch.cat(
            [t["object_labels"][J] for t, (_, J) in zip(targets, indices)])
        action_target_classes_o = torch.cat(
            [t["action_labels"][J] for t, (_, J) in zip(targets, indices)])

        # NOTE: the bare name `num_humans` is undefined in this scope in the
        # original snippet; it is presumably an attribute of the criterion and
        # is used as such here.
        human_target_classes = torch.full(human_src_logits.shape[:2],
                                          self.num_humans,
                                          dtype=torch.int64,
                                          device=human_src_logits.device)
        human_target_classes[idx] = human_target_classes_o

        object_target_classes = torch.full(object_src_logits.shape[:2],
                                           self.num_classes,
                                           dtype=torch.int64,
                                           device=object_src_logits.device)
        object_target_classes[idx] = object_target_classes_o

        # NOTE: same assumption for `num_actions` as for `num_humans` above.
        action_target_classes = torch.full(action_src_logits.shape[:2],
                                           self.num_actions,
                                           dtype=torch.int64,
                                           device=action_src_logits.device)
        action_target_classes[idx] = action_target_classes_o

        human_loss_ce = F.cross_entropy(human_src_logits.transpose(1, 2),
                                        human_target_classes,
                                        self.human_empty_weight)
        object_loss_ce = F.cross_entropy(object_src_logits.transpose(1, 2),
                                         object_target_classes,
                                         self.object_empty_weight)
        action_loss_ce = F.cross_entropy(action_src_logits.transpose(1, 2),
                                         action_target_classes,
                                         self.action_empty_weight)
        loss_ce = human_loss_ce + object_loss_ce + 2 * action_loss_ce
        losses = {
            'loss_ce': loss_ce,
            'human_loss_ce': human_loss_ce,
            'object_loss_ce': object_loss_ce,
            'action_loss_ce': action_loss_ce
        }

        if log:
            losses['class_error'] = 100 - accuracy(action_src_logits[idx],
                                                   action_target_classes_o)[0]
        return losses
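
The three empty-weight vectors used above are not shown in the example. Assuming they mirror DETR's eos_coef scheme (an assumption, since the snippet omits __init__), they would be registered along these lines:

    # hypothetical __init__ fragment: down-weight the background ("no object")
    # class of each head, as DETR does with its single empty_weight
    human_empty_weight = torch.ones(self.num_humans + 1)
    human_empty_weight[-1] = self.eos_coef
    self.register_buffer('human_empty_weight', human_empty_weight)
    object_empty_weight = torch.ones(self.num_classes + 1)
    object_empty_weight[-1] = self.eos_coef
    self.register_buffer('object_empty_weight', object_empty_weight)
    action_empty_weight = torch.ones(self.num_actions + 1)
    action_empty_weight[-1] = self.eos_coef
    self.register_buffer('action_empty_weight', action_empty_weight)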
Example #5
    def loss_labels_att(self, outputs, targets, indices, num_boxes, log=True):
        """Classification loss (NLL + Loss attenuation)
        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
        outputs must contain the mean pred_logits and the log-variance pred_logits_var
        """
        if 'pred_logits_var' not in outputs:
            return self.loss_labels(outputs, targets, indices, num_boxes, log)

        assert 'pred_logits' in outputs

        src_logits = outputs['pred_logits']
        src_logits_var = outputs['pred_logits_var']

        src_logits_var = torch.sqrt(torch.exp(src_logits_var))

        univariate_normal_dists = distributions.normal.Normal(
            src_logits, scale=src_logits_var)
        pred_class_stochastic_logits = univariate_normal_dists.rsample(
            (self.cls_var_num_samples, ))
        pred_class_stochastic_logits = pred_class_stochastic_logits.view(
            pred_class_stochastic_logits.shape[1],
            pred_class_stochastic_logits.shape[2] *
            pred_class_stochastic_logits.shape[0], -1)

        idx = self._get_src_permutation_idx(indices)
        target_classes_o = torch.cat(
            [t["labels"][J] for t, (_, J) in zip(targets, indices)])
        target_classes = torch.full(src_logits.shape[:2],
                                    self.num_classes,
                                    dtype=torch.int64,
                                    device=src_logits.device)
        target_classes[idx] = target_classes_o
        target_classes = torch.unsqueeze(target_classes, dim=0)
        target_classes = torch.repeat_interleave(target_classes,
                                                 self.cls_var_num_samples,
                                                 dim=0)
        target_classes = target_classes.view(
            target_classes.shape[1],
            target_classes.shape[2] * target_classes.shape[0])

        loss_ce = F.cross_entropy(pred_class_stochastic_logits.transpose(1, 2),
                                  target_classes, self.empty_weight)

        losses = {'loss_ce': loss_ce}

        if log:
            # TODO this should probably be a separate loss, not hacked in this
            # one here
            losses['class_error'] = 100 - \
                accuracy(src_logits[idx], target_classes_o)[0]
        return losses
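
The shape gymnastics of the attenuation variant are easiest to follow on toy sizes. A minimal, self-contained sketch (all sizes invented):

    import torch
    from torch import distributions

    mu = torch.randn(2, 100, 92)            # (batch, queries, classes): mean logits
    log_var = torch.randn(2, 100, 92)       # predicted log-variance
    sigma = torch.sqrt(torch.exp(log_var))  # standard deviation
    dist = distributions.normal.Normal(mu, scale=sigma)
    samples = dist.rsample((10,))           # (samples, batch, queries, classes);
                                            # rsample keeps the graph differentiable
    flat = samples.view(samples.shape[1],
                        samples.shape[2] * samples.shape[0], -1)
    print(flat.shape)                       # torch.Size([2, 1000, 92])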
Example #6
    def loss_labels(self,
                    outputs,
                    gt_instances: List[Instances],
                    indices,
                    num_boxes,
                    log=False):
        """Classification loss (NLL)
        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
        """
        src_logits = outputs['pred_logits']
        idx = self._get_src_permutation_idx(indices)
        target_classes = torch.full(src_logits.shape[:2],
                                    self.num_classes,
                                    dtype=torch.int64,
                                    device=src_logits.device)
        # The matched gt index for a disappeared track query is set to -1.
        labels = []
        for gt_per_img, (_, J) in zip(gt_instances, indices):
            labels_per_img = torch.ones_like(J)
            # set the labels of track-appear slots to 0 below.
            if len(gt_per_img) > 0:
                labels_per_img[J != -1] = gt_per_img.labels[J[J != -1]]
            labels.append(labels_per_img)
        target_classes_o = torch.cat(labels)
        target_classes[idx] = target_classes_o
        if self.focal_loss:
            gt_labels_target = F.one_hot(
                target_classes, num_classes=self.num_classes +
                1)[:, :, :-1]  # no loss for the last (background) class
            gt_labels_target = gt_labels_target.to(src_logits)
            loss_ce = sigmoid_focal_loss(src_logits.flatten(1),
                                         gt_labels_target.flatten(1),
                                         alpha=0.25,
                                         gamma=2,
                                         num_boxes=num_boxes,
                                         mean_in_dim1=False)
            loss_ce = loss_ce.sum()
        else:
            loss_ce = F.cross_entropy(src_logits.transpose(1, 2),
                                      target_classes, self.empty_weight)
        losses = {'loss_ce': loss_ce}

        if log:
            # TODO this should probably be a separate loss, not hacked in this one here
            losses['class_error'] = 100 - accuracy(src_logits[idx],
                                                   target_classes_o)[0]

        return losses
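
The one-hot-then-slice idiom in the focal-loss branch maps the background index to an all-zero target row, so background queries contribute no positive term. A toy illustration:

    import torch
    import torch.nn.functional as F

    num_classes = 3                              # background encoded as index 3
    target_classes = torch.tensor([[0, 2, 3]])   # one image, three queries
    onehot = F.one_hot(target_classes, num_classes=num_classes + 1)[:, :, :-1]
    print(onehot)
    # tensor([[[1, 0, 0],
    #          [0, 0, 1],
    #          [0, 0, 0]]])  <- the background query's row is all zeros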
Example #7
    def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
        """Classification loss (NLL) targets dicts must contain the key
        'labels' containing a tensor of dim [nb_target_boxes].

        Args:
            outputs (dict): Dict of RTD outputs.
            targets (list): A list of size batch_size. Each element is a dict composed of:
                'labels': Labels of groundtruth instances (0: action).
                'boxes': Relative temporal ratio of groundtruth instances.
                'video_id': ID of the video sample.
            indices (list): A list of size batch_size.
                Each element is composed of two tensors,
                the first index_i is the indices of the selected predictions (in order),
                the second index_j is the indices of the corresponding selected targets (in order).
                For each batch element,
                it holds: len(index_i) = len(index_j) = min(num_queries, num_target_boxes)

        Returns:
            losses (dict): Dict of losses.
        """
        assert 'pred_logits' in outputs
        if indices is None:
            losses = {'loss_ce': 0}
            if log:
                losses['class_error'] = 0
            return losses

        src_logits = outputs['pred_logits']

        idx = self._get_src_permutation_idx(indices)
        target_classes_o = torch.cat(
            [t['labels'][J] for t, (_, J) in zip(targets, indices)])
        target_classes = torch.full(src_logits.shape[:2],
                                    self.num_classes,
                                    dtype=torch.int64,
                                    device=src_logits.device)
        target_classes[idx] = target_classes_o

        loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes,
                                  self.empty_weight)
        losses = {'loss_ce': loss_ce}

        if log:
            losses['class_error'] = 100 - accuracy(src_logits[idx],
                                                   target_classes_o)[0]
        return losses
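
The docstring above spells out the structure of indices; a concrete toy value makes the gathering step explicit (all values invented):

    import torch

    # one batch element, three matched pairs:
    # prediction 5 <-> target 0, prediction 17 <-> target 2, prediction 42 <-> target 1
    indices = [(torch.tensor([5, 17, 42]), torch.tensor([0, 2, 1]))]
    targets = [{'labels': torch.tensor([7, 1, 4])}]
    target_classes_o = torch.cat(
        [t['labels'][J] for t, (_, J) in zip(targets, indices)])
    print(target_classes_o)  # tensor([7, 4, 1]): labels reordered to match predictions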
Example #8
    def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
        """Classification loss (NLL)
        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
        """
        assert 'pred_logits' in outputs
        src_logits = outputs['pred_logits']
        idx = self._get_src_permutation_idx(indices)
        target_classes_o = torch.cat(
            [t["labels"][J] for t, (_, J) in zip(targets, indices)])
        target_classes = torch.full(src_logits.shape[:2],
                                    self.num_classes,
                                    dtype=torch.int64,
                                    device=src_logits.device)
        target_classes[idx] = target_classes_o

        target_classes_onehot = torch.zeros([
            src_logits.shape[0], src_logits.shape[1], src_logits.shape[2] + 1
        ],
                                            dtype=src_logits.dtype,
                                            layout=src_logits.layout,
                                            device=src_logits.device)
        target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)

        target_classes_onehot = target_classes_onehot[:, :, :-1]

        # ################### Only Produce Loss for Activated Categories ###################
        activated_class_ids = outputs[
            'activated_class_ids']  # (bs, num_support)
        activated_class_ids = activated_class_ids.unsqueeze(1).repeat(
            1, target_classes_onehot.shape[1], 1)
        loss_ce = sigmoid_focal_loss(src_logits.gather(2, activated_class_ids),
                                     target_classes_onehot.gather(
                                         2, activated_class_ids),
                                     num_boxes,
                                     alpha=self.focal_alpha,
                                     gamma=2)

        loss_ce = loss_ce * src_logits.shape[1]

        losses = {'loss_ce': loss_ce}

        if log:
            losses['class_error'] = 100 - accuracy(src_logits[idx],
                                                   target_classes_o)[0]

        return losses
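
The gather step restricts both the logits and the one-hot targets to the activated (support) categories before the focal loss. On toy shapes (values invented):

    import torch

    src_logits = torch.randn(1, 2, 5)      # (bs, num_queries, num_classes)
    activated = torch.tensor([[0, 3]])     # (bs, num_support): classes 0 and 3 active
    activated = activated.unsqueeze(1).repeat(1, src_logits.shape[1], 1)
    picked = src_logits.gather(2, activated)
    print(picked.shape)                    # torch.Size([1, 2, 2]): the loss only
                                           # sees the activated columns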
Example #9
    def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
        """Classification loss (NLL)
        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
        """
        assert 'pred_logits' in outputs
        src_logits = outputs['pred_logits']
        idx = self._get_src_permutation_idx(indices)
        target_classes_o = torch2paddle.concat(
            [t['labels'][J] for t, (_, J) in zip(targets, indices)])
        target_classes = paddle.full(src_logits.shape[:2], self.num_classes,
                                     dtype='int64').requires_grad_(False)
        target_classes[idx] = target_classes_o
        loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes,
                                  self.empty_weight)
        losses = {'loss_ce': loss_ce}
        if log:
            losses['class_error'] = 100 - accuracy(src_logits[idx],
                                                   target_classes_o)[0]
        return losses
Example #10
    def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
        """Classification loss (NLL)
        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
        """
        assert 'pred_logits' in outputs
        # (b, num_queries=100, num_classes+1)
        # classification predictions
        src_logits = outputs['pred_logits']

        # _get_src_permutation_idx() takes `indices`, the matched prediction/GT
        # index pairs produced by indices = self.matcher(...) in forward(), and
        # returns a tuple: the batch index of every matched prediction (which
        # image in the batch it belongs to) and its query index (which query
        # object within that image). Both tensors have shape
        # (num_matched_queries1 + num_matched_queries2 + ...,).
        idx = self._get_src_permutation_idx(indices)
        # Similarly, gather the class labels of all matched GTs in the batch
        # (target_classes_o). With src_logits and target_classes_o we can build
        # the per-prediction GT assignment, target_classes, below. Its leading
        # shape matches src_logits: every query object is first initialized to
        # the background class, then the matched entries (idx) are overwritten
        # with their matched GT classes.
        # matched GTs, shape (num_matched_queries1 + num_matched_queries2 + ...,)
        target_classes_o = torch.cat(
            [t["labels"][J] for t, (_, J) in zip(targets, indices)])
        # (b, num_queries=100), initialized to the background class
        target_classes = torch.full(src_logits.shape[:2],
                                    self.num_classes,
                                    dtype=torch.int64,
                                    device=src_logits.device)
        # overwrite the matched prediction slots with their matched GT classes
        target_classes[idx] = target_classes_o

        # With the targets assembled, the loss itself is a plain cross entropy.
        # PyTorch's cross entropy expects the class dimension on dim 1, so
        # src_logits is transposed to (b, num_classes+1, num_queries=100).
        loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes,
                                  self.empty_weight)
        losses = {'loss_ce': loss_ce}

        if log:
            # TODO this should probably be a separate loss, not hacked in this one here
            # class_error is the Top-1 error rate (in percent): does the
            # highest-scoring predicted class match the assigned GT class?
            # It is only logged and does not contribute to training.
            losses['class_error'] = 100 - accuracy(src_logits[idx],
                                                   target_classes_o)[0]
        return losses
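
Putting the pieces together, the whole target-assignment and loss computation can be run standalone on toy data (shapes, values and the 0.1 eos_coef are invented; empty_weight down-weights the background class as in DETR):

    import torch
    import torch.nn.functional as F

    num_classes, num_queries = 3, 4
    src_logits = torch.randn(1, num_queries, num_classes + 1)
    # matcher output: predictions 1 and 3 matched to targets 0 and 1
    indices = [(torch.tensor([1, 3]), torch.tensor([0, 1]))]
    targets = [{'labels': torch.tensor([2, 0])}]

    batch_idx = torch.cat(
        [torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
    src_idx = torch.cat([src for (src, _) in indices])
    target_classes_o = torch.cat(
        [t['labels'][J] for t, (_, J) in zip(targets, indices)])

    # everything starts as background (= num_classes); matched slots get GT labels
    target_classes = torch.full((1, num_queries), num_classes, dtype=torch.int64)
    target_classes[(batch_idx, src_idx)] = target_classes_o

    empty_weight = torch.ones(num_classes + 1)
    empty_weight[-1] = 0.1  # eos_coef: down-weight background
    loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes,
                              empty_weight)
    print(loss_ce)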