Code example #1
from itertools import chain

import numpy as np
import torch
import torch.distributed
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint

# BaseModel, DeepLab_v3p and Encoder are assumed to be project-internal modules.


class CAC(BaseModel):
    def __init__(self,
                 num_classes,
                 conf,
                 sup_loss=None,
                 ignore_index=None,
                 testing=False,
                 pretrained=True):

        super(CAC, self).__init__()
        assert int(conf['supervised']) + int(
            conf['semi']) == 1, 'one mode only'
        if conf['supervised']:
            self.mode = 'supervised'
        elif conf['semi']:
            self.mode = 'semi'
        else:
            raise ValueError("Config must enable exactly one of 'supervised' or 'semi'")

        self.ignore_index = ignore_index

        self.num_classes = num_classes
        self.sup_loss_w = conf['supervised_w']
        self.sup_loss = sup_loss
        self.downsample = conf['downsample']
        self.backbone = conf['backbone']
        self.layers = conf['layers']
        self.out_dim = conf['out_dim']
        self.proj_final_dim = conf['proj_final_dim']

        assert self.layers in [50, 101]

        if self.backbone == 'deeplab_v3+':
            self.encoder = DeepLab_v3p(backbone='resnet{}'.format(self.layers))
            self.classifier = nn.Sequential(
                nn.Dropout(0.1),
                nn.Conv2d(256, num_classes, kernel_size=1, stride=1))
            for m in self.classifier.modules():
                if isinstance(m, nn.Conv2d):
                    torch.nn.init.kaiming_normal_(m.weight)
                elif isinstance(m, nn.BatchNorm2d):
                    m.weight.data.fill_(1)
                    m.bias.data.zero_()
                elif isinstance(m, nn.SyncBatchNorm):
                    m.weight.data.fill_(1)
                    m.bias.data.zero_()
        elif self.backbone == 'psp':
            self.encoder = Encoder(pretrained=pretrained)
            self.classifier = nn.Conv2d(self.out_dim,
                                        num_classes,
                                        kernel_size=1,
                                        stride=1)
        else:
            raise ValueError("No such backbone {}".format(self.backbone))

        if self.mode == 'semi':
            self.project = nn.Sequential(
                nn.Conv2d(self.out_dim, self.out_dim, kernel_size=1, stride=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(self.out_dim,
                          self.proj_final_dim,
                          kernel_size=1,
                          stride=1))

            self.weight_unsup = conf['weight_unsup']
            self.temp = conf['temp']
            self.epoch_start_unsup = conf['epoch_start_unsup']
            self.selected_num = conf['selected_num']
            self.step_save = conf['step_save']
            self.step_count = 0
            self.feature_bank = []
            self.pseudo_label_bank = []
            self.pos_thresh_value = conf['pos_thresh_value']
            self.stride = conf['stride']

    def forward(self, x_l=None, target_l=None, x_ul=None, target_ul=None,
                curr_iter=None, epoch=None, gpu=None, gt_l=None, ul1=None,
                br1=None, ul2=None, br2=None, flip=None):
        if not self.training:
            enc = self.encoder(x_l)
            enc = self.classifier(enc)
            return F.interpolate(enc,
                                 size=x_l.size()[2:],
                                 mode='bilinear',
                                 align_corners=True)

        if self.mode == 'supervised':
            enc = self.encoder(x_l)
            enc = self.classifier(enc)
            output_l = F.interpolate(enc,
                                     size=x_l.size()[2:],
                                     mode='bilinear',
                                     align_corners=True)

            loss_sup = self.sup_loss(output_l,
                                     target_l,
                                     ignore_index=self.ignore_index,
                                     temperature=1.0) * self.sup_loss_w

            curr_losses = {'loss_sup': loss_sup}
            outputs = {'sup_pred': output_l}
            total_loss = loss_sup
            return total_loss, curr_losses, outputs

        elif self.mode == 'semi':

            enc = self.encoder(x_l)
            enc = self.classifier(enc)
            output_l = F.interpolate(enc,
                                     size=x_l.size()[2:],
                                     mode='bilinear',
                                     align_corners=True)
            loss_sup = self.sup_loss(output_l,
                                     target_l,
                                     ignore_index=self.ignore_index,
                                     temperature=1.0) * self.sup_loss_w

            curr_losses = {'loss_sup': loss_sup}
            outputs = {'sup_pred': output_l}
            total_loss = loss_sup

            if epoch < self.epoch_start_unsup:
                return total_loss, curr_losses, outputs

            # x_ul: [batch_size, 2, 3, H, W]
            x_ul1 = x_ul[:, 0, :, :, :]
            x_ul2 = x_ul[:, 1, :, :, :]

            enc_ul1 = self.encoder(x_ul1)
            if self.downsample:
                enc_ul1 = F.avg_pool2d(enc_ul1, kernel_size=2, stride=2)
            output_ul1 = self.project(enc_ul1)  #[b, c, h, w]
            output_ul1 = F.normalize(output_ul1, 2, 1)

            enc_ul2 = self.encoder(x_ul2)
            if self.downsample:
                enc_ul2 = F.avg_pool2d(enc_ul2, kernel_size=2, stride=2)
            output_ul2 = self.project(enc_ul2)  #[b, c, h, w]
            output_ul2 = F.normalize(output_ul2, 2, 1)

            # compute pseudo label
            logits1 = self.classifier(
                enc_ul1)  #[batch_size, num_classes, h, w]
            logits2 = self.classifier(enc_ul2)
            pseudo_logits_1 = F.softmax(
                logits1, 1).max(1)[0].detach()  #[batch_size, h, w]
            pseudo_logits_2 = F.softmax(logits2, 1).max(1)[0].detach()
            pseudo_label1 = logits1.max(1)[1].detach()  #[batch_size, h, w]
            pseudo_label2 = logits2.max(1)[1].detach()

            # Crop out the overlapping region shared by the two views. The crop
            # coordinates (ul*/br*) are in input pixels; dividing by 8 maps them
            # onto the downsampled feature grid.
            output_feature_list1 = []
            output_feature_list2 = []
            pseudo_label_list1 = []
            pseudo_label_list2 = []
            pseudo_logits_list1 = []
            pseudo_logits_list2 = []
            for idx in range(x_ul1.size(0)):
                output_ul1_idx = output_ul1[idx]
                output_ul2_idx = output_ul2[idx]
                pseudo_label1_idx = pseudo_label1[idx]
                pseudo_label2_idx = pseudo_label2[idx]
                pseudo_logits_1_idx = pseudo_logits_1[idx]
                pseudo_logits_2_idx = pseudo_logits_2[idx]
                if flip[0][idx]:
                    output_ul1_idx = torch.flip(output_ul1_idx, dims=(2, ))
                    pseudo_label1_idx = torch.flip(pseudo_label1_idx,
                                                   dims=(1, ))
                    pseudo_logits_1_idx = torch.flip(pseudo_logits_1_idx,
                                                     dims=(1, ))
                if flip[1][idx]:
                    output_ul2_idx = torch.flip(output_ul2_idx, dims=(2, ))
                    pseudo_label2_idx = torch.flip(pseudo_label2_idx,
                                                   dims=(1, ))
                    pseudo_logits_2_idx = torch.flip(pseudo_logits_2_idx,
                                                     dims=(1, ))
                output_feature_list1.append(
                    output_ul1_idx[:, ul1[0][idx] // 8:br1[0][idx] // 8,
                                   ul1[1][idx] // 8:br1[1][idx] // 8].permute(
                                       1, 2, 0).contiguous().view(
                                           -1, output_ul1.size(1)))
                output_feature_list2.append(
                    output_ul2_idx[:, ul2[0][idx] // 8:br2[0][idx] // 8,
                                   ul2[1][idx] // 8:br2[1][idx] // 8].permute(
                                       1, 2, 0).contiguous().view(
                                           -1, output_ul2.size(1)))
                pseudo_label_list1.append(
                    pseudo_label1_idx[ul1[0][idx] // 8:br1[0][idx] // 8,
                                      ul1[1][idx] // 8:br1[1][idx] //
                                      8].contiguous().view(-1))
                pseudo_label_list2.append(
                    pseudo_label2_idx[ul2[0][idx] // 8:br2[0][idx] // 8,
                                      ul2[1][idx] // 8:br2[1][idx] //
                                      8].contiguous().view(-1))
                pseudo_logits_list1.append(
                    pseudo_logits_1_idx[ul1[0][idx] // 8:br1[0][idx] // 8,
                                        ul1[1][idx] // 8:br1[1][idx] //
                                        8].contiguous().view(-1))
                pseudo_logits_list2.append(
                    pseudo_logits_2_idx[ul2[0][idx] // 8:br2[0][idx] // 8,
                                        ul2[1][idx] // 8:br2[1][idx] //
                                        8].contiguous().view(-1))
            output_feat1 = torch.cat(output_feature_list1, 0)  #[n, c]
            output_feat2 = torch.cat(output_feature_list2, 0)  #[n, c]
            pseudo_label1_overlap = torch.cat(pseudo_label_list1, 0)  #[n,]
            pseudo_label2_overlap = torch.cat(pseudo_label_list2, 0)  #[n,]
            pseudo_logits1_overlap = torch.cat(pseudo_logits_list1, 0)  #[n,]
            pseudo_logits2_overlap = torch.cat(pseudo_logits_list2, 0)  #[n,]
            assert output_feat1.size(0) == output_feat2.size(0)
            assert pseudo_label1_overlap.size(0) == pseudo_label2_overlap.size(
                0)
            assert output_feat1.size(0) == pseudo_label1_overlap.size(0)

            # sample features and gather them across GPUs
            b, c, h, w = output_ul1.size()
            selected_num = self.selected_num
            output_ul1_flatten = output_ul1.permute(0, 2, 3,
                                                    1).contiguous().view(
                                                        b * h * w, c)
            output_ul2_flatten = output_ul2.permute(0, 2, 3,
                                                    1).contiguous().view(
                                                        b * h * w, c)
            selected_idx1 = np.random.choice(range(b * h * w),
                                             selected_num,
                                             replace=False)
            selected_idx2 = np.random.choice(range(b * h * w),
                                             selected_num,
                                             replace=False)
            output_ul1_flatten_selected = output_ul1_flatten[selected_idx1]
            output_ul2_flatten_selected = output_ul2_flatten[selected_idx2]
            output_ul_flatten_selected = torch.cat(
                [output_ul1_flatten_selected, output_ul2_flatten_selected],
                0)  #[2*selected_num, c]
            output_ul_all = self.concat_all_gather(
                output_ul_flatten_selected)  #[2*selected_num*world_size, c]

            pseudo_label1_flatten_selected = pseudo_label1.view(
                -1)[selected_idx1]
            pseudo_label2_flatten_selected = pseudo_label2.view(
                -1)[selected_idx2]
            pseudo_label_flatten_selected = torch.cat([
                pseudo_label1_flatten_selected, pseudo_label2_flatten_selected
            ], 0)  #[2*selected_num]
            pseudo_label_all = self.concat_all_gather(
                pseudo_label_flatten_selected)  #[2*selected_num*world_size]

            # Memory bank: cache the gathered features / pseudo-labels from the
            # last `step_save` steps to serve as additional negatives.
            self.feature_bank.append(output_ul_all)
            self.pseudo_label_bank.append(pseudo_label_all)
            if self.step_count > self.step_save:
                self.feature_bank = self.feature_bank[1:]
                self.pseudo_label_bank = self.pseudo_label_bank[1:]
            else:
                self.step_count += 1
            output_ul_all = torch.cat(self.feature_bank, 0)
            pseudo_label_all = torch.cat(self.pseudo_label_bank, 0)

            eps = 1e-8
            # Positive logits: cosine similarity between the two views' features
            # over the overlap (both were L2-normalized above). One side is
            # detached in each direction so each loss only updates one branch.
            pos1 = (output_feat1 * output_feat2.detach()).sum(
                -1, keepdim=True) / self.temp  #[n, 1]
            pos2 = (output_feat1.detach() * output_feat2).sum(
                -1, keepdim=True) / self.temp  #[n, 1]

            # compute loss1: the denominator over negatives is accumulated in
            # chunks of size b under gradient checkpointing to bound peak memory.
            # The first chunk (run1_0) prepends the positive pair and records the
            # row-wise max (neg_max1), which later chunks (run1) reuse for a
            # numerically stable exp.
            b = 8000

            def run1(pos, output_feat1, output_ul_idx, pseudo_label_idx,
                     pseudo_label1_overlap, neg_max1):
                # print("gpu: {}, i_1: {}".format(gpu, i))
                mask1_idx = (
                    pseudo_label_idx.unsqueeze(0) !=
                    pseudo_label1_overlap.unsqueeze(-1)).float()  #[n, b]
                neg1_idx = (
                    output_feat1 @ output_ul_idx.T) / self.temp  #[n, b]
                logits1_neg_idx = (torch.exp(neg1_idx - neg_max1) *
                                   mask1_idx).sum(-1)  #[n, ]
                return logits1_neg_idx

            def run1_0(pos, output_feat1, output_ul_idx, pseudo_label_idx,
                       pseudo_label1_overlap):
                # print("gpu: {}, i_1_0: {}".format(gpu, i))
                mask1_idx = (
                    pseudo_label_idx.unsqueeze(0) !=
                    pseudo_label1_overlap.unsqueeze(-1)).float()  #[n, b]
                neg1_idx = (
                    output_feat1 @ output_ul_idx.T) / self.temp  #[n, b]
                neg1_idx = torch.cat([pos, neg1_idx], 1)  #[n, 1+b]
                mask1_idx = torch.cat([
                    torch.ones(mask1_idx.size(0), 1).float().cuda(), mask1_idx
                ], 1)  #[n, 1+b]
                neg_max1 = torch.max(neg1_idx, 1, keepdim=True)[0]  #[n, 1]
                logits1_neg_idx = (torch.exp(neg1_idx - neg_max1) *
                                   mask1_idx).sum(-1)  #[n, ]
                return logits1_neg_idx, neg_max1

            N = output_ul_all.size(0)
            logits1_down = torch.zeros(pos1.size(0)).float().cuda()
            for i in range((N - 1) // b + 1):
                # print("gpu: {}, i: {}".format(gpu, i))
                pseudo_label_idx = pseudo_label_all[i * b:(i + 1) * b]
                output_ul_idx = output_ul_all[i * b:(i + 1) * b]
                if i == 0:
                    logits1_neg_idx, neg_max1 = torch.utils.checkpoint.checkpoint(
                        run1_0, pos1, output_feat1, output_ul_idx,
                        pseudo_label_idx, pseudo_label1_overlap)
                else:
                    logits1_neg_idx = torch.utils.checkpoint.checkpoint(
                        run1, pos1, output_feat1, output_ul_idx,
                        pseudo_label_idx, pseudo_label1_overlap, neg_max1)
                logits1_down += logits1_neg_idx

            logits1 = torch.exp(pos1 - neg_max1).squeeze(-1) / (logits1_down +
                                                                eps)

            # Directional mask: keep only pixels where view 2 is confident
            # (above pos_thresh_value) and more confident than view 1, so the
            # less confident prediction is pulled toward the more confident one.
            pos_mask_1 = (
                (pseudo_logits2_overlap > self.pos_thresh_value) &
                (pseudo_logits1_overlap < pseudo_logits2_overlap)).float()
            loss1 = -torch.log(logits1 + eps)
            loss1 = (loss1 * pos_mask_1).sum() / (pos_mask_1.sum() + 1e-12)

            # compute loss2 (symmetric to loss1, with the two views' roles swapped)
            def run2(pos, output_feat2, output_ul_idx, pseudo_label_idx,
                     pseudo_label2_overlap, neg_max2):
                # print("gpu: {}, i_2: {}".format(gpu, i))
                mask2_idx = (
                    pseudo_label_idx.unsqueeze(0) !=
                    pseudo_label2_overlap.unsqueeze(-1)).float()  #[n, b]
                neg2_idx = (
                    output_feat2 @ output_ul_idx.T) / self.temp  #[n, b]
                logits2_neg_idx = (torch.exp(neg2_idx - neg_max2) *
                                   mask2_idx).sum(-1)  #[n, ]
                return logits2_neg_idx

            def run2_0(pos, output_feat2, output_ul_idx, pseudo_label_idx,
                       pseudo_label2_overlap):
                # print("gpu: {}, i_2_0: {}".format(gpu, i))
                mask2_idx = (
                    pseudo_label_idx.unsqueeze(0) !=
                    pseudo_label2_overlap.unsqueeze(-1)).float()  #[n, b]
                neg2_idx = (
                    output_feat2 @ output_ul_idx.T) / self.temp  #[n, b]
                neg2_idx = torch.cat([pos, neg2_idx], 1)  #[n, 1+b]
                mask2_idx = torch.cat([
                    torch.ones(mask2_idx.size(0), 1).float().cuda(), mask2_idx
                ], 1)  #[n, 1+b]
                neg_max2 = torch.max(neg2_idx, 1, keepdim=True)[0]  #[n, 1]
                logits2_neg_idx = (torch.exp(neg2_idx - neg_max2) *
                                   mask2_idx).sum(-1)  #[n, ]
                return logits2_neg_idx, neg_max2

            N = output_ul_all.size(0)
            logits2_down = torch.zeros(pos2.size(0)).float().cuda()
            for i in range((N - 1) // b + 1):
                pseudo_label_idx = pseudo_label_all[i * b:(i + 1) * b]
                output_ul_idx = output_ul_all[i * b:(i + 1) * b]
                if i == 0:
                    logits2_neg_idx, neg_max2 = torch.utils.checkpoint.checkpoint(
                        run2_0, pos2, output_feat2, output_ul_idx,
                        pseudo_label_idx, pseudo_label2_overlap)
                else:
                    logits2_neg_idx = torch.utils.checkpoint.checkpoint(
                        run2, pos2, output_feat2, output_ul_idx,
                        pseudo_label_idx, pseudo_label2_overlap, neg_max2)
                logits2_down += logits2_neg_idx

            logits2 = torch.exp(pos2 - neg_max2).squeeze(-1) / (logits2_down +
                                                                eps)

            pos_mask_2 = (
                (pseudo_logits1_overlap > self.pos_thresh_value) &
                (pseudo_logits2_overlap < pseudo_logits1_overlap)).float()

            loss2 = -torch.log(logits2 + eps)
            loss2 = (loss2 * pos_mask_2).sum() / (pos_mask_2.sum() + 1e-12)

            loss_unsup = self.weight_unsup * (loss1 + loss2)
            curr_losses['loss1'] = loss1
            curr_losses['loss2'] = loss2
            curr_losses['loss_unsup'] = loss_unsup
            total_loss = total_loss + loss_unsup
            return total_loss, curr_losses, outputs

        else:
            raise ValueError("No such mode {}".format(self.mode))

    def concat_all_gather(self, tensor):
        """
        Performs all_gather operation on the provided tensors.
        *** Warning ***: torch.distributed.all_gather has no gradient.
        """
        with torch.no_grad():
            tensors_gather = [
                torch.ones_like(tensor)
                for _ in range(torch.distributed.get_world_size())
            ]
            torch.distributed.all_gather(tensors_gather,
                                         tensor,
                                         async_op=False)

            output = torch.cat(tensors_gather, dim=0)
        return output
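
    # Because all_gather provides no gradient, the gathered features act purely
    # as negatives; gradients flow only through the locally computed overlap
    # features (output_feat1 / output_feat2).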

    def get_backbone_params(self):
        return self.encoder.get_backbone_params()

    def get_other_params(self):
        if self.mode == 'supervised':
            return chain(self.encoder.get_module_params(),
                         self.classifier.parameters())
        elif self.mode == 'semi':
            return chain(self.encoder.get_module_params(),
                         self.classifier.parameters(),
                         self.project.parameters())
        else:
            raise ValueError("No such mode {}".format(self.mode))
Code example #2
class CCT(BaseModel):
    def __init__(self,
                 num_classes,
                 conf,
                 sup_loss=None,
                 cons_w_unsup=None,
                 ignore_index=None,
                 testing=False,
                 pretrained=True,
                 use_weak_lables=False,
                 weakly_loss_w=0.4):

        if not testing:
            assert (ignore_index
                    is not None) and (sup_loss is not None) and (cons_w_unsup
                                                                 is not None)

        super(CCT, self).__init__()
        assert int(conf['supervised']) + int(
            conf['semi']) == 1, 'one mode only'
        if conf['supervised']:
            self.mode = 'supervised'
        else:
            self.mode = 'semi'

        # Supervised and unsupervised losses
        self.ignore_index = ignore_index
        if conf['un_loss'] == "KL":
            self.unsuper_loss = softmax_kl_loss
        elif conf['un_loss'] == "MSE":
            self.unsuper_loss = softmax_mse_loss
        elif conf['un_loss'] == "JS":
            self.unsuper_loss = softmax_js_loss
        else:
            raise ValueError(f"Invalid supervised loss {conf['un_loss']}")

        self.unsup_loss_w = cons_w_unsup
        self.sup_loss_w = conf['supervised_w']
        self.softmax_temp = conf['softmax_temp']
        self.sup_loss = sup_loss
        self.sup_type = conf['sup_loss']

        # Use weak labels
        self.use_weak_lables = use_weak_lables
        self.weakly_loss_w = weakly_loss_w
        # pair-wise loss (see supplementary material)
        self.aux_constraint = conf['aux_constraint']
        self.aux_constraint_w = conf['aux_constraint_w']
        # confidence masking (see supplementary material)
        self.confidence_th = conf['confidence_th']
        self.confidence_masking = conf['confidence_masking']

        # Create the model
        self.encoder = Encoder(pretrained=pretrained)

        # The main decoder
        upscale = 8
        num_out_ch = 2048
        decoder_in_ch = num_out_ch // 4
        self.main_decoder = MainDecoder(upscale,
                                        decoder_in_ch,
                                        num_classes=num_classes)

        # The auxiliary decoders
        if self.mode == 'semi' or self.mode == 'weakly_semi':
            vat_decoder = [
                VATDecoder(upscale,
                           decoder_in_ch,
                           num_classes,
                           xi=conf['xi'],
                           eps=conf['eps']) for _ in range(conf['vat'])
            ]
            drop_decoder = [
                DropOutDecoder(upscale,
                               decoder_in_ch,
                               num_classes,
                               drop_rate=conf['drop_rate'],
                               spatial_dropout=conf['spatial'])
                for _ in range(conf['drop'])
            ]
            cut_decoder = [
                CutOutDecoder(upscale,
                              decoder_in_ch,
                              num_classes,
                              erase=conf['erase'])
                for _ in range(conf['cutout'])
            ]
            context_m_decoder = [
                ContextMaskingDecoder(upscale, decoder_in_ch, num_classes)
                for _ in range(conf['context_masking'])
            ]
            object_masking = [
                ObjectMaskingDecoder(upscale, decoder_in_ch, num_classes)
                for _ in range(conf['object_masking'])
            ]
            feature_drop = [
                FeatureDropDecoder(upscale, decoder_in_ch, num_classes)
                for _ in range(conf['feature_drop'])
            ]
            feature_noise = [
                FeatureNoiseDecoder(upscale,
                                    decoder_in_ch,
                                    num_classes,
                                    uniform_range=conf['uniform_range'])
                for _ in range(conf['feature_noise'])
            ]
            random_rotate = [
                RandomRotationDecoder(
                    upscale,
                    decoder_in_ch,
                    num_classes,
                    deg_range=conf['random_rotate_deg_range'])
                for _ in range(conf['rotate'])
            ]
            random_zoom = [
                RandomZoomDecoder(
                    upscale,
                    decoder_in_ch,
                    num_classes,
                    zoom_ratio_lower=conf['random_zoom_ratio_lower'],
                    zoom_ratio_upper=conf['random_zoom_ratio_upper'])
                for _ in range(conf['zoom'])
            ]
            random_blur = [
                RandomBlurDecoder(
                    upscale,
                    decoder_in_ch,
                    num_classes,
                    sigma_lower=conf['gaussian_blur_sigma_lower'],
                    sigma_upper=conf['gaussian_blur_sigma_upper'])
                for _ in range(conf['blur'])
            ]
            random_brightness = [
                RandomBrightnessDecoder(
                    upscale,
                    decoder_in_ch,
                    num_classes,
                    brightness_factor_lower=conf['brightness_factor_lower'],
                    brightness_factor_upper=conf['brightness_factor_upper'])
                for _ in range(conf['brightness'])
            ]

            self.aux_decoders = nn.ModuleList([
                *vat_decoder, *drop_decoder, *cut_decoder, *context_m_decoder,
                *object_masking, *feature_drop, *feature_noise, *random_rotate,
                *random_zoom, *random_blur, *random_brightness
            ])

    def forward(self,
                x_l=None,
                target_l=None,
                x_ul=None,
                target_ul=None,
                curr_iter=None,
                epoch=None):
        if not self.training:
            return self.main_decoder(self.encoder(x_l))

        # We compute the losses in the forward pass to avoid problems encountered with multi-GPU training

        # Forward pass on the labeled examples
        input_size = (x_l.size(2), x_l.size(3))
        output_l = self.main_decoder(self.encoder(x_l))
        if output_l.shape[2:] != x_l.shape[2:]:
            output_l = F.interpolate(output_l,
                                     size=input_size,
                                     mode='bilinear',
                                     align_corners=True)

        # Supervised loss
        if self.sup_type == 'CE':
            loss_sup = self.sup_loss(
                output_l,
                target_l,
                ignore_index=self.ignore_index,
                temperature=self.softmax_temp) * self.sup_loss_w
        elif self.sup_type == 'FL':
            loss_sup = self.sup_loss(output_l, target_l) * self.sup_loss_w
        else:
            loss_sup = self.sup_loss(
                output_l,
                target_l,
                curr_iter=curr_iter,
                epoch=epoch,
                ignore_index=self.ignore_index) * self.sup_loss_w

        # If supervised mode only, return
        if self.mode == 'supervised':
            curr_losses = {'loss_sup': loss_sup}
            outputs = {'sup_pred': output_l}
            total_loss = loss_sup
            return total_loss, curr_losses, outputs

        # If semi supervised mode
        elif self.mode == 'semi':
            # Get main prediction
            x_ul = self.encoder(x_ul)
            output_ul = self.main_decoder(x_ul)

            # Get auxiliary predictions
            outputs_ul = [
                aux_decoder(x_ul, output_ul.detach())
                for aux_decoder in self.aux_decoders
            ]
            targets = F.softmax(output_ul.detach(), dim=1)

            # Compute unsupervised loss
            loss_unsup = sum([
                self.unsuper_loss(inputs=u,
                                  targets=targets,
                                  conf_mask=self.confidence_masking,
                                  threshold=self.confidence_th,
                                  use_softmax=False) for u in outputs_ul
            ])
            loss_unsup = (loss_unsup / len(outputs_ul))
            curr_losses = {'loss_sup': loss_sup}

            if output_ul.shape[2:] != x_l.shape[2:]:
                output_ul = F.interpolate(output_ul,
                                          size=input_size,
                                          mode='bilinear',
                                          align_corners=True)
            outputs = {'sup_pred': output_l, 'unsup_pred': output_ul}

            # Weight the unsupervised loss by the ramp-up schedule
            weight_u = self.unsup_loss_w(epoch=epoch, curr_iter=curr_iter)
            loss_unsup = loss_unsup * weight_u
            curr_losses['loss_unsup'] = loss_unsup
            total_loss = loss_unsup + loss_sup

            # In case we're using weak labels, add the weak loss term with a weight (self.weakly_loss_w)
            if self.use_weak_lables:
                weight_w = (weight_u /
                            self.unsup_loss_w.final_w) * self.weakly_loss_w
                loss_weakly = sum([
                    CE_loss(outp, target_ul, ignore_index=self.ignore_index)
                    for outp in outputs_ul
                ]) / len(outputs_ul)
                loss_weakly = loss_weakly * weight_w
                curr_losses['loss_weakly'] = loss_weakly
                total_loss += loss_weakly

            # Pair-wise loss
            if self.aux_constraint:
                pair_wise = pair_wise_loss(outputs_ul) * self.aux_constraint_w
                curr_losses['pair_wise'] = pair_wise
                total_loss += pair_wise

            return total_loss, curr_losses, outputs

    def get_backbone_params(self):
        return self.encoder.get_backbone_params()

    def get_other_params(self):
        if self.mode == 'semi':
            return chain(self.encoder.get_module_params(),
                         self.main_decoder.parameters(),
                         self.aux_decoders.parameters())

        return chain(self.encoder.get_module_params(),
                     self.main_decoder.parameters())
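
Note on `cons_w_unsup`: forward only assumes it is callable as `w(epoch=..., curr_iter=...)` and exposes a `final_w` attribute, so it is presumably a ramp-up schedule for the consistency weight. A minimal sketch of such a schedule (the class name and the sigmoid shape, standard in consistency training, are assumptions rather than this project's exact implementation):

import math

class ConsistencyWeight:
    """Hypothetical sigmoid ramp-up for the unsupervised loss weight."""

    def __init__(self, final_w, iters_per_epoch, rampup_epochs=10):
        self.final_w = final_w  # read by CCT.forward to scale the weak-label loss
        self.iters_per_epoch = iters_per_epoch
        self.rampup_length = rampup_epochs * iters_per_epoch

    def __call__(self, epoch, curr_iter):
        step = epoch * self.iters_per_epoch + curr_iter
        if self.rampup_length == 0 or step >= self.rampup_length:
            return self.final_w
        # Sigmoid ramp-up (Laine & Aila, "Temporal Ensembling"): exp(-5 (1 - t)^2)
        t = step / self.rampup_length
        return self.final_w * math.exp(-5.0 * (1.0 - t) ** 2)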
Code example #3
File: model.py Project: saramsv/TCT
class CCT(BaseModel):
    def __init__(self,
                 num_classes,
                 conf,
                 sup_loss=None,
                 cons_w_unsup=None,
                 ignore_index=None,
                 testing=False,
                 pretrained=True,
                 use_weak_lables=False,
                 unsupervised_mode=None,
                 weakly_loss_w=0.4):

        if not testing:
            assert (ignore_index
                    is not None) and (sup_loss is not None) and (cons_w_unsup
                                                                 is not None)

        super(CCT, self).__init__()
        assert int(conf['supervised']) + int(
            conf['semi']) == 1, 'one mode only'
        if conf['supervised']:
            self.mode = 'supervised'
        else:
            self.mode = 'semi'

        # Supervised and unsupervised losses
        self.ignore_index = ignore_index
        if conf['un_loss'] == "KL":
            self.unsuper_loss = softmax_kl_loss
        elif conf['un_loss'] == "MSE":
            self.unsuper_loss = softmax_mse_loss
        elif conf['un_loss'] == "JS":
            self.unsuper_loss = softmax_js_loss
        else:
            raise ValueError(f"Invalid supervised loss {conf['un_loss']}")

        self.unsup_loss_w = cons_w_unsup
        self.sup_loss_w = conf['supervised_w']
        self.softmax_temp = conf['softmax_temp']
        self.sup_loss = sup_loss
        self.sup_type = conf['sup_loss']
        self.unsupervised_mode = unsupervised_mode

        # Use weak labels
        self.use_weak_lables = use_weak_lables
        self.weakly_loss_w = weakly_loss_w
        # pair-wise loss (see supplementary material)
        self.aux_constraint = conf['aux_constraint']
        self.aux_constraint_w = conf['aux_constraint_w']
        # confidence masking (see supplementary material)
        self.confidence_th = conf['confidence_th']
        self.confidence_masking = conf['confidence_masking']

        # Create the model
        self.encoder = Encoder(pretrained=pretrained)

        # The main decoder
        upscale = 8
        num_out_ch = 2048
        decoder_in_ch = num_out_ch // 4
        self.main_decoder = MainDecoder(upscale,
                                        decoder_in_ch,
                                        num_classes=num_classes)

        # The auxiliary decoders
        if self.mode == 'semi' or self.mode == 'weakly_semi':
            # 'pertAndSeq' needs both decoder sets (see forward and
            # get_other_params), so match case-insensitively and use two
            # independent branches instead of if/elif.
            if 'seq' in unsupervised_mode.lower():
                vat_decoder_seq = [
                    VATDecoder(upscale,
                               decoder_in_ch,
                               num_classes,
                               xi=conf['xi'],
                               eps=conf['eps']) for _ in range(conf['vat'])
                ]
                self.seq_decoder = nn.ModuleList([*vat_decoder_seq])

            if 'pert' in unsupervised_mode.lower():
                vat_decoder = [
                    VATDecoder(upscale,
                               decoder_in_ch,
                               num_classes,
                               xi=conf['xi'],
                               eps=conf['eps']) for _ in range(conf['vat'])
                ]
                drop_decoder = [
                    DropOutDecoder(upscale,
                                   decoder_in_ch,
                                   num_classes,
                                   drop_rate=conf['drop_rate'],
                                   spatial_dropout=conf['spatial'])
                    for _ in range(conf['drop'])
                ]
                cut_decoder = [
                    CutOutDecoder(upscale,
                                  decoder_in_ch,
                                  num_classes,
                                  erase=conf['erase'])
                    for _ in range(conf['cutout'])
                ]
                context_m_decoder = [
                    ContextMaskingDecoder(upscale, decoder_in_ch, num_classes)
                    for _ in range(conf['context_masking'])
                ]
                object_masking = [
                    ObjectMaskingDecoder(upscale, decoder_in_ch, num_classes)
                    for _ in range(conf['object_masking'])
                ]
                feature_drop = [
                    FeatureDropDecoder(upscale, decoder_in_ch, num_classes)
                    for _ in range(conf['feature_drop'])
                ]
                feature_noise = [
                    FeatureNoiseDecoder(upscale,
                                        decoder_in_ch,
                                        num_classes,
                                        uniform_range=conf['uniform_range'])
                    for _ in range(conf['feature_noise'])
                ]

                self.aux_decoders = nn.ModuleList([
                    *vat_decoder, *drop_decoder, *cut_decoder,
                    *context_m_decoder, *object_masking, *feature_drop,
                    *feature_noise
                ])

    def forward(self,
                x_l=None,
                target_l=None,
                x_ul=None,
                target_ul=None,
                curr_iter=None,
                epoch=None,
                x_seq_triplet=None,
                unsupervised_mode=None):
        '''x_seq_triplet is a list whose length equals the length of the sequence.
        Each element is a tuple of the form
        (imgs: torch.Size([5, 3, 320, 320]), annotations: torch.Size([5, 320, 320, 3]), [list of the img_names]),
        where 5 here is the batch size.'''
        if not self.training:
            return self.main_decoder(self.encoder(x_l))

        # Helper: parse an approximate capture date out of an image file name.
        def get_date(x):
            date_ = x.split('.')[0]
            y = '00'
            if date_[3] == '1':
                y = '12'
            elif date_[3] == '0':
                y = '11'
            m = date_[4:6]
            d = date_[6:8]
            date_ = m + d + y
            return datetime.datetime.strptime(date_, '%m%d%y')

        # We compute the losses in the forward pass to avoid problems encountered with multi-GPU training

        # Forward pass on the labeled examples
        input_size = (x_l.size(2), x_l.size(3))
        output_l = self.main_decoder(self.encoder(x_l))
        if output_l.shape[2:] != x_l.shape[2:]:
            output_l = F.interpolate(output_l,
                                     size=input_size,
                                     mode='bilinear',
                                     align_corners=True)

        # Supervised loss
        if self.sup_type == 'CE':
            loss_sup = self.sup_loss(
                output_l,
                target_l,
                ignore_index=self.ignore_index,
                temperature=self.softmax_temp) * self.sup_loss_w
        elif self.sup_type == 'FL':
            loss_sup = self.sup_loss(output_l, target_l) * self.sup_loss_w
        else:
            loss_sup = self.sup_loss(
                output_l,
                target_l,
                curr_iter=curr_iter,
                epoch=epoch,
                ignore_index=self.ignore_index) * self.sup_loss_w

        # If supervised mode only, return
        if self.mode == 'supervised':
            curr_losses = {'loss_sup': loss_sup}
            outputs = {'sup_pred': output_l}
            total_loss = loss_sup
            return total_loss, curr_losses, outputs

        # If semi supervised mode
        elif self.mode == 'semi':
            # Get sequence predictions
            if unsupervised_mode == 'seq' or unsupervised_mode == 'pertAndSeq':
                seqs = []
                for i in range(len(x_seq_triplet)):
                    '''A batch of the ith images in the sequence. With batch_size == 5:
                        (imgs: torch.Size([5, 3, 320, 320]), annotations: torch.Size([5, 320, 320, 3]), [list of the img_names])'''
                    x_seq = self.encoder(x_seq_triplet[i][0])

                    output_seq = self.main_decoder(x_seq)
                    seqs.append((x_seq, output_seq))

                # Get auxiliary predictions
                seq_losses = []
                '''seq_losses has one element per image position in the sequence;
                the ith element is the unsupervised loss between the outputs of
                the main decoder and the seq decoder for the batch of ith images.'''
                all_outputs_ul = []
                '''the ith element of all_outputs_ul holds the seq decoder outputs
                for the batch of the ith images in the sequence.'''
                for pair in seqs:
                    '''There is one pair per image position in the sequence. Each
                    pair holds (encoder output for the batch, main decoder output
                    on pair[0]); outputs_ul below is the seq decoder output for
                    that batch.'''
                    outputs_ul = [
                        aux_decoder(pair[0], pair[1].detach())
                        for aux_decoder in self.seq_decoder
                    ]
                    targets = F.softmax(pair[1].detach(), dim=1)
                    all_outputs_ul.append(outputs_ul)

                    # Compute the unsupervised loss between the outputs of the
                    # main decoder and the seq decoder
                    loss_seq = sum([
                        self.unsuper_loss(inputs=u,
                                          targets=targets,
                                          conf_mask=self.confidence_masking,
                                          threshold=self.confidence_th,
                                          use_softmax=False)
                        for u in outputs_ul
                    ])
                    loss_seq = (loss_seq / len(outputs_ul))
                    seq_losses.append(loss_seq)

                # use Python's sum so autograd tracks the per-position losses
                loss_seq = sum(seq_losses) / len(seqs)
                loss_seq_accross_decoder = []
                for i in range(len(all_outputs_ul)):
                    for j in range(i + 1, len(all_outputs_ul)):
                        pert_loss = 0
                        # iterate over the decoders inside seq_decoder
                        for z in range(len(all_outputs_ul[0])):
                            pert_loss += self.unsuper_loss(
                                inputs=all_outputs_ul[i][z],
                                targets=all_outputs_ul[j][z],
                                conf_mask=self.confidence_masking,
                                threshold=self.confidence_th,
                                use_softmax=False)
                        pert_loss = pert_loss / len(all_outputs_ul[0])
                        loss_seq_accross_decoder.append(pert_loss)
                loss_seq += sum(loss_seq_accross_decoder) / len(
                    loss_seq_accross_decoder)

                output_seqs = []
                for i in range(len(seqs)):
                    temp_x_seq = seqs[i][0]
                    temp_output_seq = seqs[i][1]
                    if temp_output_seq.shape[2:] != input_size:
                        temp_output_seq = F.interpolate(temp_output_seq,
                                                        size=input_size,
                                                        mode='bilinear',
                                                        align_corners=True)
                    output_seqs.append(temp_output_seq)

                # TODO: Figure out why we only return the first output from the main decoder.
                outputs = {'sup_pred': output_l, 'unsup_pred': output_seqs[0]}
                weight_u = self.unsup_loss_w(epoch=epoch, curr_iter=curr_iter)
                loss_seq = loss_seq * weight_u

            if unsupervised_mode != 'seq':
                # Get main prediction
                x_ul = self.encoder(x_ul)
                output_ul = self.main_decoder(x_ul)
                # Get auxiliary predictions
                outputs_ul = [
                    aux_decoder(x_ul, output_ul.detach())
                    for aux_decoder in self.aux_decoders
                ]
                targets = F.softmax(output_ul.detach(), dim=1)
                # Compute unsupervised loss
                loss_unsup = sum([
                    self.unsuper_loss(inputs=u,
                                      targets=targets,
                                      conf_mask=self.confidence_masking,
                                      threshold=self.confidence_th,
                                      use_softmax=False) for u in outputs_ul
                ])
                loss_unsup = (loss_unsup / len(outputs_ul))

                if output_ul.shape[2:] != x_l.shape[2:]:
                    output_ul = F.interpolate(output_ul,
                                              size=input_size,
                                              mode='bilinear',
                                              align_corners=True)
                outputs = {'sup_pred': output_l, 'unsup_pred': output_ul}

                weight_u = self.unsup_loss_w(epoch=epoch, curr_iter=curr_iter)
                loss_unsup = loss_unsup * weight_u

            curr_losses = {'loss_sup': loss_sup}
            if unsupervised_mode == 'seq':
                curr_losses['loss_unsup'] = loss_seq
                total_loss = loss_seq + loss_sup
            elif unsupervised_mode == 'pertAndSeq':
                curr_losses['loss_unsup'] = loss_seq + loss_unsup
                total_loss = loss_seq + loss_unsup + loss_sup
            else:
                curr_losses['loss_unsup'] = loss_unsup
                total_loss = loss_unsup + loss_sup

            # In case we're using weak labels, add the weak loss term with a
            # weight (self.weakly_loss_w)
            if self.use_weak_lables:
                target_ul = target_ul[:, :, :, 0]
                weight_w = (weight_u /
                            self.unsup_loss_w.final_w) * self.weakly_loss_w
                loss_weakly = sum([
                    CE_loss(outp, target_ul, ignore_index=self.ignore_index)
                    for outp in outputs_ul
                ]) / len(outputs_ul)
                loss_weakly = loss_weakly * weight_w
                curr_losses['loss_weakly'] = loss_weakly
                total_loss += loss_weakly

            # Pair-wise loss
            if self.aux_constraint:
                pair_wise = pair_wise_loss(outputs_ul) * self.aux_constraint_w
                curr_losses['pair_wise'] = pair_wise
                total_loss += pair_wise

            return total_loss, curr_losses, outputs

    def get_backbone_params(self):
        return self.encoder.get_backbone_params()

    def get_other_params(self):
        if self.mode == 'semi':
            if self.unsupervised_mode == 'pertAndSeq':
                return chain(self.encoder.get_module_params(),
                             self.main_decoder.parameters(),
                             self.aux_decoders.parameters(),
                             self.seq_decoder.parameters())
            elif self.unsupervised_mode == 'seq':
                return chain(self.encoder.get_module_params(),
                             self.main_decoder.parameters(),
                             self.seq_decoder.parameters())
            elif self.unsupervised_mode == 'pert':
                return chain(self.encoder.get_module_params(),
                             self.main_decoder.parameters(),
                             self.aux_decoders.parameters())
        return chain(self.encoder.get_module_params(),
                     self.main_decoder.parameters())
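
Note on the consistency losses: all three examples call `self.unsuper_loss` as `loss(inputs=..., targets=..., conf_mask=..., threshold=..., use_softmax=...)`, where `inputs` are raw auxiliary-decoder logits and `targets` are already-softmaxed main-decoder predictions. A minimal sketch of what the MSE variant plausibly looks like, inferred from those call sites (an assumption, not the project's exact implementation):

import torch
import torch.nn.functional as F

def softmax_mse_loss(inputs, targets, conf_mask=False, threshold=None,
                     use_softmax=False):
    """Hypothetical confidence-masked MSE consistency loss."""
    inputs = F.softmax(inputs, dim=1)
    if use_softmax:  # targets were passed as raw logits
        targets = F.softmax(targets, dim=1)
    if conf_mask:
        # Penalize only pixels where the target prediction is confident.
        mask = targets.max(dim=1)[0] > threshold              # [B, H, W]
        if mask.sum() == 0:
            return inputs.new_tensor(0.0)
        loss = F.mse_loss(inputs, targets, reduction='none')  # [B, C, H, W]
        return loss.mean(dim=1)[mask].mean()
    return F.mse_loss(inputs, targets)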