コード例 #1
0
ファイル: ccm_trainer.py プロジェクト: feipan664/CCM-1
    def get_target_slm(self):
        """Compute the Semantic Layout Matrix (SLM) of every target image.

        Each target training image (``self.target_all``) is passed through the
        model. Its softmax map is summed over the width axis, giving a per-row
        class profile, and over the height axis, giving a per-column class
        profile. Each profile is L1-normalised per class and cached by image
        name in ``self.target_his_h`` / ``self.target_his_w``.

        NOTE(review): the cached tensors are laid out as (rows, classes) and
        (cols, classes). ``semantic_layout_matching`` builds its source
        profiles as (classes, rows) / (classes, cols) before calling
        ``his_kl_simi``; confirm that helper reconciles the two layouts.
        """
        print('Generate SLM[Semantic layour matrix] of target samples')
        loader = dataset.init_test_dataset(self.config,
                                           self.config.target,
                                           set='train',
                                           selected=self.target_all)
        self.target_his_h = {}
        self.target_his_w = {}

        for batch in tqdm(loader):
            img, _, _, _, name = batch
            name = name[0]
            with torch.no_grad():
                # Presumably batch size 1: squeeze() drops the batch axis,
                # leaving a (C, H, W) class-probability map.
                prob = F.softmax(self.model.forward(img.cuda()), dim=1).squeeze()
            # Sum over width -> (C, H), over height -> (C, W); transpose so
            # classes are columns.
            cu_h = prob.sum(dim=2).t()
            cu_w = prob.sum(dim=1).t()
            # L1-normalise each class column into a distribution over rows/cols.
            self.target_his_h[name] = F.normalize(cu_h, p=1, dim=0)
            self.target_his_w[name] = F.normalize(cu_w, p=1, dim=0)
コード例 #2
0
    def source_pixel_selection(self, round):
        """Write pixel-level pseudo labels for the selected source images.

        For every image in ``self.source_selected``, the model prediction is
        upsampled to the native resolution of the source dataset. It is then
        passed to ``self.calc_pixel_simi`` together with the ground-truth label
        and the template pool ``self.pool.pool``. The resulting label map is
        saved as a uint8 PNG under
        ``<config.plabel>/<config.note>/<round>/<config.source>/<subdir>/``.

        Args:
            round: Self-training round index; only used to build the output path.

        Raises:
            ValueError: If ``config.source`` is not one of the supported datasets.
        """
        print("Pixel-wise similarity matching")
        # Native (width, height) of each supported source dataset.
        source_sizes = {
            'gta5': (1914, 1052),
            'synthia': (1280, 760),
            'cityscapes': (2048, 1024),
        }
        if self.config.source not in source_sizes:
            # Previously this fell through and crashed later with UnboundLocalError.
            raise ValueError('Unsupported source dataset: {!r}'.format(
                self.config.source))
        src_w, src_h = source_sizes[self.config.source]

        loader = dataset.init_test_dataset(self.config,
                                           self.config.source,
                                           set="train",
                                           selected=self.source_selected,
                                           label_ori=True)
        interp = nn.Upsample(size=(src_h, src_w),
                             mode="bilinear",
                             align_corners=True)
        self.source_plabel_path = osp.join(self.config.plabel,
                                           self.config.note, str(round),
                                           self.config.source)
        mkdir(self.source_plabel_path)

        target_template = self.pool.pool

        # Per-class bookkeeping reset on the instance (not read in this method).
        self.mean_memo = {i: [] for i in range(self.config.num_classes)}
        self.target_cnts = {i: 0 for i in range(self.config.num_classes)}
        with torch.no_grad():
            for batch in tqdm(loader):
                image, label, _, _, name = batch
                label = label.cuda()
                # The first path component of the sample name becomes the output
                # sub-directory; the last one becomes the (label-style) file name.
                parts = name[0].split("/")
                img_name = parts[-1].replace('leftImg8bit', 'gtFine_labelIds')
                temp_dir = osp.join(self.source_plabel_path, parts[0])
                os.makedirs(temp_dir, exist_ok=True)

                output = interp(self.model.forward(image.cuda()))
                pred = F.softmax(output, dim=1)
                plabel = self.calc_pixel_simi(pred, label, target_template)

                plabel = plabel.view(src_h, src_w).cpu().numpy()
                plabel = Image.fromarray(np.asarray(plabel, dtype=np.uint8))
                plabel.save("%s/%s.png" % (temp_dir, img_name.split(".")[0]))
コード例 #3
0
ファイル: ccm_trainer.py プロジェクト: feipan664/CCM-1
    def semantic_layout_matching(self, round, num_to_sel):
        """Select the source images whose semantic layout best matches the target.

        The target-domain SLMs are computed by ``get_target_slm``, stacked and
        clustered with ``self.k_means``. For every source training image, a
        per-row and a per-column class histogram is built from its
        ground-truth label and L1-normalised per class. Both histograms are
        scored against the clustered target SLMs with ``self.his_kl_simi``.
        Images are ranked by ascending total score (row score + column score).
        The scorer is presumably a divergence, so lower means more similar.

        Args:
            round: Self-training round index (not used in this method).
            num_to_sel: Number of source image names to select.

        Returns:
            The ``num_to_sel`` lowest-scoring source image names. The same
            list is also stored in ``self.selected``.
        """
        print('[Semantic Layout Matching]')

        loader = dataset.init_test_dataset(self.config,
                                           self.config.source,
                                           set='train')
        self.get_target_slm()  # Calculate the SLM of the target domain
        self.target_his_h = torch.stack(list(self.target_his_h.values()))
        self.target_his_w = torch.stack(list(self.target_his_w.values()))

        self.target_his_h = self.k_means(self.target_his_h)
        self.target_his_w = self.k_means(self.target_his_w)

        class_ids = torch.arange(self.config.num_classes).cuda()
        score_dict = {}
        for batch in tqdm(loader):
            _, label, _, _, name = batch
            name = name[0]

            label = label.cuda().squeeze()  # (h, w)
            # (h, w, C) one-hot map; labels outside [0, C), e.g. the 255
            # ignore index, match no class.
            onehot = label.unsqueeze(-1) == class_ids
            cu_h = onehot.sum(dim=1).float()  # (h, C): class pixels per row
            cu_w = onehot.sum(dim=0).float()  # (w, C): class pixels per column
            # L1-normalise each class column, then lay out as (C, h) / (C, w).
            cu_h = F.normalize(cu_h, p=1, dim=0).t()
            cu_w = F.normalize(cu_w, p=1, dim=0).t()

            score1 = self.his_kl_simi(self.target_his_h, cu_h)
            score2 = self.his_kl_simi(self.target_his_w, cu_w)
            score_dict[name] = score1 + score2

        sorted_pair = sorted(score_dict.items(), key=operator.itemgetter(1))
        # Take exactly num_to_sel names. Slicing to num_to_sel - 1 dropped one,
        # and returned almost everything when num_to_sel == 0.
        self.selected = [img_name for img_name, _ in sorted_pair[:num_to_sel]]
        return self.selected
コード例 #4
0
ファイル: ccm_trainer.py プロジェクト: feipan664/CCM-1
    def gene_thres(self, prop, num_cls=19):
        """Compute a class-wise confidence threshold from target predictions.

        Every target training image is run through the model. For each
        predicted class, the max-softmax confidences of its pixels are sorted,
        and every 5th value is kept to bound memory. A class's threshold is
        the ``index``-th largest of its pooled, sorted confidences, where
        ``index = int(count * prop)`` clamped to ``[1, count]``
        (``r`` in Sec. 3.3 of the paper).

        Args:
            prop: Fraction of each class's (subsampled) confidences that
                should lie at or above the threshold.
            num_cls: Number of class indices to consider.

        Returns:
            Dict mapping class index -> threshold value. Classes that are never
            predicted are absent.
        """
        print('[Calculate Threshold using config.cb_prop]')  # r in section 3.3

        probs = {}
        loader = dataset.init_test_dataset(self.config,
                                           self.config.target,
                                           set="train",
                                           selected=self.target_all,
                                           batchsize=1)
        with torch.no_grad():
            for batch in tqdm(loader):
                img, _, _, _, _ = batch
                pred = F.softmax(self.model.forward(img.cuda()), dim=1)
                pred_probs = pred.max(dim=1)[0].squeeze()
                pred_label = torch.argmax(pred, dim=1).squeeze()
                for i in range(num_cls):
                    cls_mask = pred_label == i
                    if not cls_mask.any():
                        continue
                    # Sort before subsampling so every 5th value is an evenly
                    # spaced sample of the class's confidence distribution.
                    cls_probs = torch.masked_select(pred_probs, cls_mask).sort()[0]
                    probs.setdefault(i, []).extend(cls_probs[::5].cpu().tolist())

        thres = {}
        for k, cls_prob in probs.items():
            # Per-image chunks were sorted individually; sort the pooled values.
            cls_prob = np.sort(np.array(cls_prob))
            cls_total = len(cls_prob)
            # Clamp to [1, cls_total]. An index of 0 would wrap to cls_prob[0]
            # (the minimum, accepting every pixel); a larger one raises IndexError.
            index = min(max(int(cls_total * prop), 1), cls_total)
            thres[k] = cls_prob[-index]
        print(thres)
        return thres
コード例 #5
0
 def validate2(self, count):
     """Compute the mean cross-entropy loss on the target validation split.

     Predictions are upsampled to 1024x2048 before the loss is computed.
     Pixels labelled 255 are ignored. The model is left in eval mode.

     Args:
         count: Epoch number; only used in the printed message.

     Returns:
         The average over all validation batches of each batch's mean loss.
     """
     self.model = self.model.eval()
     total_loss = 0
     testloader = dataset.init_test_dataset(self.config,
                                            self.config.target,
                                            set='val',
                                            batchsize=4)
     interp = nn.Upsample(size=(1024, 2048),
                          mode='bilinear',
                          align_corners=True)
     # Evaluation only: without no_grad an autograd graph (and its retained
     # activations) is built for every full-resolution batch.
     with torch.no_grad():
         for batch in tqdm(testloader):
             img, seg_label, _, _, _ = batch
             seg_label = seg_label.long().cuda()
             seg_pred = interp(self.model(img.cuda()))
             seg_loss = F.cross_entropy(seg_pred, seg_label, ignore_index=255)
             total_loss += seg_loss.item()
     # len(testloader) is the batch count. len(iter(testloader)) would create,
     # and with workers start, a throwaway iterator.
     total_loss /= len(testloader)
     print('---------------------')
     print('Validation seg loss: {} at Epoch {}'.format(total_loss, count))
     return total_loss
コード例 #6
0
ファイル: ccm_trainer.py プロジェクト: feipan664/CCM-1
    def save_pred(self, round):
        """Generate this round's target pseudo labels and write them to disk.

        For each target training image, the prediction is upsampled to
        1024x2048 and thresholded twice:

        - with the class-wise thresholds ``self.cb_thres``, via ``thres_cb_plabel``;
        - with ``config.cb_prop``, via ``gene_plabel_prop``.

        The two masks are fused with ``mask_fusion`` (Sec. 3.3 of the paper).
        The fused mask updates ``self.pool``, and pseudo-label accuracy and
        proportion are accumulated against the ground truth. The fused label
        map is saved as a uint8 PNG under
        ``<config.plabel>/<config.note>/<round>/<subdir>/``.

        Side effects:
            Sets ``self.plabel_path``, ``self.config.target_data_dir``,
            ``self.pool`` and ``self.mean_memo``. Creates output directories,
            prints accuracy/proportion, and sends them to neptune when
            ``config.neptune`` is set.
        """
        print("[Generate pseudo labels]")
        loader = dataset.init_test_dataset(self.config,
                                           self.config.target,
                                           set="train",
                                           selected=self.target_all)
        upsample = nn.Upsample(size=(1024, 2048),
                               mode="bilinear",
                               align_corners=True)

        self.plabel_path = osp.join(self.config.plabel, self.config.note,
                                    str(round))
        mkdir(self.plabel_path)
        # Redirect the target data directory to the freshly written labels.
        self.config.target_data_dir = self.plabel_path
        # Stores pseudo-label probabilities for pixel-wise similarity matching
        # (around Eq. (9) of the paper).
        self.pool = Pool()
        accs = AverageMeter()  # pixel accuracy of pseudo labels
        props = AverageMeter()  # proportion of labelled pixels
        cls_acc = GroupAverageMeter()  # class-wise acc/prop of pseudo labels

        self.mean_memo = {i: [] for i in range(self.config.num_classes)}
        n_cls = self.config.num_classes
        with torch.no_grad():
            for _, batch in tqdm(enumerate(loader)):
                image, label, _, _, name = batch
                label = label.cuda()
                path_parts = name[0].split("/")
                file_stem = path_parts[-1].replace(
                    "leftImg8bit", "gtFine_labelIds").split(".")[0]
                sub_dir = osp.join(self.plabel_path, path_parts[0])
                if not os.path.exists(sub_dir):
                    os.mkdir(sub_dir)

                logits = upsample(self.model.forward(image.cuda()))
                # Mask from the global (class-wise) thresholds.
                global_mask, _ = thres_cb_plabel(logits, self.cb_thres,
                                                 num_cls=n_cls)
                # Mask from the local (proportion-based) threshold.
                local_mask, _ = gene_plabel_prop(logits, self.config.cb_prop)
                # Fuse both masks (Sec. 3.3 of the paper).
                fused_mask, pseudo = mask_fusion(logits, global_mask, local_mask)
                self.pool.update_pool(logits, mask=fused_mask.float())

                acc, prop, cls_dict = Acc(pseudo, label, num_cls=n_cls)
                labelled_pixels = (pseudo != 255).sum().item()
                accs.update(acc, labelled_pixels)
                props.update(prop, 1)
                cls_acc.update(cls_dict)

                label_map = pseudo.view(1024, 2048).cpu().numpy()
                png = Image.fromarray(np.asarray(label_map, dtype=np.uint8))
                png.save("%s/%s.png" % (sub_dir, file_stem))

        print('The Accuracy :{:.2%} and proportion :{:.2%} of Pseudo Labels'.
              format(accs.avg.item(), props.avg.item()))
        if self.config.neptune:
            neptune.send_metric("Acc", accs.avg)
            neptune.send_metric("Prop", props.avg)
コード例 #7
0
ファイル: base_trainer.py プロジェクト: feipan664/CCM-1
    def validate(self):
        """Evaluate the model on the target validation split and return mIoU.

        Each prediction is upsampled to 1024x2048 and compared with its label
        map. Pixels whose label is >= the number of output channels ``C``
        (e.g. the 255 ignore index) are excluded. Per-class intersection,
        union and predicted-pixel counts are accumulated over the whole split.

        Prints mIoU and mAcc. When ``C == 16`` it also prints a 13-class mIoU
        that drops classes 3-5. Both metrics are optionally sent to neptune.

        Assumes batch size 1, since the batch axis is squeezed away. The model
        is left in eval mode.

        Returns:
            Mean IoU over all ``C`` classes, as a Python float.
        """
        self.model = self.model.eval()
        testloader = dataset.init_test_dataset(self.config,
                                               self.config.target,
                                               set='val')
        interp = nn.Upsample(size=(1024, 2048),
                             mode='bilinear',
                             align_corners=True)
        num_classes = self.config.num_classes
        union = torch.zeros(num_classes, 1, dtype=torch.float).cuda()
        inter = torch.zeros(num_classes, 1, dtype=torch.float).cuda()
        preds = torch.zeros(num_classes, 1, dtype=torch.float).cuda()
        with torch.no_grad():
            for batch in tqdm(testloader):
                image, label, _, _, _ = batch
                output = interp(self.model(image.cuda())).squeeze()
                C, H, W = output.shape

                # Class ids as a (C, 1, 1) column, created on the GPU. It
                # broadcasts against (1, H, W) maps, so no C x H x W index
                # tensor is built on the host and copied every iteration.
                class_ids = torch.arange(C, device=output.device,
                                         dtype=torch.float).view(C, 1, 1)
                label = label.cuda().view(1, H, W).float()
                valid = label < C  # (1, H, W); excludes ignore labels
                pred = output.argmax(dim=0).float().view(1, H, W)

                pred_mask = (class_ids == pred) & valid  # (C, H, W)
                label_mask = (class_ids == label) & valid  # (C, H, W)

                inter += (pred_mask & label_mask).flatten(1).sum(
                    dim=1, keepdim=True).float()
                union += (pred_mask | label_mask).flatten(1).sum(
                    dim=1, keepdim=True).float()
                preds += pred_mask.flatten(1).sum(dim=1, keepdim=True).float()

            iou = inter / union
            acc = inter / preds
            if C == 16:
                iou = iou.squeeze()
                # 13-class protocol: drop class indices 3, 4 and 5.
                class13_iou = torch.cat((iou[:3], iou[6:]))
                class13_miou = class13_iou.mean().item()
                print('13-Class mIoU:{:.2%}'.format(class13_miou))
            mIoU = iou.mean().item()
            mAcc = acc.mean().item()
            print('mIoU: {:.2%} mAcc : {:.2%} '.format(mIoU, mAcc))
            if self.config.neptune:
                neptune.send_metric('mIoU', mIoU)
                neptune.send_metric('mAcc', mAcc)
        return mIoU