Example #1
0
    def forward(self,
                action_label,
                stop_label,
                action_pred,
                stop_pred,
                mask,
                reduce=False,
                metadata=None):
        """Compute the action loss from predicted vs. labeled velocities and stop signal.

        Args:
            action_label: action target tensor; column 0 is forward velocity,
                column 2 is angular velocity.
            stop_label: stop-signal target tensor.
            action_pred: predicted action tensor, same column layout as action_label.
            stop_pred: predicted stop-signal tensor.
            mask: optional mask; when given, masked-out elements are removed
                from every tensor before the losses are computed.
            reduce: if True, return the weighted scalar combination of the
                three component losses; otherwise concatenate them.
            metadata: unused; kept for interface compatibility.

        Returns:
            The loss tensor (weighted scalar if reduce=True).

        Raises:
            ValueError: if any NaN appears in the computed loss.
        """
        fwd_pred = action_pred[:, 0]  # fwd velocity
        ang_pred = action_pred[:, 2]  # angular velocity
        fwd_label = action_label[:, 0]
        ang_label = action_label[:, 2]

        if mask is not None:
            # NOTE(review): .byte() kept for compatibility with mask_tensors;
            # modern PyTorch prefers .bool() masks — confirm before changing.
            mask = mask.unsqueeze(1).byte()
            (fwd_pred, ang_pred, stop_pred, fwd_label, ang_label, stop_label) = \
                mask_tensors((fwd_pred, ang_pred, stop_pred, fwd_label, ang_label, stop_label), mask)

        fwd_loss = self.act_loss(fwd_pred, fwd_label)
        ang_loss = self.act_loss(ang_pred, ang_label)
        stop_loss = self.stoploss(stop_pred, stop_label)

        if reduce:
            # Weighted combination: angular velocity dominates the total.
            loss = 0.2 * fwd_loss + 0.6 * ang_loss + 0.2 * stop_loss
        else:
            loss = torch.cat([fwd_loss, ang_loss, stop_loss])

        # Fail fast on NaNs; torch.isnan replaces the legacy `loss != loss`
        # idiom and `.data.item()` legacy autograd access.
        if torch.isnan(loss).any().item():
            raise ValueError("Nan's encountered in loss calculation")

        return loss
Example #2
0
File: action_loss.py  Project: dxsun/drif
    def forward(self, action_label, stop_label, action_pred, stop_pred, mask, reduce=False, metadata=None):
        """Compute the action loss; on NaN, warn and return an empty tensor.

        Args:
            action_label: action target tensor; column 0 is forward velocity,
                column 2 is angular velocity.
            stop_label: stop-signal target tensor.
            action_pred: predicted action tensor, same column layout as action_label.
            stop_pred: predicted stop-signal tensor.
            mask: optional mask; when given, masked-out elements are removed
                from every tensor before the losses are computed.
            reduce: if True, return the weighted scalar combination of the
                three component losses; otherwise concatenate them.
            metadata: unused; kept for interface compatibility.

        Returns:
            The loss tensor, or an empty tensor Variable if NaNs were found.
        """
        fwd_pred = action_pred[:, 0]        # fwd velocity
        ang_pred = action_pred[:, 2]        # angular velocity
        fwd_label = action_label[:, 0]
        ang_label = action_label[:, 2]

        if mask is not None:
            mask = mask.unsqueeze(1).byte()
            (fwd_pred, ang_pred, stop_pred, fwd_label, ang_label, stop_label) = \
                mask_tensors((fwd_pred, ang_pred, stop_pred, fwd_label, ang_label, stop_label), mask)

        fwd_loss = self.act_loss(fwd_pred, fwd_label)
        ang_loss = self.act_loss(ang_pred, ang_label)
        stop_loss = self.stoploss(stop_pred, stop_label)

        if reduce:
            loss = 0.2 * fwd_loss + 0.6 * ang_loss + 0.2 * stop_loss
        else:
            loss = torch.cat([fwd_loss, ang_loss, stop_loss])

        nans = torch.isnan(loss)
        # .item() replaces `.data[0]`, which indexes a 0-dim tensor and
        # raises on modern PyTorch.
        if nans.any().item():
            print("WARNING: Nan's encountered in loss calculation")
            print(loss)
            loss[nans] = 0
            # Bug fix: the original subscripted the builtin (`list[...]`) and
            # misplaced the closing paren, passing self.is_cuda/self.cuda_device
            # to Variable instead of empty_float_tensor — TODO confirm the
            # empty_float_tensor(size, cuda, device) signature against the project.
            return Variable(empty_float_tensor(list(loss.size()), self.is_cuda, self.cuda_device))

        return loss
Example #3
0
File: action_loss.py  Project: dxsun/drif
    def forward(self, action_label, action_pred, mask=None, reduce=False, flags=None, batchreduce=True):
        """Compute per-component action losses from a combined action tensor.

        Args:
            action_label: targets; columns 0, 2, 3 are fwd velocity, angular
                velocity, and the stop signal.
            action_pred: predictions, same column layout.
            mask: optional mask; masked-out elements are excluded from the loss.
            reduce: if True, apply self.reduce_loss to the component losses.
            flags: unused — the flag-loss bookkeeping below is disabled.
            batchreduce: if True, sum each component loss over the batch.

        Returns:
            (loss, flagged_losses) where flagged_losses is currently always {},
            or an empty tensor Variable if NaNs were found.
        """
        fwd_pred = action_pred[:, 0]        # fwd velocity
        ang_pred = action_pred[:, 2]        # angular velocity
        stop_pred = action_pred[:, 3]
        fwd_label = action_label[:, 0]
        ang_label = action_label[:, 2]
        stop_label = action_label[:, 3]

        if mask is not None:
            mask = mask.unsqueeze(1).byte()
            (fwd_pred, ang_pred, stop_pred, fwd_label, ang_label, stop_label) = \
                mask_tensors((fwd_pred, ang_pred, stop_pred, fwd_label, ang_label, stop_label), mask)

        # Compute loss for each element in the batch. Assumes act_loss is
        # non-reducing (per-element), since the sums below reduce — TODO confirm.
        fwd_loss = self.act_loss(fwd_pred, fwd_label)
        ang_loss = self.act_loss(ang_pred, ang_label)

        # Aggregate

        flagged_losses = {}
        """
        if flags is not None and None not in flags:
            batch_size = fwd_pred.size(0)
            seq_len = int(batch_size / len(flags))
            real_batch_size = int(batch_size / seq_len)
            for b in range(real_batch_size):
                for s in range(seq_len):
                    flag_loss = (0.2 * fwd_loss[b * seq_len + s] + 0.6 * ang_loss[b * seq_len + s]).data
                    flagged_losses[flags[b]] = flag_loss.cpu().numpy()[0]
        """

        if batchreduce:
            # Reduce the losses manually
            fwd_loss = torch.sum(fwd_loss)
            ang_loss = torch.sum(ang_loss)

            # Stop loss is already reduced, because PyTorch at the time of writing didn't have a reduce arg for it.
            stop_loss = self.stoploss(stop_pred, stop_label)
            # Bug fix: torch.stack, not torch.cat — the summed losses are
            # 0-dim tensors and torch.cat rejects 0-dim inputs on modern PyTorch.
            loss = torch.stack([fwd_loss, ang_loss, stop_loss])

        else:
            stop_loss = torch.zeros_like(stop_pred)
            for i in range(len(stop_pred)):
                stop_loss[i:i+1] = self.stoploss(stop_pred[i:i+1], stop_label[i:i+1])
            loss = torch.stack([fwd_loss, ang_loss, stop_loss], dim=1)

        if reduce:
            loss = self.reduce_loss(loss)

        nans = torch.isnan(loss)
        # .item() replaces `.data[0].item()` — indexing the 0-dim sum with [0]
        # raises on modern PyTorch.
        if nans.any().item():
            print("WARNING: Nan's encountered in loss calculation")
            print(loss)
            loss[nans] = 0
            # Bug fix: the original subscripted the builtin (`list[...]`) and
            # misplaced the closing paren, passing self.is_cuda/self.cuda_device
            # to Variable instead of empty_float_tensor — TODO confirm the
            # empty_float_tensor(size, cuda, device) signature against the project.
            return Variable(empty_float_tensor(list(loss.size()), self.is_cuda, self.cuda_device))

        return loss, flagged_losses