Example #1
0
def test(args):
    """Translate three random batches between domains A and B and save a grid.

    Loads the Gab/Gba generators from a hard-coded checkpoint, translates
    three batches in both directions and writes all real/fake images to
    ``args.results_dir + '/sample.png'``.
    """
    transform = transforms.Compose(
        [transforms.Resize((args.crop_height, args.crop_width)),
         transforms.ToTensor(),
         # transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
         ])

    # NOTE(review): dataset_dirs is currently unused (the ImageFolder datasets
    # below are commented out); kept for parity with the other test() variants.
    dataset_dirs = utils.get_testdata_link(args.dataset_dir)

    # a_test_data = dsets.ImageFolder(dataset_dirs['testA'], transform=transform)
    # b_test_data = dsets.ImageFolder(dataset_dirs['testB'], transform=transform)
    # Hard-coded list files describing the two image domains.
    a_test_data = ListDataSet('/media/l/新加卷1/city/data/river/train_256_9w.lst', transform=transform)
    b_test_data = ListDataSet('/media/l/新加卷/city/jinan_z3.lst', transform=transform)

    a_test_loader = torch.utils.data.DataLoader(a_test_data, batch_size=args.batch_size, shuffle=True, num_workers=4)
    b_test_loader = torch.utils.data.DataLoader(b_test_data, batch_size=args.batch_size, shuffle=True, num_workers=4)

    Gab = define_Gen(input_nc=3, output_nc=3, ngf=args.ngf, netG='resnet_9blocks', norm=args.norm,
                     use_dropout=not args.no_dropout, gpu_ids=args.gpu_ids)
    Gba = define_Gen(input_nc=3, output_nc=3, ngf=args.ngf, netG='resnet_9blocks', norm=args.norm,
                     use_dropout=not args.no_dropout, gpu_ids=args.gpu_ids)

    utils.print_networks([Gab, Gba], ['Gab', 'Gba'])

    try:
        ckpt = utils.load_checkpoint('/media/l/新加卷/city/project/cycleGAN-PyTorch/checkpoints/horse2zebra/latest.ckpt')
        Gab.load_state_dict(ckpt['Gab'])
        Gba.load_state_dict(ckpt['Gba'])
    except Exception:  # was a bare except: don't swallow SystemExit/KeyboardInterrupt
        print(' [*] No checkpoint!')

    Gab.eval()
    Gba.eval()

    res = []
    for i in range(3):
        # BUG FIX: `iterator.next()` was removed in Python 3 -- use the builtin
        # next(). iter() restarts the loader each time, so with shuffle=True
        # this draws three independent random batches.
        a_real_test = Variable(next(iter(a_test_loader))[0], requires_grad=True)
        b_real_test = Variable(next(iter(b_test_loader))[0], requires_grad=True)
        a_real_test, b_real_test = utils.cuda([a_real_test, b_real_test])

        with torch.no_grad():
            a_fake_test = Gab(b_real_test)
            b_fake_test = Gba(a_real_test)
            a_recon_test = Gab(b_fake_test)
            b_recon_test = Gba(a_fake_test)
        res.extend([a_real_test, b_fake_test, b_real_test, a_fake_test])

    # Map generator outputs from [-1, 1] back to [0, 1] for saving.
    pic = (torch.cat(res, dim=0).data + 1) / 2.0

    if not os.path.isdir(args.results_dir):
        os.makedirs(args.results_dir)

    torchvision.utils.save_image(pic, args.results_dir + '/sample.png', nrow=2)
Example #2
0
def test(args, epoch):
    """Translate one test batch in each direction and save a labelled grid.

    The saved filename embeds ``epoch`` and an SSIM-derived score
    (1 - 2*SSIM) for each translation direction.
    """
    transform = transforms.Compose(
        [transforms.Resize((args.crop_height, args.crop_width)),
         transforms.ToTensor(),
         transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])

    dataset_dirs = utils.get_testdata_link(args.dataset_dir)

    a_test_data = dsets.ImageFolder(dataset_dirs['testA'], transform=transform)
    b_test_data = dsets.ImageFolder(dataset_dirs['testB'], transform=transform)

    a_test_loader = torch.utils.data.DataLoader(a_test_data, batch_size=args.batch_size, shuffle=True, num_workers=4)
    b_test_loader = torch.utils.data.DataLoader(b_test_data, batch_size=args.batch_size, shuffle=True, num_workers=4)

    Gab = define_Gen(input_nc=3, output_nc=3, ngf=args.ngf, netG=args.gen_net, norm=args.norm,
                     use_dropout=args.use_dropout, gpu_ids=args.gpu_ids, self_attn=args.self_attn, spectral=args.spectral)
    Gba = define_Gen(input_nc=3, output_nc=3, ngf=args.ngf, netG=args.gen_net, norm=args.norm,
                     use_dropout=args.use_dropout, gpu_ids=args.gpu_ids, self_attn=args.self_attn, spectral=args.spectral)

    utils.print_networks([Gab, Gba], ['Gab', 'Gba'])

    ckpt = utils.load_checkpoint('%s/latest.ckpt' % (args.checkpoint_path))
    Gab.load_state_dict(ckpt['Gab'])
    Gba.load_state_dict(ckpt['Gba'])

    """ run """
    # BUG FIX: `iterator.next()` was removed in Python 3 -- use the builtin next().
    a_real_test = Variable(next(iter(a_test_loader))[0], requires_grad=True)
    b_real_test = Variable(next(iter(b_test_loader))[0], requires_grad=True)
    a_real_test, b_real_test = utils.cuda([a_real_test, b_real_test])

    Gab.eval()
    Gba.eval()

    with torch.no_grad():
        a_fake_test = Gab(b_real_test)
        b_fake_test = Gba(a_real_test)
        a_recon_test = Gab(b_fake_test)
        b_recon_test = Gba(a_fake_test)
        # SSIM on grayscale images rescaled from [-1, 1] to [0, 1].
        gray = kornia.color.RgbToGrayscale()
        m = kornia.losses.SSIM(11, 'mean')
        ba_ssim = m(gray((a_real_test + 1) / 2.0), gray((b_fake_test + 1) / 2.0))
        ab_ssim = m(gray((b_real_test + 1) / 2.0), gray((a_fake_test + 1) / 2.0))

    pic = (torch.cat([a_real_test, b_fake_test, a_recon_test, b_real_test, a_fake_test, b_recon_test], dim=0).data + 1) / 2.0

    if not os.path.isdir(args.results_path):
        os.makedirs(args.results_path)

    torchvision.utils.save_image(pic, args.results_path + '/sample_' + str(epoch) + '_' + str(1 - 2 * round(ba_ssim.item(), 4)) + '_' + str(1 - 2 * round(ab_ssim.item(), 4)) + '.jpg', nrow=args.batch_size)
Example #3
0
def test(args):
    """Translate one test batch in each direction and save the six-image grid
    (A real, B fake, A recon, B real, A fake, B recon) to results_dir.
    """
    transform = transforms.Compose(
        [transforms.Resize((args.crop_height, args.crop_width)),
         transforms.ToTensor(),
         transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])

    dataset_dirs = utils.get_testdata_link(args.dataset_dir)

    a_test_data = dsets.ImageFolder(dataset_dirs['testA'], transform=transform)
    b_test_data = dsets.ImageFolder(dataset_dirs['testB'], transform=transform)

    a_test_loader = torch.utils.data.DataLoader(a_test_data, batch_size=args.batch_size, shuffle=True, num_workers=4)
    b_test_loader = torch.utils.data.DataLoader(b_test_data, batch_size=args.batch_size, shuffle=True, num_workers=4)

    Gab = define_Gen(input_nc=3, output_nc=3, ngf=args.ngf, netG='resnet_9blocks', norm=args.norm,
                     use_dropout=not args.no_dropout, gpu_ids=args.gpu_ids)
    Gba = define_Gen(input_nc=3, output_nc=3, ngf=args.ngf, netG='resnet_9blocks', norm=args.norm,
                     use_dropout=not args.no_dropout, gpu_ids=args.gpu_ids)

    utils.print_networks([Gab, Gba], ['Gab', 'Gba'])

    try:
        # ckpt = utils.load_checkpoint('%s/latest.ckpt' % (args.checkpoint_dir))
        ckpt = utils.load_checkpoint('%s' % args.checkpoint)
        Gab.load_state_dict(ckpt['Gab'])
        Gba.load_state_dict(ckpt['Gba'])
    except Exception:  # was a bare except: don't swallow SystemExit/KeyboardInterrupt
        print(' [*] No checkpoint!')

    """ run """
    # BUG FIX: `iterator.next()` was removed in Python 3 -- use the builtin next().
    a_real_test = Variable(next(iter(a_test_loader))[0], requires_grad=True)
    b_real_test = Variable(next(iter(b_test_loader))[0], requires_grad=True)
    a_real_test, b_real_test = utils.cuda([a_real_test, b_real_test])

    Gab.eval()
    Gba.eval()

    with torch.no_grad():
        a_fake_test = Gab(b_real_test)
        b_fake_test = Gba(a_real_test)
        a_recon_test = Gab(b_fake_test)
        b_recon_test = Gba(a_fake_test)

    # Rescale generator outputs from [-1, 1] to [0, 1] for saving.
    pic = (torch.cat([a_real_test, b_fake_test, a_recon_test, b_real_test, a_fake_test, b_recon_test],
                     dim=0).data + 1) / 2.0

    if not os.path.isdir(args.results_dir):
        os.makedirs(args.results_dir)

    torchvision.utils.save_image(pic, args.results_dir + '/sample.jpg', nrow=3)
Example #4
0
    def __init__(self, args):
        """Supervised-segmentation setup.

        Builds the image-to-segmentation generator Gsi, its loss/optimizer,
        the tensorboard writer and metrics, then restores the latest
        checkpoint if one exists.
        """
        # Number of segmentation classes per dataset.
        if args.dataset == 'voc2012':
            self.n_channels = 21
        elif args.dataset == 'cityscapes':
            self.n_channels = 20
        elif args.dataset == 'acdc':
            self.n_channels = 4
        else:
            # BUG FIX: previously an unknown dataset fell through silently and
            # crashed later with AttributeError on self.n_channels.
            raise ValueError('Unsupported dataset: %s' % args.dataset)

        # Define the network (image -> segmentation)
        self.Gsi = define_Gen(
            input_nc=3,
            output_nc=self.n_channels,
            ngf=args.ngf,
            netG='resnet_9blocks_softmax',
            norm=args.norm,
            use_dropout=not args.no_dropout,
            gpu_ids=args.gpu_ids)

        utils.print_networks([self.Gsi], ['Gsi'])

        self.CE = nn.CrossEntropyLoss()
        self.activation_softmax = nn.Softmax2d()
        self.gsi_optimizer = torch.optim.Adam(self.Gsi.parameters(),
                                              lr=args.lr,
                                              betas=(0.9, 0.999))

        ### writer for tensorboard
        self.writer_supervised = SummaryWriter(tensorboard_loc + '_supervised')
        self.running_metrics_val = utils.runningScore(self.n_channels,
                                                      args.dataset)

        self.args = args

        if not os.path.isdir(args.checkpoint_dir):
            os.makedirs(args.checkpoint_dir)

        # Resume from the latest checkpoint when available; otherwise start
        # from scratch.
        try:
            ckpt = utils.load_checkpoint('%s/latest_supervised_model.ckpt' %
                                         (args.checkpoint_dir))
            self.start_epoch = ckpt['epoch']
            self.Gsi.load_state_dict(ckpt['Gsi'])
            self.gsi_optimizer.load_state_dict(ckpt['gsi_optimizer'])
            self.best_iou = ckpt['best_iou']
        except Exception:  # was a bare except
            print(' [*] No checkpoint!')
            self.start_epoch = 0
            self.best_iou = -100
Example #5
0
    def __init__(self, args):
        """Build CycleGAN generators/discriminators, optimizers, losses and
        the training data loader.

        Networks are placed on ``cuda:<args.cuda_id>`` when CUDA is
        available, otherwise on the CPU.
        """
        super(cycleGAN, self).__init__()
        # BUG FIX: the device was previously recomputed as "cuda:0"
        # immediately after this assignment (debug leftover), silently
        # discarding args.cuda_id. Keep only the intended selection.
        self.device = torch.device(
            ("cuda:" + str(args.cuda_id))
            if torch.cuda.is_available() else "cpu")
        print(self.device)
        self.GAB = generator.Generator().to(self.device)
        self.GBA = generator.Generator().to(self.device)
        self.DA = discriminator.Discriminator().to(self.device)
        self.DB = discriminator.Discriminator().to(self.device)
        self.init_weights(self.GAB)
        self.init_weights(self.GBA)
        self.init_weights(self.DA)
        self.init_weights(self.DB)
        utils.print_networks([self.GAB, self.GBA, self.DA, self.DB],
                             ['GAB', 'GBA', 'DA', 'DB'])

        # One optimizer over both generators, one over both discriminators.
        self.optim_G = torch.optim.Adam(itertools.chain(
            self.GAB.parameters(), self.GBA.parameters()),
                                        lr=args.g_lr,
                                        betas=(args.beta1, args.beta2))
        self.optim_D = torch.optim.Adam(itertools.chain(
            self.DA.parameters(), self.DB.parameters()),
                                        lr=args.d_lr,
                                        betas=(args.beta1, args.beta2))
        self.MSE = nn.MSELoss()
        self.L1 = nn.L1Loss()
        # self.MSE = nn.BCEWithLogitsLoss()

        self.train_loader = data_loader.get_loader(args.img_path, args.mode,
                                                   args.batch_size,
                                                   args.num_workers,
                                                   args.crop_size,
                                                   args.img_size)
Example #6
0
 def __init__(self, args):
     """Build the four CycleGAN networks on the selected device, restore
     them from the latest checkpoint if present, and create the test loader.
     """
     super(cycleGAN, self).__init__()
     # Parenthesized for clarity: the conditional selects the whole device
     # string (the original relied on conditional-expression precedence and
     # an appended empty string).
     self.device = torch.device(
         ("cuda:" + str(args.cuda_id))
         if torch.cuda.is_available() else "cpu")
     self.GAB = generator.Generator().to(self.device)
     self.GBA = generator.Generator().to(self.device)
     self.DA = discriminator.Discriminator().to(self.device)
     self.DB = discriminator.Discriminator().to(self.device)
     utils.print_networks([self.GAB, self.GBA, self.DA, self.DB],
                          ['GAB', 'GBA', 'DA', 'DB'])
     try:
         ckpt = utils.load_checkpoint('%s/latest.ckpt' %
                                      (args.model_save_dir))
         self.GAB.load_state_dict(ckpt['GAB'])
         self.GBA.load_state_dict(ckpt['GBA'])
         self.DA.load_state_dict(ckpt['DA'])
         self.DB.load_state_dict(ckpt['DB'])
     except Exception:  # was a bare except
         print(' [*] No checkpoint!')
     self.test_loader = data_loader.get_loader(args.img_path, args.mode,
                                               args.batch_size,
                                               args.num_workers,
                                               args.crop_size,
                                               args.img_size)
    def __init__(self, args):
        """Semi-supervised CycleGAN setup.

        Builds the Gis/Gsi generators and Di/Ds discriminators, the frozen
        'old' networks used for the auxiliary (Arnab) loss, losses,
        optimizers, LR schedulers, tensorboard writer and metrics, then
        restores the latest checkpoint if one exists.
        """
        # Number of segmentation classes per dataset.
        if args.dataset == 'voc2012':
            self.n_channels = 21
        elif args.dataset == 'cityscapes':
            self.n_channels = 20
        elif args.dataset == 'acdc':
            self.n_channels = 4
        else:
            # BUG FIX: previously an unknown dataset fell through silently and
            # crashed later with AttributeError on self.n_channels.
            raise ValueError('Unsupported dataset: %s' % args.dataset)

        # Define the network
        #####################################################
        # for segmentaion to image
        self.Gis = define_Gen(input_nc=self.n_channels,
                              output_nc=3,
                              ngf=args.ngf,
                              netG='deeplab',
                              norm=args.norm,
                              use_dropout=not args.no_dropout,
                              gpu_ids=args.gpu_ids)
        # for image to segmentation
        self.Gsi = define_Gen(input_nc=3,
                              output_nc=self.n_channels,
                              ngf=args.ngf,
                              netG='deeplab',
                              norm=args.norm,
                              use_dropout=not args.no_dropout,
                              gpu_ids=args.gpu_ids)
        self.Di = define_Dis(input_nc=3,
                             ndf=args.ndf,
                             netD='pixel',
                             n_layers_D=3,
                             norm=args.norm,
                             gpu_ids=args.gpu_ids)
        self.Ds = define_Dis(
            input_nc=self.n_channels,
            ndf=args.ndf,
            netD='pixel',
            n_layers_D=3,
            norm=args.norm,
            gpu_ids=args.gpu_ids)  # for voc 2012, there are 21 classes

        # Frozen reference networks used for the auxiliary loss.
        self.old_Gis = define_Gen(input_nc=self.n_channels,
                                  output_nc=3,
                                  ngf=args.ngf,
                                  netG='resnet_9blocks',
                                  norm=args.norm,
                                  use_dropout=not args.no_dropout,
                                  gpu_ids=args.gpu_ids)
        self.old_Gsi = define_Gen(input_nc=3,
                                  output_nc=self.n_channels,
                                  ngf=args.ngf,
                                  netG='resnet_9blocks_softmax',
                                  norm=args.norm,
                                  use_dropout=not args.no_dropout,
                                  gpu_ids=args.gpu_ids)
        self.old_Di = define_Dis(input_nc=3,
                                 ndf=args.ndf,
                                 netD='pixel',
                                 n_layers_D=3,
                                 norm=args.norm,
                                 gpu_ids=args.gpu_ids)

        ### To put the pretrained weights in Gis and Gsi
        # if args.dataset != 'acdc':
        #     saved_state_dict = torch.load(pretrained_loc)
        #     new_params_Gsi = self.Gsi.state_dict().copy()
        #     # new_params_Gis = self.Gis.state_dict().copy()
        #     for name, param in new_params_Gsi.items():
        #         # print(name)
        #         if name in saved_state_dict and param.size() == saved_state_dict[name].size():
        #             new_params_Gsi[name].copy_(saved_state_dict[name])
        #             # print('copy {}'.format(name))
        #     self.Gsi.load_state_dict(new_params_Gsi)
        # for name, param in new_params_Gis.items():
        #     # print(name)
        #     if name in saved_state_dict and param.size() == saved_state_dict[name].size():
        #         new_params_Gis[name].copy_(saved_state_dict[name])
        #         # print('copy {}'.format(name))
        # # self.Gis.load_state_dict(new_params_Gis)

        ### This is just so as to get pretrained methods for the case of Gis
        if args.dataset == 'voc2012':
            try:
                ckpt_for_Arnab_loss = utils.load_checkpoint(
                    './ckpt_for_Arnab_loss.ckpt')
                self.old_Gis.load_state_dict(ckpt_for_Arnab_loss['Gis'])
                self.old_Gsi.load_state_dict(ckpt_for_Arnab_loss['Gsi'])
            except Exception:  # was a bare except
                print(
                    '**There is an error in loading the ckpt_for_Arnab_loss**')

        # (removed a redundant print_networks([Gsi]) call -- Gsi is already
        # included in the summary below)
        utils.print_networks([self.Gis, self.Gsi, self.Di, self.Ds],
                             ['Gis', 'Gsi', 'Di', 'Ds'])

        self.args = args

        ### interpolation
        self.interp = nn.Upsample((args.crop_height, args.crop_width),
                                  mode='bilinear',
                                  align_corners=True)

        self.MSE = nn.MSELoss()
        self.L1 = nn.L1Loss()
        self.CE = nn.CrossEntropyLoss()
        self.activation_softmax = nn.Softmax2d()
        self.activation_tanh = nn.Tanh()
        self.activation_sigmoid = nn.Sigmoid()

        ### Tensorboard writer
        self.writer_semisuper = SummaryWriter(tensorboard_loc + '_semisuper')
        self.running_metrics_val = utils.runningScore(self.n_channels,
                                                      args.dataset)

        ### For adding gaussian noise
        self.gauss_noise = utils.GaussianNoise(sigma=0.2)

        # Optimizers
        #####################################################
        self.g_optimizer = torch.optim.Adam(itertools.chain(
            self.Gis.parameters(), self.Gsi.parameters()),
                                            lr=args.lr,
                                            betas=(0.5, 0.999))
        self.d_optimizer = torch.optim.Adam(itertools.chain(
            self.Di.parameters(), self.Ds.parameters()),
                                            lr=args.lr,
                                            betas=(0.5, 0.999))

        # Linear decay of the learning rate starting at args.decay_epoch.
        self.g_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.g_optimizer,
            lr_lambda=utils.LambdaLR(args.epochs, 0, args.decay_epoch).step)
        self.d_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.d_optimizer,
            lr_lambda=utils.LambdaLR(args.epochs, 0, args.decay_epoch).step)

        # Try loading checkpoint
        #####################################################
        if not os.path.isdir(args.checkpoint_dir):
            os.makedirs(args.checkpoint_dir)

        try:
            ckpt = utils.load_checkpoint('%s/latest_semisuper_cycleGAN.ckpt' %
                                         (args.checkpoint_dir))
            self.start_epoch = ckpt['epoch']
            self.Di.load_state_dict(ckpt['Di'])
            self.Ds.load_state_dict(ckpt['Ds'])
            self.Gis.load_state_dict(ckpt['Gis'])
            self.Gsi.load_state_dict(ckpt['Gsi'])
            self.d_optimizer.load_state_dict(ckpt['d_optimizer'])
            self.g_optimizer.load_state_dict(ckpt['g_optimizer'])
            self.best_iou = ckpt['best_iou']
        except Exception:  # was a bare except
            print(' [*] No checkpoint!')
            self.start_epoch = 0
            self.best_iou = -100
Example #8
0
    def __init__(self, args):
        """CycleGAN trainer setup.

        Builds the Gab/Gba generators and Da/Db discriminators, losses,
        optimizers and LR schedulers, then restores the latest checkpoint
        if one exists.
        """
        # Set up both gens and discs
        self.Gab = define_Gen(input_nc=3,
                              output_nc=3,
                              ngf=args.ngf,
                              netG=args.gen_net,
                              norm=args.norm,
                              use_dropout=args.use_dropout,
                              gpu_ids=args.gpu_ids,
                              self_attn=args.self_attn,
                              spectral=args.spectral)
        self.Gba = define_Gen(input_nc=3,
                              output_nc=3,
                              ngf=args.ngf,
                              netG=args.gen_net,
                              norm=args.norm,
                              use_dropout=args.use_dropout,
                              gpu_ids=args.gpu_ids,
                              self_attn=args.self_attn,
                              spectral=args.spectral)

        self.Da = define_Dis(input_nc=3,
                             ndf=args.ndf,
                             netD=args.dis_net,
                             n_layers_D=3,
                             norm=args.norm,
                             gpu_ids=args.gpu_ids,
                             spectral=args.spectral,
                             self_attn=args.self_attn)
        self.Db = define_Dis(input_nc=3,
                             ndf=args.ndf,
                             netD=args.dis_net,
                             n_layers_D=3,
                             norm=args.norm,
                             gpu_ids=args.gpu_ids,
                             spectral=args.spectral,
                             self_attn=args.self_attn)

        utils.print_networks([self.Gab, self.Gba, self.Da, self.Db],
                             ['Gab', 'Gba', 'Da', 'Db'])

        # Loss functions
        self.MSE = nn.MSELoss()
        self.L1 = nn.L1Loss()
        self.ssim = kornia.losses.SSIM(11, reduction='mean')

        # Optimizers: one over both generators, one over both discriminators.
        self.g_optimizer = torch.optim.Adam(itertools.chain(
            self.Gab.parameters(), self.Gba.parameters()),
                                            lr=args.lr,
                                            betas=(0.5, 0.999))
        self.d_optimizer = torch.optim.Adam(itertools.chain(
            self.Da.parameters(), self.Db.parameters()),
                                            lr=args.lr,
                                            betas=(0.5, 0.999))

        # Linear decay of the learning rate starting at args.decay_epoch.
        self.g_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.g_optimizer,
            lr_lambda=utils.LambdaLR(args.epochs, 0, args.decay_epoch).step)
        self.d_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.d_optimizer,
            lr_lambda=utils.LambdaLR(args.epochs, 0, args.decay_epoch).step)

        # Checkpoints
        if not os.path.isdir(args.checkpoint_path):
            os.makedirs(args.checkpoint_path)

        try:
            ckpt = utils.load_checkpoint('%s/latest.ckpt' %
                                         (args.checkpoint_path))
            self.start_epoch = ckpt['epoch']
            self.Da.load_state_dict(ckpt['Da'])
            self.Db.load_state_dict(ckpt['Db'])
            self.Gab.load_state_dict(ckpt['Gab'])
            self.Gba.load_state_dict(ckpt['Gba'])
            self.d_optimizer.load_state_dict(ckpt['d_optimizer'])
            self.g_optimizer.load_state_dict(ckpt['g_optimizer'])
        except Exception:  # was a bare except
            print(' [*] No checkpoint!')
            self.start_epoch = 0
def gen_samples(args, epoch):
    """Translate every test (or sample) image in both directions, save the
    generated/reconstructed images per category and record per-image SSIM
    scores to CSV files.

    ``epoch`` is currently unused but kept for interface compatibility with
    the training loop that calls this.
    """
    transform = transforms.Compose([
        transforms.Resize((args.crop_height, args.crop_width)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    if args.specific_samples:
        dataset_dirs = utils.get_sampledata_link(args.dataset_dir)

        a_test_data = dsets.ImageFolder(dataset_dirs['sampleA'],
                                        transform=transform)
        b_test_data = dsets.ImageFolder(dataset_dirs['sampleB'],
                                        transform=transform)

    else:
        dataset_dirs = utils.get_testdata_link(args.dataset_dir)

        a_test_data = dsets.ImageFolder(dataset_dirs['testA'],
                                        transform=transform)
        b_test_data = dsets.ImageFolder(dataset_dirs['testB'],
                                        transform=transform)

    # shuffle=False keeps loader order aligned with dataset.samples below.
    a_test_loader = torch.utils.data.DataLoader(a_test_data,
                                                batch_size=args.batch_size,
                                                shuffle=False,
                                                num_workers=4)
    b_test_loader = torch.utils.data.DataLoader(b_test_data,
                                                batch_size=args.batch_size,
                                                shuffle=False,
                                                num_workers=4)

    Gab = define_Gen(input_nc=3,
                     output_nc=3,
                     ngf=args.ngf,
                     netG=args.gen_net,
                     norm=args.norm,
                     use_dropout=args.use_dropout,
                     gpu_ids=args.gpu_ids,
                     self_attn=args.self_attn,
                     spectral=args.spectral)
    Gba = define_Gen(input_nc=3,
                     output_nc=3,
                     ngf=args.ngf,
                     netG=args.gen_net,
                     norm=args.norm,
                     use_dropout=args.use_dropout,
                     gpu_ids=args.gpu_ids,
                     self_attn=args.self_attn,
                     spectral=args.spectral)

    utils.print_networks([Gab, Gba], ['Gab', 'Gba'])

    ckpt = utils.load_checkpoint('%s/latest.ckpt' % (args.checkpoint_path))
    Gab.load_state_dict(ckpt['Gab'])
    Gba.load_state_dict(ckpt['Gba'])

    # Hoisted out of the batch loop: these are stateless and reusable.
    gray = kornia.color.RgbToGrayscale()
    m = kornia.losses.SSIM(11, 'mean')
    Gab.eval()
    Gba.eval()

    def _save(tensor, subdir, fname):
        # Save one image tensor (rescaled from [-1, 1] to [0, 1]) under
        # results_path/subdir/fname, creating the directory if needed.
        path = args.results_path + '/' + subdir + '/'
        if not os.path.isdir(path):
            os.makedirs(path)
        torchvision.utils.save_image((tensor.data + 1) / 2.0, path + fname)

    ab_ssims = []
    ba_ssims = []
    a_names = []
    b_names = []
    """ run """
    for i, (a_real_test,
            b_real_test) in enumerate(zip(a_test_loader, b_test_loader)):
        # BUG FIX: filenames were sliced with a hard-coded 16; use the actual
        # batch size so names stay aligned with images for any batch_size.
        bs = args.batch_size
        a_fnames = a_test_loader.dataset.samples[i * bs:i * bs + bs]
        b_fnames = b_test_loader.dataset.samples[i * bs:i * bs + bs]

        a_real_test = Variable(a_real_test[0], requires_grad=True)
        b_real_test = Variable(b_real_test[0], requires_grad=True)
        a_real_test, b_real_test = utils.cuda([a_real_test, b_real_test])

        with torch.no_grad():
            a_fake_test = Gab(b_real_test)
            b_fake_test = Gba(a_real_test)
            a_recon_test = Gab(b_fake_test)
            b_recon_test = Gba(a_fake_test)
            # Actual batch size (the last batch may be short).
            b = a_real_test.size(0)

            for j in range(min(args.batch_size, b)):
                a_real = a_real_test[j].unsqueeze(0)
                b_fake = b_fake_test[j].unsqueeze(0)
                a_recon = a_recon_test[j].unsqueeze(0)
                b_real = b_real_test[j].unsqueeze(0)
                a_fake = a_fake_test[j].unsqueeze(0)
                b_recon = b_recon_test[j].unsqueeze(0)

                # SSIM on grayscale images rescaled from [-1, 1] to [0, 1].
                ba_ssim = m(gray((a_real + 1) / 2.0), gray((b_fake + 1) / 2.0))
                ab_ssim = m(gray((b_real + 1) / 2.0), gray((a_fake + 1) / 2.0))

                ab_ssims.append(ab_ssim.item())
                ba_ssims.append(ba_ssim.item())

                # (removed an unused `pic = torch.cat(...)` grid that was
                # rebuilt on every image and never saved)

                a_name = a_fnames[j][0].split('/')[-1]
                b_name = b_fnames[j][0].split('/')[-1]

                _save(b_fake, 'b_fake', a_name)
                a_names.append(a_name)
                _save(a_recon, 'a_recon', a_name)

                _save(a_fake, 'a_fake', b_name)
                b_names.append(b_name)
                _save(b_recon, 'b_recon', b_name)

        # Rewritten each batch so partial results survive an interruption;
        # the final write holds the complete table.
        df1 = pd.DataFrame(list(zip(a_names, ba_ssims)),
                           columns=['Name', 'SSIM_A_to_B'])
        df2 = pd.DataFrame(list(zip(b_names, ab_ssims)),
                           columns=['Name', 'SSIM_B_to_A'])

        df1.to_csv(args.results_path + '/b_fake/' + 'SSIM_A_to_B.csv')
        df2.to_csv(args.results_path + '/a_fake/' + 'SSIM_B_to_A.csv')
Example #10
0
def test(args):
    """Save a six-image sample grid, then run the whole test set and save
    per-image results (inputA = real B images, outputA = generated A images).
    """
    transform = transforms.Compose([
        transforms.Resize((args.crop_height, args.crop_width)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    dataset_dirs = utils.get_testdata_link(args.dataset_dir)

    a_test_data = dsets.ImageFolder(dataset_dirs['testA'], transform=transform)
    b_test_data = dsets.ImageFolder(dataset_dirs['testB'], transform=transform)

    a_test_loader = torch.utils.data.DataLoader(a_test_data,
                                                batch_size=args.batch_size,
                                                num_workers=4)
    b_test_loader = torch.utils.data.DataLoader(b_test_data,
                                                batch_size=args.batch_size,
                                                num_workers=4)

    Gab = Generator(input_nc=3,
                    output_nc=3,
                    ngf=args.ngf,
                    netG='resnet_9blocks',
                    norm=args.norm,
                    use_dropout=not args.no_dropout,
                    gpu_ids=args.gpu_ids)
    Gba = Generator(input_nc=3,
                    output_nc=3,
                    ngf=args.ngf,
                    netG='resnet_9blocks',
                    norm=args.norm,
                    use_dropout=not args.no_dropout,
                    gpu_ids=args.gpu_ids)

    utils.print_networks([Gab, Gba], ['Gab', 'Gba'])

    try:
        ckpt = utils.load_checkpoint('%s/latest.ckpt' % (args.checkpoint_dir))
        Gab.load_state_dict(ckpt['Gab'])
        Gba.load_state_dict(ckpt['Gba'])
    except Exception:  # was a bare except: don't swallow SystemExit/KeyboardInterrupt
        print(' [*] No checkpoint!')
    """ run """
    # BUG FIX: `iterator.next()` was removed in Python 3 -- use the builtin next().
    a_real_test = Variable(next(iter(a_test_loader))[0], requires_grad=True)
    b_real_test = Variable(next(iter(b_test_loader))[0], requires_grad=True)
    a_real_test, b_real_test = utils.cuda([a_real_test, b_real_test])

    Gab.eval()
    Gba.eval()

    with torch.no_grad():
        a_fake_test = Gab(b_real_test)
        b_fake_test = Gba(a_real_test)
        a_recon_test = Gab(b_fake_test)
        b_recon_test = Gba(a_fake_test)

    pic = (torch.cat([
        a_real_test, b_fake_test, a_recon_test, b_real_test, a_fake_test,
        b_recon_test
    ],
                     dim=0).data + 1) / 2.0

    if not os.path.isdir(args.results_dir):
        os.makedirs(args.results_dir)

    torchvision.utils.save_image(pic, args.results_dir + '/sample.jpg', nrow=3)

    # create output dirs if they don't exist
    if not os.path.exists(args.results_dir + '/inputA'):
        os.makedirs(args.results_dir + '/inputA')
    if not os.path.exists(args.results_dir + '/outputA'):
        os.makedirs(args.results_dir + '/outputA')
    # BUG FIX: the existence check used './outputB' while the creation used
    # '/outputB', so os.makedirs raised FileExistsError on any repeat run.
    if not os.path.exists(args.results_dir + '/outputB'):
        os.makedirs(args.results_dir + '/outputB')

    for i, (a_real_test,
            b_real_test) in enumerate(zip(a_test_loader, b_test_loader)):
        # set model input
        a_real_test = Variable(a_real_test[0], requires_grad=True)
        b_real_test = Variable(b_real_test[0], requires_grad=True)
        a_real_test, b_real_test = utils.cuda([a_real_test, b_real_test])

        Gab.eval()
        Gba.eval()

        with torch.no_grad():
            a_fake_test = Gab(b_real_test)
            b_fake_test = Gba(a_real_test)
            a_recon_test = Gab(b_fake_test)
            b_recon_test = Gba(a_fake_test)
            # Rescale from [-1, 1] to [0, 1] before saving.
            a_fake_test = (a_fake_test + 1) / 2.0
            b_real_test = (b_real_test + 1) / 2.0
            a_real_test = (a_real_test + 1) / 2.0
            b_fake_test = (b_fake_test + 1) / 2.0

        # Save image files
        torchvision.utils.save_image(
            b_real_test, args.results_dir + '/inputA/%04d.png' % (i + 1))
        torchvision.utils.save_image(
            a_fake_test, args.results_dir + '/outputA/%04d.png' % (i + 1))

    print("\n\nCreated Output Directories\n\n")
Example #11
0
    def __init__(self, args):
        """CycleGAN trainer setup.

        Builds the G_AtoB/G_BtoA generators and D_A/D_B discriminators,
        losses, optimizers and LR schedulers, then restores the latest
        checkpoint if one exists (otherwise training starts from epoch 0).
        """
        # Generators and Discriminators
        self.G_AtoB = define_Gen(input_nc=3,
                                 output_nc=3,
                                 ngf=args.ngf,
                                 norm=args.norm,
                                 use_dropout=not args.no_dropout,
                                 gpu_ids=args.gpu_ids)
        self.G_BtoA = define_Gen(input_nc=3,
                                 output_nc=3,
                                 ngf=args.ngf,
                                 norm=args.norm,
                                 use_dropout=not args.no_dropout,
                                 gpu_ids=args.gpu_ids)
        self.D_A = define_Dis(input_nc=3,
                              ndf=args.ndf,
                              norm=args.norm,
                              gpu_ids=args.gpu_ids)
        self.D_B = define_Dis(input_nc=3,
                              ndf=args.ndf,
                              norm=args.norm,
                              gpu_ids=args.gpu_ids)

        utils.print_networks([self.G_AtoB, self.G_BtoA, self.D_A, self.D_B],
                             ['G_AtoB', 'G_BtoA', 'D_A', 'D_B'])

        # MSE loss and L1 loss
        self.MSE = nn.MSELoss()
        self.L1 = nn.L1Loss()

        # Optimizers and lr_scheduler: one optimizer over both generators,
        # one over both discriminators.
        self.g_optimizer = torch.optim.Adam(itertools.chain(
            self.G_AtoB.parameters(), self.G_BtoA.parameters()),
                                            lr=args.lr,
                                            betas=(0.5, 0.999))
        self.d_optimizer = torch.optim.Adam(itertools.chain(
            self.D_A.parameters(), self.D_B.parameters()),
                                            lr=args.lr,
                                            betas=(0.5, 0.999))

        # Linear decay of the learning rate starting at args.decay_epoch.
        self.g_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.g_optimizer,
            lr_lambda=utils.LambdaLR(args.epochs, 0, args.decay_epoch).step)
        self.d_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.d_optimizer,
            lr_lambda=utils.LambdaLR(args.epochs, 0, args.decay_epoch).step)

        # Check if there is a checkpoint
        if not os.path.isdir(args.checkpoint_dir):
            os.makedirs(args.checkpoint_dir)

        try:
            ckpt = utils.load_checkpoint('%s/latest.ckpt' %
                                         (args.checkpoint_dir))
            self.start_epoch = ckpt['epoch']
            self.D_A.load_state_dict(ckpt['D_A'])
            self.D_B.load_state_dict(ckpt['D_B'])
            self.G_AtoB.load_state_dict(ckpt['G_AtoB'])
            self.G_BtoA.load_state_dict(ckpt['G_BtoA'])
            self.d_optimizer.load_state_dict(ckpt['d_optimizer'])
            self.g_optimizer.load_state_dict(ckpt['g_optimizer'])
        except Exception:  # was a bare except
            print(' [*] No checkpoint! Train from the beginning! ')
            self.start_epoch = 0
Example #12
0
File: main.py — Project: jshe/wasserstein-2
def main():
    """Entry point for wasserstein-2 training.

    Parses CLI options, prepares the experiment directories and (optionally)
    a tensorboardX writer, trains the selected OT solver (w1 / w2 / bary_ot),
    saves the trained networks, and finally maps a fixed pickled latent batch
    through the learned generator for a qualitative test image.
    """
    ## parse flags
    config = Options().parse()
    utils.print_opts(config)

    ## set up folders
    exp_dir = os.path.join(config.exp_dir, config.exp_name)
    model_dir = os.path.join(exp_dir, 'models')
    img_dir = os.path.join(exp_dir, 'images')
    for d in (exp_dir, model_dir, img_dir):
        os.makedirs(d, exist_ok=True)

    if config.use_tbx:
        # Remove ALL stale tensorboardX event files.  The previous version
        # deleted only logs[0], silently leaving older runs' files behind.
        for log in glob.glob(os.path.join(exp_dir, 'events.out.tfevents.*')):
            os.remove(log)
        tbx_writer = SummaryWriter(exp_dir)
    else:
        tbx_writer = None

    ## initialize data loaders/generators & model
    r_loader, z_loader = get_loader(config)
    if config.solver == 'w1':
        model = W1(config, r_loader, z_loader)
    elif config.solver == 'w2':
        model = W2(config, r_loader, z_loader)
    elif config.solver == 'bary_ot':
        model = BaryOT(config, r_loader, z_loader)
    else:
        # Fail fast on a bad flag instead of a NameError on `model` below.
        raise ValueError('unknown solver: %s' % config.solver)
    cudnn.benchmark = True
    networks = model.get_networks()
    utils.print_networks(networks)

    ## training
    ## stage 1 (dual stage) of bary_ot
    start_time = time.time()
    if config.solver == 'bary_ot':
        print("Starting: dual stage for %d iters." % config.dual_iters)
        for step in range(config.dual_iters):
            model.train_diter_only(config)
            if ((step + 1) % 100) == 0:
                stats = model.get_stats(config)
                end_time = time.time()
                # elapsed wall-clock minutes since the last report
                stats['disp_time'] = (end_time - start_time) / 60.
                start_time = end_time
                utils.print_out(stats, step + 1, config.dual_iters, tbx_writer)
        print("dual stage iterations complete.")

    ## main training loop of w1 / w2 or stage 2 (map stage) of bary-ot
    map_iters = config.map_iters if config.solver == 'bary_ot' else config.train_iters
    if config.solver == 'bary_ot':
        print("Starting: map stage for %d iters." % map_iters)
    else:
        print("Starting training...")
    for step in range(map_iters):
        model.train_iter(config)
        if ((step + 1) % 100) == 0:
            stats = model.get_stats(config)
            end_time = time.time()
            stats['disp_time'] = (end_time - start_time) / 60.
            start_time = end_time
            utils.print_out(stats, step + 1, map_iters, tbx_writer)
        if ((step + 1) % 500) == 0:
            images = model.get_visuals(config)
            utils.visualize_iter(images, img_dir, step + 1, config)
    print("Training complete.")
    networks = model.get_networks()
    utils.save_networks(networks, model_dir)

    ## testing: push a fixed latent batch through the learned map
    with open(os.path.join("./mvg_test", "data.pkl"), "rb") as f:
        fixed_z = pickle.load(f)
    fixed_z = utils.to_var(fixed_z)
    fixed_gz = model.g(fixed_z).view(*fixed_z.size())
    utils.visualize_single(fixed_gz, os.path.join(img_dir, 'test.png'), config)
예제 #13
0
파일: main.py 프로젝트: lufeng22/OT-ICNN
def main():
    """Entry point for OT-ICNN training.

    Parses options, sets up trial-specific output folders, optionally
    computes a discrete-OT assignment benchmark on a fixed data batch,
    trains the selected solver (w1 / w2 / bary_ot), periodically logging
    stats and visualizations, and saves the resulting networks.
    """
    config = Options().parse()
    utils.print_opts(config)

    ## set up folders (w2 runs also encode the generator flag in the path)
    dir_string = './{0}_{1}/trial_{2}/'.format(config.solver, config.data, config.trial) if config.solver != 'w2' else \
                                    './{0}_gen{2}_{1}/trial_{3}/'.format(config.solver, config.data, config.gen, config.trial)

    exp_dir = dir_string
    model_dir = os.path.join(exp_dir, 'models')
    img_dir = os.path.join(exp_dir, 'images')
    for d in (exp_dir, model_dir, img_dir):
        os.makedirs(d, exist_ok=True)

    if config.use_tbx:
        # Remove ALL stale tensorboardX event files; the previous version
        # only removed the first match.
        for log in glob.glob(os.path.join(exp_dir, 'events.out.tfevents.*')):
            os.remove(log)
        tbx_writer = SummaryWriter(exp_dir)
    else:
        tbx_writer = None

    ## initialize data loaders & model
    r_loader, z_loader = get_loader(config)
    if config.solver == 'w1':
        model = W1(config, r_loader, z_loader)
    elif config.solver == 'w2':
        model = W2(config, r_loader, z_loader)
    elif config.solver == 'bary_ot':
        model = BaryOT(config, r_loader, z_loader)
    else:
        # Fail fast on a bad flag instead of a NameError on `model` below.
        raise ValueError('unknown solver: %s' % config.solver)
    cudnn.benchmark = True
    networks = model.get_networks(config)
    utils.print_networks(networks)

    # Visualize the fixed data/latent batch used for benchmarking.
    fixed_r, fixed_z = model.get_fixed_data()
    utils.visualize_single(utils.to_data(fixed_z),
                           utils.to_data(fixed_r),
                           None,
                           os.path.join(img_dir, 'data.png'),
                           data_range=(-12,
                                       12) if config.data == '8gaussians' else
                           (-6, 6))
    if not config.no_benchmark:
        # Discrete optimal-transport assignment as ground truth to compare
        # the learned map against during training.
        print('computing discrete-OT benchmark...')
        start_time = time.time()
        cost = model.get_cost()
        discrete_tz = utils.solve_assignment(fixed_z, fixed_r, cost,
                                             fixed_r.size(0))
        print('Done in %.4f seconds.' % (time.time() - start_time))
        utils.visualize_single(utils.to_data(fixed_z), utils.to_data(fixed_r),
                               utils.to_data(discrete_tz),
                               os.path.join(img_dir, 'assignment.png'))

    ## training
    ## stage 1 (dual stage) of bary_ot
    start_time = time.time()
    if config.solver == 'bary_ot':
        print("Starting: dual stage for %d iters." % config.dual_iters)
        for step in range(config.dual_iters):
            model.train_diter_only(config)
            if ((step + 1) % 10) == 0:
                stats = model.get_stats(config)
                end_time = time.time()
                stats['disp_time'] = (end_time - start_time) / 60.
                start_time = end_time
                utils.print_out(stats, step + 1, config.dual_iters, tbx_writer)
        print("dual stage complete.")

    ## main training loop of w1 / w2 or stage 2 (map stage) of bary-ot
    map_iters = config.map_iters if config.solver == 'bary_ot' else config.train_iters
    if config.solver == 'bary_ot':
        print("Starting: map stage for %d iters." % map_iters)
    else:
        print("Starting training...")
    for step in range(map_iters):
        model.train_iter(config)
        if ((step + 1) % 10) == 0:
            stats = model.get_stats(config)
            end_time = time.time()
            stats['disp_time'] = (end_time - start_time) / 60.
            start_time = end_time
            if not config.no_benchmark:
                # L2 distance between the learned map and the discrete-OT
                # assignment computed above.
                if config.gen:
                    stats['l2_dist/discrete_T_x--G_x'] = losses.calc_l2(
                        fixed_z, model.g(fixed_z), discrete_tz).data.item()
                else:
                    stats['l2_dist/discrete_T_x--T_x'] = losses.calc_l2(
                        fixed_z, model.get_tx(fixed_z, reverse=True),
                        discrete_tz).data.item()
            utils.print_out(stats, step + 1, map_iters, tbx_writer)
        if ((step + 1) % 10000) == 0 or step == 0:
            images = model.get_visuals(config)
            utils.visualize_iter(
                images,
                img_dir,
                step + 1,
                config,
                data_range=(-12, 12) if config.data == '8gaussians' else
                (-6, 6))
    print("Training complete.")
    networks = model.get_networks(config)
    utils.save_networks(networks, model_dir)
예제 #14
0
    def __init__(self, args):
        """Build CycleGAN training state.

        Creates two generators (Gab: B->A, Gba: A->B) and two
        discriminators (Da, Db), loss criteria, optimizers, linearly
        decaying LR schedulers, restores the latest checkpoint when one
        exists, and opens a tensorboard writer.
        """
        # Define the network
        #####################################################
        '''
        Define the network:
        Two generators: Gab, Gba
        Two discriminators: Da, Db
        '''
        self.Gab = define_Gen(input_nc=3,
                              output_nc=3,
                              ngf=args.ngf,
                              netG=args.gen_net,
                              norm=args.norm,
                              use_dropout=not args.no_dropout,
                              gpu_ids=args.gpu_ids)
        self.Gba = define_Gen(input_nc=3,
                              output_nc=3,
                              ngf=args.ngf,
                              netG=args.gen_net,
                              norm=args.norm,
                              use_dropout=not args.no_dropout,
                              gpu_ids=args.gpu_ids)
        self.Da = define_Dis(input_nc=3,
                             ndf=args.ndf,
                             netD=args.dis_net,
                             n_layers_D=3,
                             norm=args.norm,
                             gpu_ids=args.gpu_ids)
        self.Db = define_Dis(input_nc=3,
                             ndf=args.ndf,
                             netD=args.dis_net,
                             n_layers_D=3,
                             norm=args.norm,
                             gpu_ids=args.gpu_ids)

        utils.print_networks([self.Gab, self.Gba, self.Da, self.Db],
                             ['Gab', 'Gba', 'Da', 'Db'])

        # Define loss criteria (attribute names kept for compatibility
        # with existing callers, including the "criteron" spelling)
        self.identity_criteron = nn.L1Loss()
        self.adversarial_criteron = nn.MSELoss()
        self.cycle_consistency_criteron = nn.L1Loss()

        # Define optimizers (one over both generators, one over both
        # discriminators)
        self.g_optimizer = torch.optim.Adam(itertools.chain(
            self.Gab.parameters(), self.Gba.parameters()),
                                            lr=args.lr,
                                            betas=(0.5, 0.999))
        self.d_optimizer = torch.optim.Adam(itertools.chain(
            self.Da.parameters(), self.Db.parameters()),
                                            lr=args.lr,
                                            betas=(0.5, 0.999))

        # Linear LR decay starting at args.decay_epoch
        self.g_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.g_optimizer,
            lr_lambda=utils.LambdaLR(args.epochs, 0, args.decay_epoch).step)
        self.d_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.d_optimizer,
            lr_lambda=utils.LambdaLR(args.epochs, 0, args.decay_epoch).step)

        # Try loading checkpoint
        #####################################################
        if not os.path.isdir(args.checkpoint_dir):
            os.makedirs(args.checkpoint_dir)

        try:
            ckpt = utils.load_checkpoint('%s/latest.ckpt' %
                                         (args.checkpoint_dir))
            self.start_epoch = ckpt['epoch']
            self.Da.load_state_dict(ckpt['Da'])
            self.Db.load_state_dict(ckpt['Db'])
            self.Gab.load_state_dict(ckpt['Gab'])
            self.Gba.load_state_dict(ckpt['Gba'])
            self.d_optimizer.load_state_dict(ckpt['d_optimizer'])
            self.g_optimizer.load_state_dict(ckpt['g_optimizer'])
        except Exception:
            # Best-effort resume: any load failure (missing file, bad keys)
            # falls back to training from scratch.  Was a bare `except:`,
            # which also swallowed KeyboardInterrupt/SystemExit.
            print(' [*] No checkpoint!')
            self.start_epoch = 0

        # Tensorboard Setup
        # NOTE(review): the run timestamp is hard-coded so every run resumes
        # the 2020-10-24 log directory; restore the commented datetime line
        # to start a fresh log dir per run.
        # current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        current_time = '20201024-102158'
        train_log_dir = 'logs/sketch2pokemon/' + current_time
        self.writer = SummaryWriter(train_log_dir)

        # Stability variables setup: previous/current test outputs used to
        # compare successive evaluations
        self.last_test_output = []
        self.cur_test_output = []
    def __init__(self, args):
        """Build supervised segmentation training state.

        Creates a single image->segmentation generator (Gsi) sized for the
        chosen dataset, optionally copies matching pretrained weights,
        defines the loss/optimizer/metrics, and restores the latest
        supervised checkpoint when one exists.
        """
        # Number of output classes per dataset (incl. background/void)
        if args.dataset == 'voc2012':
            self.n_channels = 21
        elif args.dataset == 'cityscapes':
            self.n_channels = 20
        elif args.dataset == 'acdc':
            self.n_channels = 4

        # Define the network
        self.Gsi = define_Gen(
            input_nc=3,
            output_nc=self.n_channels,
            ngf=args.ngf,
            netG='deeplab',
            norm=args.norm,
            use_dropout=not args.no_dropout,
            gpu_ids=args.gpu_ids)  # for image to segmentation

        ### Now we put in the pretrained weights in Gsi
        ### These will only be used in the case of VOC and cityscapes
        if args.dataset != 'acdc':
            saved_state_dict = torch.load(pretrained_loc)
            new_params = self.Gsi.state_dict().copy()
            # Copy only parameters whose name AND shape match the
            # pretrained checkpoint; everything else keeps its init.
            for name, param in new_params.items():
                if name in saved_state_dict and param.size(
                ) == saved_state_dict[name].size():
                    new_params[name].copy_(saved_state_dict[name])
            # NOTE(review): the actual load is commented out, so the copied
            # weights are currently never applied to Gsi — confirm intent.
            # self.Gsi.load_state_dict(new_params)

        utils.print_networks([self.Gsi], ['Gsi'])

        ### Interpolation to match network output to feature-map size
        self.interp = nn.Upsample(size=(args.crop_height, args.crop_width),
                                  mode='bilinear',
                                  align_corners=True)
        self.interp_val = nn.Upsample(size=(512, 512),
                                      mode='bilinear',
                                      align_corners=True)

        self.CE = nn.CrossEntropyLoss()
        self.activation_softmax = nn.Softmax2d()
        self.gsi_optimizer = torch.optim.Adam(self.Gsi.parameters(),
                                              lr=args.lr,
                                              betas=(0.9, 0.999))

        ### writer for tensorboard
        self.writer_supervised = SummaryWriter(tensorboard_loc + '_supervised')
        self.running_metrics_val = utils.runningScore(self.n_channels,
                                                      args.dataset)

        self.args = args

        if not os.path.isdir(args.checkpoint_dir):
            os.makedirs(args.checkpoint_dir)

        try:
            ckpt = utils.load_checkpoint('%s/latest_supervised_model.ckpt' %
                                         (args.checkpoint_dir))
            self.start_epoch = ckpt['epoch']
            self.Gsi.load_state_dict(ckpt['Gsi'])
            self.gsi_optimizer.load_state_dict(ckpt['gsi_optimizer'])
            self.best_iou = ckpt['best_iou']
        except Exception:
            # Best-effort resume; narrowed from a bare `except:` which also
            # caught KeyboardInterrupt/SystemExit.
            print(' [*] No checkpoint!')
            self.start_epoch = 0
            self.best_iou = -100
예제 #16
0
def test(args):
    """Run CycleGAN inference over the paired test sets.

    Loads both generators from the latest checkpoint (best effort), maps
    each test batch A->B and B->A plus reconstructions, and saves a 3-wide
    image grid per batch under args.results_dir.
    """
    transform = transforms.Compose([
        transforms.Resize((args.crop_height, args.crop_width)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    dataset_dirs = utils.get_testdata_link(args.dataset_dir)

    a_test_data = datasets.ImageFolder(dataset_dirs['testA'],
                                       transform=transform)
    b_test_data = datasets.ImageFolder(dataset_dirs['testB'],
                                       transform=transform)

    a_test_loader = torch.utils.data.DataLoader(a_test_data,
                                                batch_size=args.batch_size,
                                                shuffle=True,
                                                num_workers=4)
    b_test_loader = torch.utils.data.DataLoader(b_test_data,
                                                batch_size=args.batch_size,
                                                shuffle=True,
                                                num_workers=4)

    G_AtoB = define_Gen(input_nc=3,
                        output_nc=3,
                        ngf=args.ngf,
                        norm=args.norm,
                        use_dropout=not args.no_dropout,
                        gpu_ids=args.gpu_ids)
    G_BtoA = define_Gen(input_nc=3,
                        output_nc=3,
                        ngf=args.ngf,
                        norm=args.norm,
                        use_dropout=not args.no_dropout,
                        gpu_ids=args.gpu_ids)

    utils.print_networks([G_AtoB, G_BtoA], ['G_AtoB', 'G_BtoA'])

    try:
        ckpt = utils.load_checkpoint('%s/latest.ckpt' % (args.checkpoint_dir))
        G_AtoB.load_state_dict(ckpt['G_AtoB'])
        G_BtoA.load_state_dict(ckpt['G_BtoA'])
    except Exception:
        # Best effort: proceed with untrained weights if no checkpoint.
        # Narrowed from a bare `except:`.
        print(' [*] No checkpoint! ')

    # eval mode and the results directory are loop-invariant — set up once
    G_AtoB.eval()
    G_BtoA.eval()
    if not os.path.isdir(args.results_dir):
        os.makedirs(args.results_dir)

    for i, (a_real_test,
            b_real_test) in enumerate(zip(a_test_loader, b_test_loader)):
        a_real_test = Variable(a_real_test[0], requires_grad=True)
        b_real_test = Variable(b_real_test[0], requires_grad=True)
        a_real_test, b_real_test = utils.cuda([a_real_test, b_real_test])

        with torch.no_grad():
            a_fake_test = G_BtoA(b_real_test)
            b_fake_test = G_AtoB(a_real_test)
            a_recon_test = G_BtoA(b_fake_test)
            b_recon_test = G_AtoB(a_fake_test)

        # Rows: realA | fakeB | reconA / realB | fakeA | reconB, rescaled
        # from [-1, 1] to [0, 1] for saving
        pic = (torch.cat([
            a_real_test, b_fake_test, a_recon_test, b_real_test, a_fake_test,
            b_recon_test
        ],
                         dim=0).data + 1) / 2.0

        torchvision.utils.save_image(pic,
                                     args.results_dir + '/sample_' + str(i) +
                                     '.jpg',
                                     nrow=3)
예제 #17
0
# Build the two CycleGAN generators (Gab: B->A, Gba: A->B) at script level.
# NOTE(review): `ngf`, `norm`, `no_dropout`, `gpu_ids`, `checkpoint_dir` and
# `checkpoint_name` are defined earlier in this script (not visible here).
Gab = define_Gen(input_nc=3,
                 output_nc=3,
                 ngf=ngf,
                 netG='resnet_9blocks',
                 norm=norm,
                 use_dropout=not no_dropout,
                 gpu_ids=gpu_ids)
Gba = define_Gen(input_nc=3,
                 output_nc=3,
                 ngf=ngf,
                 netG='resnet_9blocks',
                 norm=norm,
                 use_dropout=not no_dropout,
                 gpu_ids=gpu_ids)

utils.print_networks([Gab, Gba], ['Gab', 'Gba'])

# Best-effort checkpoint restore; continues with untrained weights on
# failure.  NOTE(review): the bare `except:` also swallows
# KeyboardInterrupt/SystemExit — consider `except Exception:`.
try:
    ckpt = utils.load_checkpoint('%s/%s.ckpt' %
                                 (checkpoint_dir, checkpoint_name))
    Gab.load_state_dict(ckpt['Gab'])
    Gba.load_state_dict(ckpt['Gba'])
except:
    print(' [*] No checkpoint!')

def save_sample_image(len):
    itera = iter(a_test_loader)
    iterb = iter(b_test_loader)
    res = []
    for i in range(len):
예제 #18
0
model_dir = exp_dir +'/models'
img_dir = exp_dir +'/images'

## initialize data loaders & model
r_loader, z_loader = get_loader(config)

if config.solver == 'w1':
    model = W1(config, r_loader, z_loader)
elif config.solver == 'w2':
    model = W2(config, r_loader, z_loader)
elif config.solver == 'bary_ot':
    model = BaryOT(config, r_loader, z_loader)

cudnn.benchmark = True
networks = model.get_networks(config)
utils.print_networks(networks)

model.phi.load_state_dict(torch.load(model_dir+'/phi.pkl'))
model.phi.eval()
if config.solver == 'w1':
    model.g.load_state_dict(torch.load(model_dir+'/gen.pkl'))
    model.g.eval()

if config.solver == 'w2':
    model.eps.load_state_dict(torch.load(model_dir+'/eps.pkl'))
    model.eps.eval()
    if gen:
        model.g.load_state_dict(torch.load(model_dir+'/gen.pkl'))
        model.g.eval()

if config.solver == 'bary_ot':
예제 #19
0
    def __init__(self, args):
        """Build semi-supervised CycleGAN state for segmentation.

        Creates a segmentation->image generator (Gis), an
        image->segmentation generator (Gsi), pixel discriminators for
        images (Di) and segmentations (Ds), losses, optimizers, LR
        schedulers, tensorboard writer and metrics, then restores the
        latest checkpoint when one exists.
        """
        # Number of output classes per dataset (incl. background/void)
        if args.dataset == 'voc2012':
            self.n_channels = 21
        elif args.dataset == 'cityscapes':
            self.n_channels = 20
        elif args.dataset == 'acdc':
            self.n_channels = 4

        # Define the network
        #####################################################
        # for segmentaion to image
        self.Gis = define_Gen(input_nc=self.n_channels,
                              output_nc=3,
                              ngf=args.ngf,
                              netG='resnet_9blocks',
                              norm=args.norm,
                              use_dropout=not args.no_dropout,
                              gpu_ids=args.gpu_ids)
        # for image to segmentation
        self.Gsi = define_Gen(input_nc=3,
                              output_nc=self.n_channels,
                              ngf=args.ngf,
                              netG='resnet_9blocks_softmax',
                              norm=args.norm,
                              use_dropout=not args.no_dropout,
                              gpu_ids=args.gpu_ids)
        self.Di = define_Dis(input_nc=3,
                             ndf=args.ndf,
                             netD='pixel',
                             n_layers_D=3,
                             norm=args.norm,
                             gpu_ids=args.gpu_ids)
        self.Ds = define_Dis(
            input_nc=1,
            ndf=args.ndf,
            netD='pixel',
            n_layers_D=3,
            norm=args.norm,
            gpu_ids=args.gpu_ids)  # for voc 2012, there are 21 classes

        utils.print_networks([self.Gis, self.Gsi, self.Di, self.Ds],
                             ['Gis', 'Gsi', 'Di', 'Ds'])

        self.args = args

        # Loss criteria: MSE for adversarial, L1 for cycle/identity,
        # cross-entropy for supervised segmentation
        self.MSE = nn.MSELoss()
        self.L1 = nn.L1Loss()
        self.CE = nn.CrossEntropyLoss()
        self.activation_softmax = nn.Softmax2d()

        ### Tensorboard writer
        self.writer_semisuper = SummaryWriter(tensorboard_loc + '_semisuper')
        self.running_metrics_val = utils.runningScore(self.n_channels,
                                                      args.dataset)

        ### For adding gaussian noise
        self.gauss_noise = utils.GaussianNoise(sigma=0.2)

        # Optimizers (one over both generators, one over both
        # discriminators)
        #####################################################
        self.g_optimizer = torch.optim.Adam(itertools.chain(
            self.Gis.parameters(), self.Gsi.parameters()),
                                            lr=args.lr,
                                            betas=(0.5, 0.999))
        self.d_optimizer = torch.optim.Adam(itertools.chain(
            self.Di.parameters(), self.Ds.parameters()),
                                            lr=args.lr,
                                            betas=(0.5, 0.999))

        # Linear LR decay starting at args.decay_epoch
        self.g_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.g_optimizer,
            lr_lambda=utils.LambdaLR(args.epochs, 0, args.decay_epoch).step)
        self.d_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.d_optimizer,
            lr_lambda=utils.LambdaLR(args.epochs, 0, args.decay_epoch).step)

        # Try loading checkpoint
        #####################################################
        if not os.path.isdir(args.checkpoint_dir):
            os.makedirs(args.checkpoint_dir)

        try:
            ckpt = utils.load_checkpoint('%s/latest_semisuper_cycleGAN.ckpt' %
                                         (args.checkpoint_dir))
            self.start_epoch = ckpt['epoch']
            self.Di.load_state_dict(ckpt['Di'])
            self.Ds.load_state_dict(ckpt['Ds'])
            self.Gis.load_state_dict(ckpt['Gis'])
            self.Gsi.load_state_dict(ckpt['Gsi'])
            self.d_optimizer.load_state_dict(ckpt['d_optimizer'])
            self.g_optimizer.load_state_dict(ckpt['g_optimizer'])
            self.best_iou = ckpt['best_iou']
        except Exception:
            # Best-effort resume; narrowed from a bare `except:` which also
            # caught KeyboardInterrupt/SystemExit.
            print(' [*] No checkpoint!')
            self.start_epoch = 0
            self.best_iou = -100
예제 #20
0
    def __init__(self, args):
        """Build CycleGAN training state.

        Creates two generators (Gab, Gba) and two discriminators (Da, Db),
        losses, optimizers, linearly decaying LR schedulers, and restores
        the latest checkpoint when one exists.
        """
        # Define the network
        #####################################################
        self.Gab = define_Gen(input_nc=3,
                              output_nc=3,
                              ngf=args.ngf,
                              netG=args.gen_net,
                              norm=args.norm,
                              use_dropout=not args.no_dropout,
                              gpu_ids=args.gpu_ids)
        self.Gba = define_Gen(input_nc=3,
                              output_nc=3,
                              ngf=args.ngf,
                              netG=args.gen_net,
                              norm=args.norm,
                              use_dropout=not args.no_dropout,
                              gpu_ids=args.gpu_ids)
        self.Da = define_Dis(input_nc=3,
                             ndf=args.ndf,
                             netD=args.dis_net,
                             n_layers_D=3,
                             norm=args.norm,
                             gpu_ids=args.gpu_ids)
        self.Db = define_Dis(input_nc=3,
                             ndf=args.ndf,
                             netD=args.dis_net,
                             n_layers_D=3,
                             norm=args.norm,
                             gpu_ids=args.gpu_ids)

        utils.print_networks([self.Gab, self.Gba, self.Da, self.Db],
                             ['Gab', 'Gba', 'Da', 'Db'])

        # Define Loss criterias: MSE for adversarial, L1 for
        # cycle/identity terms
        self.MSE = nn.MSELoss()
        self.L1 = nn.L1Loss()

        # Optimizers (one over both generators, one over both
        # discriminators)
        #####################################################
        self.g_optimizer = torch.optim.Adam(itertools.chain(
            self.Gab.parameters(), self.Gba.parameters()),
                                            lr=args.lr,
                                            betas=(0.5, 0.999))
        self.d_optimizer = torch.optim.Adam(itertools.chain(
            self.Da.parameters(), self.Db.parameters()),
                                            lr=args.lr,
                                            betas=(0.5, 0.999))

        # Linear LR decay starting at args.decay_epoch
        self.g_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.g_optimizer,
            lr_lambda=utils.LambdaLR(args.epochs, 0, args.decay_epoch).step)
        self.d_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.d_optimizer,
            lr_lambda=utils.LambdaLR(args.epochs, 0, args.decay_epoch).step)

        # Try loading checkpoint
        #####################################################
        if not os.path.isdir(args.checkpoint_dir):
            os.makedirs(args.checkpoint_dir)

        try:
            ckpt = utils.load_checkpoint('%s/latest.ckpt' %
                                         (args.checkpoint_dir))
            self.start_epoch = ckpt['epoch']
            self.Da.load_state_dict(ckpt['Da'])
            self.Db.load_state_dict(ckpt['Db'])
            self.Gab.load_state_dict(ckpt['Gab'])
            self.Gba.load_state_dict(ckpt['Gba'])
            self.d_optimizer.load_state_dict(ckpt['d_optimizer'])
            self.g_optimizer.load_state_dict(ckpt['g_optimizer'])
        except Exception:
            # Best-effort resume; narrowed from a bare `except:` which also
            # caught KeyboardInterrupt/SystemExit.
            print(' [*] No checkpoint!')
            self.start_epoch = 0
예제 #21
0
def test(args):
    """Run CycleGAN inference over the test sets and save image grids.

    Loads Gab/Gba from the latest checkpoint (stripping DataParallel's
    'module.' prefix when running on CPU), then maps each test batch both
    ways plus reconstructions and saves a 3-wide grid per batch.
    """
    transform = transforms.Compose([
        transforms.Resize((args.crop_height, args.crop_width)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    dataset_dirs = utils.get_testdata_link(args.dataset_dir)

    a_test_data = dsets.ImageFolder(dataset_dirs['testA'], transform=transform)
    b_test_data = dsets.ImageFolder(dataset_dirs['testB'], transform=transform)

    a_test_loader = torch.utils.data.DataLoader(a_test_data,
                                                batch_size=args.batch_size,
                                                shuffle=True,
                                                num_workers=4)
    b_test_loader = torch.utils.data.DataLoader(b_test_data,
                                                batch_size=args.batch_size,
                                                shuffle=True,
                                                num_workers=4)

    Gab = define_Gen(input_nc=3,
                     output_nc=3,
                     ngf=args.ngf,
                     netG=args.gen_net,
                     norm=args.norm,
                     use_dropout=not args.no_dropout,
                     gpu_ids=args.gpu_ids)
    Gba = define_Gen(input_nc=3,
                     output_nc=3,
                     ngf=args.ngf,
                     netG=args.gen_net,
                     norm=args.norm,
                     use_dropout=not args.no_dropout,
                     gpu_ids=args.gpu_ids)

    utils.print_networks([Gab, Gba], ['Gab', 'Gba'])

    try:
        ckpt = utils.load_checkpoint('%s/latest.ckpt' % (args.checkpoint_dir))
        if len(args.gpu_ids) > 0:
            print(args.gpu_ids)
            Gab.load_state_dict(ckpt['Gab'])
            Gba.load_state_dict(ckpt['Gba'])
        else:
            # CPU run: checkpoints saved from DataParallel prefix every key
            # with 'module.' — strip it before loading.
            Gab.load_state_dict({
                k.replace('module.', ''): v
                for k, v in ckpt['Gab'].items()
            })
            Gba.load_state_dict({
                k.replace('module.', ''): v
                for k, v in ckpt['Gba'].items()
            })
    except Exception:
        # Best effort: continue with untrained weights if no checkpoint.
        # Narrowed from a bare `except:`.
        print(' [*] No checkpoint!')
    """ run """
    # eval mode is loop-invariant; set once
    Gab.eval()
    Gba.eval()

    a_it = iter(a_test_loader)
    b_it = iter(b_test_loader)
    for i in range(len(a_test_loader)):
        # On exhaustion, restart the loader AND fetch again.  The original
        # only reset the iterator, leaving the batch stale (or unbound on
        # the first pass); it also had leftover debug print/return
        # statements that aborted after one batch without saving anything.
        try:
            a_real_test = Variable(next(a_it)[0], requires_grad=True)
        except StopIteration:
            a_it = iter(a_test_loader)
            a_real_test = Variable(next(a_it)[0], requires_grad=True)

        try:
            b_real_test = Variable(next(b_it)[0], requires_grad=True)
        except StopIteration:
            b_it = iter(b_test_loader)
            b_real_test = Variable(next(b_it)[0], requires_grad=True)

        a_real_test, b_real_test = utils.cuda([a_real_test, b_real_test])

        with torch.no_grad():
            a_fake_test = Gab(b_real_test)
            b_fake_test = Gba(a_real_test)
            a_recon_test = Gab(b_fake_test)
            b_recon_test = Gba(a_fake_test)

        # Rows: realA | fakeB | reconA / realB | fakeA | reconB, rescaled
        # from [-1, 1] to [0, 1] for saving
        pic = (torch.cat([
            a_real_test, b_fake_test, a_recon_test, b_real_test, a_fake_test,
            b_recon_test
        ],
                         dim=0).data + 1) / 2.0

        if not os.path.isdir(args.results_dir):
            os.makedirs(args.results_dir)

        path = str.format("{}/sample-{}.jpg", args.results_dir, i)
        torchvision.utils.save_image(pic, path, nrow=3)