Esempio n. 1
0
    def val_func_process(self, input_data, label, device=None):
        """Evaluate one (image, label) pair on GPU and return exp(score).

        Builds an 8x-downscaled pairwise class-similarity matrix from the
        label and feeds it to ``self.val_func`` as ``aux_label``.  When
        ``self.is_flip`` is set, the horizontally mirrored image is also
        scored and its mirrored-back score added in (test-time augmentation).

        Args:
            input_data: single image array; a batch axis is prepended via
                ``[None]`` (CHW layout assumed -- TODO confirm).
            label: 2-D integer label map (H, W); a batch axis is prepended.
            device: CUDA device index for ``.cuda()``; ``None`` uses the
                current device.

        Returns:
            ``torch.exp`` of the first element of the network output,
            still resident on the GPU.
        """
        # Prepend the batch axis and force contiguous memory before the
        # host-to-device copy.
        input_data = np.ascontiguousarray(input_data[None, :, :, :],
                                          dtype=np.float32)
        input_data = torch.FloatTensor(input_data).cuda(device)
        label = np.ascontiguousarray(label[None, :, :], dtype=np.int16)
        label = torch.LongTensor(label).cuda(device)

        b, h, w = label.size()
        # Nearest-neighbour downscale of the labels by 8x (0.125);
        # F.interpolate needs the explicit channel axis, hence (b, 1, h, w).
        scaled_gts = F.interpolate((label.view(b, 1, h, w)).float(),
                                   scale_factor=0.125,
                                   mode="nearest")
        b, c, h, w = scaled_gts.size()
        # squeeze_() drops ALL size-1 dims (batch AND channel, since b == 1
        # here); the trailing view(b, h, w) restores the batch axis.
        scaled_gts = scaled_gts.squeeze_().view(b, h, w)
        C = config.num_classes + 1
        # one_hot -> (b, C, h*w); bmm of (hw, C) x (C, hw) yields a binary
        # (hw, hw) matrix: 1 where two positions carry the same class.
        one_hot_gts = one_hot(scaled_gts, C).view(b, C, -1)
        similarity_gts = torch.bmm(one_hot_gts.permute(0, 2, 1), one_hot_gts)

        with torch.cuda.device(input_data.get_device()):
            self.val_func.eval()
            self.val_func.to(input_data.get_device())
            with torch.no_grad():
                score = self.val_func(input_data, aux_label=similarity_gts)
                score = score[0]

                # Flip augmentation: score the mirrored image, mirror the
                # result back, and accumulate.
                if self.is_flip:
                    input_data = input_data.flip(-1)
                    score_flip = self.val_func(input_data,
                                               aux_label=similarity_gts)
                    score_flip = score_flip[0]
                    score += score_flip.flip(-1)
                # exp() converts scores to probabilities -- presumably the
                # network emits log-softmax outputs; TODO confirm.
                score = torch.exp(score)
                # score = score.data

        return score
Esempio n. 2
0
def get_similarity_gt(gts, scale_factor):
    """Build a pairwise class-similarity target from a label map.

    Downscales ``gts`` with nearest-neighbour interpolation, one-hot encodes
    the result over ``config.num_classes + 1`` classes, and returns the
    binary matrix whose (i, j) entry is 1 iff downscaled positions i and j
    share the same class.

    Args:
        gts: integer label map of shape (b, h, w).
        scale_factor: spatial scale passed to ``F.interpolate``
            (e.g. 0.125 for an 8x downscale).

    Returns:
        Tensor of shape (b, N, N) with N = scaled_h * scaled_w.
    """
    b, h, w = gts.size()
    # F.interpolate requires an explicit channel axis, hence (b, 1, h, w).
    scaled_gts = F.interpolate((gts.view(b, 1, h, w)).float(),
                               scale_factor=scale_factor,
                               mode="nearest")
    b, c, h, w = scaled_gts.size()
    # BUGFIX: drop only the channel axis.  The previous in-place squeeze_()
    # removed *every* size-1 dim, so a batch of 1 (or a spatial dim scaled
    # down to 1) lost its axis and the view(b, C, -1) below mis-shaped.
    # The val_func_process variant guards with the same view(b, h, w).
    scaled_gts = scaled_gts.view(b, h, w)
    C = config.num_classes + 1
    one_hot_gts = one_hot(scaled_gts, C).view(b, C, -1)
    # (N, C) x (C, N) -> binary (N, N) similarity matrix.
    similarity_gts = torch.bmm(one_hot_gts.permute(0, 2, 1), one_hot_gts)

    return similarity_gts
Esempio n. 3
0
            minibatch = dataloader.next()
            imgs = minibatch['data']
            gts = minibatch['label']

            # Async host-to-device copies (effective with pinned memory).
            imgs = imgs.cuda(non_blocking=True)
            gts = gts.cuda(non_blocking=True)

            b, h, w = gts.size()
            # Nearest-neighbour 8x downscale of the labels; F.interpolate
            # needs the explicit channel axis, hence (b, 1, h, w).
            scaled_gts = F.interpolate((gts.view(b, 1, h, w)).float(),
                                       scale_factor=0.125,
                                       mode="nearest")
            b, c, h, w = scaled_gts.size()
            # NOTE(review): squeeze_() drops ALL size-1 dims, so with b == 1
            # the batch axis is lost too -- the val_func_process variant
            # guards with .view(b, h, w); confirm batch size is always > 1.
            scaled_gts = scaled_gts.squeeze_()
            C = config.num_classes + 1
            # Pairwise binary similarity target: (N, N) with N = h*w, entry
            # 1 where two downscaled positions carry the same class.
            one_hot_gts = one_hot(scaled_gts, C).view(b, C, -1)
            similarity_gts = torch.bmm(one_hot_gts.permute(0, 2, 1),
                                       one_hot_gts)

            # Shift labels so class ids start at 0 -- presumably maps the
            # "unlabelled" class 0 to -1/ignore; verify against the dataset.
            gts = gts - 1

            loss = model(imgs, gts, similarity_gts)

            # reduce the whole loss over multi-gpu
            if engine.distributed:
                dist.all_reduce(loss, dist.ReduceOp.SUM)
                loss = loss / engine.world_size
            # else:
            #     loss = Reduce.apply(*loss) / len(loss)

            optimizer.zero_grad()
Esempio n. 4
0
    def forward(self, pred, target):
        """Pairwise-similarity auxiliary loss.

        Compares a predicted pixel-pair similarity map against the binary
        similarity matrix derived from ``target``: element-wise BCE over all
        valid pairs, plus BCE-style precision / recall / specificity terms.

        Args:
            pred: predicted similarity, shape (b, N, N) with N = h*w after
                downscaling by ``self.scale``; values in [0, 1] are assumed
                since they feed ``self.criterion`` directly -- TODO confirm.
            target: integer label map of shape (b, h, w); pixels equal to
                ``self.ignore_index`` are excluded from the loss.

        Returns:
            Scalar loss tensor (sum of the four terms).
        """
        b, h, w = target.size()
        # Nearest-neighbour downscale; interpolate needs a channel axis.
        scaled_gts = F.interpolate((target.view(b, 1, h, w)).float(),
                                   scale_factor=self.scale,
                                   mode="nearest")

        # valid_mask marks labelled pixels; its outer product marks pixel
        # PAIRS whose endpoints are both valid.
        valid_mask = torch.ones_like(scaled_gts)
        valid_mask[scaled_gts == self.ignore_index] = 0
        valid_vector = valid_mask.view(b, -1, 1)
        valid_mask = torch.bmm(valid_vector, valid_vector.permute(0, 2, 1))

        # Map ignore pixels onto the extra class so one_hot stays in range.
        scaled_gts[scaled_gts == self.ignore_index] = self.num_class
        # BUGFIX: squeeze only the channel axis.  The old bare squeeze_()
        # removed every size-1 dimension, so a batch of 1 lost its batch
        # axis and the view(b, C, -1) below mis-shaped.
        scaled_gts = scaled_gts.squeeze(1)
        C = self.num_class + 1
        one_hot_gts = one_hot(scaled_gts, C).view(b, C, -1)
        # (N, C) x (C, N) -> binary (N, N): 1 iff two positions share a class.
        similarity_gts = torch.bmm(one_hot_gts.permute(0, 2, 1),
                                   one_hot_gts)

        # Pairwise BCE, normalised by the number of valid pairs (clamped to
        # 1 so an all-ignore target cannot divide by zero).
        bce_loss = self.criterion(pred, similarity_gts)
        num_valid = valid_mask.sum()
        num_valid = torch.where(num_valid > 0, num_valid,
                                torch.ones(1, device=num_valid.device))
        bce_loss = valid_mask * bce_loss
        bce_loss = bce_loss.sum() / num_valid

        # Per-pixel normaliser for the three ratio terms below.
        valid_vector = valid_vector.view(b, -1)
        num_valid = valid_vector.sum()
        num_valid = torch.where(num_valid > 0, num_valid,
                                torch.ones(1, device=num_valid.device))

        vtarget = similarity_gts * valid_mask

        # Precision: fraction of predicted similarity mass that is correct.
        precision_part = torch.sum(pred * vtarget, dim=2)
        denominator = torch.sum(pred, dim=2)
        # BUGFIX: `1 - (denominator > 0)` subtracts from a bool tensor,
        # which raises a RuntimeError on PyTorch >= 1.2; use the boolean
        # mask directly (fill rows with zero denominator with 1).
        denominator = denominator.masked_fill_(denominator == 0, 1)
        precision_part = precision_part.div_(denominator)
        precision_label = torch.ones_like(precision_part)
        precision_loss = self.criterion(precision_part, precision_label)
        precision_loss = valid_vector * precision_loss
        precision_loss = precision_loss.sum() / num_valid

        # Recall: fraction of true-pair mass recovered by the prediction.
        recall_part = torch.sum(pred * vtarget, dim=2)
        denominator = torch.sum(vtarget, dim=2)
        denominator = denominator.masked_fill_(denominator == 0, 1)
        recall_part = recall_part.div_(denominator)
        recall_label = torch.ones_like(recall_part)
        recall_loss = self.criterion(recall_part, recall_label)
        recall_loss = valid_vector * recall_loss
        recall_loss = recall_loss.sum() / num_valid

        # Specificity: agreement on dissimilar (negative) pairs.
        vtarget = (1 - similarity_gts) * valid_mask
        spec_part = torch.sum((1 - pred) * vtarget, dim=2)
        denominator = torch.sum(vtarget, dim=2)
        denominator = denominator.masked_fill_(denominator == 0, 1)
        spec_part = spec_part.div_(denominator)
        spec_label = torch.ones_like(spec_part)
        spec_loss = self.criterion(spec_part, spec_label)
        spec_loss = valid_vector * spec_loss
        spec_loss = spec_loss.sum() / num_valid

        loss = bce_loss + recall_loss + spec_loss + precision_loss

        return loss
Esempio n. 5
0
            minibatch = dataloader.next()
            imgs = minibatch['data']
            gts = minibatch['label']

            # Async host-to-device copies (effective with pinned memory).
            imgs = imgs.cuda(non_blocking=True)
            gts = gts.cuda(non_blocking=True)

            b, h, w = gts.size()
            # scaled_gts = F.interpolate((gts.view(b, 1, h, w)).float(),
            #                            scale_factor=0.125,
            #                            mode="nearest")
            # b, c, h, w = scaled_gts.size()
            # scaled_gts = scaled_gts.squeeze_()
            C = config.num_classes + 1
            # Full-resolution pairwise similarity target (the downscale above
            # is commented out): (h*w, h*w) per sample, so memory use grows
            # quartically with image side -- presumably intentional; confirm.
            one_hot_gts = one_hot(gts, C).view(b, C, -1)
            similarity_gts = torch.bmm(one_hot_gts.permute(0, 2, 1),
                                       one_hot_gts)

            # Shift labels so class ids start at 0 -- presumably maps class 0
            # to -1/ignore; verify against the dataset's label convention.
            gts = gts - 1

            loss = model(imgs, gts, similarity_gts)

            # reduce the whole loss over multi-gpu
            # NOTE(review): loss is summed across ranks but NOT divided by
            # world_size, unlike the distributed branch in the other training
            # loop in this file -- confirm this is intended.
            dist.all_reduce(loss, dist.ReduceOp.SUM)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            current_idx = epoch * config.niters_per_epoch + idx