コード例 #1
0
ファイル: test.py プロジェクト: yhu9/RCAN
    def getIOU(self, lr, hr):
        """Run the agent and every SR model on the patches of one LR/HR pair.

        Both images are scaled by 1/255 when ``self.model == 'ESRGAN'``, split
        into patches with ``util.getTrainingPatches(..., transform=False)`` and
        moved to ``self.device``. The agent scores and the per-model mean
        absolute errors are then computed.

        NOTE(review): despite the name, no IoU value is derived or returned;
        the only value handed back is the ``info`` pair produced by
        ``util.getTrainingPatches``. The intermediate results are discarded.

        Args:
            lr: low-resolution image.
            hr: high-resolution image; must expose a 3-element ``shape``.

        Returns:
            The third value returned by ``util.getTrainingPatches``.
        """
        if self.model == 'ESRGAN':
            lr, hr = lr / 255.0, hr / 255.0

        # Unpacking doubles as a check that hr has exactly three dimensions.
        _height, _width, _depth = hr.shape
        lr_patches, hr_patches, info = util.getTrainingPatches(
            lr, hr, self.args, transform=False)
        # Likewise checks that info holds exactly two values.
        _info_a, _info_b = info

        lr_patches = lr_patches.to(self.device)
        hr_patches = hr_patches.to(self.device)

        # Agent scores for each patch, one column per SR model.
        scores = self.agent.model(lr_patches)
        _top_score, _top_model = scores.max(dim=1)
        _low_score, _low_model = scores.min(dim=1)

        # Mean absolute error of every SR model, reduced over all non-batch dims.
        per_model_error = torch.stack(
            [
                torch.abs(model(lr_patches) - hr_patches)
                .mean(dim=1).mean(dim=1).mean(dim=1)
                for model in self.SRmodels
            ],
            dim=1)

        _, optimal_idx = per_model_error.min(dim=1)
        _, predicted_idx = scores.max(dim=1)

        return info
コード例 #2
0
ファイル: train.py プロジェクト: yhu9/RCAN
    def getTrainingIndices(self):
        """Return one ``range`` of patch indices per training image.

        Every LR/HR pair listed in ``self.TRAINING_LRPATH`` /
        ``self.TRAINING_HRPATH`` is read from disk and split with
        ``util.getTrainingPatches``; the length of the resulting LR patch
        collection determines the range stored for that image.

        Returns:
            list[range]: entry ``i`` is ``range(number of LR patches)`` for the
            ``i``-th training image.
        """
        patch_ranges = []
        for image_idx in range(len(self.TRAINING_HRPATH)):
            hr_path = self.TRAINING_HRPATH[image_idx]
            lr_path = self.TRAINING_LRPATH[image_idx]
            lr_image = imageio.imread(lr_path)
            hr_image = imageio.imread(hr_path)
            lr_patches, _hr_patches, _info = util.getTrainingPatches(
                lr_image, hr_image, args)

            patch_ranges.append(range(len(lr_patches)))
        return patch_ranges
コード例 #3
0
ファイル: train.py プロジェクト: yhu9/RCAN
    def getGroundTruthIOU(self, lr, hr, samplesize=0.01):
        """Estimate how often the agent picks the lowest-error SR model.

        A random subset of the image's patches is evaluated in one batch. For
        every sampled patch the SR model with the smallest mean squared error
        against the HR patch is taken as the optimal choice and compared with
        the argmax of the agent's output.

        Args:
            lr: low-resolution image.
            hr: high-resolution image.
            samplesize: fraction of the patches to sample. The resulting count
                is clamped to ``[1, number of patches]`` so that small images
                (fewer than ``1 / samplesize`` patches) no longer produce an
                empty batch followed by a division by zero.

        Returns:
            float: fraction of sampled patches on which the agent's choice
            equals the optimal model.

        Raises:
            ValueError: if the image yields no patches at all.
        """
        if self.model == 'ESRGAN' or self.model == 'basic':
            lr = lr / 255.0
            hr = hr / 255.0

        # FORMAT THE INPUT
        lr, hr, _ = util.getTrainingPatches(lr, hr, args, transform=False)

        num_patches = len(hr)
        if num_patches == 0:
            raise ValueError("getGroundTruthIOU requires at least one patch")

        # At least one patch, at most all of them.
        maxsize = min(num_patches, max(1, int(len(lr) * samplesize)))

        # Build integer indices directly; going through a float32 tensor would
        # lose precision for very large patch counts.
        batch_ids = torch.tensor(random.sample(range(num_patches), maxsize),
                                 dtype=torch.long)

        LR = lr[batch_ids].to(self.device)
        HR = hr[batch_ids].to(self.device)

        # GET EACH SR RESULT
        choices = self.agent.model(LR)
        # Per-patch MSE of every SR model, reduced over all non-batch dims.
        errors = [
            (sisr(LR) - HR).pow(2).mean(dim=1).mean(dim=1).mean(dim=1)
            for sisr in self.SRmodels
        ]
        l1diff = torch.stack(errors, dim=1)

        _, optimal_idx = l1diff.min(dim=1)
        _, predicted_idx = choices.max(dim=1)
        matches = torch.sum((optimal_idx == predicted_idx).float()).item()

        return matches / maxsize
コード例 #4
0
ファイル: test.py プロジェクト: yhu9/RCAN
    def evaluateBounds(self, lr, hr):
        """Super-resolve one image under several model-selection strategies.

        The image pair is split into patches, every SR model in
        ``self.SRmodels`` is run on all LR patches, and four per-patch
        selections are assembled back into full images with ``util.recombine``:

        * ``optimal`` / ``worst``: model with the lowest / highest mean
          absolute error against the HR patch,
        * ``max`` / ``min``: model with the highest / lowest agent score.

        For each selection a colour-coded decision mask is produced as well
        (model ``k`` lights up channel ``k`` with value 255), together with a
        weight image holding the raw agent scores per channel.

        Args:
            lr: low-resolution image.
            hr: high-resolution image; must expose a 3-element ``shape``.

        Returns:
            dict with keys ``'HR'``, ``'min'``, ``'max'``, ``'worst'``,
            ``'optimal'``, ``'optimalchoice'``, ``'worstchoice'``,
            ``'maxchoice'``, ``'minchoice'``, ``'choices'`` and ``'mask'``.
            ``'HR'`` is the (possibly rescaled) input and ``'choices'`` the raw
            agent output; the remaining entries are whatever
            ``util.recombine`` returns, clipped and scaled to 0..255
            (presumably numpy arrays, since ``.clip`` is called on them —
            TODO confirm).
        """
        if self.model == 'ESRGAN':
            lr = lr / 255.0
            hr = hr / 255.0

        # FORMAT THE INPUT
        h1, w1, _ = hr.shape
        LR, HR, patch_info = util.getTrainingPatches(lr,
                                                     hr,
                                                     self.args,
                                                     transform=False)
        h2, w2 = patch_info

        LR = LR.to(self.device)
        HR = HR.to(self.device)

        # GET EACH SR RESULT
        choices = self.agent.model(LR)
        _, maxarg = choices.max(dim=1)
        _, minarg = choices.min(dim=1)

        outputs = [sisr(LR) for sisr in self.SRmodels]
        l1diff = torch.stack(
            [
                torch.abs(sr - HR).mean(dim=1).mean(dim=1).mean(dim=1)
                for sr in outputs
            ],
            dim=1)
        # Stacked as (patch, model, ...) so one model can be gathered per patch.
        sisrs = torch.stack(outputs, dim=1)

        # Colour code: model k is drawn in channel k. Bounded by both the
        # number of models and the number of channels so it also works when
        # there are not exactly three models.
        color_mask = torch.zeros(sisrs.shape).to(self.device)
        for k in range(min(sisrs.size(1), sisrs.size(2))):
            color_mask[:, k, k] += 255

        _, optimal_idx = l1diff.min(dim=1)
        _, worst_idx = l1diff.max(dim=1)

        # Agent score of model j broadcast over the whole plane of channel j
        # (one vectorised assignment instead of a write per patch and model).
        weight = torch.zeros(outputs[0].shape).to(self.device)
        shared = min(choices.size(1), weight.size(1))
        weight[:, :shared] = choices[:, :shared].unsqueeze(-1).unsqueeze(-1)

        rows = torch.arange(sisrs.size(0))

        def assemble(stack, model_idx):
            # Pick one model per patch, then stitch the patches back together.
            return util.recombine(stack[rows, model_idx], h1, w1, h2, w2)

        # GATHER SUPER RESOLUTION BASED ON CHOICES AND RECOMBINE RESULTS
        optimal = assemble(sisrs, optimal_idx)
        worst = assemble(sisrs, worst_idx)
        sr = assemble(sisrs, maxarg)
        minsisr = assemble(sisrs, minarg)

        # DECISION MASKS
        optimalchoice = assemble(color_mask, optimal_idx)
        worstchoice = assemble(color_mask, worst_idx)
        maxchoice = assemble(color_mask, maxarg)
        minchoice = assemble(color_mask, minarg)
        weight_image = util.recombine(weight, h1, w1, h2, w2)

        # FORMAT THE OUTPUT
        if self.model == 'ESRGAN':
            # ESRGAN inputs were scaled by 1/255 above; map back to 0..255.
            optimal = optimal.clip(0, 1) * 255.0
            worst = worst.clip(0, 1) * 255.0
            sr = sr.clip(0, 1) * 255.0
            minsisr = minsisr.clip(0, 1) * 255.0
        else:
            optimal = optimal.clip(0, 255)
            worst = worst.clip(0, 255)
            sr = sr.clip(0, 255)
            minsisr = minsisr.clip(0, 255)

        maxchoice = maxchoice.clip(0, 255)
        minchoice = minchoice.clip(0, 255)
        worstchoice = worstchoice.clip(0, 255)
        optimalchoice = optimalchoice.clip(0, 255)
        weight_image = weight_image * 255.0

        return {
            'HR': hr,
            'min': minsisr,
            'max': sr,
            'worst': worst,
            'optimal': optimal,
            'optimalchoice': optimalchoice,
            'worstchoice': worstchoice,
            'maxchoice': maxchoice,
            'minchoice': minchoice,
            'choices': choices,
            'mask': weight_image
        }
コード例 #5
0
    def train(self,maxepoch=100,start=.01,end=0.0001):
        """Jointly train the SR models and the selection agent.

        Each step draws one random training image, splits it into patches,
        samples ``self.batch_size`` patches and minimises the sum over models
        of ``mean(per-patch L1 error * agent score for that model)``. Both the
        SR optimizers and the agent optimizer are stepped on that single loss.
        Scalars, histograms and sample images are sent to ``self.logger`` on
        every step; every 100 logger steps ``self.test.validateSet5`` is run
        and ``self.savemodels()`` is called.

        The outer ``for c in count()`` loop has no ``break`` or ``return``, so
        this method only ends when the process is stopped or an exception
        propagates.

        Args:
            maxepoch: not referenced in the body.
            start: not referenced in the body.
            end: not referenced in the body.
        """
        #EACH EPISODE TAKE ONE LR/HR PAIR WITH CORRESPONDING PATCHES
        #AND ATTEMPT TO SUPER RESOLVE EACH PATCH

        #requires pytorch 1.1.0+ which is not possible on the server
        #scheduler = torch.optim.lr_scheduler.CyclicLR(self.agent.optimizer,base_lr=0.0001,max_lr=0.1)

        #QUICK CHECK ON EVERYTHING
        #with torch.no_grad():
        #    psnr,ssim,info = self.test.validateSet5(save=False,quick=False)

        #START TRAINING
        indices = list(range(len(self.TRAINING_HRPATH)))
        # NOTE(review): the four objects below are created but never used in this method.
        lossfn = torch.nn.L1Loss()
        lossMSE = torch.nn.MSELoss()
        lossCE = torch.nn.CrossEntropyLoss()
        softmaxfn = torch.nn.Softmax(dim=1)

        #random.shuffle(indices)
        # c counts passes over `indices`; count() is unbounded.
        for c in count():

            #FOR EACH HIGH RESOLUTION IMAGE
            for n,idx in enumerate(indices):
                # The enumerated idx is discarded: every iteration draws an image
                # index at random, so one pass does not necessarily visit every image.
                idx = random.sample(indices,1)[0]
                #idx = 0

                #GET INPUT FROM CURRENT IMAGE
                HRpath = self.TRAINING_HRPATH[idx]
                LRpath = self.TRAINING_LRPATH[idx]
                LR = imageio.imread(LRpath)
                HR = imageio.imread(HRpath)

                LR,HR,_ = util.getTrainingPatches(LR,HR,args)

                # WE GO THROUGH PATCH IN RANDOM ORDER
                patch_ids = list(range(len(LR)))
                random.shuffle(patch_ids)
                P = []  # unused
                # A single mini-batch is taken from each sampled image.
                for step in range(1):
                    batch_ids = random.sample(patch_ids,self.batch_size)
                    #batch_ids = patch_ids[:self.batch_size]
                    labels = torch.Tensor(batch_ids).long()

                    lrbatch = LR[labels,:,:,:]
                    hrbatch = HR[labels,:,:,:]

                    lrbatch = lrbatch.to(self.device)
                    hrbatch = hrbatch.to(self.device)
                    if args.model == 'ESRGAN':
                        lrbatch = lrbatch / 255.0
                        hrbatch = hrbatch / 255.0

                    #GET SISR RESULTS FROM EACH MODEL
                    loss_SISR = 0  # unused
                    sisrs = []

                    # Agent output, indexed below as probs[:, j] for model j.
                    probs = self.agent.model(lrbatch)
                    for sisr in self.SRmodels:
                        hr_pred = sisr(lrbatch)
                        sisrs.append(hr_pred)

                    #UPDATE BOTH THE SISR MODELS AND THE SELECTION MODEL ACCORDING TO THEIR LOSS
                    sisrloss = []
                    l1loss = []
                    maxarg = probs.max(dim=1)[1]
                    #onehot_mask = torch.nn.functional.one_hot(maxarg,len(sisrs)).float()
                    for j, sr in enumerate(sisrs):
                        self.SRoptimizers[j].zero_grad()
                        # Absolute error summed over the three non-batch dims and divided by
                        # (PATCH_SIZE*UPSIZE)^2 * 3 — a per-patch mean if the patches are
                        # 3-channel squares of that side length (TODO confirm).
                        l1 = torch.abs(sr - hrbatch).sum(dim=1).sum(dim=1).sum(dim=1) / ((self.PATCH_SIZE * self.UPSIZE)**2 * 3)
                        l1loss.append(l1)
                        # Each model's error is weighted by the agent's score for that model.
                        loss = torch.mean(l1 * probs[:,j])
                        #loss = torch.mean(l1 * onehot_mask[:,j])
                        sisrloss.append(loss)

                    l1loss = torch.stack(l1loss,dim=1)
                    self.agent.opt.zero_grad()
                    # One backward pass through the summed loss feeds every optimizer.
                    sisrloss_total = sum(sisrloss)
                    sisrloss_total.backward()
                    [opt.step() for opt in self.SRoptimizers]
                    self.agent.opt.step()

                    #[sched.step() for sched in self.schedulers]
                    #self.agent.scheduler.step()

                    #CONSOLE OUTPUT FOR QUICK AND DIRTY DEBUGGING
                    lr = self.SRoptimizers[-1].param_groups[0]['lr']
                    lr2 = self.agent.opt.param_groups[0]['lr']
                    # Sample image: first patch of the batch, from the model the agent scores highest.
                    _,maxarg = probs[0].max(0)
                    sampleSR = sisrs[maxarg.item()][0]
                    sampleHR = hrbatch[0]
                    if args.model != 'ESRGAN':
                        sampleSR = sampleSR / 255.0
                        sampleHR = sampleHR / 255.0

                    # c1..c3 and s1..s3 hard-code indices 0..2, i.e. at least three SR models.
                    choice = probs.max(dim=1)[1]
                    c1 = (choice == 0).float().mean()
                    c2 = (choice == 1).float().mean()
                    c3 = (choice == 2).float().mean()
                    s1 = torch.mean(l1loss[:,0]).item()
                    s2 = torch.mean(l1loss[:,1]).item()
                    s3 = torch.mean(l1loss[:,2]).item()
                    # Mean error of the model the agent's argmax selects for each patch.
                    agentloss = torch.mean(l1loss.gather(1,choice.unsqueeze(1)))

                    print('\rEpoch/img: {}/{} | LR sr/ag: {:.8f}/{:.8f} | Agent Loss: {:.4f}, SISR Loss: {:.4f} | s1: {:.4f} | s2: {:.4f} | s3: {:.4f}'\
                            .format(c,n,lr,lr2,agentloss.item(),sisrloss_total.item(),s1,s2,s3),end="\n")

                    #LOG AND SAVE THE INFORMATION
                    scalar_summaries = {'Loss/AgentLoss': agentloss, 'Loss/SISRLoss': sisrloss_total, "choice/c1": c1, "choice/c2": c2, "choice/c3": c3, "sisr/s1": s1, "sisr/s2": s2, "sisr/s3": s3}
                    hist_summaries = {'actions': probs.view(-1), "choices": choice.view(-1)}
                    img_summaries = {'sr/HR': sampleHR.clamp(0,1),'sr/SR': sampleSR.clamp(0,1)}
                    self.logger.hist_summary(hist_summaries)
                    self.logger.scalar_summary(scalar_summaries)
                    self.logger.image_summary(img_summaries)
                    # Periodic validation and checkpoint, every 100 logger steps.
                    if self.logger.step % 100 == 0:
                        with torch.no_grad():
                            psnr,ssim,info = self.test.validateSet5(save=False,quick=False)
                        # NOTE(review): self.logger is already used unconditionally above,
                        # so this guard appears redundant.
                        if self.logger:
                            self.logger.scalar_summary({'Testing_PSNR': psnr, 'Testing_SSIM': ssim})
                            weightedmask = torch.from_numpy(info['mask']).permute(2,0,1) / 255.0
                            mask = torch.from_numpy(info['maxchoice']).permute(2,0,1) / 255.0
                            optimal_mask = torch.from_numpy(info['optimalchoice']).permute(2,0,1) / 255.0
                            hrimg = torch.Tensor(info['HR']).permute(2,0,1) / 255.0
                            srimg = torch.from_numpy(info['max']).permute(2,0,1) / 255.0
                            self.logger.image_summary({'Testing/Test Assignment':mask[:3],'Testing/Weight': weightedmask[:3], 'Testing/SR':srimg, 'Testing/HR': hrimg, 'Testing/optimalmask': optimal_mask})
                        self.savemodels()
                        # Put all models back in training mode after validation
                        # (presumably validateSet5 switches them to eval — TODO confirm).
                        self.agent.model.train()
                        [model.train() for model in self.SRmodels]

                    self.logger.incstep()
コード例 #6
0
ファイル: train.py プロジェクト: yhu9/RCAN
    def optimize(self, data, iou_threshold=0.5):
        """Train the SR models and the selection agent on random mini-batches.

        Every iteration samples one image index from ``data``, splits the image
        pair into patches and draws ``self.batch_size`` patches. The loss is the
        sum of

        * the mean, over the batch, of the smallest per-model MSE
          (only the best model for each patch is penalised), and
        * the negative mean log of the agent's score for that best model.

        Progress is printed and logged on every iteration; every 100 logger
        steps ``self.test.validateSet5`` is run and ``self.savemodels()`` is
        called.

        Args:
            data: sequence of indices into ``self.TRAINING_HRPATH`` /
                ``self.TRAINING_LRPATH`` to sample from.
            iou_threshold: the loop stops once the running mean (over at most
                the last 100 batches) of the fraction of patches where the
                agent's argmax equals the lowest-error model exceeds this
                value.

        Returns:
            None. The loop also stops on every 10000th iteration regardless
            of the running score.
        """
        self.agent.model.train()
        for model in self.SRmodels:
            model.train()
        agent_iou = deque(maxlen=100)

        # while the agent iou is not good enough
        for c in count():
            # get an image
            idx = random.sample(data, 1)[0]

            hr_path = self.TRAINING_HRPATH[idx]
            lr_path = self.TRAINING_LRPATH[idx]
            lr = imageio.imread(lr_path)
            hr = imageio.imread(hr_path)

            lr, hr, _ = util.getTrainingPatches(lr, hr, args, transform=False)

            patch_ids = list(range(len(lr)))
            random.shuffle(patch_ids)

            # get the mini batch
            batch_ids = random.sample(patch_ids, self.batch_size)
            labels = torch.Tensor(batch_ids).long()
            lr_batch = lr[labels, :, :, :].to(self.device)
            hr_batch = hr[labels, :, :, :].to(self.device)
            if args.model == 'ESRGAN' or args.model == 'basic':
                lr_batch = lr_batch / 255.0
                hr_batch = hr_batch / 255.0

            # UPDATE THE SISR MODELS
            self.agent.opt.zero_grad()
            probs = self.agent.model(lr_batch)
            sisrs = []
            pred_diff = []
            for j, sisr in enumerate(self.SRmodels):
                self.SRoptimizers[j].zero_grad()
                hr_pred = sisr(lr_batch)
                # Per-patch MSE: mean of squared error across the non-batch dims.
                diff = (hr_pred - hr_batch).pow(2).mean(dim=1).mean(
                    dim=1).mean(dim=1)
                pred_diff.append(diff)
                sisrs.append(hr_pred)
            pred_diff = torch.stack(pred_diff, dim=1)
            minval, optimalidx = pred_diff.min(dim=1)

            # Negative log of the agent's score for the best model; the clamp
            # keeps log() finite when a score is exactly zero.
            optimal_prob = probs.gather(1, optimalidx.unsqueeze(1))
            selectionloss = torch.mean(optimal_prob.clamp(1e-16, 1).log()) * -1
            sisrloss = minval.mean()
            sisr_loss_total = sisrloss + selectionloss
            sisr_loss_total.backward()
            for opt in self.SRoptimizers:
                opt.step()
            self.agent.opt.step()

            # VISUALIZATION
            # CONSOLE OUTPUT FOR QUICK AND DIRTY DEBUGGING
            lr2 = self.agent.opt.param_groups[0]['lr']
            # Sample image: first patch, from the model the agent scores highest.
            _, maxarg = probs[0].max(0)
            sample_sr = sisrs[maxarg.item()][0]
            sample_hr = hr_batch[0]
            if args.model != 'ESRGAN' and args.model != 'basic':
                sample_sr = sample_sr / 255.0
                sample_hr = sample_hr / 255.0

            # c1..c3 and s1..s3 hard-code indices 0..2 (at least three models).
            choice = probs.max(dim=1)[1]
            iou = (choice == optimalidx).float().sum() / (len(choice))
            c1 = (choice == 0).float().mean()
            c2 = (choice == 1).float().mean()
            c3 = (choice == 2).float().mean()
            s1 = torch.mean(pred_diff[:, 0]).item()
            s2 = torch.mean(pred_diff[:, 1]).item()
            s3 = torch.mean(pred_diff[:, 2]).item()

            agent_iou.append(iou.item())
            mean_iou = np.mean(agent_iou)

            print('\rEpoch/img: {}/{} | LR: {:.8f} | Agent Loss: {:.4f}, SISR Loss: {:.4f} | IOU: {:.4f} | c1: {:.4f}, c2: {:.4f}, c3: {:.4f}'
                  .format(c, self.logger.step, lr2, selectionloss.item(),
                          sisrloss.item(), mean_iou, c1.item(), c2.item(),
                          c3.item()),
                  end="\n")

            # LOG AND SAVE THE INFORMATION
            # NOTE(review): both loss entries log the combined loss, while the
            # console line above reports the two terms separately — confirm
            # which is intended before changing the logged series.
            scalar_summaries = {
                'Loss/AgentLoss': sisr_loss_total,
                'Loss/SISRLoss': sisr_loss_total,
                "Loss/IOU": mean_iou,
                "choice/c1": c1,
                "choice/c2": c2,
                "choice/c3": c3,
                "sisr/s1": s1,
                "sisr/s2": s2,
                "sisr/s3": s3
            }
            img_summaries = {
                'sr/HR': sample_hr.clamp(0, 1),
                'sr/SR': sample_sr.clamp(0, 1)
            }
            self.logger.scalar_summary(scalar_summaries)
            self.logger.image_summary(img_summaries)
            if self.logger.step % 100 == 0:
                with torch.no_grad():
                    psnr, ssim, info = self.test.validateSet5(save=False,
                                                              quick=False)
                # Back to training mode after validation.
                self.agent.model.train()
                for model in self.SRmodels:
                    model.train()
                if self.logger:
                    self.logger.scalar_summary({
                        'Testing_PSNR': psnr,
                        'Testing_SSIM': ssim
                    })
                    weightedmask = torch.from_numpy(info['mask']).permute(
                        2, 0, 1) / 255.0
                    mask = torch.from_numpy(info['maxchoice']).permute(
                        2, 0, 1) / 255.0
                    optimal_mask = torch.from_numpy(
                        info['optimalchoice']).permute(2, 0, 1) / 255.0
                    # NOTE(review): unlike the other images here, HR is not
                    # divided by 255 — confirm the expected value range.
                    hrimg = torch.Tensor(info["HR"]).permute(2, 0, 1)
                    srimg = torch.from_numpy(info['max']).permute(2, 0,
                                                                  1) / 255.0
                    self.logger.image_summary({
                        'Testing/Test Assignment': mask[:3],
                        'Testing/Weight': weightedmask[:3],
                        'Testing/SR': srimg,
                        'Testing/HR': hrimg,
                        'Testing/optimalmask': optimal_mask[:3]
                    })
                self.savemodels()
            self.logger.incstep()

            # `%` binds tighter than `+`, so the iteration count must be
            # parenthesised for the 10000-iteration cap to ever trigger.
            if mean_iou > iou_threshold or (c + 1) % 10000 == 0:
                break
コード例 #7
0
ファイル: train.py プロジェクト: yhu9/RCAN
    def optimize(self, data, iou_threshold=0.8):
        """Train the SR models and a per-pixel selection agent.

        ``data`` is shuffled in place. Then, ``len(data)`` times, an inner loop
        trains on random mini-batches until the running mean (over at most the
        last 100 batches) of the agreement between the agent's per-pixel
        argmax and the per-pixel lowest-error model exceeds ``iou_threshold``.

        The loss is the sum over models of
        ``mean(agent score for model j * channel-averaged L1 error of model j)``;
        the SR optimizers and the agent optimizer are all stepped on it.
        Progress is printed and logged every iteration; every 100 logger steps
        ``self.test.validateSet5`` is run and ``self.savemodels()`` is called.

        Args:
            data: mutable sequence of indices into ``self.TRAINING_HRPATH`` /
                ``self.TRAINING_LRPATH``; shuffled in place.
            iou_threshold: running agreement that ends one inner loop.

        Returns:
            None.
        """
        random.shuffle(data)
        # Only the number of entries drives the outer loop; the image used in
        # each step is sampled at random inside the inner loop.
        for c in range(len(data)):
            # running agreement score, restarted for every outer pass
            agent_iou = deque(maxlen=100)

            # while the agent iou is not good enough
            while True:
                # get an image
                idx = random.sample(data, 1)[0]

                hr_path = self.TRAINING_HRPATH[idx]
                lr_path = self.TRAINING_LRPATH[idx]
                lr = imageio.imread(lr_path)
                hr = imageio.imread(hr_path)

                # NOTE(review): this unpacks two values from
                # util.getTrainingPatches — confirm against the helper's
                # return signature in this version of the project.
                lr, hr = util.getTrainingPatches(lr, hr, args)
                patch_ids = list(range(len(lr)))

                # get the mini batch (all patches come from a single image)
                batch_ids = random.sample(patch_ids, self.batch_size)
                labels = torch.Tensor(batch_ids).long()
                lr_batch = lr[labels, :, :, :].to(self.device)
                hr_batch = hr[labels, :, :, :].to(self.device)
                if self.model == 'ESRGAN':
                    lr_batch = lr_batch / 255.0
                    hr_batch = hr_batch / 255.0

                # GET SISR RESULTS FROM EACH MODEL
                probs = self.agent.model(lr_batch)
                sisrs = [sisr(lr_batch) for sisr in self.SRmodels]

                # update the sisr model and the selection model
                # Score-weighted blend of all SR outputs, used only for
                # visualisation, so it is accumulated without gradient tracking.
                sr_result = torch.zeros(
                    self.batch_size, 3, self.PATCH_SIZE * self.UPSIZE,
                    self.PATCH_SIZE * self.UPSIZE).to(self.device)
                l1diff = []
                sisr_loss = 0
                for j, sr in enumerate(sisrs):
                    self.SRoptimizers[j].zero_grad()
                    # absolute error averaged over the channel dim only
                    l1 = torch.abs(sr - hr_batch).mean(dim=1)
                    sisr_loss += torch.mean(probs[:, j] * l1)
                    l1diff.append(l1)
                    sr_result += (sr * probs[:, j].unsqueeze(1)).detach()
                self.agent.opt.zero_grad()
                sisr_loss_total = sisr_loss
                sisr_loss_total.backward()
                for opt in self.SRoptimizers:
                    opt.step()
                self.agent.opt.step()

                # visualizations
                if self.model != 'ESRGAN':
                    sr_result = sr_result / 255.0
                l1diff = torch.stack(l1diff, dim=1)
                # Softmax over models of the negated, mean-centred error:
                # lower error gives a brighter entry.
                diffmap = torch.nn.functional.softmax(
                    -255 * (l1diff - torch.mean(l1diff)), dim=1)
                minval, minidx = l1diff.min(dim=1)
                # One-hot map of the lowest-error model, moved to dim 1.
                target = torch.nn.functional.one_hot(
                    minidx, len(sisrs)).permute(0, 3, 1, 2)

                lr1 = self.SRoptimizers[-1].param_groups[0]['lr']
                lr2 = self.agent.opt.param_groups[0]['lr']
                choice = probs.max(dim=1)[1]
                # Fraction of positions where the agent picks the best model.
                iou = (choice == minidx).float().sum() / (
                    choice.shape[0] * choice.shape[1] * choice.shape[2])
                # c1..c3 and s1..s3 hard-code indices 0..2 (three models).
                c1 = (choice == 0).float().mean()
                c2 = (choice == 1).float().mean()
                c3 = (choice == 2).float().mean()
                s1 = torch.mean(l1diff[:, 0]).item()
                s2 = torch.mean(l1diff[:, 1]).item()
                s3 = torch.mean(l1diff[:, 2]).item()

                # SAVE THE IOU FOR THIS BATCH
                agent_iou.append(iou.item())
                mean_iou = np.mean(agent_iou)

                print('\rEpoch/img: {}/{} | LR sr/ag: {:.8f}/{:.8f} | Agent Loss {:.4f} |  SISR Loss: {:.4f} | IOU: {:.4f} | s1: {:.4f}, s2: {:.4f}, s3: {:.4f}'
                      .format(c, self.logger.step, lr1, lr2,
                              sisr_loss_total.item(), sisr_loss_total.item(),
                              mean_iou, s1, s2, s3),
                      end="\n")

                # LOG AND SAVE THE INFORMATION
                scalar_summaries = {
                    'Loss/IOU': mean_iou,
                    'Loss/SISRLoss': sisr_loss_total,
                    "choice/c1": c1,
                    "choice/c2": c2,
                    "choice/c3": c3
                }
                hist_summaries = {
                    'actions': probs[0].view(-1),
                    "choices": choice[0].view(-1)
                }
                img_summaries = {
                    'sr/mask': probs[0][:3],
                    'sr/sr': sr_result[0].clamp(0, 1),
                    'sr/targetmask': target[0][:3],
                    'sr/diffmap': diffmap[0][:3]
                }
                self.logger.scalar_summary(scalar_summaries)
                self.logger.hist_summary(hist_summaries)
                self.logger.image_summary(img_summaries)
                if self.logger.step % 100 == 0:
                    with torch.no_grad():
                        psnr, ssim, info = self.test.validateSet5(save=False,
                                                                  quick=False)
                    # Back to training mode after validation.
                    self.agent.model.train()
                    for model in self.SRmodels:
                        model.train()
                    if self.logger:
                        self.logger.scalar_summary({
                            'Testing_PSNR': psnr,
                            'Testing_SSIM': ssim
                        })
                        mask = torch.from_numpy(
                            info['choices']).float().permute(2, 0, 1)
                        best_mask = info['upperboundmask'].squeeze()
                        hrimg = info['HR'].squeeze()
                        srimg = torch.from_numpy(info['weighted']).permute(
                            2, 0, 1).clamp(0, 1)
                        advantage = info['advantage']
                        self.logger.image_summary({
                            'Testing/Test Assignment': mask[:3],
                            'Testing/SR': srimg,
                            'Testing/HR': hrimg,
                            'Testing/upperboundmask': best_mask,
                            'Testing/advantage': advantage
                        })
                    self.savemodels()
                self.logger.incstep()

                # breaking condition
                if mean_iou > iou_threshold:
                    print(agent_iou)
                    break