def get_target_slm(self):
    """Compute the semantic layout matrices (SLM) of the target samples.

    Runs the model over the target ``train`` split (restricted to
    ``self.target_all``) and, for every image, sums the softmax
    probabilities along the width and along the height. The two resulting
    matrices are stored in ``self.target_his_h`` and ``self.target_his_w``,
    keyed by the first sample name of the batch. Both dictionaries are
    reset at the start of the call.

    The loader must yield one image per batch: ``Tensor.t()`` below only
    accepts tensors with at most two dimensions, so a larger batch fails.
    """
    print('Generate SLM[Semantic layour matrix] of target samples')
    loader = dataset.init_test_dataset(self.config, self.config.target,
                                       set='train', selected=self.target_all)
    self.target_his_h = {}
    self.target_his_w = {}
    for batch in tqdm(loader):
        img, _, _, _, name = batch
        sample_name = name[0]
        with torch.no_grad():
            probs = F.softmax(self.model.forward(img.cuda()), dim=1)
        # Drop only the batch axis: (1, C, H, W) -> (C, H, W).
        probs = probs.squeeze(0)
        # Sum over width -> (C, H); sum over height -> (C, W). Each is
        # transposed to (H, C) / (W, C) and L1-normalised along dim 0.
        row_hist = F.normalize(probs.sum(dim=2).t(), p=1, dim=0)
        col_hist = F.normalize(probs.sum(dim=1).t(), p=1, dim=0)
        self.target_his_h[sample_name] = row_hist
        self.target_his_w[sample_name] = col_hist
def source_pixel_selection(self, round):
    """Write pixel-wise pseudo labels for the selected source images.

    For every image of the source ``train`` split restricted to
    ``self.source_selected``, the model output is upsampled to the native
    resolution of the source dataset, turned into probabilities and passed
    to ``self.calc_pixel_simi`` together with the original label and
    ``self.pool.pool``. The returned label map is saved as an 8-bit PNG
    under ``<config.plabel>/<config.note>/<round>/<config.source>/<dir>/``.

    Side effects: sets ``self.source_plabel_path`` and resets
    ``self.mean_memo`` and ``self.target_cnts``.

    Args:
        round: Identifier of the current round; only used (via ``str``) as
            a path component.

    Raises:
        ValueError: If ``self.config.source`` is not ``'gta5'``,
            ``'synthia'`` or ``'cityscapes'``.
    """
    print("Pixel-wise similarity matching")

    # (height, width) of the label maps of each supported source dataset.
    source_sizes = {
        'gta5': (1052, 1914),
        'synthia': (760, 1280),
        'cityscapes': (1024, 2048),
    }
    if self.config.source not in source_sizes:
        raise ValueError(
            "Unsupported source dataset {!r}; expected one of {}".format(
                self.config.source, sorted(source_sizes)))
    src_h, src_w = source_sizes[self.config.source]

    loader = dataset.init_test_dataset(self.config, self.config.source,
                                       set="train",
                                       selected=self.source_selected,
                                       label_ori=True)
    interp = nn.Upsample(size=(src_h, src_w), mode="bilinear",
                         align_corners=True)

    self.source_plabel_path = osp.join(self.config.plabel, self.config.note,
                                       str(round), self.config.source)
    mkdir(self.source_plabel_path)

    target_template = self.pool.pool
    self.mean_memo = {i: [] for i in range(self.config.num_classes)}
    self.target_cnts = {i: 0 for i in range(self.config.num_classes)}

    with torch.no_grad():
        for batch in tqdm(loader):
            image, label, _, _, name = batch
            label = label.cuda()

            # Sample names look like "<dir>/<file>"; only the first name of
            # the batch is used.
            path_parts = name[0].split("/")
            img_name = path_parts[-1].replace('leftImg8bit',
                                              'gtFine_labelIds')
            temp_dir = osp.join(self.source_plabel_path, path_parts[0])
            os.makedirs(temp_dir, exist_ok=True)

            output = interp(self.model.forward(image.cuda()))
            pred = F.softmax(output, dim=1)

            plabel = self.calc_pixel_simi(pred, label, target_template)
            plabel = plabel.view(src_h, src_w).cpu().numpy()
            plabel = Image.fromarray(np.asarray(plabel, dtype=np.uint8))
            plabel.save("%s/%s.png" % (temp_dir, img_name.split(".")[0]))
def semantic_layout_matching(self, round, num_to_sel):
    """Select the source images with the lowest layout-matching scores.

    First builds the target semantic layout matrices with
    ``self.get_target_slm``, stacks them and reduces them with
    ``self.k_means``. Then, for every source ``train`` image, builds the
    per-row and per-column class histograms of its label map and scores
    them against the target matrices with ``self.his_kl_simi``. Images are
    ranked by the sum of the two scores in ascending order.

    Args:
        round: Unused here; kept for interface compatibility.
        num_to_sel: Number of source images to select.

    Returns:
        The names of the ``num_to_sel`` lowest-scoring source images (fewer
        if the loader yields fewer). Also stored in ``self.selected``.
    """
    print('[Semantic Layout Matching]')
    loader = dataset.init_test_dataset(self.config, self.config.source,
                                       set='train')

    self.get_target_slm()  # Calculate the SLM of the target domain
    self.target_his_h = torch.stack(list(self.target_his_h.values()))
    self.target_his_w = torch.stack(list(self.target_his_w.values()))
    self.target_his_h = self.k_means(self.target_his_h)
    self.target_his_w = self.k_means(self.target_his_w)

    num_classes = self.config.num_classes
    score_dict = {}
    for batch in tqdm(loader):
        _, label, _, _, name = batch
        label = label.cuda().squeeze()
        h, w = label.shape

        # Column i holds the pixel count of class i per row / per column.
        cu_h = torch.zeros(h, num_classes).float().cuda()
        cu_w = torch.zeros(w, num_classes).float().cuda()
        for i in range(num_classes):
            mask = label == i
            cu_h[:, i] = mask.sum(dim=1).float()
            cu_w[:, i] = mask.sum(dim=0).float()

        # L1-normalise each class column, then move classes to dim 0.
        cu_h = F.normalize(cu_h, p=1, dim=0).t()
        cu_w = F.normalize(cu_w, p=1, dim=0).t()

        score_h = self.his_kl_simi(self.target_his_h, cu_h)
        score_w = self.his_kl_simi(self.target_his_w, cu_w)
        score_dict[name[0]] = score_h + score_w

    sorted_pair = sorted(score_dict.items(), key=operator.itemgetter(1))
    sorted_name = [pair[0] for pair in sorted_pair]
    # The slice previously ended at ``num_to_sel - 1``, which returned one
    # name too few (and all but the last name for ``num_to_sel == 0``).
    self.selected = sorted_name[:max(num_to_sel, 0)]
    return self.selected
def gene_thres(self, prop, num_cls=19):
    """Compute a per-class confidence threshold on the target ``train`` split.

    For every target image the maximum softmax probability of each pixel is
    collected under its predicted class (keeping every fifth value of the
    sorted per-image list). For each class, the threshold is the confidence
    at the boundary of the top ``prop`` fraction of the collected values.

    Args:
        prop: Fraction of the most confident collected values of each class
            that should lie at or above the threshold.
        num_cls: Number of classes to scan.

    Returns:
        dict mapping class index to its threshold. Classes that were never
        predicted are absent from the dict.
    """
    print('[Calculate Threshold using config.cb_prop]')  # r in section 3.3
    probs = {}
    loader = dataset.init_test_dataset(self.config, self.config.target,
                                       set="train", selected=self.target_all,
                                       batchsize=1)
    with torch.no_grad():
        for batch in tqdm(loader):
            img = batch[0]
            pred = F.softmax(self.model.forward(img.cuda()), dim=1)
            pred_probs = pred.max(dim=1)[0].squeeze()
            pred_label = torch.argmax(pred, dim=1).squeeze()

            for i in range(num_cls):
                cls_mask = pred_label == i
                if cls_mask.sum() == 0:
                    continue
                cls_probs = torch.masked_select(pred_probs, cls_mask)
                cls_probs = cls_probs.cpu().numpy().tolist()
                cls_probs.sort()
                # Keep every fifth value to reduce memory consumption.
                probs.setdefault(i, []).extend(cls_probs[::5])

    thres = {}
    for k, cls_prob in probs.items():
        cls_prob = np.sort(np.array(cls_prob))
        cls_total = len(cls_prob)
        index = int(cls_total * prop)
        # Clamp to [1, cls_total]. With index == 0, ``cls_prob[-0]`` is the
        # smallest value rather than the largest, which inverts the
        # threshold for classes with few samples; an index above cls_total
        # would be out of range.
        index = min(max(index, 1), cls_total)
        thres[k] = cls_prob[-index]
    print(thres)
    return thres
def validate2(self, count):
    """Return the mean cross-entropy loss over the target ``val`` split.

    The model is switched to eval mode, its output is upsampled to
    1024x2048 and compared with the labels (label 255 is ignored). The
    per-batch losses are averaged over the number of batches.

    Args:
        count: Epoch number, used only in the printed message.

    Returns:
        The average per-batch segmentation loss as a Python float.

    Raises:
        ZeroDivisionError: If the validation loader yields no batch.
    """
    self.model = self.model.eval()
    testloader = dataset.init_test_dataset(self.config, self.config.target,
                                           set='val', batchsize=4)
    interp = nn.Upsample(size=(1024, 2048), mode='bilinear',
                         align_corners=True)

    total_loss = 0
    num_batches = 0
    # No gradients are needed for validation; without this every forward
    # pass would record an autograd graph.
    with torch.no_grad():
        for batch in tqdm(testloader):
            img, seg_label = batch[0], batch[1]
            seg_label = seg_label.long().cuda()
            seg_pred = interp(self.model(img.cuda()))
            seg_loss = F.cross_entropy(seg_pred, seg_label, ignore_index=255)
            total_loss += seg_loss.item()
            num_batches += 1

    # Batches are counted in the loop instead of building a second loader
    # iterator just to take its length.
    total_loss /= num_batches
    print('---------------------')
    print('Validation seg loss: {} at Epoch {}'.format(total_loss, count))
    return total_loss
def save_pred(self, round):
    """Generate thresholded pseudo labels for the target images and save them.

    For every image of the target ``train`` split restricted to
    ``self.target_all``, the model output is upsampled to 1024x2048. A mask
    from the class-wise thresholds in ``self.cb_thres`` and a mask from the
    per-image proportion ``self.config.cb_prop`` are fused with
    ``mask_fusion``. The fused pseudo label is saved as an 8-bit PNG under
    ``<config.plabel>/<config.note>/<round>/<dir>/``.

    Side effects: sets ``self.plabel_path`` and
    ``self.config.target_data_dir``, replaces ``self.pool`` with a fresh
    ``Pool`` that is updated with every output and fused mask, resets
    ``self.mean_memo``, prints the running accuracy/proportion and, when
    ``self.config.neptune`` is set, sends both to neptune.

    Args:
        round: Identifier of the current round; only used (via ``str``) as
            a path component.
    """
    # Using the threshold to generate pseudo labels and save
    print("[Generate pseudo labels]")
    loader = dataset.init_test_dataset(self.config, self.config.target,
                                       set="train", selected=self.target_all)
    interp = nn.Upsample(size=(1024, 2048), mode="bilinear",
                         align_corners=True)

    self.plabel_path = osp.join(self.config.plabel, self.config.note,
                                str(round))
    mkdir(self.plabel_path)
    self.config.target_data_dir = self.plabel_path

    # Stores the probability of pseudo labels for the pixel-wise similarity
    # matching, which is detailed around Eq. (9).
    self.pool = Pool()
    accs = AverageMeter()  # Counter
    props = AverageMeter()  # Counter
    cls_acc = GroupAverageMeter()  # Class-wise Acc/Prop of Pseudo labels

    self.mean_memo = {i: [] for i in range(self.config.num_classes)}
    with torch.no_grad():
        for batch in tqdm(loader):
            image, label, _, _, name = batch
            label = label.cuda()

            # Sample names look like "<dir>/<file>"; only the first name of
            # the batch is used.
            path_parts = name[0].split("/")
            img_name = path_parts[-1].replace("leftImg8bit",
                                              "gtFine_labelIds")
            temp_dir = osp.join(self.plabel_path, path_parts[0])
            os.makedirs(temp_dir, exist_ok=True)

            output = interp(self.model.forward(image.cuda()))

            # Pseudo labels selected by the global threshold; only the mask
            # is used, the label map comes from the fusion step below.
            mask, _ = thres_cb_plabel(output, self.cb_thres,
                                      num_cls=self.config.num_classes)
            # Pseudo labels selected by the local threshold.
            mask2, _ = gene_plabel_prop(output, self.config.cb_prop)
            # Mask fusion; the strategy is detailed in Sec. 3.3 of the paper.
            mask, plabel = mask_fusion(output, mask, mask2)
            self.pool.update_pool(output, mask=mask.float())

            acc, prop, cls_dict = Acc(plabel, label,
                                      num_cls=self.config.num_classes)
            # Weight the accuracy by the number of labelled (non-255) pixels.
            cnt = (plabel != 255).sum().item()
            accs.update(acc, cnt)
            props.update(prop, 1)
            cls_acc.update(cls_dict)

            plabel = plabel.view(1024, 2048).cpu().numpy()
            plabel = Image.fromarray(np.asarray(plabel, dtype=np.uint8))
            plabel.save("%s/%s.png" % (temp_dir, img_name.split(".")[0]))

    print('The Accuracy :{:.2%} and proportion :{:.2%} of Pseudo Labels'.
          format(accs.avg.item(), props.avg.item()))
    if self.config.neptune:
        neptune.send_metric("Acc", accs.avg)
        neptune.send_metric("Prop", props.avg)
def validate(self):
    """Evaluate the model on the target ``val`` split and return its mIoU.

    The model output of each image is upsampled to 1024x2048. Per class,
    the intersection, the union and the number of predicted pixels are
    accumulated over all images, counting only pixels whose label is below
    the number of output channels. ``mIoU`` is the mean of
    ``intersection / union``; ``mAcc`` is the mean of
    ``intersection / predicted pixels``. Both are printed and, when
    ``self.config.neptune`` is set, sent to neptune. With 16 classes, the
    13-class mIoU (class indices 3, 4 and 5 left out) is printed as well.

    The loader must yield one image per batch, because the batch dimension
    is squeezed away before the output is unpacked as ``(C, H, W)``.

    NOTE: a class with an empty union (or no predicted pixel) gives 0/0 and
    therefore NaN, which propagates into the reported means.

    Returns:
        The mean IoU over all classes as a Python float.
    """
    self.model = self.model.eval()
    testloader = dataset.init_test_dataset(self.config, self.config.target,
                                           set='val')
    interp = nn.Upsample(size=(1024, 2048), mode='bilinear',
                         align_corners=True)

    num_classes = self.config.num_classes
    union = torch.zeros(num_classes, 1, dtype=torch.float).cuda()
    inter = torch.zeros(num_classes, 1, dtype=torch.float).cuda()
    preds = torch.zeros(num_classes, 1, dtype=torch.float).cuda()

    # (C, 1, 1) tensor of class indices, built once on first use and
    # broadcast against the (H, W) maps below. This replaces two
    # (C, H, W) float tensors built and copied to the GPU for every image.
    class_ids = None

    with torch.no_grad():
        for batch in tqdm(testloader):
            image, label = batch[0], batch[1]
            output = interp(self.model(image.cuda())).squeeze()
            label = label.cuda()
            C, H, W = output.shape
            if class_ids is None:
                class_ids = torch.arange(C, dtype=torch.float)
                class_ids = class_ids.view(C, 1, 1).cuda()

            # Pixels whose label is not a valid class index are excluded
            # from both masks.
            valid = (label.squeeze() < C).byte()

            pred = output.argmax(dim=0).float()
            pred_mask = torch.eq(class_ids, pred).byte() * valid

            label = label.view(1, H, W)
            label_mask = torch.eq(class_ids, label.float()).byte() * valid

            # 2 where both masks agree, 1 where exactly one is set.
            overlap = label_mask + pred_mask
            inter += (overlap == 2).view(C, -1).sum(dim=1,
                                                    keepdim=True).float()
            union += (overlap > 0).view(C, -1).sum(dim=1,
                                                   keepdim=True).float()
            preds += pred_mask.view(C, -1).sum(dim=1, keepdim=True).float()

    iou = inter / union
    acc = inter / preds
    # The in-place accumulation above requires the output channel count to
    # match config.num_classes, so it can stand in for the loop variable.
    if num_classes == 16:
        iou = iou.squeeze()
        class13_iou = torch.cat((iou[:3], iou[6:]))
        class13_miou = class13_iou.mean().item()
        print('13-Class mIoU:{:.2%}'.format(class13_miou))

    mIoU = iou.mean().item()
    mAcc = acc.mean().item()
    print('mIoU: {:.2%} mAcc : {:.2%} '.format(mIoU, mAcc))
    if self.config.neptune:
        neptune.send_metric('mIoU', mIoU)
        neptune.send_metric('mAcc', mAcc)
    return mIoU