def compute_aug_loss(stu_out, tea_out):
    """Compute the unsupervised augmentation-consistency loss.

    This is a closure: it reads ``torch``, ``network_architectures``,
    ``use_rampup``, ``confidence_thresh``, ``loss``, ``cls_bal_scale``,
    ``cls_bal_scale_range``, ``n_classes``, ``rampup_weight_in_list``,
    ``cls_balance``, ``cls_bal_fn`` and ``rampup`` from an enclosing scope
    that is not visible here.

    Args:
        stu_out: Student predictions. Presumably a (batch, n_classes) tensor;
            TODO confirm against the caller.
        tea_out: Teacher predictions with the same shape as ``stu_out``.

    Returns:
        ``(unsup_loss, conf_mask_count, unsup_mask_count)``. When
        ``use_rampup`` is truthy, both counts are ``None``. Otherwise both are
        the same tensor: the number of samples whose maximum teacher output
        along dim 1 exceeds ``confidence_thresh``.
    """
    if use_rampup:
        # No confidence masking here; the loss is scaled by the rampup weight.
        mask = None
        mask_total = None
    else:
        top_teacher, _ = torch.max(tea_out, 1)
        mask = (top_teacher > confidence_thresh).float()
        mask_total = mask.sum()

    # Element-wise consistency term: robust BCE or squared difference.
    if loss == 'bce':
        per_elem = network_architectures.robust_binary_crossentropy(stu_out, tea_out)
    else:
        diff = stu_out - tea_out
        per_elem = diff * diff

    if cls_bal_scale:
        # Re-weight each class column by (expected count per class) divided by
        # the teacher's summed output for that class (floored at 1).
        # The weights are detached, so no gradient flows through them.
        n_samples = float(per_elem.shape[0]) if use_rampup else mask.sum()
        expected_per_class = n_samples / float(n_classes)
        teacher_class_mass = torch.clamp(tea_out.sum(dim=0), min=1.0)
        scale = expected_per_class / teacher_class_mass
        if cls_bal_scale_range != 0.0:
            scale = torch.clamp(scale, min=1.0/cls_bal_scale_range, max=cls_bal_scale_range)
        per_elem = per_elem * scale.detach()[None, :]

    per_sample = torch.mean(per_elem, 1)

    unsup_loss = (
        per_sample.mean() * rampup_weight_in_list[0]
        if use_rampup
        else (per_sample * mask).mean()
    )

    if cls_balance > 0.0:
        # Compare the batch-average student prediction against the uniform
        # target 1/n_classes using cls_bal_fn (defined elsewhere).
        mean_student = stu_out.mean(dim=0)
        balance_term = cls_bal_fn(mean_student, float(1.0 / n_classes)).mean() * n_classes
        if use_rampup:
            balance_term = balance_term * rampup_weight_in_list[0]
        elif rampup == 0:
            # Scale by the fraction of samples that passed the confidence mask.
            balance_term = balance_term * mask.mean(dim=0)
        unsup_loss += balance_term * cls_balance

    return unsup_loss, mask_total, mask_total
Пример #2
0
        def compute_aug_loss(stu_out, tea_out):
            """Return the unsupervised consistency loss between student and teacher.

            Reads ``torch``, ``network_architectures``, ``use_rampup``,
            ``confidence_thresh``, ``loss``, ``rampup_weight_in_list``,
            ``cls_balance``, ``cls_bal_fn``, ``n_classes`` and ``rampup``
            from an enclosing scope that is not visible here.

            Args:
                stu_out: Student predictions. Presumably (batch, n_classes);
                    TODO confirm against the caller.
                tea_out: Teacher predictions shaped like ``stu_out``.

            Returns:
                ``(unsup_loss, conf_mask_count, unsup_mask_count)``. Both
                counts are ``None`` under rampup. Otherwise both are the same
                tensor counting samples whose top teacher output along dim 1
                exceeds ``confidence_thresh``.
            """
            if not use_rampup:
                top_teacher, _ = torch.max(tea_out, 1)
                mask = (top_teacher > confidence_thresh).float()
                mask_total = mask.sum()
            else:
                # Rampup path: no masking, weighting is applied below instead.
                mask = None
                mask_total = None

            if loss != 'bce':
                diff = stu_out - tea_out
                per_elem = diff * diff
            else:
                per_elem = network_architectures.robust_binary_crossentropy(
                    stu_out, tea_out)

            # Average over dim 1 to get one loss value per sample.
            per_sample = per_elem.mean(dim=1)

            if not use_rampup:
                unsup_loss = (per_sample * mask).mean()
            else:
                unsup_loss = per_sample.mean() * rampup_weight_in_list[0]

            if cls_balance > 0.0:
                # Compare the batch-mean student prediction with the uniform
                # target 1/n_classes via cls_bal_fn (defined elsewhere).
                mean_student = stu_out.mean(dim=0)
                balance_term = cls_bal_fn(mean_student, float(1.0 / n_classes))
                balance_term = balance_term.mean() * n_classes

                if use_rampup:
                    balance_term = balance_term * rampup_weight_in_list[0]
                elif rampup == 0:
                    # Scale by the fraction of samples kept by the mask.
                    balance_term = balance_term * mask.mean(dim=0)

                unsup_loss += balance_term * cls_balance

            return unsup_loss, mask_total, mask_total