def test(self):
        """Evaluate the model on ``self.val_loader`` and print mean PSNR/SSIM.

        Every restored image in each batch is converted to uint8, saved to
        ``self.args.checkpoint`` under the file name from ``batches['name']``,
        and scored against its target with ``compare_psnr`` /
        ``compare_ssim``. The averages are printed at the end.
        """
        # switch to evaluate mode
        self.model.eval()

        ssimes = AverageMeter()
        psnres = AverageMeter()

        with torch.no_grad():
            for batches in self.val_loader:

                inputs = batches['image'].to(self.device)
                targets = batches['target'].to(self.device)

                outputs = self.model(inputs)

                # Architectures return either a bare tensor, a sequence whose
                # first item is the image, or a nested list of images.
                if isinstance(outputs, torch.Tensor):
                    output = outputs
                elif isinstance(outputs[0], list):
                    output = outputs[0][0]
                else:
                    output = outputs[0]

                # Score every image of the batch individually (previously
                # only item 0 was scored but weighted by the batch size).
                for j in range(output.size(0)):
                    # recover the image to the 0-255 uint8 range
                    out_img = im_to_numpy(
                        torch.clamp(output[j] * 255, min=0.0,
                                    max=255.0)).astype(np.uint8)
                    tgt_img = im_to_numpy(
                        torch.clamp(targets[j] * 255, min=0.0,
                                    max=255.0)).astype(np.uint8)

                    skimage.io.imsave(
                        '%s/%s' % (self.args.checkpoint, batches['name'][j]),
                        out_img)

                    psnres.update(compare_psnr(tgt_img, out_img), 1)
                    ssimes.update(
                        compare_ssim(tgt_img, out_img, multichannel=True), 1)

        print("%s:PSNR:%s,SSIM:%s" %
              (self.args.checkpoint, psnres.avg, ssimes.avg))
        print("DONE.\n")
Exemplo n.º 2
0
    def train(self, epoch):
        """Train generator and discriminator for one epoch.

        The generator receives ``cat(input_image, mask)`` and returns the
        restored image plus attention maps at three scales. Its loss is
        ``GAN + 100 * L1 + 90 * (att8s + att4s + att2s)``. The discriminator
        is then updated with ``0.5 * (loss_real + loss_fake)``. Per-epoch
        averages are written to ``self.writer`` under ``train/...``.

        Args:
            epoch: epoch index, used as the scalar step and to derive the
                image-logging step.
        """
        self.current_epoch = epoch

        if self.args.freeze and epoch > 10:
            # NOTE(review): this runs on every epoch after 10, so Adam is
            # re-created (and its moment estimates reset) each epoch.
            self.model.freeze_weighting_of_rasc()
            self.optimizer_G = torch.optim.Adam(
                filter(lambda p: p.requires_grad, self.model.parameters()),
                lr=self.args.lr,
                betas=(0.5, 0.999),
                weight_decay=self.args.weight_decay)

        batch_time = AverageMeter()
        data_time = AverageMeter()
        LoggerLossG = AverageMeter()
        LoggerLossGGAN = AverageMeter()
        LoggerLossGL1 = AverageMeter()

        LoggerLossD = AverageMeter()
        LoggerLossDreal = AverageMeter()
        LoggerLossDfake = AverageMeter()

        lossMask8s = AverageMeter()
        lossMask4s = AverageMeter()
        lossMask2s = AverageMeter()

        # switch to train mode
        self.model.train()
        self.discriminator.train()

        end = time.time()

        bar = Bar('Processing {} '.format(self.args.arch),
                  max=len(self.train_loader))

        for i, (inputs, target) in enumerate(self.train_loader):
            # time spent waiting for the data loader
            data_time.update(time.time() - end)

            input_image, mask, m2s, m4s, m8s = inputs

            current_index = len(self.train_loader) * epoch + i

            if self.args.gpu:
                input_image = input_image.cuda()
                mask = mask.cuda()
                m2s = m2s.cuda()
                m4s = m4s.cuda()
                m8s = m8s.cuda()
                target = target.cuda()

            # Adversarial targets are created on the batch's device instead
            # of being forced onto CUDA regardless of ``args.gpu``.
            batch_size = input_image.size(0)
            valid = torch.ones((batch_size, self.patch, self.patch),
                               device=input_image.device)
            fake = torch.zeros((batch_size, self.patch, self.patch),
                               device=input_image.device)

            # ---------------
            # Train model
            # --------------

            self.optimizer_G.zero_grad()
            fake_input, mask8s, mask4s, mask2s = self.model(
                torch.cat((input_image, mask), 1))
            pred_fake = self.discriminator(fake_input, input_image)
            loss_GAN = self.criterion_GAN(pred_fake, valid)
            loss_pixel = self.criterion_L1(fake_input, target)
            # attention supervision at 1/8, 1/4 and 1/2 scale
            masked_loss8s = self.attentionLoss8s(mask8s, m8s)
            masked_loss4s = self.attentionLoss4s(mask4s, m4s)
            masked_loss2s = self.attentionLoss2s(mask2s, m2s)
            loss_G = (loss_GAN + 100 * loss_pixel + 90 * masked_loss8s +
                      90 * masked_loss4s + 90 * masked_loss2s)

            loss_G.backward()
            self.optimizer_G.step()

            self.optimizer_D.zero_grad()
            pred_real = self.discriminator(target, input_image)
            loss_real = self.criterion_GAN(pred_real, valid)
            # detach so the discriminator step does not backprop into G
            pred_fake = self.discriminator(fake_input.detach(), input_image)
            loss_fake = self.criterion_GAN(pred_fake, fake)
            loss_D = 0.5 * (loss_real + loss_fake)
            loss_D.backward()
            self.optimizer_D.step()

            # ---------------------
            #        Logger
            # ---------------------

            LoggerLossGGAN.update(loss_GAN.item(), batch_size)
            LoggerLossGL1.update(loss_pixel.item(), batch_size)
            LoggerLossG.update(loss_G.item(), batch_size)
            # real/fake meters were previously fed each other's values
            LoggerLossDreal.update(loss_real.item(), batch_size)
            LoggerLossDfake.update(loss_fake.item(), batch_size)
            LoggerLossD.update(loss_D.item(), batch_size)
            lossMask8s.update(masked_loss8s.item(), batch_size)
            lossMask4s.update(masked_loss4s.item(), batch_size)
            lossMask2s.update(masked_loss2s.item(), batch_size)

            # ---------------------
            #        Visualize
            # ---------------------

            if i == 1:
                self.writer.add_images('train/Goutput', deNorm(fake_input),
                                       current_index)
                self.writer.add_images('train/target', deNorm(target),
                                       current_index)
                self.writer.add_images('train/input', deNorm(input_image),
                                       current_index)
                self.writer.add_images('train/mask', mask.repeat((1, 3, 1, 1)),
                                       current_index)
                self.writer.add_images('train/attention2s',
                                       mask2s.repeat(1, 3, 1, 1),
                                       current_index)
                self.writer.add_images('train/attention4s',
                                       mask4s.repeat(1, 3, 1, 1),
                                       current_index)
                self.writer.add_images('train/attention8s',
                                       mask8s.repeat(1, 3, 1, 1),
                                       current_index)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # plot progress
            bar.suffix = '({batch}/{size}) Data: {data:.2f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss D: {loss_d:.4f} | Loss G: {loss_g:.4f} | Loss L1: {loss_l1:.6f} '.format(
                batch=i + 1,
                size=len(self.train_loader),
                data=data_time.val,
                bt=batch_time.val,
                total=bar.elapsed_td,
                eta=bar.eta_td,
                loss_d=LoggerLossD.avg,
                loss_g=LoggerLossGGAN.avg,
                loss_l1=LoggerLossGL1.avg)
            bar.next()

        bar.finish()
        self.writer.add_scalar('train/loss/GAN', LoggerLossGGAN.avg, epoch)
        self.writer.add_scalar('train/loss/D', LoggerLossD.avg, epoch)
        self.writer.add_scalar('train/loss/L1', LoggerLossGL1.avg, epoch)
        self.writer.add_scalar('train/loss/G', LoggerLossG.avg, epoch)
        self.writer.add_scalar('train/loss/Dreal', LoggerLossDreal.avg, epoch)
        self.writer.add_scalar('train/loss/Dfake', LoggerLossDfake.avg, epoch)

        self.writer.add_scalar('train/loss_Mask8s', lossMask8s.avg, epoch)
        self.writer.add_scalar('train/loss_Mask4s', lossMask4s.avg, epoch)
        self.writer.add_scalar('train/loss_Mask2s', lossMask2s.avg, epoch)
Exemplo n.º 3
0
    def validate(self, epoch):
        """Evaluate on ``self.val_loader``; log PSNR, SSIM and attention losses.

        Outputs and targets are passed through ``deNorm`` before PSNR is
        computed as ``10 * log10(1 / mse)``, where the MSE comes from
        ``self.criterion_GAN``. ``self.metric`` is set to the mean PSNR.

        Args:
            epoch: epoch index, used as the TensorBoard step.
        """
        self.current_epoch = epoch
        batch_time = AverageMeter()
        data_time = AverageMeter()
        psnres = AverageMeter()
        ssimes = AverageMeter()
        lossMask8s = AverageMeter()
        lossMask4s = AverageMeter()
        lossMask2s = AverageMeter()

        # switch to evaluate mode
        self.model.eval()

        end = time.time()
        bar = Bar('Processing {} '.format(self.args.arch),
                  max=len(self.val_loader))

        with torch.no_grad():
            for i, (inputs, target) in enumerate(self.val_loader):
                # time spent waiting for the data loader
                data_time.update(time.time() - end)

                input_image, mask, m2s, m4s, m8s = inputs

                if self.args.gpu:
                    input_image = input_image.cuda()
                    mask = mask.cuda()
                    m2s = m2s.cuda()
                    m4s = m4s.cuda()
                    m8s = m8s.cuda()
                    target = target.cuda()

                # attention maps at 1/8, 1/4 and 1/2 scale
                output, mask8s, mask4s, mask2s = self.model(
                    torch.cat((input_image, mask), 1))

                output = deNorm(output)
                target = deNorm(target)

                masked_loss8s = self.attentionLoss8s(mask8s, m8s)
                masked_loss4s = self.attentionLoss4s(mask4s, m4s)
                masked_loss2s = self.attentionLoss2s(mask2s, m2s)

                ## psnr and ssim calculator.
                mse = self.criterion_GAN(output, target)
                psnr = 10 * log10(1 / mse.item())
                # store a Python float rather than a (device) tensor
                ssim = pytorch_ssim.ssim(output, target).item()

                batch_size = input_image.size(0)
                psnres.update(psnr, batch_size)
                ssimes.update(ssim, batch_size)
                lossMask8s.update(masked_loss8s.item(), batch_size)
                lossMask4s.update(masked_loss4s.item(), batch_size)
                lossMask2s.update(masked_loss2s.item(), batch_size)

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()
                # plot progress
                bar.suffix = '({batch}/{size}) Data: {data:.2f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | SSIM: {ssim:.4f} | PSNR: {psnr:.4f}'.format(
                    batch=i + 1,
                    size=len(self.val_loader),
                    data=data_time.val,
                    bt=batch_time.val,
                    total=bar.elapsed_td,
                    eta=bar.eta_td,
                    ssim=ssimes.avg,
                    psnr=psnres.avg)
                bar.next()
        bar.finish()

        self.writer.add_scalar('val/SSIM', ssimes.avg, epoch)
        self.writer.add_scalar('val/PSNR', psnres.avg, epoch)
        # validation losses get their own tags; they previously overwrote
        # the 'train/loss_Mask*' scalars logged by train() at the same step
        self.writer.add_scalar('val/loss_Mask8s', lossMask8s.avg, epoch)
        self.writer.add_scalar('val/loss_Mask4s', lossMask4s.avg, epoch)
        self.writer.add_scalar('val/loss_Mask2s', lossMask2s.avg, epoch)

        self.metric = psnres.avg
Exemplo n.º 4
0
    def train(self, epoch):
        """Train for one epoch with a scaled L2 loss and optional gradient loss.

        ``inputs`` carries the image in channels 0-2 and the mask in channel 3.
        The optimised loss is ``1e10 * L2 + 1e9 * gradient_loss``. The gradient
        term is only computed when ``self.args.gradient_loss`` is set.

        Args:
            epoch: epoch index, used as the TensorBoard step.
        """
        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        gradientes = AverageMeter()

        # switch to train mode
        self.model.train()

        end = time.time()

        bar = Bar('Processing {} '.format(self.args.arch),
                  max=len(self.train_loader))

        for i, (inputs, target) in enumerate(self.train_loader):
            # measure data loading time
            data_time.update(time.time() - end)

            if self.args.gpu:
                inputs = inputs.cuda()
                target = target.cuda()
            # channel 3 is the mask; it already lives on the inputs' device
            mask = inputs[:, 3:4, :, :]

            output = self.model(inputs)

            if i == 1:
                current_index = len(self.train_loader) * epoch + i
                self.writer.add_images('train/output', output, current_index)
                self.writer.add_images('train/target', target, current_index)
                self.writer.add_images('train/input', inputs[:, 0:3, :, :],
                                       current_index)
                self.writer.add_images('train/mask', mask.repeat(1, 3, 1, 1),
                                       current_index)

            L2_loss = self.loss(output, target)

            if self.args.gradient_loss:
                # NOTE(review): the reference gradients come from the input
                # image, not ``target`` -- confirm this is intended.
                tgx, tgy = image_gradient(inputs[:, 0:3, :, :])
                ogx, ogy = image_gradient(output)
                gradient_loss = self.gradient_loss_y(
                    ogy, tgy) + self.gradient_loss_x(ogx, tgx)
            else:
                gradient_loss = 0

            total_loss = 1e10 * L2_loss + 1e9 * gradient_loss

            # compute gradient and do SGD step
            self.optimizer.zero_grad()
            total_loss.backward()
            self.optimizer.step()

            # record the scaled losses
            losses.update(1e10 * L2_loss.item(), inputs.size(0))

            if self.args.gradient_loss:
                gradientes.update(1e9 * gradient_loss.item(), inputs.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # plot progress
            bar.suffix = '({batch}/{size}) Data: {data:.2f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss L2: {loss_label:.4f}'.format(
                batch=i + 1,
                size=len(self.train_loader),
                data=data_time.val,
                bt=batch_time.val,
                total=bar.elapsed_td,
                eta=bar.eta_td,
                loss_label=losses.avg)
            bar.next()
        bar.finish()
        self.writer.add_scalar('train/loss_L2', losses.avg, epoch)
        self.writer.add_scalar('train/loss_gradient', gradientes.avg, epoch)
Exemplo n.º 5
0
    def validate(self, epoch):
        """Evaluate on ``self.val_loader`` and log L2 loss, PSNR and SSIM.

        PSNR is ``10 * log10(1 / mse)``, where ``mse`` comes from ``self.loss``.
        The logged L2 loss is ``1e10 * mse``. ``self.metric`` is set to the
        mean PSNR.

        Args:
            epoch: epoch index, used as the TensorBoard step.
        """
        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        ssimes = AverageMeter()
        psnres = AverageMeter()

        # switch to evaluate mode
        self.model.eval()

        end = time.time()
        bar = Bar('Processing', max=len(self.val_loader))
        with torch.no_grad():
            for i, (inputs, target) in enumerate(self.val_loader):

                # measure data loading time
                data_time.update(time.time() - end)
                if self.args.gpu:
                    inputs = inputs.cuda()  # image and bbox
                    target = target.cuda()

                output = self.model(inputs)
                mse = self.loss(output, target)

                L2_loss = 1e10 * mse
                psnr = 10 * log10(1 / mse.item())
                # store a Python float rather than a (device) tensor
                ssim = pytorch_ssim.ssim(output, target).item()

                losses.update(L2_loss.item(), inputs.size(0))
                psnres.update(psnr, inputs.size(0))
                ssimes.update(ssim, inputs.size(0))

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                # plot progress
                bar.suffix = '({batch}/{size}) Data: {data:.2f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss_L2: {loss_label:.4f} | SSIM: {ssim:.4f} | PSNR: {psnr:.4f}'.format(
                    batch=i + 1,
                    size=len(self.val_loader),
                    data=data_time.val,
                    bt=batch_time.val,
                    total=bar.elapsed_td,
                    eta=bar.eta_td,
                    loss_label=losses.avg,
                    psnr=psnres.avg,
                    ssim=ssimes.avg,
                )
                bar.next()
        bar.finish()

        self.writer.add_scalar('val/loss_L2', losses.avg, epoch)
        self.writer.add_scalar('val/PSNR', psnres.avg, epoch)
        self.writer.add_scalar('val/SSIM', ssimes.avg, epoch)
        self.metric = psnres.avg
    def train(self, epoch):
        """Train for ``self.hl`` passes over ``self.train_loader``.

        The model receives ``cat(image, mask)``. With ``self.args.res`` its
        output is added to the input image. The loss is
        ``L2 + self.args.style_loss * vgg_loss``. Every ``self.args.freq``
        steps (when positive) the model is validated and checkpointed.

        Args:
            epoch: epoch index, used to derive the global step.
        """
        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        lossvgg = AverageMeter()

        # switch to train mode
        self.model.train()
        end = time.time()

        # Defined before the loop so the final ``record`` call does not raise
        # UnboundLocalError when the loader yields no batches.
        current_index = len(self.train_loader) * epoch

        bar = Bar('Processing', max=len(self.train_loader) * self.hl)
        for _ in range(self.hl):
            for i, batches in enumerate(self.train_loader):
                # measure data loading time
                data_time.update(time.time() - end)
                inputs = batches['image'].to(self.device)
                target = batches['target'].to(self.device)
                mask = batches['mask'].to(self.device)
                # NOTE(review): this index restarts for each of the
                # ``self.hl`` passes, so steps repeat when hl > 1.
                current_index = len(self.train_loader) * epoch + i

                # both operands are already on self.device
                feeded = torch.cat([inputs, mask], dim=1)

                output = self.model(feeded)

                if self.args.res:
                    output = output + inputs

                L2_loss = self.loss(output, target)

                if self.args.style_loss > 0:
                    vgg_loss = self.vggloss(output, target, mask)
                else:
                    vgg_loss = 0

                total_loss = L2_loss + self.args.style_loss * vgg_loss

                # compute gradient and do SGD step
                self.optimizer.zero_grad()
                total_loss.backward()
                self.optimizer.step()

                # record loss
                losses.update(L2_loss.item(), inputs.size(0))

                if self.args.style_loss > 0:
                    lossvgg.update(vgg_loss.item(), inputs.size(0))

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                # plot progress
                suffix = '({batch}/{size}) Data: {data:.2f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss L2: {loss_label:.4f} | Loss VGG: {loss_vgg:.4f}'.format(
                    batch=i + 1,
                    size=len(self.train_loader),
                    data=data_time.val,
                    bt=batch_time.val,
                    total=bar.elapsed_td,
                    eta=bar.eta_td,
                    loss_label=losses.avg,
                    loss_vgg=lossvgg.avg)

                if current_index % 1000 == 0:
                    print(suffix)

                if self.args.freq > 0 and current_index % self.args.freq == 0:
                    self.validate(current_index)
                    self.flush()
                    self.save_checkpoint()
                    # validate() puts the model in eval mode; switch back so
                    # the rest of the epoch trains with train-mode layers.
                    self.model.train()

        self.record('train/loss_L2', losses.avg, current_index)
    def validate(self, epoch):
        """Score the model on ``self.val_loader`` and record L2, PSNR and SSIM.

        The network sees ``cat(image, mask)``. With ``self.args.res`` its
        output is added to the image. PSNR is ``10 * log10(1 / mse)``, where
        ``mse`` comes from ``self.loss``. The averages are printed, recorded
        under ``val/...``, and the mean PSNR is stored in ``self.metric``.

        Args:
            epoch: step index used for recording (printed as ``epoch + 1``).
        """
        step_timer = AverageMeter()
        loss_meter = AverageMeter()
        ssim_meter = AverageMeter()
        psnr_meter = AverageMeter()
        # switch to evaluate mode
        self.model.eval()

        tick = time.time()
        with torch.no_grad():
            for batch in self.val_loader:
                image = batch['image'].to(self.device)
                gt = batch['target'].to(self.device)
                hole = batch['mask'].to(self.device)

                net_in = torch.cat([image, hole], dim=1).to(self.device)
                pred = self.model(net_in)
                if self.args.res:
                    pred = pred + image

                mse = self.loss(pred, gt)
                mse_value = mse.item()
                weight = image.size(0)

                psnr_value = 10 * log10(1 / mse_value)
                ssim_value = pytorch_ssim.ssim(pred, gt).item()

                loss_meter.update(mse_value, weight)
                psnr_meter.update(psnr_value, weight)
                ssim_meter.update(ssim_value, weight)

                step_timer.update(time.time() - tick)
                tick = time.time()

        print("Epoches:%s,Losses:%.3f,PSNR:%.3f,SSIM:%.3f" %
              (epoch + 1, loss_meter.avg, psnr_meter.avg, ssim_meter.avg))
        for tag, value in (('val/loss_L2', loss_meter.avg),
                           ('val/PSNR', psnr_meter.avg),
                           ('val/SSIM', ssim_meter.avg)):
            self.record(tag, value, epoch)

        self.metric = psnr_meter.avg
Exemplo n.º 8
0
    def train(self, epoch):
        """Train a pix2pix-style generator/discriminator pair for one epoch.

        ``inputs`` carries the image in channels 0-2 and the mask in channel 3.
        The generator loss is ``GAN + 100 * L1``. The discriminator loss is
        ``0.5 * (loss_real + loss_fake)``. Images are logged roughly ten times
        per epoch, and per-epoch averages are written under ``train/loss/...``.

        Args:
            epoch: epoch index, used as the scalar step and to derive the
                image-logging step.
        """
        self.current_epoch = epoch

        batch_time = AverageMeter()
        data_time = AverageMeter()
        LoggerLossG = AverageMeter()
        LoggerLossGGAN = AverageMeter()
        LoggerLossGL1 = AverageMeter()

        LoggerLossD = AverageMeter()
        LoggerLossDreal = AverageMeter()
        LoggerLossDfake = AverageMeter()

        # switch to train mode
        self.model.train()
        self.discriminator.train()

        end = time.time()

        bar = Bar('Processing {} '.format(self.args.arch),
                  max=len(self.train_loader))

        # Log images ~10 times per epoch; max() avoids a modulo-by-zero when
        # the loader has fewer than 10 batches.
        vis_every = max(1, len(self.train_loader) // 10)

        for i, (inputs, target) in enumerate(self.train_loader):
            # time spent waiting for the data loader
            data_time.update(time.time() - end)

            current_index = len(self.train_loader) * epoch + i

            if self.args.gpu:
                inputs = inputs.cuda()
                target = target.cuda()
            input_image = inputs[:, 0:3, :, :]
            mask = inputs[:, 3:4, :, :]

            # adversarial targets on the batch's device (no forced CUDA)
            batch_size = inputs.size(0)
            valid = torch.ones((batch_size, self.patch, self.patch),
                               device=inputs.device)
            fake = torch.zeros((batch_size, self.patch, self.patch),
                               device=inputs.device)

            # ---------------
            # Train model
            # --------------

            self.optimizer_G.zero_grad()
            fake_input = self.model(inputs)
            pred_fake = self.discriminator(fake_input, input_image)
            loss_GAN = self.criterion_GAN(pred_fake, valid)
            loss_pixel = self.criterion_L1(fake_input, target)
            loss_G = loss_GAN + 100 * loss_pixel
            loss_G.backward()
            self.optimizer_G.step()

            self.optimizer_D.zero_grad()
            pred_real = self.discriminator(target, input_image)
            loss_real = self.criterion_GAN(pred_real, valid)
            # detach so the discriminator step does not backprop into G
            pred_fake = self.discriminator(fake_input.detach(), input_image)
            loss_fake = self.criterion_GAN(pred_fake, fake)
            loss_D = 0.5 * (loss_real + loss_fake)
            loss_D.backward()
            self.optimizer_D.step()

            # ---------------------
            #        Logger
            # ---------------------

            LoggerLossGGAN.update(loss_GAN.item(), batch_size)
            LoggerLossGL1.update(loss_pixel.item(), batch_size)
            LoggerLossG.update(loss_G.item(), batch_size)
            # real/fake meters were previously fed each other's values
            LoggerLossDreal.update(loss_real.item(), batch_size)
            LoggerLossDfake.update(loss_fake.item(), batch_size)
            LoggerLossD.update(loss_D.item(), batch_size)

            # ---------------------
            #        Visualize
            # ---------------------

            if current_index % vis_every == 0:
                self.writer.add_images('train/Goutput', deNorm(fake_input),
                                       current_index)
                self.writer.add_images('train/target', deNorm(target),
                                       current_index)
                self.writer.add_images('train/input', deNorm(input_image),
                                       current_index)
                self.writer.add_images('train/mask',
                                       mask.repeat((1, 3, 1, 1)),
                                       current_index)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # plot progress
            bar.suffix = '({batch}/{size}) Data: {data:.2f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss D: {loss_d:.4f} | Loss G: {loss_g:.4f} | Loss L1: {loss_l1:.6f} '.format(
                batch=i + 1,
                size=len(self.train_loader),
                data=data_time.val,
                bt=batch_time.val,
                total=bar.elapsed_td,
                eta=bar.eta_td,
                loss_d=LoggerLossD.avg,
                loss_g=LoggerLossGGAN.avg,
                loss_l1=LoggerLossGL1.avg)
            bar.next()

        bar.finish()
        self.writer.add_scalar('train/loss/GAN', LoggerLossGGAN.avg, epoch)
        self.writer.add_scalar('train/loss/D', LoggerLossD.avg, epoch)
        self.writer.add_scalar('train/loss/L1', LoggerLossGL1.avg, epoch)
        self.writer.add_scalar('train/loss/G', LoggerLossG.avg, epoch)
        self.writer.add_scalar('train/loss/Dreal', LoggerLossDreal.avg, epoch)
        self.writer.add_scalar('train/loss/Dfake', LoggerLossDfake.avg, epoch)
Exemplo n.º 9
0
    def validate(self, epoch):
        """Evaluate on ``self.val_loader`` and log PSNR/SSIM.

        Outputs and targets go through ``deNorm`` before PSNR is computed as
        ``10 * log10(1 / mse)``, where the MSE comes from
        ``self.criterion_GAN``. Batch 10 is logged as images. ``self.metric``
        is set to the mean PSNR.

        Args:
            epoch: epoch index, used as the TensorBoard step.
        """
        self.current_epoch = epoch
        batch_time = AverageMeter()
        data_time = AverageMeter()
        psnres = AverageMeter()
        ssimes = AverageMeter()

        # switch to evaluate mode
        self.model.eval()

        end = time.time()
        bar = Bar('Processing {} '.format(self.args.arch),
                  max=len(self.val_loader))

        with torch.no_grad():
            for i, (inputs, target) in enumerate(self.val_loader):

                # measure data loading time
                data_time.update(time.time() - end)
                if self.args.gpu:
                    inputs = inputs.cuda()
                    target = target.cuda()

                output = self.model(inputs)

                output = deNorm(output)
                target = deNorm(target)

                ## psnr and ssim calculator.
                mse = self.criterion_GAN(output, target)
                psnr = 10 * log10(1 / mse.item())
                # store a Python float rather than a (device) tensor
                ssim = pytorch_ssim.ssim(output, target).item()

                if i == 10:
                    self.writer.add_images('val/Goutput', output, epoch)
                    self.writer.add_images('val/target', target, epoch)
                    self.writer.add_images('val/input',
                                           deNorm(inputs[:, 0:3, :, :]), epoch)
                    self.writer.add_images(
                        'val/mask', inputs[:, 3:4, :, :].repeat((1, 3, 1, 1)),
                        epoch)

                psnres.update(psnr, inputs.size(0))
                ssimes.update(ssim, inputs.size(0))

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()
                # plot progress
                bar.suffix = '({batch}/{size}) Data: {data:.2f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | SSIM: {ssim:.4f} | PSNR: {psnr:.4f}'.format(
                    batch=i + 1,
                    size=len(self.val_loader),
                    data=data_time.val,
                    bt=batch_time.val,
                    total=bar.elapsed_td,
                    eta=bar.eta_td,
                    ssim=ssimes.avg,
                    psnr=psnres.avg)
                bar.next()
        bar.finish()

        self.writer.add_scalar('val/SSIM', ssimes.avg, epoch)
        self.writer.add_scalar('val/PSNR', psnres.avg, epoch)

        self.metric = psnres.avg
Exemplo n.º 10
0
def validate(val_loader, model, criterions, args):
    """Evaluate a joint classification + segmentation model.

    Args:
        val_loader: iterable of ``(inputs, target, label, full_image)`` batches.
        model: called as ``model(inputs, full_image)``; returns
            ``(output_label, output_mask)``.
        criterions: sequence whose items 0 and 1 are the classification and
            segmentation criteria.
        args: namespace; when ``args.gpu`` is truthy, tensors are moved to CUDA.

    Returns:
        ``(label_loss, label_acc, mask_loss, mask_acc)``, each averaged with
        batch-size weights.
    """
    criterion_classification = criterions[0]
    criterion_segmentation = criterions[1]

    losses_label = AverageMeter()
    acces_label = AverageMeter()
    losses_mask = AverageMeter()
    acces_mask = AverageMeter()

    # switch to evaluate mode
    model.eval()

    # The whole loop runs without autograd; previously only the Variable
    # wrapping was inside no_grad, so the forward pass still built a graph.
    with torch.no_grad():
        for inputs, target, label, full_image in val_loader:
            if args.gpu:
                inputs = inputs.cuda()
                full_image = full_image.cuda()
                target = target.cuda()
                label = label.cuda()

            # compute output
            output_label, output_mask = model(inputs, full_image)

            loss_label = criterion_classification(output_label, label.long())
            loss_mask = criterion_segmentation(output_mask, target.long())

            acc_label = accuracy(output_label.cpu(), label.cpu())
            acc_mask = accuracy(output_mask.cpu(), target.cpu())

            # record loss and accuracy
            batch_size = inputs.size(0)
            losses_label.update(loss_label.item(), batch_size)
            losses_mask.update(loss_mask.item(), batch_size)
            acces_label.update(acc_label, batch_size)
            acces_mask.update(acc_mask, batch_size)

    return losses_label.avg, acces_label.avg, losses_mask.avg, acces_mask.avg
Exemplo n.º 11
0
def train(train_loader, model, criterions, optimizer, args):
    """Train a joint classification + segmentation model for one epoch.

    The optimised loss is ``loss_label + 10 * loss_mask``.

    Args:
        train_loader: iterable of ``(inputs, target, label, full_image)``
            batches.
        model: called as ``model(inputs, full_image)``; returns
            ``(output_label, output_mask)``.
        criterions: sequence whose items 0 and 1 are the classification and
            segmentation criteria.
        optimizer: optimizer stepping the model's parameters.
        args: namespace; when ``args.gpu`` is truthy, tensors are moved to CUDA.

    Returns:
        ``(label_loss, label_acc, mask_loss, mask_acc)``, each averaged with
        batch-size weights.
    """
    losses_label = AverageMeter()
    acces_label = AverageMeter()
    losses_mask = AverageMeter()
    acces_mask = AverageMeter()

    criterion_classification = criterions[0]
    criterion_segmentation = criterions[1]

    # switch to train mode
    model.train()

    for inputs, target, label, full_image in train_loader:
        if args.gpu:
            inputs = inputs.cuda()
            full_image = full_image.cuda()
            target = target.cuda()
            label = label.cuda()

        # compute output (tensors are used directly; the deprecated
        # torch.autograd.Variable wrapper is a no-op in modern PyTorch)
        output_label, output_mask = model(inputs, full_image)
        loss_label = criterion_classification(output_label, label.long())
        loss_mask = criterion_segmentation(output_mask, target.long())

        loss = loss_label + 10 * loss_mask

        acc_label = accuracy(output_label.detach(), label)
        acc_mask = accuracy(output_mask.detach(), target)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # record loss and accuracy
        batch_size = inputs.size(0)
        losses_label.update(loss_label.item(), batch_size)
        losses_mask.update(loss_mask.item(), batch_size)
        acces_label.update(acc_label, batch_size)
        acces_mask.update(acc_mask, batch_size)

    return losses_label.avg, acces_label.avg, losses_mask.avg, acces_mask.avg