Example #1
    def forward(self, input, target):
        probs = F.softmax(input[0], dim=1)
        confidence = torch.sigmoid(input[1])

        # Make sure we don't have any numerical instability
        eps = 1e-12
        probs = torch.clamp(probs, 0.0 + eps, 1.0 - eps)
        confidence = torch.clamp(confidence, 0.0 + eps, 1.0 - eps)

        if self.half_random:
            # Randomly set half of the confidences to 1 (i.e. no hints)
            b = torch.bernoulli(
                torch.Tensor(confidence.size()).uniform_(0, 1)).to(self.device)
            conf = confidence * b + (1 - b)
        else:
            conf = confidence

        labels_hot = misc.one_hot_embedding(target,
                                            self.nb_classes).to(self.device)
        # Segmentation special case
        if self.task == "segmentation":
            labels_hot = labels_hot.permute(0, 3, 1, 2)
        # Interpolate between the predicted probabilities and the one-hot
        # target according to the (possibly hinted) confidence, then take the
        # log so NLLLoss can be applied directly.
        probs_interpol = torch.log(conf * probs + (1 - conf) * labels_hot)
        self.loss_nll = nn.NLLLoss()(probs_interpol, target)
        # Penalty that pushes confidences towards 1 (i.e. asking for no hints)
        self.loss_confid = torch.mean(-(torch.log(confidence)))
        total_loss = self.loss_nll + self.lbda * self.loss_confid

        # Update lbda (budget mechanism): if the confidence penalty exceeds
        # the budget beta, increase its weight, otherwise relax it
        if self.lbda_control:
            if self.loss_confid >= self.beta:
                self.lbda /= 0.99
            else:
                self.lbda /= 1.01
        return total_loss
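
This and the following examples call a misc.one_hot_embedding helper that is not shown. A minimal sketch of what such a helper could look like (the repository's actual implementation may differ):

import torch

def one_hot_embedding(labels, num_classes):
    # Index an identity matrix with the label tensor: each label becomes a
    # one-hot row, so the output shape is labels.shape + (num_classes,).
    # For segmentation targets of shape (N, H, W) this yields (N, H, W, C),
    # which is why the losses above permute to (N, C, H, W) afterwards.
    eye = torch.eye(num_classes, device=labels.device)
    return eye[labels]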
Example #2
    def forward(self, input, target):
        probs = F.softmax(input[0], dim=1)
        confidence = torch.sigmoid(input[1]).squeeze()
        # Apply optional weighting: up-weight samples the model misclassifies
        weights = torch.ones_like(target).type(torch.FloatTensor).to(
            self.device)
        weights[(probs.argmax(dim=1) != target)] *= self.weighting
        labels_hot = misc.one_hot_embedding(target,
                                            self.nb_classes).to(self.device)
        # Segmentation special case
        if self.task == "segmentation":
            labels_hot = labels_hot.permute(0, 3, 1, 2)
        # Regress the predicted confidence onto the true-class probability (TCP)
        loss = weights * (confidence - (probs * labels_hot).sum(dim=1))**2
        return torch.mean(loss)
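
In the loss above, (probs * labels_hot).sum(dim=1) is the probability the network assigns to the true class, i.e. the TCP target, and the predicted confidence is regressed onto it with a weighted MSE. A self-contained sketch of the same computation on plain tensors (the function name and signature here are made up for illustration):

import torch
import torch.nn.functional as F

def tcp_mse_loss(logits, confid_logits, target, num_classes, weighting=1.0):
    probs = F.softmax(logits, dim=1)                       # (N, C)
    confidence = torch.sigmoid(confid_logits).squeeze(-1)  # (N,)
    weights = torch.ones_like(target, dtype=torch.float)
    weights[probs.argmax(dim=1) != target] *= weighting    # up-weight errors
    labels_hot = F.one_hot(target, num_classes).float()    # (N, C)
    tcp = (probs * labels_hot).sum(dim=1)                  # true-class probability
    return torch.mean(weights * (confidence - tcp) ** 2)

logits = torch.randn(4, 10)
confid_logits = torch.randn(4, 1)
target = torch.randint(0, 10, (4,))
print(tcp_mse_loss(logits, confid_logits, target, num_classes=10))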
Example #3
    def forward(self, input, target):
        probs = F.softmax(input[0], dim=1)
        confidence = torch.sigmoid(input[1]).squeeze()
        # Apply optional weighting
        weights = torch.ones_like(target).type(torch.FloatTensor).to(
            self.device)

        print('loss_weights:', weights, 'loss_conf', confidence, 'loss_probs',
              probs, 'loss_target', target, target.dtype, target.shape)

        weights[(probs.argmax(dim=1) !=
                 target)] *= self.weighting  # up-weight misclassified samples
        labels_hot = misc.one_hot_embedding(target,
                                            self.nb_classes).to(self.device)

        print('loss_hot:', labels_hot)
        print('loss_w1:', weights)

        loss = weights * (confidence - (probs * labels_hot).sum(dim=1))**2
        return torch.mean(loss)
Example #4
    def evaluate(self,
                 dloader,
                 len_dataset,
                 split="test",
                 mode="mcp",
                 samples=50,
                 verbose=False):
        self.model.eval()
        metrics = Metrics(self.metrics, len_dataset, self.num_classes)
        loss = 0

        # Special case of mc-dropout
        if mode == "mc_dropout":
            self.model.keep_dropout_in_test()
            LOGGER.info(f"Sampling {samples} times")

        # Evaluation loop
        loop = tqdm(dloader, disable=not verbose)
        for batch_id, (data, target) in enumerate(loop):
            data, target = data.to(self.device), target.to(self.device)

            with torch.no_grad():
                if mode == "mcp":
                    output = self.model(data)
                    if self.task == "classification":
                        loss += self.criterion(output, target)
                    elif self.task == "segmentation":
                        loss += self.criterion(output, target.squeeze(dim=1))
                    confidence, pred = F.softmax(output,
                                                 dim=1).max(dim=1,
                                                            keepdim=True)

                elif mode == "tcp":
                    output = self.model(data)
                    if self.task == "classification":
                        loss += self.criterion(output, target)
                    elif self.task == "segmentation":
                        loss += self.criterion(output, target.squeeze(dim=1))
                    probs = F.softmax(output, dim=1)
                    pred = probs.max(dim=1, keepdim=True)[1]
                    labels_hot = misc.one_hot_embedding(
                        target, self.num_classes).to(self.device)
                    # Segmentation special case
                    if self.task == "segmentation":
                        labels_hot = labels_hot.squeeze(1).permute(0, 3, 1, 2)
                    confidence, _ = (labels_hot * probs).max(dim=1,
                                                             keepdim=True)

                elif mode == "mc_dropout":
                    if self.task == "classification":
                        outputs = torch.zeros(samples, data.shape[0],
                                              self.num_classes).to(self.device)
                    elif self.task == "segmentation":
                        outputs = torch.zeros(
                            samples,
                            data.shape[0],
                            self.num_classes,
                            data.shape[2],
                            data.shape[3],
                        ).to(self.device)
                    # Average the logits over the stochastic forward passes
                    for i in range(samples):
                        outputs[i] = self.model(data)
                    output = outputs.mean(0)
                    if self.task == "classification":
                        loss += self.criterion(output, target)
                    elif self.task == "segmentation":
                        loss += self.criterion(output, target.squeeze(dim=1))
                    probs = F.softmax(output, dim=1)
                    confidence = (probs * torch.log(probs + 1e-9)).sum(
                        dim=1)  # negative entropy (higher = more confident)
                    pred = probs.max(dim=1, keepdim=True)[1]

                metrics.update(pred, target, confidence)

        scores = metrics.get_scores(split=split)
        losses = {"loss_nll": loss}
        return losses, scores
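
The mc_dropout branch relies on the model exposing a keep_dropout_in_test() method. A plausible sketch of such a method, assuming it simply flips dropout modules back into training mode after model.eval() so that repeated forward passes stay stochastic (an assumption, not necessarily the repository's exact code):

import torch.nn as nn

def keep_dropout_in_test(model):
    # Re-enable dropout only; other layers (e.g. batch norm) stay in eval
    # mode, so the sampled predictions differ only through dropout masks.
    for module in model.modules():
        if isinstance(module, nn.Dropout):
            module.train()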
Example #5
    def evaluate(self,
                 dloader,
                 len_dataset,
                 split="test",
                 mode="mcp",
                 samples=50,
                 verbose=False):
        self.model.eval()
        metrics = Metrics(self.metrics, len_dataset, self.num_classes)
        loss = 0

        # Special case of mc-dropout
        if mode == "mc_dropout":
            self.model.keep_dropout_in_test()
            LOGGER.info(f"Sampling {samples} times")

        # Evaluation loop
        loop = tqdm(dloader, desc="Iteration", disable=not verbose)
        for step, batch in enumerate(loop):
            batch = tuple(t.to(self.device) for t in batch)
            idx_ids, input_ids, input_mask, segment_ids, label_ids = batch
            print(label_ids)

            with torch.no_grad():
                if mode == "mcp":
                    print(True)
                    output, pooled_output = self.model(input_ids,
                                                       segment_ids,
                                                       input_mask,
                                                       labels=None)

                    current_loss = self.criterion(output.view(-1, 2),
                                                  label_ids.view(-1))
                    loss += current_loss

                    confidence, pred = F.softmax(output,
                                                 dim=1).max(dim=1,
                                                            keepdim=True)

                    print(confidence)
                    print(pred)

                elif mode == "tcp":
                    output, pooled_output = self.model(input_ids,
                                                       segment_ids,
                                                       input_mask,
                                                       labels=None)

                    current_loss = self.criterion(output.view(-1, 2),
                                                  label_ids.view(-1))
                    loss += current_loss

                    probs = F.softmax(output, dim=1)

                    pred = probs.max(dim=1, keepdim=True)[1]

                    labels_hot = misc.one_hot_embedding(
                        label_ids, self.num_classes).to(self.device)

                    confidence, _ = (labels_hot * probs).max(dim=1,
                                                             keepdim=True)

                elif mode == "mc_dropout":
                    print('---------------input_ids.shape---------------')
                    print(input_ids.shape)
                    # Size by the actual batch: the last batch may be smaller
                    # than the configured batch_size
                    outputs = torch.zeros(samples, input_ids.shape[0],
                                          self.num_classes).to(self.device)

                    for i in range(samples):
                        outputs[i], _ = self.model(input_ids,
                                                   segment_ids,
                                                   input_mask,
                                                   labels=None)
                    output = outputs.mean(0)

                    loss += self.criterion(output.view(-1, 2),
                                           label_ids.view(-1))

                    probs = F.softmax(output, dim=1)
                    confidence = (probs * torch.log(probs + 1e-9)).sum(
                        dim=1)  # negative entropy (higher = more confident)
                    pred = probs.max(dim=1, keepdim=True)[1]

                metrics.update(idx_ids, pred, label_ids, confidence)
                pred_detach = pred.detach()
                label_detach = label_ids.detach()
                confidence_detach = confidence.detach()
                idx_detach = idx_ids.detach()
                print('pred', pred_detach.cpu())
                print('label', label_detach.cpu())
                print('idx', idx_detach.cpu())
                print('confidence', confidence_detach.cpu())

        print('----------------------------------------------------')
        pred_list = []
        target_list = []
        confidence_list = []

        for i, p, t, c in zip(metrics.new_idx, metrics.new_pred,
                              metrics.new_taget, metrics.new_conf):
            print('idx,pred,target,confidence', i, p[0], t, c[0])
            pred_list.append(p[0])
            target_list.append(t)
            confidence_list.append(c[0])

        print('----------------------------------------------------')
        report = classifiction_metric(
            np.array(pred_list), np.array(target_list),
            np.array(self.config_args['data']['label_list']))
        print(report)
        print('----------------------------------------------------')

        scores = metrics.get_scores(split=split)
        losses = {"loss_nll": loss}
        return losses, scores
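
In both evaluate() examples, the mc_dropout confidence is the negative entropy of the averaged predictive distribution: peaked distributions score close to 0, uniform ones score lowest. A small sanity check of that behaviour:

import torch

def neg_entropy(probs, eps=1e-9):
    # Same expression as in the evaluate() code: sum_k p_k * log(p_k)
    return (probs * torch.log(probs + eps)).sum(dim=1)

peaked = torch.tensor([[0.97, 0.01, 0.01, 0.01]])
uniform = torch.tensor([[0.25, 0.25, 0.25, 0.25]])
print(neg_entropy(peaked))   # approx. -0.17 -> more confident
print(neg_entropy(uniform))  # approx. -1.39 -> less confident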