def setup_experiment_dirs():
    """
    Make sure the experiment and model directories exist and remember where the log goes.

    Calls ``check_dir`` on ``args.experiment_dir`` and then on ``args.model_dir``, and
    sets ``args.log_path`` to ``<args.experiment_dir>/log.txt``. The log file itself is
    not created here.

    NOTE(review): ``setup_experiment_dirs`` is defined more than once in this file; only
    the last definition is bound at import time — consider deleting the duplicate.

    :return: None
    """
    for dir_attr in ('experiment_dir', 'model_dir'):
        check_dir(getattr(args, dir_attr))
    args.log_path = os.path.join(args.experiment_dir, 'log.txt')
def save_batch_gt(predict, name_list, label_list, batch_index):
    '''
    Resize each predicted score map to its image's original size, refine it with the
    dense CRF (``crf``), save a colour visualisation, and build the binary ground truth.

    NOTE(review): ``save_batch_gt`` is defined more than once in this file; only the
    last definition is bound at import time — consider deleting the duplicate.

    :param predict: batch of per-class score maps. Each ``predict[ii]`` gets a batch
        axis and is bilinearly resized by ``F.interpolate`` to a 2-D size, so
        ``predict`` is expected to be 4-D (N, C, H, W) — TODO confirm against callers
    :param name_list: image names without extension; the ground truth is read from
        ``<gt_dir>/<name>.png`` and the RGB image from ``<img_dir>/<name>.jpg``
    :param label_list: class label of each image; ground-truth pixels equal to it
        become 1, every other pixel becomes 0
    :param batch_index: index of the current batch; combined with ``args.b_v`` to
        number the saved visualisations ``<batch_index * args.b_v + ii>_pre.png``
    :return: predict_list: first output of ``crf`` for each image (a label map that is
        compared against class indices when colouring)
    :return: gt_list: int64 binary ground-truth arrays at the original resolution;
        pixels equal to 255 in the source mask are forced to 0
    '''
    vis_dir = args.experiment_dir + '/_vis_val_que'
    check_dir(vis_dir)  # same directory for every image, so check it once
    predict_list = []
    gt_list = []
    for ii, (score, name, label) in enumerate(zip(predict, name_list, label_list)):
        label = int(label)
        # Converted copies are kept; the context managers close the source files.
        with Image.open(gt_dir + '/' + name + '.png') as gt_file:
            gt = gt_file.convert('P')
        with Image.open(img_dir + '/' + name + '.jpg') as img_file:
            img = img_file.convert('RGB')
        w, h = gt.size  # PIL reports (width, height)

        # Upsample to the ground-truth resolution. align_corners=False is PyTorch's
        # implicit default for bilinear mode; stating it silences the warning.
        score = F.interpolate(score.unsqueeze(0), (h, w), mode='bilinear',
                              align_corners=False).squeeze(0)
        predict_temp, _ = crf(img, score)
        predict_list.append(predict_temp)

        output_img = np.zeros((h, w, 3), dtype=np.uint8)
        for i, color in enumerate(index2color):
            output_img[predict_temp == i, :] = color
        # 'PNG' is the format argument of Image.save, not of str.format.
        Image.fromarray(output_img).save(
            '{}/{}.png'.format(vis_dir, str(batch_index * args.b_v + ii) + '_pre'), 'PNG')

        gt_ori = np.array(gt).astype(np.int64)
        # Compare against the untouched mask: zeroing non-label pixels first and then
        # testing `== label` marks every pixel as foreground when label == 0.
        gt = (gt_ori == label).astype(np.int64)
        gt[gt_ori == 255] = 0
        gt_list.append(gt)
    return predict_list, gt_list
def setup_experiment_dirs():
    """
    Make sure the experiment and model directories exist and remember where the log goes.

    Calls ``check_dir`` on ``args.experiment_dir`` and then on ``args.model_dir``, and
    sets ``args.log_path`` to ``<args.experiment_dir>/log.txt``. The log file itself is
    not created here.

    NOTE(review): ``setup_experiment_dirs`` is defined more than once in this file; only
    the last definition is bound at import time — consider deleting the duplicate.

    :return: None
    """
    for dir_attr in ('experiment_dir', 'model_dir'):
        check_dir(getattr(args, dir_attr))
    args.log_path = os.path.join(args.experiment_dir, 'log.txt')
def val_pair(self):
    """
    Run one validation pass over ``self.val_loader_pair`` and log the IoU.

    For every (support, query) batch the model predicts the query mask under
    ``torch.no_grad()``:
    - the running query loss is printed;
    - colour-coded images of the prediction and of the query/support ground truths
      are written to ``<args.experiment_dir>/_vis_val_que``;
    - the CRF-refined predictions returned by ``save_batch_gt`` are fed to
      ``self.cal_miou``.
    After the loop, the last (FB_IoU, fore_IoU) pair is printed, appended to
    ``args.log_path``, and stored in ``self.val_performence_current_epoch`` and
    ``self.miou``.

    NOTE(review): the original source was whitespace-collapsed. The IoU update is
    placed inside the batch loop and the reporting after it. A fresh ``Cal_mIoU`` is
    created per call, so it presumably accumulates across batches — confirm.

    :raises RuntimeError: if the validation loader yields no batches.
    """
    loss_1_epoch = AverageMeter()
    self.model.eval()
    self.cal_miou = Cal_mIoU()
    device = device_ids[0]
    vis_dir = args.experiment_dir + '/_vis_val_que'
    check_dir(vis_dir)

    def _save_colored(masks, batch_index, suffix):
        # Save each 2-D class-index mask in `masks` as an RGB PNG coloured by
        # index2color; the canvas follows the mask's own (H, W), square or not.
        for ii, msk in enumerate(masks.detach().cpu().numpy()):
            rgb = np.zeros(msk.shape + (3,), dtype=np.uint8)
            for cls, color in enumerate(index2color):
                rgb[msk == cls, :] = color
            Image.fromarray(rgb).save(
                '{}/{}.png'.format(vis_dir, str(batch_index * args.b_v + ii) + suffix), 'PNG')

    FB_IoU = fore_IoU = None
    with torch.no_grad():
        for batch_index, (data_list, gt_pair, name_list, label_pair) in enumerate(self.val_loader_pair):
            if self.break_for_debug and batch_index == 5:
                break
            data_sup = data_list[0].to(device=device)
            data_que = data_list[1].to(device=device)
            gt_sup = gt_pair[0].to(device=device)
            gt_que = gt_pair[1].to(device=device)
            name_que = name_list[1]
            # Float copy of the support mask for the model; .float() stays on the
            # device instead of round-tripping through a CPU FloatTensor.
            gt_sup_1 = gt_sup.float()

            output = self.model(x_q=data_que, x_s=data_sup, x_s_mask=gt_sup_1, is_train=False)
            loss_1 = self.criterion(output[1], gt_que)
            _, predict = torch.max(output[1], dim=1)

            loss_1_epoch.update(loss_1.item())
            print('val:\t{}|{}\t{}|{}\tloss_1:{}\t'.format(self.current_epoch, args.epoches, batch_index + 1,
                                                          len(self.val_loader_pair), loss_1_epoch.avg))

            # NOTE(review): save_batch_gt writes its CRF-refined result to the same
            # '<n>_pre.png' path, so this raw prediction image is overwritten below.
            _save_colored(predict, batch_index, '_pre')
            _save_colored(gt_que, batch_index, '_gt_que')
            _save_colored(gt_sup, batch_index, '_gt_sup')

            predict_list, crf_gt_list = save_batch_gt(predict=F.softmax(output[1], dim=1), name_list=name_que,
                                                      label_list=label_pair, batch_index=batch_index)
            FB_IoU, fore_IoU = self.cal_miou.caculate_miou(predict_list, crf_gt_list, 2)

    if fore_IoU is None:
        raise RuntimeError('val_pair: the validation loader produced no batches')

    print('the total miou(on the original_size) is ', FB_IoU, fore_IoU)
    self.val_performence_current_epoch = fore_IoU
    string_1 = 'miou calculated on original---{} {} {} {}'.format(self.flag_val, self.current_epoch, fore_IoU, FB_IoU)
    with open(args.log_path, 'a+') as f:
        f.write(string_1 + '\n')
    self.miou = fore_IoU
def train_pair(self):
    """
    Train the model for one epoch on ``self.train_loader_pair``.

    Each (support, query) batch is passed through ``self.model`` with
    ``is_train=True``. The optimised loss is the sum of three terms:
    - ``self.criterion(output[1], query_mask)``;
    - ``self.re_loss(output[2], self.get_R_truth(query_mask, support_mask))``;
    - ``self.criterion(output[0], query_mask)``.
    Running averages of the three losses and of the batch mIoU / foreground mIoU
    (``caculate_miou`` on the argmax of ``output[1]``) are printed every batch. The
    colour-coded query predictions are saved as
    ``<args.experiment_dir>/_vis_train_que/<query name>.png``.
    """
    loss_1_epoch = AverageMeter()
    loss_2_epoch = AverageMeter()
    loss_3_epoch = AverageMeter()
    miou_epoch = AverageMeter()
    miou_fore_epoch = AverageMeter()
    self.model.train()
    device = device_ids[0]
    vis_dir = args.experiment_dir + '/_vis_train_que'
    check_dir(vis_dir)  # same directory for every image, so check it once

    def _colorize(msk):
        # Map a 2-D class-index mask to an RGB image using index2color; the canvas
        # follows the mask's own (H, W), so non-square masks work too.
        rgb = np.zeros(msk.shape + (3,), dtype=np.uint8)
        for cls, color in enumerate(index2color):
            rgb[msk == cls, :] = color
        return Image.fromarray(rgb)

    for batch_index, (data_list, gt_list, name_list, label_pair) in enumerate(self.train_loader_pair):
        if self.break_for_debug and batch_index == 5:
            break
        data_sup = data_list[0].to(device=device)
        data_que = data_list[1].to(device=device)
        gt_sup = gt_list[0].to(device=device)
        gt_que = gt_list[1].to(device=device)
        name_que = name_list[1]
        # Float copies of the masks; .float() stays on the device instead of
        # round-tripping through a CPU FloatTensor.
        gt_sup_1 = gt_sup.float()
        gt_que_1 = gt_que.float()

        output = self.model(x_q=data_que, x_s=data_sup, x_s_mask=gt_sup_1, is_train=True)
        loss_1 = self.criterion(output[1], gt_que)
        loss_2 = self.re_loss(output[2], self.get_R_truth(gt_que_1, gt_sup_1))
        loss_3 = self.criterion(output[0], gt_que)
        loss = loss_1 + loss_2 + loss_3

        self.optimizer_pair.zero_grad()
        loss.backward()
        self.optimizer_pair.step()

        _, predict = torch.max(output[1], dim=1)
        predict_np = predict.detach().cpu().numpy()
        gt_temp = gt_que.detach().cpu().numpy()
        # caculate_miou is external; hand it a copy so any in-place edits it might
        # make cannot leak into the visualisation below.
        miou, miou_fore = caculate_miou(predict_np.copy(), gt_temp, 2)

        loss_1_epoch.update(loss_1.item())
        loss_2_epoch.update(loss_2.item())
        loss_3_epoch.update(loss_3.item())
        miou_epoch.update(miou)
        miou_fore_epoch.update(miou_fore)
        print('train:\t{}|{}\t{}|{}\tloss_1:{}\tloss_2:{}\tloss_3:{}\tmiou:{}\tmiou_fore:{}'.format(
            self.current_epoch, args.epoches, batch_index + 1, len(self.train_loader_pair),
            loss_1_epoch.avg, loss_2_epoch.avg, loss_3_epoch.avg, miou_epoch.avg, miou_fore_epoch.avg))

        for ii, msk in enumerate(predict_np):
            _colorize(msk).save('{}/{}.png'.format(vis_dir, name_que[ii]), 'PNG')
def save_batch_gt(predict, name_list, label_list, batch_index):
    '''
    Resize each predicted score map to its image's original size, refine it with the
    dense CRF (``crf``), save a colour visualisation, and build the binary ground truth.

    NOTE(review): ``save_batch_gt`` is defined more than once in this file; only the
    last definition is bound at import time — consider deleting the duplicate.

    :param predict: batch of per-class score maps. Each ``predict[ii]`` gets a batch
        axis and is bilinearly resized by ``F.interpolate`` to a 2-D size, so
        ``predict`` is expected to be 4-D (N, C, H, W) — TODO confirm against callers
    :param name_list: image names without extension; the ground truth is read from
        ``<gt_dir>/<name>.png`` and the RGB image from ``<img_dir>/<name>.jpg``
    :param label_list: class label of each image; ground-truth pixels equal to it
        become 1, every other pixel becomes 0
    :param batch_index: index of the current batch; combined with ``args.b_v`` to
        number the saved visualisations ``<batch_index * args.b_v + ii>_pre.png``
    :return: predict_list: first output of ``crf`` for each image (a label map that is
        compared against class indices when colouring)
    :return: gt_list: int64 binary ground-truth arrays at the original resolution;
        pixels equal to 255 in the source mask are forced to 0
    '''
    vis_dir = args.experiment_dir + '/_vis_val_que'
    check_dir(vis_dir)  # same directory for every image, so check it once
    predict_list = []
    gt_list = []
    for ii, (score, name, label) in enumerate(zip(predict, name_list, label_list)):
        label = int(label)
        # Converted copies are kept; the context managers close the source files.
        with Image.open(gt_dir + '/' + name + '.png') as gt_file:
            gt = gt_file.convert('P')
        with Image.open(img_dir + '/' + name + '.jpg') as img_file:
            img = img_file.convert('RGB')
        w, h = gt.size  # PIL reports (width, height)

        # Upsample to the ground-truth resolution. align_corners=False is PyTorch's
        # implicit default for bilinear mode; stating it silences the warning.
        score = F.interpolate(score.unsqueeze(0), (h, w), mode='bilinear',
                              align_corners=False).squeeze(0)
        predict_temp, _ = crf(img, score)
        predict_list.append(predict_temp)

        output_img = np.zeros((h, w, 3), dtype=np.uint8)
        for i, color in enumerate(index2color):
            output_img[predict_temp == i, :] = color
        # 'PNG' is the format argument of Image.save, not of str.format.
        Image.fromarray(output_img).save(
            '{}/{}.png'.format(vis_dir, str(batch_index * args.b_v + ii) + '_pre'), 'PNG')

        gt_ori = np.array(gt).astype(np.int64)
        # Compare against the untouched mask: zeroing non-label pixels first and then
        # testing `== label` marks every pixel as foreground when label == 0.
        gt = (gt_ori == label).astype(np.int64)
        gt[gt_ori == 255] = 0
        gt_list.append(gt)
    return predict_list, gt_list