def _colorize_mask(msk, palette):
    """Return an ``(H, W, 3)`` uint8 RGB image painting label ``i`` with ``palette[i]``.

    ``msk`` is a 2-D integer label array. Pixels whose label has no palette
    entry stay black. The buffer is sized from both mask dimensions, so
    non-square masks are handled too.
    """
    rgb = np.zeros(tuple(msk.shape[:2]) + (3,), dtype=np.uint8)
    for label, color in enumerate(palette):
        rgb[msk == label, :] = color
    return rgb


def _save_color_masks(masks, out_dir, stems, palette):
    """Save ``masks[i]`` as a colour PNG at ``{out_dir}/{stems[i]}.png``."""
    check_dir(out_dir)
    for ii, msk in enumerate(masks):
        image = Image.fromarray(_colorize_mask(msk, palette))
        image.save('{}/{}.png'.format(out_dir, stems[ii]), 'PNG')


def val_pair(self):
    """Run one gradient-free validation pass over ``self.val_loader_pair``.

    For every batch this:
      * computes ``loss_1`` (``self.criterion`` on ``output[1]`` vs. the query
        mask) and prints its running average;
      * writes colour PNGs of the query prediction and of both ground-truth
        masks to ``args.experiment_dir + '/_vis_val_que'``;
      * passes the softmax scores to ``save_batch_gt``.

    After the loop, foreground / FB IoU are computed by
    ``Cal_mIoU.caculate_miou`` on the lists returned by the last
    ``save_batch_gt`` call. Presumably ``save_batch_gt`` accumulates across
    batches; confirm against its definition. The result is printed,
    appended to ``args.log_path``, and stored in
    ``self.val_performence_current_epoch`` and ``self.miou``.

    Raises:
        RuntimeError: if the validation loader yields no batches.
    """
    loss_1_epoch = AverageMeter()
    self.model.eval()
    self.cal_miou = Cal_mIoU()
    device = device_ids[0]
    vis_dir = args.experiment_dir + '/_vis_val_que'
    predict_list = gt_list = None

    with torch.no_grad():
        for batch_index, (data_list, gt_pair, name_list, label_pair) in enumerate(self.val_loader_pair):
            # Debug mode only looks at the first 5 batches.
            if self.break_for_debug and batch_index == 5:
                break

            data_sup = data_list[0].to(device=device)
            data_que = data_list[1].to(device=device)
            gt_sup = gt_pair[0].to(device=device)
            gt_que = gt_pair[1].to(device=device)
            name_que = name_list[1]

            # .float() keeps the tensor on its current device (no CPU round trip).
            output = self.model(x_q=data_que, x_s=data_sup, x_s_mask=gt_sup.float(), is_train=False)
            loss_1 = self.criterion(output[1], gt_que)
            _, predict = torch.max(output[1], dim=1)
            loss_1_epoch.update(loss_1.item())
            print('val:\t{}|{}\t{}|{}\tloss_1:{}\t'.format(
                self.current_epoch, args.epoches, batch_index + 1,
                len(self.val_loader_pair), loss_1_epoch.avg))

            # File stems are global sample indices: batch_index * batch_size + offset.
            first_index = batch_index * args.b_v
            for tensor, suffix in ((predict, '_pre'), (gt_que, '_gt_que'), (gt_sup, '_gt_sup')):
                masks = tensor.cpu().numpy()
                stems = [str(first_index + ii) + suffix for ii in range(len(masks))]
                _save_color_masks(masks, vis_dir, stems, index2color)

            predict_list, gt_list = save_batch_gt(predict=F.softmax(output[1], dim=1),
                                                  name_list=name_que,
                                                  label_list=label_pair,
                                                  batch_index=batch_index)

    if predict_list is None:
        raise RuntimeError('val_pair: the validation loader produced no batches')

    FB_IoU, fore_IoU = self.cal_miou.caculate_miou(predict_list, gt_list, 2)
    print('the total miou(on the original_size) is ', FB_IoU, fore_IoU)
    self.val_performence_current_epoch = fore_IoU
    string_1 = 'miou calculated on original---{} {} {} {}'.format(self.flag_val, self.current_epoch, fore_IoU, FB_IoU)
    with open(args.log_path, 'a+') as f:
        f.write(string_1 + '\n')
    self.miou = fore_IoU
class Trainer():
    """Trainer/evaluator for a support/query (few-shot) segmentation model.

    Constructing an instance immediately runs the schedule selected by
    ``flag_val`` (see ``start``).
    """

    def __init__(self, model=None, criterion=None, flag_val=None, break_for_debug=True):
        self.model = model
        self.criterion = criterion
        self.re_loss = get_reconstruction_loss()
        self.optimizer_pair = get_optimizer_pair(self.model)
        self.train_loader_pair, self.val_loader_pair = get_dataloader_few_shot()
        # When True, every train/val pass stops after 5 batches.
        self.break_for_debug = break_for_debug
        self.flag_val = flag_val
        self.current_epoch = 0
        self.val_performence_current_epoch = 0
        self.train_performence_current_epoch = 0
        self.best_val = 0
        # Set before start(): previously it was reset to 0 after start(),
        # discarding the mIoU recorded by val_pair().
        self.miou = 0
        self.start()

    def start(self):
        """Run the schedule selected by ``self.flag_val``.

        * ``'just_train'``: currently does nothing.
        * ``'just_val'``: validate once.
        * ``'train_val'``: for each epoch in ``range(args.start_epoch, args.epoches)``,
          adjust the LR, train, then validate. On a new best validation
          score, save ``self.model.state_dict()`` to ``args.model_dir``.
          Finally, append a summary line to ``args.log_path``.
        """
        if self.flag_val == 'just_train':
            # Original comment: "train continuously".
            # NOTE(review): this branch is a no-op; confirm whether a training loop is missing.
            pass
        elif self.flag_val == 'just_val':
            # Validate only once.
            self.val_pair()
        elif self.flag_val == 'train_val':
            # Validate once after every training epoch.
            for epoch in range(args.start_epoch, args.epoches):
                adjust_learning_rate(self.optimizer_pair, epoch)
                self.current_epoch = epoch
                self.train_pair()
                self.val_pair()
                if self.val_performence_current_epoch > self.best_val:
                    self.best_val = self.val_performence_current_epoch
                    model_params_save_path = os.path.join(
                        args.model_dir,
                        'epoch:_{}_perform:_{}_{}.pkl'.format(
                            self.current_epoch, self.val_performence_current_epoch, self.flag_val))
                    torch.save(self.model.state_dict(), model_params_save_path)
                    # NOTE(review): runs right after the save above; confirm that
                    # delete_existed_params keeps the checkpoint just written.
                    delete_existed_params(args.model_dir)
                    print('\nnew best val saved epoch {} best val:{} '.format(
                        self.current_epoch, self.val_performence_current_epoch))
                string = '{} {} {} {}'.format(self.flag_val, self.current_epoch,
                                              self.val_performence_current_epoch, self.miou)
                with open(args.log_path, 'a+') as f:
                    f.write(string + '\n')

    def get_R_truth(self, gt_que_1, gt_sup_1, size=(20, 20)):
        """Build the query/support affinity target consumed by ``self.re_loss``.

        Both masks are bilinearly resized to ``size`` and flattened. The
        result satisfies ``R_truth[b, i, j] = que[b, i] * sup[b, j]``.

        Args:
            gt_que_1: float query masks. Assumed ``(B, H, W)``: a channel axis
                is added with ``unsqueeze(1)`` before interpolation.
            gt_sup_1: float support masks, same layout as ``gt_que_1``.
            size: spatial size the masks are resized to. The default
                ``(20, 20)`` is the previously hard-coded value. It presumably
                matches the model's affinity resolution; confirm.

        Returns:
            Tensor of shape ``(B, size[0] * size[1], size[0] * size[1])``.
        """
        que = F.interpolate(gt_que_1.unsqueeze(1), size, mode='bilinear', align_corners=False)
        que = que.view(que.size()[0], -1, 1)
        sup = F.interpolate(gt_sup_1.unsqueeze(1), size, mode='bilinear', align_corners=False)
        sup = sup.view(sup.size()[0], 1, -1)
        return torch.matmul(que, sup)

    @staticmethod
    def _colorize_mask(msk, palette):
        """Return an ``(H, W, 3)`` uint8 RGB image painting label ``i`` with ``palette[i]``.

        Labels with no palette entry stay black. Both mask dimensions are
        used, so non-square masks work.
        """
        rgb = np.zeros(tuple(msk.shape[:2]) + (3,), dtype=np.uint8)
        for label, color in enumerate(palette):
            rgb[msk == label, :] = color
        return rgb

    @staticmethod
    def _save_color_masks(masks, out_dir, stems, palette):
        """Save ``masks[i]`` as a colour PNG at ``{out_dir}/{stems[i]}.png``."""
        check_dir(out_dir)
        for ii, msk in enumerate(masks):
            image = Image.fromarray(Trainer._colorize_mask(msk, palette))
            image.save('{}/{}.png'.format(out_dir, stems[ii]), 'PNG')

    def train_pair(self):
        """Train the model for one epoch over ``self.train_loader_pair``.

        The optimised loss is ``loss_1 + loss_2 + loss_3``:
          * ``loss_1``: ``self.criterion`` on ``output[1]`` vs. the query mask;
          * ``loss_2``: ``self.re_loss`` between ``output[2]`` and ``get_R_truth``;
          * ``loss_3``: ``self.criterion`` on ``output[0]`` vs. the query mask.

        Running losses and batch mIoU are printed every step. Query predictions
        are written as colour PNGs to ``args.experiment_dir + '/_vis_train_que'``.
        """
        loss_1_epoch = AverageMeter()
        loss_2_epoch = AverageMeter()
        loss_3_epoch = AverageMeter()
        miou_epoch = AverageMeter()
        miou_fore_epoch = AverageMeter()
        self.model.train()
        device = device_ids[0]
        vis_dir = args.experiment_dir + '/_vis_train_que'

        for batch_index, (data_list, gt_pair, name_list, label_pair) in enumerate(self.train_loader_pair):
            # Debug mode only looks at the first 5 batches.
            if self.break_for_debug and batch_index == 5:
                break

            data_sup = data_list[0].to(device=device)
            data_que = data_list[1].to(device=device)
            gt_sup = gt_pair[0].to(device=device)
            gt_que = gt_pair[1].to(device=device)
            name_que = name_list[1]
            # Float copies of the masks. .float() stays on the current device.
            gt_sup_1 = gt_sup.float()
            gt_que_1 = gt_que.float()

            output = self.model(x_q=data_que, x_s=data_sup, x_s_mask=gt_sup_1, is_train=True)
            loss_1 = self.criterion(output[1], gt_que)
            loss_2 = self.re_loss(output[2], self.get_R_truth(gt_que_1, gt_sup_1))
            loss_3 = self.criterion(output[0], gt_que)
            loss = loss_1 + loss_2 + loss_3

            self.optimizer_pair.zero_grad()
            loss.backward()
            self.optimizer_pair.step()

            _, predict = torch.max(output[1], dim=1)
            predict_temp = predict.cpu().numpy()
            gt_temp = gt_que.cpu().numpy()
            miou, miou_fore = caculate_miou(predict_temp, gt_temp, 2)

            loss_1_epoch.update(loss_1.item())
            loss_2_epoch.update(loss_2.item())
            loss_3_epoch.update(loss_3.item())
            miou_epoch.update(miou)
            miou_fore_epoch.update(miou_fore)
            print('train:\t{}|{}\t{}|{}\tloss_1:{}\tloss_2:{}\tloss_3:{}\tmiou:{}\tmiou_fore:{}'.format(
                self.current_epoch, args.epoches, batch_index + 1, len(self.train_loader_pair),
                loss_1_epoch.avg, loss_2_epoch.avg, loss_3_epoch.avg,
                miou_epoch.avg, miou_fore_epoch.avg))

            # Reuse the numpy predictions computed above for visualisation.
            self._save_color_masks(predict_temp, vis_dir, name_que, index2color)

    def val_pair(self):
        """Run one gradient-free validation pass over ``self.val_loader_pair``.

        For every batch this:
          * computes ``loss_1`` (``self.criterion`` on ``output[1]`` vs. the
            query mask) and prints its running average;
          * writes colour PNGs of the query prediction and of both ground-truth
            masks to ``args.experiment_dir + '/_vis_val_que'``;
          * passes the softmax scores to ``save_batch_gt``.

        After the loop, foreground / FB IoU are computed by
        ``Cal_mIoU.caculate_miou`` on the lists returned by the last
        ``save_batch_gt`` call. Presumably ``save_batch_gt`` accumulates
        across batches; confirm. The result is printed, appended to
        ``args.log_path``, and stored in
        ``self.val_performence_current_epoch`` and ``self.miou``.

        Raises:
            RuntimeError: if the validation loader yields no batches.
        """
        loss_1_epoch = AverageMeter()
        self.model.eval()
        self.cal_miou = Cal_mIoU()
        device = device_ids[0]
        vis_dir = args.experiment_dir + '/_vis_val_que'
        predict_list = gt_list = None

        with torch.no_grad():
            for batch_index, (data_list, gt_pair, name_list, label_pair) in enumerate(self.val_loader_pair):
                # Debug mode only looks at the first 5 batches.
                if self.break_for_debug and batch_index == 5:
                    break

                data_sup = data_list[0].to(device=device)
                data_que = data_list[1].to(device=device)
                gt_sup = gt_pair[0].to(device=device)
                gt_que = gt_pair[1].to(device=device)
                name_que = name_list[1]

                output = self.model(x_q=data_que, x_s=data_sup, x_s_mask=gt_sup.float(), is_train=False)
                loss_1 = self.criterion(output[1], gt_que)
                _, predict = torch.max(output[1], dim=1)
                loss_1_epoch.update(loss_1.item())
                print('val:\t{}|{}\t{}|{}\tloss_1:{}\t'.format(
                    self.current_epoch, args.epoches, batch_index + 1,
                    len(self.val_loader_pair), loss_1_epoch.avg))

                # File stems are global sample indices: batch_index * batch_size + offset.
                first_index = batch_index * args.b_v
                for tensor, suffix in ((predict, '_pre'), (gt_que, '_gt_que'), (gt_sup, '_gt_sup')):
                    masks = tensor.cpu().numpy()
                    stems = [str(first_index + ii) + suffix for ii in range(len(masks))]
                    self._save_color_masks(masks, vis_dir, stems, index2color)

                predict_list, gt_list = save_batch_gt(predict=F.softmax(output[1], dim=1),
                                                      name_list=name_que,
                                                      label_list=label_pair,
                                                      batch_index=batch_index)

        if predict_list is None:
            raise RuntimeError('val_pair: the validation loader produced no batches')

        FB_IoU, fore_IoU = self.cal_miou.caculate_miou(predict_list, gt_list, 2)
        print('the total miou(on the original_size) is ', FB_IoU, fore_IoU)
        self.val_performence_current_epoch = fore_IoU
        string_1 = 'miou calculated on original---{} {} {} {}'.format(self.flag_val, self.current_epoch, fore_IoU, FB_IoU)
        with open(args.log_path, 'a+') as f:
            f.write(string_1 + '\n')
        self.miou = fore_IoU