Beispiel #1
0
        def eval_batch(model, image, target):
            """Run ``model`` on one batch and return segmentation statistics.

            Only the first element of the model output is scored; any further
            outputs (an earlier variant had weather / time-of-day heads) are
            ignored.

            Returns:
                ``(correct, labeled, inter, union)`` as produced by
                ``utils.batch_pix_accuracy`` and
                ``utils.batch_intersection_union`` for ``self.nclass`` classes.
            """
            seg_logits = model(image)[0]
            logits_data = seg_logits.data
            n_correct, n_labeled = utils.batch_pix_accuracy(logits_data, target)
            area_inter, area_union = utils.batch_intersection_union(
                logits_data, target, self.nclass)
            return n_correct, n_labeled, area_inter, area_union
 def evaluate(self, x, target=None):
     """Forward ``x`` and optionally score the prediction against ``target``.

     If ``self.forward`` returns a tuple or list, only its first element is
     treated as the prediction.

     Args:
         x: network input, passed straight to ``self.forward``.
         target: optional ground-truth labels.

     Returns:
         The prediction when ``target`` is None; otherwise
         ``(correct, labeled, inter, union)`` from ``batch_pix_accuracy`` and
         ``batch_intersection_union`` over ``self.nclass`` classes.
     """
     output = self.forward(x)
     pred = output[0] if isinstance(output, (tuple, list)) else output
     if target is not None:
         pred_data, target_data = pred.data, target.data
         correct, labeled = batch_pix_accuracy(pred_data, target_data)
         inter, union = batch_intersection_union(
             pred_data, target_data, self.nclass)
         return correct, labeled, inter, union
     return pred
def get_mIoU_from_softmax(softmax, target, inter_union=False, nclass=19):
    """Compute the mean IoU of a batch of class scores against ``target``.

    Args:
        softmax: per-class scores, passed unchanged to
            ``utils_seg.batch_intersection_union``.
        target: ground-truth labels for the same batch.
        inter_union: if True, return the raw ``(inter, union)`` pair instead
            of the scalar mIoU (e.g. so the caller can accumulate over
            batches).
        nclass: number of classes; defaults to 19, the value that used to be
            hard-coded here.

    Returns:
        ``(inter, union)`` when ``inter_union`` is True; otherwise the mean
        IoU over the classes with a non-zero union, as a Python float. When
        no class has a non-zero union, 0.0 is returned.
    """
    inter, union = utils_seg.batch_intersection_union(softmax, target, nclass)
    if inter_union:
        return inter, union
    present = union > 0
    if not present.any():
        # Averaging an empty selection would emit a RuntimeWarning and give
        # nan, which nan_to_num below would map to 0.0 anyway.
        return 0.0
    # np.spacing(1) guards against division by zero.
    iou = 1.0 * inter[present] / (np.spacing(1) + union[present])
    return float(np.nan_to_num(iou.mean()))
Beispiel #4
0
 def eval_batch(model, image, target):
     """Score one batch, tiling very large images into patches.

     Images with at most 2,250,000 pixels (1500x1500) are run through
     ``model`` whole, and the first output is bilinearly resized to the
     target's spatial size. Larger images are split by ``global2patch`` into
     ``size_p`` patches, which are run in chunks of ``sub_batch_size``, merged
     back with ``patch2global`` and scored image by image.

     ``size_p``, ``sub_batch_size``, ``global2patch``, ``patch2global`` and
     ``self`` come from the enclosing scope.

     Returns:
         ``(correct, labeled, inter, union)``, summed over the batch in the
         patch path.
     """
     nclass = self.nclass
     if image.size(2) * image.size(3) <= 2250000:  # 1500x1500
         pred = model(image)[0]
         # Resize to the label resolution in case the input was downsampled.
         # F.interpolate replaces the deprecated F.upsample; align_corners=False
         # is the value F.upsample used implicitly, so results are unchanged.
         pred = F.interpolate(
             pred,
             size=(target.size(1), target.size(2)),
             mode='bilinear',
             align_corners=False,
         )
         correct, labeled = utils.batch_pix_accuracy(pred.data, target)
         inter, union = utils.batch_intersection_union(
             pred.data, target, nclass)
         return correct, labeled, inter, union

     patches, coordinates, sizes = global2patch(image, size_p)
     n_images = len(image)
     predicted_patches = [
         torch.zeros(len(coordinates[i]), nclass, size_p[0], size_p[1])
         for i in range(n_images)
     ]
     for i in range(n_images):
         # Run the patches of image i in chunks of at most sub_batch_size.
         for start in range(0, len(coordinates[i]), sub_batch_size):
             outputs = model(patches[i][start:start + sub_batch_size])[0]
             predicted_patches[i][start:start + outputs.size(0)] = outputs

     # Merge per-patch scores (including overlaps) back into full-image maps.
     pred = patch2global(predicted_patches, nclass, sizes, coordinates, size_p)

     correct, labeled, inter, union = 0, 0, 0, 0
     for i in range(n_images):
         pred_i = pred[i].unsqueeze(0)  # restore a batch axis of size 1
         correct_i, labeled_i = utils.batch_pix_accuracy(pred_i, target[i])
         inter_i, union_i = utils.batch_intersection_union(
             pred_i, target[i], nclass)
         correct += correct_i
         labeled += labeled_i
         inter += inter_i
         union += union_i
     return correct, labeled, inter, union
def get_mIoU(image, target, inter_union=False, nclass=19):
    """Segment ``image`` with ``seg`` and compute its mean IoU against ``target``.

    ``seg`` is not a parameter; it is looked up from the enclosing (module)
    scope. Only the first element of its output is scored.

    Args:
        image: input batch, already normalized by the caller (the original
            note says "transferred by 0.5/0.5"; presumably mean 0.5 / std
            0.5 — confirm).
        target: ground-truth labels for the batch.
        inter_union: if True, return the raw ``(inter, union)`` pair instead
            of the scalar mIoU.
        nclass: number of classes; defaults to 19, the value that used to be
            hard-coded here.

    Returns:
        ``(inter, union)`` when ``inter_union`` is True; otherwise the mean
        IoU over the classes with a non-zero union, as a Python float. When
        no class has a non-zero union, 0.0 is returned.
    """
    outputs = seg(image)
    pred = outputs[0]
    inter, union = utils_seg.batch_intersection_union(pred.data, target, nclass)
    if inter_union:
        return inter, union
    present = union > 0
    if not present.any():
        # Averaging an empty selection would emit a RuntimeWarning and give
        # nan, which nan_to_num below would map to 0.0 anyway.
        return 0.0
    # np.spacing(1) guards against division by zero.
    iou = 1.0 * inter[present] / (np.spacing(1) + union[present])
    return float(np.nan_to_num(iou.mean()))
Beispiel #6
0
        def eval_batch(model, image, target):
            """Enhance ``image`` with ``gan.netG_A`` and score ``model`` on the result.

            The steps are:

            1. Build the gray map ``1 - (0.299 r + 0.587 g + 0.114 b) / 2`` from
               the first three channels of ``image``, each shifted by +1.
            2. Run the generator on ``(image, gray)`` without gradients.
            3. Segment the enhanced image, clamped to [-1, 1].
            4. Bilinearly resize the first output to the target's spatial size
               before scoring.

            NOTE(review): the +1 and /2 scaling suggests ``image`` is
            (B, 3, H, W) in [-1, 1]; confirm this against the data loader.

            Returns:
                ``(correct, labeled, inter, union)`` for ``self.nclass`` classes.
            """
            red = image[:, 0, :, :] + 1
            green = image[:, 1, :, :] + 1
            blue = image[:, 2, :, :] + 1
            gray = 1. - (0.299 * red + 0.587 * green + 0.114 * blue) / 2.
            gray = gray.unsqueeze(1)  # insert a channel axis at dim 1
            with torch.no_grad():
                fake_B, _, _ = gan.netG_A.forward(image, gray)

            # Evaluate the network the caller passed in. This previously used
            # self.model and silently ignored the ``model`` argument, unlike
            # the other eval_batch variants.
            outputs = model(fake_B.clamp(-1, 1))

            # F.interpolate replaces the deprecated F.upsample; align_corners=False
            # is what F.upsample used implicitly for bilinear mode.
            pred = F.interpolate(
                outputs[0],
                size=(target.size(1), target.size(2)),
                mode='bilinear',
                align_corners=False,
            )

            correct, labeled = utils.batch_pix_accuracy(pred.data, target)
            inter, union = utils.batch_intersection_union(
                pred.data, target, self.nclass)
            return correct, labeled, inter, union
Beispiel #7
0
    def backward_G(self, epoch, seg_criterion=None):
        """Assemble the generator loss ``self.loss_G`` and backpropagate it.

        The loss is built from these terms:

        * Global adversarial term from ``netD_A`` (WGAN, relativistic-average
          or plain ``criterionGAN``, depending on ``self.opt``), scaled by
          ``self.adv_image``. An unscaled ``netD_A_Seg`` term is added when
          ``self.use_seg_D`` is set.
        * Patch adversarial terms from ``netD_P`` (and ``netD_P_Seg``) when
          ``opt.patchD`` / ``opt.patchD_3`` are enabled. They are divided by
          ``patchD_3 + 1`` when ``patchD_3 > 0`` and doubled when
          ``opt.D_P_times2`` is set.
        * VGG (``opt.vgg > 0``) or FCN (``opt.fcn > 0``) perceptual term. It is
          weighted by 0 when ``epoch < 0``, by 1 without ``seg_criterion`` and
          by 0.3 with it. If neither is enabled, ``loss_G`` is just
          ``loss_G_A``.
        * ``10 * seg_criterion(self.fake_B_Seg, self.mask)`` when
          ``seg_criterion`` is given. In that case the mIoU of
          ``self.fake_B_Seg`` and ``self.real_A_Seg`` against ``self.mask`` is
          stored in ``self.mIoU`` / ``self.mIoU_ori``, and an exponential
          moving average of their difference in ``self.mIoU_delta_mean``.

        Args:
            epoch: current epoch; negative values disable the perceptual term.
            seg_criterion: optional segmentation loss ``(pred, mask) -> loss``.
        """
        def _mean_iou(seg_pred):
            # Mean IoU of seg_pred vs. self.mask over 19 classes, averaged
            # over classes with a non-zero union; nan is mapped to 0.
            inter, union = utils_seg.batch_intersection_union(
                seg_pred.data, self.mask, 19)
            idx = union > 0
            IoU = 1.0 * inter[idx] / (np.spacing(1) + union[idx])
            return np.nan_to_num(IoU.mean())

        # ---- global adversarial loss (netD_A, optionally netD_A_Seg) ----
        pred_fake = self.netD_A.forward(self.fake_B)
        if self.use_seg_D:
            pred_fake_Seg = self.netD_A_Seg.forward(self.fake_B_Seg)
        if self.opt.use_wgan:
            self.loss_G_A = self.adv_image * -pred_fake.mean()
            if self.use_seg_D:
                self.loss_G_A += -pred_fake_Seg.mean()
        elif self.opt.use_ragan:
            pred_real = self.netD_A.forward(self.real_B)
            self.loss_G_A = self.adv_image * (
                (self.criterionGAN(pred_real - torch.mean(pred_fake), False)
                 + self.criterionGAN(pred_fake - torch.mean(pred_real), True))
                / 2)
            if self.use_seg_D:
                pred_real_Seg = self.netD_A_Seg.forward(self.real_B_Seg)
                self.loss_G_A += (
                    self.criterionGAN(pred_real_Seg - torch.mean(pred_fake_Seg), False)
                    + self.criterionGAN(pred_fake_Seg - torch.mean(pred_real_Seg), True)) / 2
        else:
            self.loss_G_A = self.adv_image * self.criterionGAN(pred_fake, True)
            if self.use_seg_D:
                self.loss_G_A += self.criterionGAN(pred_fake_Seg, True)

        # ---- local patch adversarial losses (netD_P, optionally netD_P_Seg) ----
        loss_G_A = 0
        if self.opt.patchD:
            pred_fake_patch = self.netD_P.forward(self.fake_patch)
            if self.use_seg_D:
                pred_fake_patch_Seg = self.netD_P_Seg.forward(self.fake_patch_Seg)
            if self.opt.hybrid_loss:
                loss_G_A += self.adv_image * self.criterionGAN(pred_fake_patch, True)
                if self.use_seg_D:
                    loss_G_A += self.criterionGAN(pred_fake_patch_Seg, True)
            else:
                pred_real_patch = self.netD_P.forward(self.real_patch)
                loss_G_A += self.adv_image * (
                    (self.criterionGAN(pred_real_patch - torch.mean(pred_fake_patch), False)
                     + self.criterionGAN(pred_fake_patch - torch.mean(pred_real_patch), True))
                    / 2)
                if self.use_seg_D:
                    pred_real_patch_Seg = self.netD_P_Seg.forward(self.real_patch_Seg)
                    loss_G_A += (
                        self.criterionGAN(pred_real_patch_Seg - torch.mean(pred_fake_patch_Seg), False)
                        + self.criterionGAN(pred_fake_patch_Seg - torch.mean(pred_real_patch_Seg), True)) / 2
        if self.opt.patchD_3 > 0:
            for i in range(self.opt.patchD_3):
                pred_fake_patch_1 = self.netD_P.forward(self.fake_patch_1[i])
                if self.use_seg_D:
                    pred_fake_patch_1_Seg = self.netD_P_Seg.forward(self.fake_patch_1_Seg[i])
                if self.opt.hybrid_loss:
                    loss_G_A += self.adv_image * self.criterionGAN(pred_fake_patch_1, True)
                    if self.use_seg_D:
                        loss_G_A += self.criterionGAN(pred_fake_patch_1_Seg, True)
                else:
                    pred_real_patch_1 = self.netD_P.forward(self.real_patch_1[i])
                    loss_G_A += self.adv_image * (
                        (self.criterionGAN(pred_real_patch_1 - torch.mean(pred_fake_patch_1), False)
                         + self.criterionGAN(pred_fake_patch_1 - torch.mean(pred_real_patch_1), True))
                        / 2)
                    if self.use_seg_D:
                        pred_real_patch_1_Seg = self.netD_P_Seg.forward(self.real_patch_1_Seg[i])
                        loss_G_A += (
                            self.criterionGAN(pred_real_patch_1_Seg - torch.mean(pred_fake_patch_1_Seg), False)
                            + self.criterionGAN(pred_fake_patch_1_Seg - torch.mean(pred_real_patch_1_Seg), True)) / 2
            # Average over the main patch set plus patchD_3 extra sets.
            if not self.opt.D_P_times2:
                self.loss_G_A += loss_G_A / float(self.opt.patchD_3 + 1)
            else:
                self.loss_G_A += loss_G_A / float(self.opt.patchD_3 + 1) * 2
        elif not self.opt.D_P_times2:
            self.loss_G_A += loss_G_A
        else:
            self.loss_G_A += loss_G_A * 2

        # ---- perceptual (VGG / FCN) loss ----
        if epoch < 0:
            vgg_w = 0
        elif seg_criterion is None:
            vgg_w = 1
        else:
            vgg_w = 0.3

        # Default when no perceptual loss is configured. Previously
        # self.loss_G was left unset (AttributeError) or stale from the
        # previous iteration, so backward() could run on an old graph.
        self.loss_G = self.loss_G_A
        if self.opt.vgg > 0:
            self.loss_vgg_b = self.vgg_loss.compute_vgg_loss(
                self.vgg, self.fake_B, self.real_A) * self.opt.vgg
            if self.opt.patch_vgg:
                patch_loss = self.vgg_patch_loss if self.opt.IN_vgg else self.vgg_loss
                loss_vgg_patch = patch_loss.compute_vgg_loss(
                    self.vgg, self.fake_patch, self.input_patch) * self.opt.vgg
                if self.opt.patchD_3 > 0:
                    for i in range(self.opt.patchD_3):
                        loss_vgg_patch += patch_loss.compute_vgg_loss(
                            self.vgg, self.fake_patch_1[i],
                            self.input_patch_1[i]) * self.opt.vgg
                    self.loss_vgg_b += loss_vgg_patch / float(self.opt.patchD_3 + 1)
                else:
                    self.loss_vgg_b += loss_vgg_patch
            self.loss_G = self.loss_G_A + self.loss_vgg_b * vgg_w
        elif self.opt.fcn > 0:
            self.loss_fcn_b = self.fcn_loss.compute_fcn_loss(
                self.fcn, self.fake_B, self.real_A) * self.opt.fcn
            if self.opt.patchD:
                # NOTE(review): the patch terms call compute_vgg_loss on
                # fcn_loss, unlike compute_fcn_loss above; confirm intended.
                loss_fcn_patch = self.fcn_loss.compute_vgg_loss(
                    self.fcn, self.fake_patch, self.input_patch) * self.opt.fcn
                if self.opt.patchD_3 > 0:
                    for i in range(self.opt.patchD_3):
                        loss_fcn_patch += self.fcn_loss.compute_vgg_loss(
                            self.fcn, self.fake_patch_1[i],
                            self.input_patch_1[i]) * self.opt.fcn
                    self.loss_fcn_b += loss_fcn_patch / float(self.opt.patchD_3 + 1)
                else:
                    self.loss_fcn_b += loss_fcn_patch
            self.loss_G = self.loss_G_A + self.loss_fcn_b * vgg_w

        # ---- segmentation loss and mIoU logging ----
        if seg_criterion is not None:
            self.loss_Seg = seg_criterion(self.fake_B_Seg, self.mask)
            lambd = 10
            # Out-of-place add: self.loss_G may alias self.loss_G_A, and a
            # tensor ``+=`` would modify loss_G_A in place.
            self.loss_G = self.loss_G + lambd * self.loss_Seg

            self.mIoU = _mean_iou(self.fake_B_Seg)
            with torch.no_grad():
                self.mIoU_ori = _mean_iou(self.real_A_Seg)
                self.mIoU_delta_mean = 0.8 * self.mIoU_delta_mean + 0.2 * np.round(self.mIoU - self.mIoU_ori, 3)

            # .item() replaces .data[0], which raises IndexError on 0-dim
            # tensors in PyTorch >= 0.5.
            print("G:", self.loss_G.item(), "mIoU-origin:", np.round(self.mIoU - self.mIoU_ori, 3), "mean:", np.round(self.mIoU_delta_mean, 3), "lum:", 255 * (1 - self.input_A_gray).mean(), "epoch:", epoch)

        self.loss_G.backward(retain_graph=True)
    def backward_G(self, epoch, seg_criterion=None, A_gt=False):
        """Assemble the generator loss ``self.loss_G`` and backpropagate it.

        The loss is built from these terms:

        * Global adversarial loss ``self.loss_G_A``. It comes from ``netD_A``,
          or, when ``self.multi_D`` is set, is summed over ``self.netD_As[c]``
          for ``c`` in 0..4, each applied to the samples whose
          ``self.category`` equals ``c``. The form (WGAN, relativistic-average
          or plain ``criterionGAN``) depends on ``self.opt``.
        * Patch adversarial losses from ``netD_P`` (``opt.patchD`` /
          ``opt.patchD_3``). They are divided by ``patchD_3 + 1`` when
          ``patchD_3 > 0`` and doubled when ``opt.D_P_times2`` is set.
        * VGG (``opt.vgg > 0``) or FCN (``opt.fcn > 0``) perceptual loss, used
          only when ``epoch >= 0`` and ``seg_criterion`` is None.
        * ``seg_criterion(fake_B_Seg, mask) + 3 * seg_criterion(seg_real_A, mask)``
          when ``seg_criterion`` is given. mIoU statistics are then printed
          and stored in ``self.mIoU``, ``self.mIoU_ori`` and
          ``self.mIoU_delta_mean``.
        * ``0.1 * mean(|fake_B' - A_gt'| * A_boundary)`` when ``A_gt`` is
          True, where ``x'`` denotes ``(x + 1) / 2 * 255``.

        Args:
            epoch: current epoch; negative values disable the perceptual term.
            seg_criterion: optional segmentation loss ``(pred, mask) -> loss``.
            A_gt: if True, add the boundary-weighted L1 loss against
                ``self.A_gt``.
        """
        def _mean_iou(seg_pred):
            # Mean IoU of seg_pred vs. self.mask over 19 classes, averaged
            # over classes with a non-zero union; nan is mapped to 0.
            inter, union = utils_seg.batch_intersection_union(
                seg_pred.data, self.mask, 19)
            idx = union > 0
            IoU = 1.0 * inter[idx] / (np.spacing(1) + union[idx])
            return np.nan_to_num(IoU.mean())

        # ---- global adversarial loss ----
        if not self.multi_D:
            pred_fake = self.netD_A.forward(self.fake_B)
            if self.opt.use_wgan:
                self.loss_G_A = -pred_fake.mean()
            elif self.opt.use_ragan:
                pred_real = self.netD_A.forward(self.real_B)
                self.loss_G_A = (
                    self.criterionGAN(pred_real - torch.mean(pred_fake), False)
                    + self.criterionGAN(pred_fake - torch.mean(pred_real), True)) / 2
            else:
                self.loss_G_A = self.criterionGAN(pred_fake, True)
        else:
            self.loss_G_A = 0
            # One discriminator per category; categories 0..4 are hard-coded.
            for c in range(5):
                sample_idx = (self.category == c).nonzero()
                if sample_idx.size(0) == 0:
                    continue  # no sample of this category in the batch
                sample_idx = sample_idx.view(-1).type(torch.cuda.LongTensor)
                pred_fake = self.netD_As[c].forward(
                    torch.index_select(self.fake_B, 0, sample_idx))
                if self.opt.use_wgan:
                    self.loss_G_A += -pred_fake.mean()
                elif self.opt.use_ragan:
                    pred_real = self.netD_As[c].forward(
                        torch.index_select(self.real_B, 0, sample_idx))
                    self.loss_G_A += (
                        self.criterionGAN(pred_real - torch.mean(pred_fake), False)
                        + self.criterionGAN(pred_fake - torch.mean(pred_real), True)) / 2
                else:
                    self.loss_G_A += self.criterionGAN(pred_fake, True)

        # ---- local patch adversarial losses ----
        loss_G_A = 0
        if self.opt.patchD:
            pred_fake_patch = self.netD_P.forward(self.fake_patch)
            if self.opt.hybrid_loss:
                loss_G_A += self.criterionGAN(pred_fake_patch, True)
            else:
                pred_real_patch = self.netD_P.forward(self.real_patch)
                loss_G_A += (
                    self.criterionGAN(pred_real_patch - torch.mean(pred_fake_patch), False)
                    + self.criterionGAN(pred_fake_patch - torch.mean(pred_real_patch), True)) / 2
        if self.opt.patchD_3 > 0:
            for i in range(self.opt.patchD_3):
                pred_fake_patch_1 = self.netD_P.forward(self.fake_patch_1[i])
                if self.opt.hybrid_loss:
                    loss_G_A += self.criterionGAN(pred_fake_patch_1, True)
                else:
                    pred_real_patch_1 = self.netD_P.forward(self.real_patch_1[i])
                    loss_G_A += (
                        self.criterionGAN(pred_real_patch_1 - torch.mean(pred_fake_patch_1), False)
                        + self.criterionGAN(pred_fake_patch_1 - torch.mean(pred_real_patch_1), True)) / 2
            # Average over the main patch set plus patchD_3 extra sets.
            if not self.opt.D_P_times2:
                self.loss_G_A += loss_G_A / float(self.opt.patchD_3 + 1)
            else:
                self.loss_G_A += loss_G_A / float(self.opt.patchD_3 + 1) * 2
        elif not self.opt.D_P_times2:
            self.loss_G_A += loss_G_A
        else:
            self.loss_G_A += loss_G_A * 2

        # ---- perceptual (VGG / FCN) loss ----
        self.loss_G = self.loss_G_A
        if epoch < 0 or seg_criterion is not None:
            vgg_w = 0
        else:
            vgg_w = 1
        if vgg_w > 0:
            if self.opt.vgg > 0:
                self.loss_vgg_b = self.vgg_loss.compute_vgg_loss(
                    self.vgg, self.fake_B, self.real_A) * self.opt.vgg
                if self.opt.patch_vgg:
                    patch_loss = self.vgg_patch_loss if self.opt.IN_vgg else self.vgg_loss
                    loss_vgg_patch = patch_loss.compute_vgg_loss(
                        self.vgg, self.fake_patch, self.input_patch) * self.opt.vgg
                    if self.opt.patchD_3 > 0:
                        for i in range(self.opt.patchD_3):
                            loss_vgg_patch += patch_loss.compute_vgg_loss(
                                self.vgg, self.fake_patch_1[i],
                                self.input_patch_1[i]) * self.opt.vgg
                        self.loss_vgg_b += loss_vgg_patch / float(self.opt.patchD_3 + 1)
                    else:
                        self.loss_vgg_b += loss_vgg_patch
                self.loss_G = self.loss_G_A + self.loss_vgg_b * vgg_w
            elif self.opt.fcn > 0:
                self.loss_fcn_b = self.fcn_loss.compute_fcn_loss(
                    self.fcn, self.fake_B, self.real_A) * self.opt.fcn
                if self.opt.patchD:
                    # NOTE(review): the patch terms call compute_vgg_loss on
                    # fcn_loss, unlike compute_fcn_loss above; confirm intended.
                    loss_fcn_patch = self.fcn_loss.compute_vgg_loss(
                        self.fcn, self.fake_patch, self.input_patch) * self.opt.fcn
                    if self.opt.patchD_3 > 0:
                        for i in range(self.opt.patchD_3):
                            loss_fcn_patch += self.fcn_loss.compute_vgg_loss(
                                self.fcn, self.fake_patch_1[i],
                                self.input_patch_1[i]) * self.opt.fcn
                        self.loss_fcn_b += loss_fcn_patch / float(self.opt.patchD_3 + 1)
                    else:
                        self.loss_fcn_b += loss_fcn_patch
                self.loss_G = self.loss_G_A + self.loss_fcn_b * vgg_w

        # ---- segmentation loss ----
        if seg_criterion is not None:
            # mIoU of the enhanced image.
            self.mIoU = _mean_iou(self.fake_B_Seg)
            with torch.no_grad():
                # mIoU of the original image under the pretrained seg model.
                self.mIoU_ori = _mean_iou(self.real_A_Seg)
                self.mIoU_delta_mean = 0.8 * self.mIoU_delta_mean + 0.2 * np.round(
                    self.mIoU - self.mIoU_ori, 3)
                # mIoU of the original image under the generator's seg output.
                print("mIoU_generator", np.round(_mean_iou(self.seg_real_A), 3))

            # .item() replaces .data[0], which raises IndexError on 0-dim
            # tensors in PyTorch >= 0.5.
            print("G:", self.loss_G.item(), "mIoU gain:",
                  np.round(self.mIoU - self.mIoU_ori, 3), "mean:",
                  np.round(self.mIoU_delta_mean, 3), "lum:",
                  255 * (1 - self.input_A_gray).mean(), "epoch:", epoch)

            lambd = 3
            self.loss_Seg = seg_criterion(
                self.fake_B_Seg,
                self.mask) + lambd * seg_criterion(self.seg_real_A, self.mask)
            # Out-of-place add: self.loss_G aliases self.loss_G_A whenever no
            # perceptual term was added, and a tensor ``+=`` would modify
            # loss_G_A in place.
            self.loss_G = self.loss_G + self.loss_Seg

        # ---- ground-truth (paired) L1 loss ----
        if A_gt:
            # Compare in the [0, 255] range, weighted per pixel by A_boundary.
            l1 = (F.l1_loss((self.fake_B + 1) / 2 * 255,
                            (self.A_gt + 1) / 2 * 255,
                            reduction='none') * self.A_boundary).mean()
            self.loss_gt = 0.1 * l1
            print("loss_gt", self.loss_gt.item())
            self.loss_G = self.loss_G + self.loss_gt

        self.loss_G.backward(retain_graph=True)