Пример #1
0
 def init_device(self, net, gpu_id=None, whether_DP=False):
     """Move ``net`` to the chosen device and optionally wrap it for multi-GPU.

     Args:
         net: The ``torch.nn.Module`` to place on a device.
         gpu_id: CUDA device index. ``None`` falls back to
             ``self.default_gpu``; an explicit ``0`` is honoured.
         whether_DP: If True, wrap the module in ``DataParallelWithCallback``
             over every visible CUDA device.

     Returns:
         The module on the selected device (CPU when CUDA is unavailable),
         wrapped in ``DataParallelWithCallback`` when ``whether_DP`` is set.
     """
     # ``gpu_id or default`` would silently replace an explicit GPU 0.
     if gpu_id is None:
         gpu_id = self.default_gpu
     device = torch.device(
         "cuda:{}".format(gpu_id) if torch.cuda.is_available() else 'cpu')
     net = net.to(device)
     if whether_DP:
         device_ids = list(range(torch.cuda.device_count()))
         # DataParallel requires the module's parameters to live on
         # device_ids[0], so list the device we just moved it to first.
         if gpu_id in device_ids:
             device_ids.remove(gpu_id)
             device_ids.insert(0, gpu_id)
         net = DataParallelWithCallback(net, device_ids=device_ids)
     return net
Пример #2
0
def inceptions_score_fid_all(base_dir, generator_func, z_sampling_func,
                             y_sampling_func, use_data_parallel,
                             n_minibatch_sampling,
                             refrence_fid_statistics_path):
    """Compute Inception Score and FID for every saved EMA generator.

    Every file matching ``<base_dir>/models/gen_ema*.pytorch`` is visited in
    sorted order. For each one:

    - a fresh generator is built with ``generator_func`` and the checkpoint
      weights are loaded into it;
    - ``n_minibatch_sampling`` minibatches are generated from
      ``z_sampling_func`` / ``y_sampling_func``;
    - the concatenated images are scored.

    The epoch number is parsed from the file name
    (``gen_ema_epoch_<N>.pytorch``). Results are written to
    ``<base_dir>/inception_score.csv`` with columns ``epoch``,
    ``inception_score`` and ``fid``.

    Args:
        base_dir: Experiment directory that contains a ``models`` folder.
        generator_func: Zero-argument factory that builds a generator.
        z_sampling_func: Zero-argument callable returning a latent batch.
        y_sampling_func: Zero-argument callable returning a label batch.
        use_data_parallel: Wrap each generator in ``DataParallelWithCallback``.
        n_minibatch_sampling: Number of minibatches generated per checkpoint.
        refrence_fid_statistics_path: Reference statistics passed to
            ``calculate_fid_given_tensor``.
    """
    checkpoint_paths = sorted(glob.glob(base_dir + "/models/gen_ema*.pytorch"))
    n_checkpoints = len(checkpoint_paths)

    def _generate(generator):
        # Sample every minibatch without autograd and gather them on the CPU.
        with torch.no_grad():
            batches = [
                generator(z_sampling_func(), y_sampling_func())
                for _ in range(n_minibatch_sampling)
            ]
            return torch.cat(batches, dim=0).cpu()

    def _epoch_from_path(path):
        # "gen_ema_epoch_<N>.pytorch" -> N
        stem = os.path.basename(path)
        stem = stem.replace("gen_ema_epoch_", "").replace(".pytorch", "")
        return int(stem)

    table = {"epoch": [], "inception_score": [], "fid": []}

    print(f"Calculating All Inception Scores / FIDs...  (# {n_checkpoints})")
    for position, checkpoint in enumerate(checkpoint_paths, start=1):
        generator = generator_func()
        generator.load_state_dict(torch.load(checkpoint))
        if use_data_parallel:
            generator = DataParallelWithCallback(generator)

        images = _generate(generator)
        inception, _ = calculate_inception_score_given_tensor(images)
        fid = calculate_fid_given_tensor(images, refrence_fid_statistics_path)
        epoch = _epoch_from_path(checkpoint)

        table["epoch"].append(epoch)
        table["inception_score"].append(inception)
        table["fid"].append(fid)
        print(f"epoch = {epoch}, inception_score = {inception}, "
              f"fid = {fid}    [{position}/{n_checkpoints}]")

    pd.DataFrame(table).to_csv(base_dir + "/inception_score.csv", index=False)
Пример #3
0
def main(args):
    """Train the image/wireframe GAN configured by ``args``.

    Setup:

    - Restricts CUDA to the GPUs in ``args.gpu`` and builds train/validation
      loaders.
    - Creates a ``Generator``, an ``NLayerDiscriminator`` and a ``PNet``
      perceptual loss. Each is wrapped in ``DataParallelWithCallback`` when
      more than one GPU is requested.

    Training, for every batch:

    - run ``args.D_steps`` discriminator updates;
    - then run one generator update (generator losses + perceptual loss +
      adversarial loss).

    Logging and checkpointing:

    - Losses and image grids are logged to TensorBoard under
      ``args.log_path``.
    - G/D state dicts are saved every ``args.save_freq`` epochs (epoch 0
      skipped).
    - After each epoch both ``ReduceLROnPlateau`` schedulers are stepped on
      the perceptual loss of a single validation batch.

    Args:
        args: Parsed command-line namespace; every path and hyper-parameter
            used here is read from it.
    """
    # set which gpu(s) to use, should set PCI_BUS_ID first
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    # NOTE(review): counts GPUs from the string length, which assumes
    # single-digit comma-separated ids like "0,1,2"; "10,11" would yield 3.
    num_gpus = (len(args.gpu) + 1) // 2

    # create model directories
    checkpath(args.modelG_path)
    checkpath(args.modelD_path)

    # tensorboard writer
    checkpath(args.log_path)
    writer = SummaryWriter(args.log_path)

    # load data
    data_loader, num_train = get_loader(args,
                                        batch_size=args.batch_size,
                                        shuffle=True,
                                        num_workers=args.num_workers,
                                        training=True)
    data_loader_val, num_test = get_loader(args,
                                           batch_size=args.val_bs,
                                           shuffle=False,
                                           num_workers=args.num_workers,
                                           training=False)
    print('Finished data loading')
    print("The length of the train set is: {}".format(num_train))
    print("The length of the test set is: {}".format(num_test))

    # colour guidance is on unless --nocolor is given
    colorguide = True
    if args.nocolor:
        colorguide = False

    # loss multipliers
    lambdas = [
        args.lambda_imgl1, args.lambda_wfl1, args.lambda_ssim,
        args.lambda_color
    ]
    lambda_perceptual = args.lambda_perceptual

    # Generator
    netG = Generator(lambdas=lambdas,
                     colorguide=colorguide,
                     input_nc=1,
                     output_nc=1)

    if num_gpus > 1:
        # multi-gpu training with synchronized batch normalization
        # make sure enough number of gpus are available
        assert (torch.cuda.device_count() >= num_gpus)
        # since we have set CUDA_VISIBLE_DEVICES to avoid some invalid device id issues
        netG = DataParallelWithCallback(
            netG, device_ids=[i for i in range(num_gpus)])
        # the unwrapped module is used for optimizers, checkpoints and validation
        netG_single = netG.module
    else:
        # single gpu training
        netG_single = netG

    # Discriminator
    netD = NLayerDiscriminator(input_nc=4, n_layers=4)
    if num_gpus > 1:
        netD = DataParallelWithCallback(
            netD, device_ids=[i for i in range(num_gpus)])
        netD_single = netD.module
    else:
        netD_single = netD

    # print(netG_single)
    # print(netD_single)

    # resume only when both checkpoint paths are given
    if args.pretrained and args.netG_path != '' and args.netD_path != '':
        netG_single.load_state_dict(torch.load(args.netG_path))
        netD_single.load_state_dict(torch.load(args.netD_path))

    # Right now we only support gpu training
    if torch.cuda.is_available():
        netG = netG.cuda()
        netD = netD.cuda()

    # define the perceptual loss, place outside the forward func in G for better multi-gpu training
    Ploss = PNet()
    if num_gpus > 1:
        Ploss = DataParallelWithCallback(
            Ploss, device_ids=[i for i in range(num_gpus)])

    if torch.cuda.is_available():
        Ploss = Ploss.cuda()

    # setup optimizer
    lr = args.learning_rate
    optimizerD = optim.Adam(netD_single.parameters(),
                            lr=lr,
                            betas=(args.beta1, 0.999))
    schedulerD = ReduceLROnPlateau(optimizerD,
                                   factor=0.7,
                                   patience=10,
                                   mode='min',
                                   min_lr=1e-06)
    optimizerG = optim.Adam(netG_single.parameters(),
                            lr=lr,
                            betas=(args.beta1, 0.999))
    schedulerG = ReduceLROnPlateau(optimizerG,
                                   factor=0.7,
                                   patience=10,
                                   mode='min',
                                   min_lr=1e-06)

    for epoch in range(args.num_epochs):
        # switch to train mode
        netG.train()
        netD.train()

        for i, (img_real, wf_real, color_real) in enumerate(data_loader, 0):
            img_real = img_real.cuda()
            wf_real = wf_real.cuda()
            color_real = color_real.cuda()

            # Update D network, we freeze parameters in G to save memory
            for p in netG_single.parameters():
                p.requires_grad = False
            for p in netD_single.parameters():
                p.requires_grad = True

            # if using TTUR, D can be trained multiple steps per G step
            for _ in range(args.D_steps):
                optimizerD.zero_grad()

                # train with real
                real_AB = torch.cat((img_real, wf_real), 1)
                errD_real = 0.5 * netD(trainG=False,
                                       trainReal=True,
                                       real_AB=real_AB,
                                       fake_AB=None).sum()
                errD_real.backward()

                # train with fake
                img_fake, wf_fake, _, _, _, _, _ = netG(trainG=False,
                                                        img_real=None,
                                                        wf_real=wf_real,
                                                        color_real=color_real)
                fake_AB = torch.cat((img_fake, wf_fake), 1)
                errD_fake = 0.5 * netD(trainG=False,
                                       trainReal=False,
                                       real_AB=None,
                                       fake_AB=fake_AB).sum()
                errD_fake.backward()

                errD = errD_real + errD_fake
                optimizerD.step()
                del img_fake, wf_fake, fake_AB, real_AB, errD_real, errD_fake

            # NOTE(review): errD holds only the last D step's loss and is
            # undefined here if args.D_steps == 0.
            iterations_before_epoch = epoch * len(data_loader)
            writer.add_scalar('D Loss', errD.item(),
                              iterations_before_epoch + i)
            del errD

            # Update G network, we freeze parameters in D to save memory
            for p in netG.parameters():
                p.requires_grad = True
            for p in netD.parameters():
                p.requires_grad = False

            optimizerG.zero_grad()

            img_fake, wf_fake, lossG, wf_ssim, img_l1, color_l1, wf_l1 = netG(
                trainG=True,
                img_real=img_real,
                wf_real=wf_real,
                color_real=color_real)
            ploss = Ploss(img_fake, img_real.detach()).sum()
            fake_AB = torch.cat((img_fake, wf_fake), 1)
            lossD = netD(trainG=True,
                         trainReal=False,
                         real_AB=None,
                         fake_AB=fake_AB).sum()
            # total G objective: G-internal losses + perceptual + adversarial
            errG = (lossG.sum() + lambda_perceptual * ploss + lossD)
            errG.backward()
            optimizerG.step()

            del color_real, fake_AB, lossG, errG

            # NOTE(review): the reported SSIM value is offset by num_gpus;
            # presumably each replica returns a negated SSIM so this shows
            # sum(1 - SSIM) — confirm against Generator.
            if args.nocolor:
                print(
                    'Epoch: [{}/{}] Iter: [{}/{}] PercLoss : {:.4f} ImageL1 : {:.6f} WfL1 : {:.6f} WfSSIM : {:.6f}'
                    .format(epoch, args.num_epochs, i, len(data_loader),
                            ploss.item(),
                            img_l1.sum().item(),
                            wf_l1.sum().item(),
                            num_gpus + wf_ssim.sum().item()))
            else:
                print(
                    'Epoch: [{}/{}] Iter: [{}/{}] PercLoss : {:.4f} ImageL1 : {:.6f} WfL1 : {:.6f} WfSSIM : {:.6f} ColorL1 : {:.6f}'
                    .format(epoch, args.num_epochs, i, len(data_loader),
                            ploss.item(),
                            img_l1.sum().item(),
                            wf_l1.sum().item(),
                            num_gpus + wf_ssim.sum().item(),
                            color_l1.sum().item()))
                writer.add_scalar('Color Loss',
                                  color_l1.sum().item(),
                                  iterations_before_epoch + i)

            # tensorboard log
            writer.add_scalar('G Loss', lossD.item(),
                              iterations_before_epoch + i)
            writer.add_scalar('Image L1 Loss',
                              img_l1.sum().item(), iterations_before_epoch + i)
            writer.add_scalar('Wireframe MSSSIM Loss',
                              num_gpus + wf_ssim.sum().item(),
                              iterations_before_epoch + i)
            writer.add_scalar('Wireframe L1',
                              wf_l1.sum().item(), iterations_before_epoch + i)
            writer.add_scalar('Image Perceptual Loss', ploss.item(),
                              iterations_before_epoch + i)

            del wf_ssim, ploss, img_l1, color_l1, wf_l1, lossD

            with torch.no_grad():
                # show generated training images in tensorboard
                if i % args.val_freq == 0:
                    real_img = vutils.make_grid(
                        img_real.detach()[:args.val_size],
                        normalize=True,
                        scale_each=True)
                    writer.add_image('Real Image', real_img,
                                     (iterations_before_epoch + i) //
                                     args.val_freq)
                    real_wf = vutils.make_grid(
                        wf_real.detach()[:args.val_size],
                        normalize=True,
                        scale_each=True)
                    writer.add_image('Real Wireframe', real_wf,
                                     (iterations_before_epoch + i) //
                                     args.val_freq)
                    fake_img = vutils.make_grid(
                        img_fake.detach()[:args.val_size],
                        normalize=True,
                        scale_each=True)
                    writer.add_image('Fake Image', fake_img,
                                     (iterations_before_epoch + i) //
                                     args.val_freq)
                    fake_wf = vutils.make_grid(
                        wf_fake.detach()[:args.val_size],
                        normalize=True,
                        scale_each=True)
                    writer.add_image('Fake Wireframe', fake_wf,
                                     (iterations_before_epoch + i) //
                                     args.val_freq)
                    del real_img, real_wf, fake_img, fake_wf

            del img_real, wf_real, img_fake, wf_fake

        # do checkpointing (every save_freq epochs, never at epoch 0)
        if epoch % args.save_freq == 0 and epoch > 0:
            torch.save(netG_single.state_dict(),
                       '{}/netG_epoch_{}.pth'.format(args.modelG_path, epoch))
            torch.save(netD_single.state_dict(),
                       '{}/netD_epoch_{}.pth'.format(args.modelD_path, epoch))

        # validation (runs on the unwrapped generator)
        with torch.no_grad():
            netG_single.eval()
            # since we use a relatively large validation batchsize, we don't go through the whole test set
            (img_real, wf_real, color_real) = next(iter(data_loader_val))
            img_real = img_real.cuda()
            wf_real = wf_real.cuda()
            color_real = color_real.cuda()

            img_fake, wf_fake, _, _, _, _, _ = netG_single(
                trainG=False,
                img_real=None,
                wf_real=wf_real,
                color_real=color_real)

            # update lr based on the validation perceptual loss
            val_score = Ploss(img_fake.detach(), img_real.detach()).sum()
            schedulerG.step(val_score)
            schedulerD.step(val_score)
            print('Current lr: {:.6f}'.format(
                optimizerG.param_groups[0]['lr']))

            real_img = vutils.make_grid(img_real.detach()[:args.val_size],
                                        normalize=True,
                                        scale_each=True)
            writer.add_image('Test: Real Image', real_img, epoch)
            real_wf = vutils.make_grid(wf_real.detach()[:args.val_size],
                                       normalize=True,
                                       scale_each=True)
            writer.add_image('Test: Real Wireframe', real_wf, epoch)
            fake_img = vutils.make_grid(img_fake.detach()[:args.val_size],
                                        normalize=True,
                                        scale_each=True)
            writer.add_image('Test: Fake Image', fake_img, epoch)
            fake_wf = vutils.make_grid(wf_fake.detach()[:args.val_size],
                                       normalize=True,
                                       scale_each=True)
            writer.add_image('Test: Fake Wireframe', fake_wf, epoch)

            netG_single.train()

            del img_real, real_img, wf_real, real_wf, img_fake, fake_img, wf_fake, fake_wf

    # close tb writer
    writer.close()
Пример #4
0
    def __setup(self):
        """Initialise directories, seeds, model, loss, logger and optimizer.

        Everything is configured from ``self.opt``. When ``self.opt.resume``
        is set, state is restored via ``self.load``; otherwise the freshly
        built network is printed.

        Raises:
            Exception: CUDA was requested (``no_cuda`` unset) but no GPU is
                available.
            ValueError: ``self.opt.loss`` is not one of ``'l2'``, ``'l1'``,
                ``'smooth_l1'``, ``'ssim'`` or ``'l2_ssim'``.
        """
        self.basedir = join('checkpoints', self.opt.arch)
        if not os.path.exists(self.basedir):
            os.makedirs(self.basedir)

        self.best_psnr = 0
        self.best_loss = 1e6
        self.epoch = 0  # start from epoch 0 or last checkpoint epoch
        self.iteration = 0

        cuda = not self.opt.no_cuda
        self.device = 'cuda' if cuda else 'cpu'
        print('Cuda Acess: %d' % cuda)
        if cuda and not torch.cuda.is_available():
            raise Exception("No GPU found, please run without --cuda")

        torch.manual_seed(self.opt.seed)
        if cuda:
            torch.cuda.manual_seed(self.opt.seed)

        # Model
        print("=> creating model '{}'".format(self.opt.arch))
        self.net = models.__dict__[self.opt.arch]()
        # initialize parameters
        init_params(
            self.net,
            init_type=self.opt.init)  # disable for default initialization

        if len(self.opt.gpu_ids) > 1:
            from models.sync_batchnorm import DataParallelWithCallback
            self.net = DataParallelWithCallback(self.net,
                                                device_ids=self.opt.gpu_ids)

        # Loss: one factory per supported ``opt.loss`` value. An unknown name
        # previously left the criterion as None and failed later with an
        # unrelated AttributeError.
        loss_factories = {
            'l2': nn.MSELoss,
            'l1': nn.L1Loss,
            'smooth_l1': nn.SmoothL1Loss,
            'ssim': lambda: SSIMLoss(data_range=1, channel=31),
            'l2_ssim': lambda: MultipleLoss(
                [nn.MSELoss(),
                 SSIMLoss(data_range=1, channel=31)],
                weight=[1, 2.5e-3]),
        }
        if self.opt.loss not in loss_factories:
            raise ValueError('Unsupported loss: {!r}'.format(self.opt.loss))
        self.criterion = loss_factories[self.opt.loss]()

        print(self.criterion)

        if cuda:
            self.net.to(self.device)
            self.criterion = self.criterion.to(self.device)

        # Logger
        if not self.opt.no_log:
            self.writer = get_summary_writer(
                os.path.join(self.basedir, 'logs'), self.opt.prefix)

        # Optimizer
        self.optimizer = optim.Adam(self.net.parameters(),
                                    lr=self.opt.lr,
                                    weight_decay=self.opt.wd,
                                    amsgrad=False)

        # Resume previous model
        if self.opt.resume:
            # Load checkpoint.
            self.load(self.opt.resumePath, not self.opt.no_ropt)
        else:
            print('==> Building model..')
            print(self.net)
Пример #5
0
class Engine(object):
    """Training / evaluation driver for the architecture named by ``opt.arch``.

    Owns the network, optimizer, criterion, TensorBoard writer and the
    epoch / iteration counters, and provides train, validate, test and
    checkpoint helpers. All configuration is read from the ``opt`` namespace.
    """

    def __init__(self, opt):
        self.prefix = opt.prefix
        self.opt = opt
        self.net = None
        self.optimizer = None
        self.criterion = None
        self.basedir = None
        self.iteration = None
        self.epoch = None
        self.best_psnr = None
        self.best_loss = None
        self.writer = None

        self.__setup()

    def __setup(self):
        """Initialise directories, seeds, model, loss, logger and optimizer.

        Raises:
            Exception: CUDA was requested (``no_cuda`` unset) but no GPU is
                available.
            ValueError: ``self.opt.loss`` is not one of ``'l2'``, ``'l1'``,
                ``'smooth_l1'``, ``'ssim'`` or ``'l2_ssim'``.
        """
        self.basedir = join('checkpoints', self.opt.arch)
        if not os.path.exists(self.basedir):
            os.makedirs(self.basedir)

        self.best_psnr = 0
        self.best_loss = 1e6
        self.epoch = 0  # start from epoch 0 or last checkpoint epoch
        self.iteration = 0

        cuda = not self.opt.no_cuda
        self.device = 'cuda' if cuda else 'cpu'
        print('Cuda Acess: %d' % cuda)
        if cuda and not torch.cuda.is_available():
            raise Exception("No GPU found, please run without --cuda")

        torch.manual_seed(self.opt.seed)
        if cuda:
            torch.cuda.manual_seed(self.opt.seed)

        # Model
        print("=> creating model '{}'".format(self.opt.arch))
        self.net = models.__dict__[self.opt.arch]()
        # initialize parameters
        init_params(
            self.net,
            init_type=self.opt.init)  # disable for default initialization

        if len(self.opt.gpu_ids) > 1:
            from models.sync_batchnorm import DataParallelWithCallback
            self.net = DataParallelWithCallback(self.net,
                                                device_ids=self.opt.gpu_ids)

        # Loss: one factory per supported ``opt.loss`` value. An unknown name
        # previously left the criterion as None and failed later with an
        # unrelated AttributeError.
        loss_factories = {
            'l2': nn.MSELoss,
            'l1': nn.L1Loss,
            'smooth_l1': nn.SmoothL1Loss,
            'ssim': lambda: SSIMLoss(data_range=1, channel=31),
            'l2_ssim': lambda: MultipleLoss(
                [nn.MSELoss(),
                 SSIMLoss(data_range=1, channel=31)],
                weight=[1, 2.5e-3]),
        }
        if self.opt.loss not in loss_factories:
            raise ValueError('Unsupported loss: {!r}'.format(self.opt.loss))
        self.criterion = loss_factories[self.opt.loss]()

        print(self.criterion)

        if cuda:
            self.net.to(self.device)
            self.criterion = self.criterion.to(self.device)

        # Logger
        if not self.opt.no_log:
            self.writer = get_summary_writer(
                os.path.join(self.basedir, 'logs'), self.opt.prefix)

        # Optimizer
        self.optimizer = optim.Adam(self.net.parameters(),
                                    lr=self.opt.lr,
                                    weight_decay=self.opt.wd,
                                    amsgrad=False)

        # Resume previous model
        if self.opt.resume:
            # Load checkpoint.
            self.load(self.opt.resumePath, not self.opt.no_ropt)
        else:
            print('==> Building model..')
            print(self.net)

    def forward(self, inputs):
        """Run the network, using the quadrant-tiled path when ``opt.chop``."""
        if self.opt.chop:
            output = self.forward_chop(inputs)
        else:
            output = self.net(inputs)

        return output

    def forward_chop(self, x, base=16):
        """Run the network on four overlapping spatial tiles and stitch them.

        The input must be 5-D (the size unpack below requires it); the last
        two dims are treated as height and width. Each tile is padded past
        the image half to a multiple of ``base``, with at least 10 extra
        pixels. Only the matching quadrant of each tile's output is copied
        back.
        """
        n, c, b, h, w = x.size()
        h_half, w_half = h // 2, w // 2

        shave_h = np.ceil(h_half / base) * base - h_half
        shave_w = np.ceil(w_half / base) * base - w_half

        shave_h = shave_h if shave_h >= 10 else shave_h + base
        shave_w = shave_w if shave_w >= 10 else shave_w + base

        h_size, w_size = int(h_half + shave_h), int(w_half + shave_w)

        inputs = [
            x[..., 0:h_size, 0:w_size], x[..., 0:h_size, (w - w_size):w],
            x[..., (h - h_size):h, 0:w_size], x[..., (h - h_size):h,
                                                (w - w_size):w]
        ]

        outputs = [self.net(input_i) for input_i in inputs]

        output = torch.zeros_like(x)
        output_w = torch.zeros_like(x)

        output[..., 0:h_half, 0:w_half] += outputs[0][..., 0:h_half, 0:w_half]
        output_w[..., 0:h_half, 0:w_half] += 1
        output[..., 0:h_half,
               w_half:w] += outputs[1][..., 0:h_half,
                                       (w_size - w + w_half):w_size]
        output_w[..., 0:h_half, w_half:w] += 1
        output[..., h_half:h,
               0:w_half] += outputs[2][..., (h_size - h + h_half):h_size,
                                       0:w_half]
        output_w[..., h_half:h, 0:w_half] += 1
        output[..., h_half:h,
               w_half:w] += outputs[3][..., (h_size - h + h_half):h_size,
                                       (w_size - w + w_half):w_size]
        output_w[..., h_half:h, w_half:w] += 1

        output /= output_w

        return output

    def __step(self, train, inputs, targets):
        """Forward (and, when ``train``, backward + optimizer step) one batch.

        If the unwrapped net has ``bandwise`` set, inputs/targets are split
        along dim 1 and each slice is processed separately. Gradients of all
        slices accumulate before the single optimizer step.

        Returns:
            ``(outputs, loss_data, total_norm)``. ``loss_data`` is the summed
            Python float loss. ``total_norm`` is the clipped gradient norm
            when training, else ``None``.
        """
        if train:
            self.optimizer.zero_grad()
        loss_data = 0
        total_norm = None
        if self.get_net().bandwise:
            band_outputs = []
            for band_in, band_target in zip(inputs.split(1, 1),
                                            targets.split(1, 1)):
                band_out = self.net(band_in)
                band_outputs.append(band_out)
                loss = self.criterion(band_out, band_target)
                if train:
                    loss.backward()
                loss_data += loss.item()
            outputs = torch.cat(band_outputs, dim=1)
        else:
            outputs = self.net(inputs)
            loss = self.criterion(outputs, targets)
            if train:
                loss.backward()
            loss_data += loss.item()
        if train:
            total_norm = nn.utils.clip_grad_norm_(self.net.parameters(),
                                                  self.opt.clip)
            self.optimizer.step()

        return outputs, loss_data, total_norm

    def load(self, resumePath=None, load_opt=True):
        """Restore network (and optionally optimizer) state from a checkpoint.

        Args:
            resumePath: Checkpoint file to load. When falsy, falls back to
                ``<basedir>/<prefix>/model_latest.pth``.
            load_opt: Also restore the optimizer state.
        """
        latest_path = join(self.basedir, self.prefix, 'model_latest.pth')

        print('==> Resuming from checkpoint %s..' % resumePath)
        assert os.path.isdir(
            'checkpoints'), 'Error: no checkpoint directory found!'
        checkpoint = torch.load(resumePath or latest_path)
        #### comment when using memnet
        self.epoch = checkpoint['epoch']
        self.iteration = checkpoint['iteration']
        if load_opt:
            self.optimizer.load_state_dict(checkpoint['optimizer'])
        ####
        self.get_net().load_state_dict(checkpoint['net'])

    def train(self, train_loader):
        """Train for one epoch over ``train_loader`` and advance ``self.epoch``."""
        print('\nEpoch: %d' % self.epoch)
        self.net.train()
        train_loss = 0
        avg_loss = 0.0  # stays defined (and is logged) for an empty loader

        for batch_idx, (inputs, targets) in enumerate(train_loader):

            if not self.opt.no_cuda:
                inputs, targets = inputs.to(self.device), targets.to(
                    self.device)
            outputs, loss_data, total_norm = self.__step(True, inputs, targets)
            train_loss += loss_data
            avg_loss = train_loss / (batch_idx + 1)

            if not self.opt.no_log:
                self.writer.add_scalar(join(self.prefix, 'train_loss'),
                                       loss_data, self.iteration)
                self.writer.add_scalar(join(self.prefix, 'train_avg_loss'),
                                       avg_loss, self.iteration)

            self.iteration += 1

            progress_bar(
                batch_idx, len(train_loader),
                'AvgLoss: %.4e | Loss: %.4e | Norm: %.4e' %
                (avg_loss, loss_data, total_norm))

        self.epoch += 1
        if not self.opt.no_log:
            self.writer.add_scalar(join(self.prefix, 'train_loss_epoch'),
                                   avg_loss, self.epoch)

    def validate(self, valid_loader, name):
        """Evaluate on ``valid_loader`` without gradients.

        Returns:
            ``(avg_psnr, avg_loss)``; both are ``0.0`` for an empty loader.
        """
        self.net.eval()
        validate_loss = 0
        total_psnr = 0
        avg_loss = 0.0
        avg_psnr = 0.0
        print('[i] Eval dataset {}...'.format(name))
        with torch.no_grad():
            for batch_idx, (inputs, targets) in enumerate(valid_loader):
                if not self.opt.no_cuda:
                    inputs, targets = inputs.to(self.device), targets.to(
                        self.device)

                outputs, loss_data, _ = self.__step(False, inputs, targets)
                psnr = np.mean(cal_bwpsnr(outputs, targets))

                validate_loss += loss_data
                avg_loss = validate_loss / (batch_idx + 1)

                total_psnr += psnr
                avg_psnr = total_psnr / (batch_idx + 1)

                progress_bar(batch_idx, len(valid_loader),
                             'Loss: %.4e | PSNR: %.4f' % (avg_loss, avg_psnr))

        if not self.opt.no_log:
            self.writer.add_scalar(join(self.prefix, name, 'val_loss_epoch'),
                                   avg_loss, self.epoch)
            self.writer.add_scalar(join(self.prefix, name, 'val_psnr_epoch'),
                                   avg_psnr, self.epoch)

        return avg_psnr, avg_loss

    def save_checkpoint(self, model_out_path=None, **kwargs):
        """Save net/optimizer state plus counters (and any ``kwargs``)."""
        if not model_out_path:
            model_out_path = join(
                self.basedir, self.prefix,
                "model_epoch_%d_%d.pth" % (self.epoch, self.iteration))

        state = {
            'net': self.get_net().state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'epoch': self.epoch,
            'iteration': self.iteration,
        }

        state.update(kwargs)

        if not os.path.isdir(join(self.basedir, self.prefix)):
            os.makedirs(join(self.basedir, self.prefix))

        torch.save(state, model_out_path)
        print("Checkpoint saved to {}".format(model_out_path))

    # saving result into disk
    def test_develop(self, test_loader, savedir=None, verbose=True):
        """Score outputs and inputs with ``MSIQA``; optionally save ``.mat`` results.

        Returns:
            ``(res_arr, input_arr)``: one row of ``MSIQA`` metrics per batch,
            for the network outputs and for the raw inputs respectively.
        """
        from scipy.io import savemat
        from os.path import basename, exists

        def torch2numpy(hsi):
            # Read the flag from the unwrapped net: DataParallel does not
            # forward arbitrary attributes of the wrapped module.
            if self.get_net().use_2dconv:
                R_hsi = hsi.data[0].cpu().numpy().transpose((1, 2, 0))
            else:
                R_hsi = hsi.data[0].cpu().numpy()[0, ...].transpose((1, 2, 0))
            return R_hsi

        self.net.eval()
        dataset = test_loader.dataset.dataset

        res_arr = np.zeros((len(test_loader), 3))
        input_arr = np.zeros((len(test_loader), 3))

        with torch.no_grad():
            for batch_idx, (inputs, targets) in enumerate(test_loader):
                if not self.opt.no_cuda:
                    inputs, targets = inputs.cuda(), targets.cuda()
                outputs, _, _ = self.__step(False, inputs, targets)

                res_arr[batch_idx, :] = MSIQA(outputs, targets)
                input_arr[batch_idx, :] = MSIQA(inputs, targets)

                psnr = res_arr[batch_idx, 0]
                ssim = res_arr[batch_idx, 1]
                if verbose:
                    print(batch_idx, psnr, ssim)

                if savedir:
                    filedir = join(
                        savedir,
                        basename(dataset.filenames[batch_idx]).split('.')[0])
                    outpath = join(filedir, '{}.mat'.format(self.opt.arch))

                    if not exists(filedir):
                        os.mkdir(filedir)

                    if not exists(outpath):
                        savemat(outpath, {'R_hsi': torch2numpy(outputs)})

        return res_arr, input_arr

    def test_real(self, test_loader, savedir=None):
        """Run on unlabeled data, visualise input|output, optionally save ``.mat``.

        Warning: this code is not compatible with bandwise flag.

        Returns:
            The outputs of the last batch, or ``None`` for an empty loader.
        """
        from scipy.io import savemat
        from os.path import basename
        self.net.eval()
        dataset = test_loader.dataset.dataset
        outputs = None

        with torch.no_grad():
            for batch_idx, inputs in enumerate(test_loader):
                if not self.opt.no_cuda:
                    inputs = inputs.cuda()

                outputs = self.forward(inputs)
                # Visualization: input and output side by side on the last axis
                input_np = inputs[0].cpu().numpy()
                output_np = outputs[0].cpu().numpy()

                display = np.concatenate([input_np, output_np], axis=-1)

                Visualize3D(display)

                if savedir:
                    R_hsi = outputs.data[0].cpu().numpy()[0, ...].transpose(
                        (1, 2, 0))
                    savepath = join(
                        savedir,
                        basename(dataset.filenames[batch_idx]).split('.')[0],
                        self.opt.arch + '.mat')
                    # savemat does not create missing parent directories.
                    os.makedirs(os.path.dirname(savepath), exist_ok=True)
                    savemat(savepath, {'R_hsi': R_hsi})

        return outputs

    def get_net(self):
        """Return the underlying module, unwrapping DataParallel on multi-GPU."""
        if len(self.opt.gpu_ids) > 1:
            return self.net.module
        else:
            return self.net
Пример #6
0
def train(cases):
    """Train the 128px conditional GAN: generator, its EMA copy, and discriminator.

    Loads the settings for ``cases`` from ``settings/res128.json`` into the
    module-level ``args``, then runs ``args.n_epoch`` epochs in which G is
    updated once for every two D updates. Each epoch, up to 25 of the latest
    fake images are saved (rows of 5) to ``res128_case{cases}/epoch_XXX.png``;
    every 5 epochs the G, EMA-G and D weights are saved under
    ``res128_case{cases}/models/``. Per-epoch mean losses are pickled to
    ``res128_case{cases}/logs.pkl`` at the end (NaN for an epoch with no
    D or no G update).

    Args:
        cases: settings case identifier; also used to name the output directory.
    """
    utils.load_settings(args, "settings/res128.json", cases)

    output_dir = f"res128_case{cases}"

    device = "cuda"
    torch.backends.cudnn.benchmark = True

    print("--- Conditions ---")
    print("- Case : ", cases)
    print(args)

    batch_size = 64
    dataloader = load_dataset(batch_size)

    model_G = Generator(64,
                        128,
                        args.n_classes,
                        n_projected_dims=args.n_projected_dims)
    model_G_ema = Generator(64,
                            128,
                            args.n_classes,
                            n_projected_dims=args.n_projected_dims)
    model_D = Discriminator(64, 128, args.n_classes)
    model_G, model_D = model_G.to(device), model_D.to(device)
    model_G_ema = model_G_ema.to(device)

    model_G, model_D = DataParallelWithCallback(
        model_G), DataParallelWithCallback(model_D)
    model_G_ema = DataParallelWithCallback(model_G_ema)

    param_G = torch.optim.Adam(model_G.parameters(), lr=5e-5, betas=(0, 0.999))
    param_D = torch.optim.Adam(model_D.parameters(), lr=2e-4, betas=(0, 0.999))

    result = {"d_loss": [], "g_loss": []}
    n = len(dataloader)
    onehot_encoding = torch.eye(args.n_classes).to(device)

    # Latest detached fake batch and its labels; D steps between G steps reuse it.
    fake_img, fake_onehots = None, None
    ema = utils.EMA(model_G, model_G_ema)
    gan_loss = utils.HingeLoss(batch_size, device)
    update_G_counter = 1  # G:D=1:2 -- G is updated when this reaches 0

    def generate_fake_imgs(batch_len, labels):
        """Sample truncated-normal latents and generate a fake batch for ``labels``."""
        fake_onehots = labels.detach()
        x = truncnorm.rvs(-1.5, 1.5,
                          size=(batch_len,
                                128))  # truncation trick = [-1.5, 1.5]
        rand_X = torch.FloatTensor(x).to(device)
        fake_img = model_G(rand_X, fake_onehots)
        return fake_img, fake_onehots

    def mean_or_nan(values):
        """Mean of ``values``, or NaN when empty (statistics.mean would raise)."""
        return statistics.mean(values) if values else float("nan")

    for epoch in range(args.n_epoch):
        log_loss_D, log_loss_G = [], []

        for i, (real_img, labels) in tqdm(enumerate(dataloader), total=n):
            batch_len = len(real_img)
            if batch_len != batch_size:
                continue  # the loss helper is built for a fixed batch size

            real_img = real_img.to(device)
            real_onehots = onehot_encoding[labels.to(device)]

            # train D
            param_G.zero_grad()
            param_D.zero_grad()
            # train real
            d_out_real = model_D(real_img, real_onehots)
            loss_real = gan_loss(d_out_real, "dis_real")
            # train fake (first step only: no G sample exists yet)
            if fake_img is None:
                fake_img, fake_onehots = generate_fake_imgs(
                    batch_len, real_onehots)
                fake_img = fake_img.detach()
            d_out_fake = model_D(fake_img, fake_onehots)
            loss_fake = gan_loss(d_out_fake, "dis_fake")
            # fake + real loss
            loss = loss_real + loss_fake
            log_loss_D.append(loss.item())

            # backprop
            loss.backward()
            param_D.step()

            # train G
            if update_G_counter == 0:
                param_G.zero_grad()
                param_D.zero_grad()

                fake_img_g, fake_onehots = generate_fake_imgs(
                    batch_len, real_onehots)
                # keep a detached copy for the following D steps / screenshot
                fake_img = fake_img_g.detach()
                g_out = model_D(fake_img_g, fake_onehots)

                loss = gan_loss(g_out, "gen")
                log_loss_G.append(
                    loss.item())  # loss without orthogonal regularization

                # backprop
                loss.backward()
                # orthogonal regularization
                utils.orthogonal_regularization(model_G)
                # update G
                param_G.step()
                # update EMA
                ema.update()
                # reset counter
                update_G_counter = 1
            else:
                update_G_counter -= 1

        # log. An epoch may contain no G step (the 1:2 counter carries over
        # between epochs) or no full batch at all; record NaN instead of
        # letting statistics.mean raise on an empty list.
        result["d_loss"].append(mean_or_nan(log_loss_D))
        result["g_loss"].append(mean_or_nan(log_loss_G))
        print(
            f"epoch = {epoch}, g_loss = {result['g_loss'][-1]}, d_loss = {result['d_loss'][-1]}"
        )

        # save screen shot (nothing to save until a fake batch exists)
        if fake_img is not None:
            os.makedirs(output_dir, exist_ok=True)
            # NOTE: newer torchvision releases rename ``range`` to
            # ``value_range``; update this call if the installed version rejects it.
            torchvision.utils.save_image(fake_img[:25],
                                         f"{output_dir}/epoch_{epoch:03}.png",
                                         nrow=5,
                                         padding=3,
                                         normalize=True,
                                         range=(-1.0, 1.0))

        # save weights
        if epoch % 5 == 0:
            os.makedirs(output_dir + "/models", exist_ok=True)
            utils.save_model(
                model_G, f"{output_dir}/models/gen_epoch_{epoch:03}.pytorch",
                True)
            utils.save_model(
                model_G_ema,
                f"{output_dir}/models/gen_ema_epoch_{epoch:03}.pytorch", True)
            utils.save_model(
                model_D, f"{output_dir}/models/dis_epoch_{epoch:03}.pytorch",
                True)

    # log (the directory may not exist yet if no epoch produced any output)
    os.makedirs(output_dir, exist_ok=True)
    with open(output_dir + "/logs.pkl", "wb") as fp:
        pickle.dump(result, fp)