Esempio n. 1
0
    def val_pair(self):
        """Evaluate the model on ``self.val_loader_pair`` without gradients.

        For every batch this prints the running average of ``loss_1``
        (``self.criterion`` applied to ``output[1]``), writes colorized
        predictions and query/support ground truths to
        ``args.experiment_dir + '/_vis_val_que'``, and scores the softmax
        prediction with a fresh ``Cal_mIoU`` instance. The foreground IoU
        returned for the last batch is stored in
        ``self.val_performence_current_epoch`` and ``self.miou`` and appended,
        together with the FB-IoU, to ``args.log_path``.
        """

        def save_color_masks(masks, out_dir, file_names):
            # Paint pixels equal to label i with index2color[i] (others stay black)
            # and write the k-th mask to '<out_dir>/<file_names[k]>.png'.
            for idx, msk in enumerate(masks):
                color_img = np.zeros(msk.shape + (3,), dtype=np.uint8)
                for label, color in enumerate(index2color):
                    color_img[msk == label, :] = color
                Image.fromarray(color_img).save('{}/{}.png'.format(out_dir, file_names[idx]), 'PNG')

        loss_1_epoch = AverageMeter()
        self.model.eval()
        self.cal_miou = Cal_mIoU()
        device = device_ids[0]
        vis_dir = args.experiment_dir + '/_vis_val_que'
        check_dir(vis_dir)
        # Defaults so the bookkeeping after the loop works even if no batch is processed.
        FB_IoU, fore_IoU = 0, 0
        with torch.no_grad():
            for batch_index, (data_list, gt_list, name_list, label_pair) in enumerate(self.val_loader_pair):
                if self.break_for_debug and batch_index == 5:
                    break
                data_sup = data_list[0].to(device=device)
                data_que = data_list[1].to(device=device)
                gt_sup = gt_list[0].to(device=device)
                gt_que = gt_list[1].to(device=device)
                name_que = name_list[1]

                # .float() converts on the current device instead of round-tripping through the CPU.
                output = self.model(x_q=data_que, x_s=data_sup, x_s_mask=gt_sup.float(), is_train=False)
                loss_1 = self.criterion(output[1], gt_que)
                _, predict = torch.max(output[1], dim=1)

                loss_1_epoch.update(loss_1.item())
                print('val:\t{}|{}\t{}|{}\tloss_1:{}\t'.format(self.current_epoch, args.epoches,
                                                               batch_index + 1,
                                                               len(self.val_loader_pair),
                                                               loss_1_epoch.avg))

                # Images are named by a running sample index; assumes args.b_v is the
                # validation batch size — TODO confirm.
                first_index = batch_index * args.b_v
                for masks, suffix in ((predict.cpu().numpy(), '_pre'),
                                      (gt_que.cpu().numpy(), '_gt_que'),
                                      (gt_sup.cpu().numpy(), '_gt_sup')):
                    save_color_masks(masks, vis_dir,
                                     [str(first_index + ii) + suffix for ii in range(len(masks))])

                # Renamed so save_batch_gt's result does not shadow the loader's gt_list.
                predict_list, gt_orig_list = save_batch_gt(predict=F.softmax(output[1], dim=1),
                                                           name_list=name_que,
                                                           label_list=label_pair,
                                                           batch_index=batch_index)
                # NOTE(review): printed as a "total" score — presumably Cal_mIoU accumulates
                # over all batches seen in this pass; confirm against its implementation.
                FB_IoU, fore_IoU = self.cal_miou.caculate_miou(predict_list, gt_orig_list, 2)
                print('the total miou(on the original_size) is ', FB_IoU, fore_IoU)

        self.val_performence_current_epoch = fore_IoU

        string_1 = 'miou calculated on original---{}   {}  {} {}'.format(self.flag_val, self.current_epoch,
                                                                         fore_IoU,
                                                                         FB_IoU)
        with open(args.log_path, 'a+') as f:
            f.write(string_1 + '\n')

        self.miou = fore_IoU
Esempio n. 2
0
class Trainer():
    """Train and/or validate a few-shot segmentation model on support/query pairs.

    Constructing a ``Trainer`` builds the optimizer and data loaders and then
    immediately runs the workflow selected by ``flag_val`` (see ``start``).
    """

    def __init__(self, model=None, criterion=None, flag_val=None, break_for_debug=True):
        """Set up training state and run the workflow selected by ``flag_val``.

        Args:
            model: network called as ``model(x_q=..., x_s=..., x_s_mask=..., is_train=...)``.
            criterion: loss applied to ``output[1]`` (and ``output[0]`` when training).
            flag_val: ``'just_train'``, ``'just_val'`` or ``'train_val'``.
            break_for_debug: when true, every train/val pass stops after 5 batches.
        """
        self.model = model
        self.criterion = criterion

        self.re_loss = get_reconstruction_loss()
        self.optimizer_pair = get_optimizer_pair(self.model)

        self.train_loader_pair, self.val_loader_pair = get_dataloader_few_shot()
        self.break_for_debug = break_for_debug
        self.flag_val = flag_val
        self.current_epoch = 0
        self.val_performence_current_epoch = 0
        self.train_performence_current_epoch = 0
        self.best_val = 0
        # Must be initialised before start(): val_pair() sets it during the run,
        # and resetting it afterwards would discard the final validation result.
        self.miou = 0
        self.start()

    def start(self):
        """Run the workflow selected by ``self.flag_val``.

        * ``'just_train'``: currently does nothing.
        * ``'just_val'``: run a single validation pass.
        * ``'train_val'``: for each epoch in ``range(args.start_epoch, args.epoches)``
          adjust the learning rate, train, validate, and save a checkpoint whenever
          the validation score beats the best seen so far.
        """
        if self.flag_val == 'just_train':  # keep training
            # NOTE(review): no training loop is implemented for this mode; it is a no-op.
            pass
        elif self.flag_val == 'just_val':  # validate only once
            self.val_pair()
        elif self.flag_val == 'train_val':  # validate once after every training epoch
            for epoch in range(args.start_epoch, args.epoches):
                adjust_learning_rate(self.optimizer_pair, epoch)
                self.current_epoch = epoch

                self.train_pair()
                self.val_pair()

                if self.val_performence_current_epoch > self.best_val:
                    self.best_val = self.val_performence_current_epoch
                    # NOTE(review): ':' in the file name is not valid on Windows file systems.
                    model_params_save_path = os.path.join(args.model_dir, 'epoch:_{}_perform:_{}_{}.pkl'
                                                          .format(self.current_epoch,
                                                                  self.val_performence_current_epoch,
                                                                  self.flag_val))
                    torch.save(self.model.state_dict(), model_params_save_path)
                    # NOTE(review): runs right after the new checkpoint is written — confirm
                    # delete_existed_params keeps that file and only removes older ones.
                    delete_existed_params(args.model_dir)
                    print('\nnew best val saved epoch {} best val:{} '.format(self.current_epoch,
                                                                              self.val_performence_current_epoch))
                    string = '{}   {}  {} {}'.format(self.flag_val, self.current_epoch,
                                                     self.val_performence_current_epoch, self.miou)
                    with open(args.log_path, 'a+') as f:
                        f.write(string + '\n')

    def get_R_truth(self, gt_que_1, gt_sup_1, size=(20, 20)):
        """Build the ground-truth query/support relation matrix.

        Both masks are given a channel axis, bilinearly resized to ``size`` and
        flattened. The query mask becomes a column of shape ``(B, N, 1)`` and
        the support mask a row of shape ``(B, 1, N)``, with ``N = size[0] * size[1]``.
        The result is their batched outer product, shape ``(B, N, N)``.

        Args:
            gt_que_1: float query mask, shape ``(B, H, W)``.
            gt_sup_1: float support mask, shape ``(B, H, W)``.
            size: spatial size both masks are resized to (default ``(20, 20)``).
        """
        gt_que_1 = F.interpolate(gt_que_1.unsqueeze(1), size, mode='bilinear', align_corners=False)
        gt_que_1 = gt_que_1.view(gt_que_1.size()[0], -1, 1)

        gt_sup_1 = F.interpolate(gt_sup_1.unsqueeze(1), size, mode='bilinear', align_corners=False)
        gt_sup_1 = gt_sup_1.view(gt_sup_1.size()[0], 1, -1)
        return torch.matmul(gt_que_1, gt_sup_1)

    @staticmethod
    def _save_color_masks(masks, out_dir, file_names):
        """Save each integer label mask in ``masks`` as a color PNG.

        Pixels equal to ``i`` are painted with ``index2color[i]``; all other
        pixels stay black. The k-th mask is written to
        ``'<out_dir>/<file_names[k]>.png'``.
        """
        for idx, msk in enumerate(masks):
            color_img = np.zeros(msk.shape + (3,), dtype=np.uint8)
            for label, color in enumerate(index2color):
                color_img[msk == label, :] = color
            Image.fromarray(color_img).save('{}/{}.png'.format(out_dir, file_names[idx]), 'PNG')

    def train_pair(self):
        """Run one training epoch over ``self.train_loader_pair``.

        Each batch minimises ``loss_1 + loss_2 + loss_3``:
        ``loss_1`` is ``self.criterion`` on ``output[1]``;
        ``loss_2`` is ``self.re_loss`` between ``output[2]`` and ``get_R_truth``;
        ``loss_3`` is ``self.criterion`` on ``output[0]``.
        Running averages of the losses and batch mIoU are printed, and the
        colorized query predictions are saved to
        ``args.experiment_dir + '/_vis_train_que'`` under the query names.
        """
        loss_1_epoch = AverageMeter()
        loss_2_epoch = AverageMeter()
        loss_3_epoch = AverageMeter()
        miou_epoch = AverageMeter()
        miou_fore_epoch = AverageMeter()
        self.model.train()
        device = device_ids[0]
        vis_dir = args.experiment_dir + '/_vis_train_que'
        check_dir(vis_dir)
        for batch_index, (data_list, gt_list, name_list, label_pair) in enumerate(self.train_loader_pair):
            if self.break_for_debug and batch_index == 5:
                break
            data_sup = data_list[0].to(device=device)
            data_que = data_list[1].to(device=device)
            gt_sup = gt_list[0].to(device=device)
            gt_que = gt_list[1].to(device=device)
            name_que = name_list[1]

            # .float() converts on the current device instead of round-tripping through the CPU.
            gt_sup_1 = gt_sup.float()
            gt_que_1 = gt_que.float()
            output = self.model(x_q=data_que, x_s=data_sup, x_s_mask=gt_sup_1, is_train=True)

            loss_1 = self.criterion(output[1], gt_que)
            loss_2 = self.re_loss(output[2], self.get_R_truth(gt_que_1, gt_sup_1))
            loss_3 = self.criterion(output[0], gt_que)
            loss = loss_1 + loss_2 + loss_3

            self.optimizer_pair.zero_grad()
            loss.backward()
            self.optimizer_pair.step()

            _, predict = torch.max(output[1], dim=1)
            predict_np = predict.detach().cpu().numpy()
            miou, miou_fore = caculate_miou(predict_np, gt_que.cpu().numpy(), 2)
            loss_1_epoch.update(loss_1.item())
            loss_2_epoch.update(loss_2.item())
            loss_3_epoch.update(loss_3.item())

            miou_epoch.update(miou)
            miou_fore_epoch.update(miou_fore)

            print('train:\t{}|{}\t{}|{}\tloss_1:{}\tloss_2:{}\tloss_3:{}\tmiou:{}\tmiou_fore:{}'.format(
                self.current_epoch, args.epoches,
                batch_index + 1, len(self.train_loader_pair),
                loss_1_epoch.avg, loss_2_epoch.avg,
                loss_3_epoch.avg,
                miou_epoch.avg, miou_fore_epoch.avg))
            self._save_color_masks(predict_np, vis_dir, name_que)

    def val_pair(self):
        """Evaluate the model on ``self.val_loader_pair`` without gradients.

        For every batch this prints the running average of ``loss_1``
        (``self.criterion`` applied to ``output[1]``), writes colorized
        predictions and query/support ground truths to
        ``args.experiment_dir + '/_vis_val_que'``, and scores the softmax
        prediction with a fresh ``Cal_mIoU`` instance. The foreground IoU
        returned for the last batch is stored in
        ``self.val_performence_current_epoch`` and ``self.miou`` and appended,
        together with the FB-IoU, to ``args.log_path``.
        """
        loss_1_epoch = AverageMeter()
        self.model.eval()
        self.cal_miou = Cal_mIoU()
        device = device_ids[0]
        vis_dir = args.experiment_dir + '/_vis_val_que'
        check_dir(vis_dir)
        # Defaults so the bookkeeping after the loop works even if no batch is processed.
        FB_IoU, fore_IoU = 0, 0
        with torch.no_grad():
            for batch_index, (data_list, gt_list, name_list, label_pair) in enumerate(self.val_loader_pair):
                if self.break_for_debug and batch_index == 5:
                    break
                data_sup = data_list[0].to(device=device)
                data_que = data_list[1].to(device=device)
                gt_sup = gt_list[0].to(device=device)
                gt_que = gt_list[1].to(device=device)
                name_que = name_list[1]

                output = self.model(x_q=data_que, x_s=data_sup, x_s_mask=gt_sup.float(), is_train=False)
                loss_1 = self.criterion(output[1], gt_que)
                _, predict = torch.max(output[1], dim=1)

                loss_1_epoch.update(loss_1.item())
                print('val:\t{}|{}\t{}|{}\tloss_1:{}\t'.format(self.current_epoch, args.epoches,
                                                               batch_index + 1,
                                                               len(self.val_loader_pair),
                                                               loss_1_epoch.avg))

                # Images are named by a running sample index; assumes args.b_v is the
                # validation batch size — TODO confirm.
                first_index = batch_index * args.b_v
                for masks, suffix in ((predict.cpu().numpy(), '_pre'),
                                      (gt_que.cpu().numpy(), '_gt_que'),
                                      (gt_sup.cpu().numpy(), '_gt_sup')):
                    self._save_color_masks(masks, vis_dir,
                                           [str(first_index + ii) + suffix for ii in range(len(masks))])

                # Renamed so save_batch_gt's result does not shadow the loader's gt_list.
                predict_list, gt_orig_list = save_batch_gt(predict=F.softmax(output[1], dim=1),
                                                           name_list=name_que,
                                                           label_list=label_pair,
                                                           batch_index=batch_index)
                # NOTE(review): printed as a "total" score — presumably Cal_mIoU accumulates
                # over all batches seen in this pass; confirm against its implementation.
                FB_IoU, fore_IoU = self.cal_miou.caculate_miou(predict_list, gt_orig_list, 2)
                print('the total miou(on the original_size) is ', FB_IoU, fore_IoU)

        self.val_performence_current_epoch = fore_IoU

        string_1 = 'miou calculated on original---{}   {}  {} {}'.format(self.flag_val, self.current_epoch,
                                                                         fore_IoU,
                                                                         FB_IoU)
        with open(args.log_path, 'a+') as f:
            f.write(string_1 + '\n')

        self.miou = fore_IoU