Example #1
    def __init__(self, data_loader, config):

        # Data loader
        self.data_loader = data_loader
        self.labels_dict = {
            0: 'airplane',
            1: 'automobile',
            2: 'bird',
            3: 'cat',
            4: 'deer',
            5: 'dog',
            6: 'frog',
            7: 'horse',
            8: 'ship',
            9: 'truck'
        }
        # Extract model and loss settings from the config
        self.model = config.model
        self.adv_loss = config.adv_loss

        # Model hyper-parameters
        self.imsize = config.imsize
        self.g_num = config.g_num
        self.z_dim = config.z_dim
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.parallel = config.parallel

        self.d_iters = config.d_iters
        self.batch_size = config.batch_size
        self.num_workers = config.num_workers

        self.pretrained_model = config.pretrained_model

        self.dataset = config.dataset

        self.model_save_path = config.model_save_path
        self.test_path = config.test_path

        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        # TODO: pass the log path in as a variable instead of a hard-coded string
        self.fid_model = FID("./log_path", device)

        self.build_model()

        # Start with trained model
        if self.pretrained_model:
            self.load_pretrained_model()
Example #2
    def __init__(self, exp_name, generator, discriminator, args, device, flog,
                 logger):
        self.generator = generator
        self.discriminator = discriminator
        self.optimizer_g = torch.optim.Adam(generator.parameters(),
                                            lr=args.lr,
                                            betas=(0.5, 0.999))
        self.optimizer_d = torch.optim.Adam(discriminator.parameters(),
                                            lr=args.lr,
                                            betas=(0.5, 0.999))
        self.exp_name = exp_name
        self.args = args
        self.device = device
        self.flog = flog
        self.logger = logger
        self.fid = FID(self.args.fid_mode, self.args.category, device, 'train')
Example #3
def calculate_fid(images, batch_size, sw, epoch):
    # load model
    from fid import FID
    fid_model = FID()

    # calculate statistics
    fake_mu, fake_sigma = fid_model.calculate_statistics(images, batch_size)

    # calculate FID against the precomputed reference statistics
    # (TARGET_FID / DATASET_FID are module-level (mu, sigma) pairs)
    fid_odd = fid_model.calculate_fid(fake_mu, fake_sigma, TARGET_FID[0], TARGET_FID[1])
    fid_real = fid_model.calculate_fid(fake_mu, fake_sigma, DATASET_FID[0], DATASET_FID[1])

    # save FID
    sw.add_scalar('DCGAN/FID ODD', fid_odd, epoch)
    sw.add_scalar('DCGAN/FID Real', fid_real, epoch)

    # clean memory
    del fid_model, images, fake_mu, fake_sigma

    return fid_real
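
For reference, the quantity these helpers compute is the Fréchet distance between Gaussians fitted to Inception activations. A self-contained numpy sketch of that formula, independent of the local FID wrapper used above:

import numpy as np
from scipy import linalg

def frechet_distance(mu1, sigma1, mu2, sigma2):
    # FID = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrt(sigma1 @ sigma2))
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False)
    if np.iscomplexobj(covmean):
        covmean = covmean.real  # discard tiny imaginary parts from sqrtm
    return diff @ diff + np.trace(sigma1 + sigma2 - 2 * covmean)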
Example #4
def calculate_fid(fake_images, sw, epoch, n_class):
    # load model
    from fid import FID
    fid_model = FID()

    # calculate statistics
    fake_mu, fake_sigma = fid_model.calculate_statistics(fake_images, 32)

    # calculate FID
    fid_odd = fid_model.calculate_fid(fake_mu, fake_sigma, TARGET_FID[0],
                                      TARGET_FID[1])
    fid_real = fid_model.calculate_fid(fake_mu, fake_sigma, DATASET_FID[0],
                                       DATASET_FID[1])

    # save FID
    sw.add_scalar(f'GAN {n_class}/FID ODD', fid_odd, epoch)
    sw.add_scalar(f'GAN {n_class}/FID Real', fid_real, epoch)

    # clean memory
    del fid_model, fake_images, fake_mu, fake_sigma

    return fid_real
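
The two calculate_fid() helpers above assume module-level TARGET_FID and DATASET_FID constants holding precomputed (mu, sigma) reference statistics, plus a SummaryWriter sw. A minimal sketch of how they might be wired up; the .npz file names and layout are assumptions, not from the source:

import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

# Assumed: reference statistics were precomputed and saved as .npz files
# with 'mu' and 'sigma' arrays (this layout is an assumption).
_target = np.load('target_stats.npz')
_dataset = np.load('dataset_stats.npz')
TARGET_FID = (_target['mu'], _target['sigma'])
DATASET_FID = (_dataset['mu'], _dataset['sigma'])

sw = SummaryWriter('runs/gan')
fake_images = torch.randn(64, 3, 32, 32)  # stand-in for generator output
fid_real = calculate_fid(fake_images, sw=sw, epoch=0, n_class=0)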
Example #5
class SAGAN_test(object):
    def __init__(self, data_loader, config):

        # Data loader
        self.data_loader = data_loader
        self.labels_dict = {
            0: 'airplane',
            1: 'automobile',
            2: 'bird',
            3: 'cat',
            4: 'deer',
            5: 'dog',
            6: 'frog',
            7: 'horse',
            8: 'ship',
            9: 'truck'
        }
        # Extract model and loss settings from the config
        self.model = config.model
        self.adv_loss = config.adv_loss

        # Model hyper-parameters
        self.imsize = config.imsize
        self.g_num = config.g_num
        self.z_dim = config.z_dim
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.parallel = config.parallel

        self.d_iters = config.d_iters
        self.batch_size = config.batch_size
        self.num_workers = config.num_workers

        self.pretrained_model = config.pretrained_model

        self.dataset = config.dataset

        self.model_save_path = config.model_save_path
        self.test_path = config.test_path
        self.sample_path = config.sample_path  # used by save_sample()

        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        # TODO: pass the log path in as a variable instead of a hard-coded string
        self.fid_model = FID("./log_path", device)

        self.build_model()

        # Start with trained model
        if self.pretrained_model:
            self.load_pretrained_model()

    def test(self):

        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        # Data iterator
        data_iter = iter(self.data_loader)

        # Fixed input for debugging; use self.z_dim (the dimension the
        # generator was built with) rather than a hard-coded 90
        fixed_z = tensor2var(torch.randn(self.batch_size, self.z_dim))

        self.G.eval()
        fid_scores = []
        n_batches = 2
        for i in range(n_batches):

            # draw from the shared iterator; re-creating iter() each time
            # would return the same first batch
            real_images, labels = next(data_iter)

            if i == n_batches - 1:
                if self.batch_size <= 10:
                    for l in labels:
                        print(self.labels_dict[l.item()])
                else:
                    print("Skipping label printout since batch size is "
                          "greater than 10")
            # Compute loss with real images
            real_images = tensor2var(real_images)
            labels = tensor2var(encode(labels))

            z = tensor2var(torch.randn(real_images.size(0), self.z_dim))
            fake_images, gf1, gf2 = self.G(z, labels)
            fid_score = self.fid_model.compute_fid(real_images, fake_images)
            fid_scores.append(fid_score)

        save_image(denorm(fake_images.data), 'SAGAN_test.png')
        print("Image saved as SAGAN_test.png")
        avg_fid_score = sum(fid_scores) / len(fid_scores)
        print("Average FID score for SAGAN over", n_batches, "batches:",
              avg_fid_score)

    def build_model(self):
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.G = Generator(self.batch_size, self.imsize, self.z_dim,
                           self.g_conv_dim).to(device)

    def build_tensorboard(self):
        # Tensorboard logging is disabled in this test harness.
        pass

    def load_pretrained_model(self):
        self.G.load_state_dict(
            torch.load(
                os.path.join(self.model_save_path,
                             '{}_G.pth'.format(self.pretrained_model))))

        print('loaded trained models (step: {})..!'.format(
            self.pretrained_model))

    def save_sample(self, data_iter):
        real_images, _ = next(data_iter)
        save_image(denorm(real_images),
                   os.path.join(self.sample_path, 'real.png'))
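
For context, a hedged sketch of how SAGAN_test might be driven. The labels_dict matches CIFAR-10, so a torchvision CIFAR-10 loader is used here; the config namespace is hypothetical, with field names mirroring the attributes read in __init__.

from types import SimpleNamespace
import torch
import torchvision
import torchvision.transforms as T

transform = T.Compose([
    T.Resize(32),
    T.ToTensor(),
    T.Normalize((0.5,) * 3, (0.5,) * 3),
])
test_set = torchvision.datasets.CIFAR10('./data', train=False, download=True,
                                        transform=transform)
loader = torch.utils.data.DataLoader(test_set, batch_size=8, shuffle=True)

# Hypothetical config: field names mirror the attributes read in __init__.
config = SimpleNamespace(model='sagan', adv_loss='hinge', imsize=32, g_num=5,
                         z_dim=90, g_conv_dim=64, d_conv_dim=64, parallel=False,
                         d_iters=5, batch_size=8, num_workers=2,
                         pretrained_model=None, dataset='cifar10',
                         model_save_path='./models', test_path='./test',
                         sample_path='./samples')

tester = SAGAN_test(loader, config)
tester.test()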
Example #6
import os
import time

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

from fid import FID  # local FID wrapper (assumed import path)

train_epoch = 5000

start_time = time.time()
epoch_start = 0
epoch_end = epoch_start + train_epoch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
G = Generator(ngf)
D = Discriminator(ndf)
G_optimizer = optim.Adam(G.parameters(), lr=lrG, betas=(beta1, beta2))
D_optimizer = optim.Adam(D.parameters(), lr=lrD, betas=(beta1, beta2))
# Losses
BCE_loss = nn.BCELoss().to(device)
L1_loss = nn.L1Loss().to(device)

fid_model = FID("./", device)
# TODO: add FID scoring, e.g. fid_model.compute_fid(real_images, G_result)

# Summary writer
writer = SummaryWriter(log_dir)

# Fixed noise for visualizing images
save_noise = torch.randn(1, nz, 1, 1, device=device)

# Load the models if checkpoints exist from a previous run
if (os.path.isfile(model_dir + 'generator_param.pkl')
        and os.path.isfile(model_dir + 'discriminator_param.pkl')):

    G_checkpoint = torch.load(model_dir + 'generator_param.pkl',
                              map_location=device)
    D_checkpoint = torch.load(model_dir + 'discriminator_param.pkl',
                              map_location=device)
Example #7
class Trainer(object):
    def __init__(self, exp_name, generator, discriminator, args, device, flog,
                 logger):
        self.generator = generator
        self.discriminator = discriminator
        self.optimizer_g = torch.optim.Adam(generator.parameters(),
                                            lr=args.lr,
                                            betas=(0.5, 0.999))
        self.optimizer_d = torch.optim.Adam(discriminator.parameters(),
                                            lr=args.lr,
                                            betas=(0.5, 0.999))
        self.exp_name = exp_name
        self.args = args
        self.device = device
        self.flog = flog
        self.logger = logger
        self.fid = FID(self.args.fid_mode, self.args.category, device, 'train')

    def load_model(self, path):
        checkpoint = torch.load(path)
        self.generator.load_state_dict(checkpoint["generator_state_dict"])
        self.discriminator.load_state_dict(
            checkpoint["discriminator_state_dict"])
        self.optimizer_g.load_state_dict(
            checkpoint["generator_optimizer_state_dict"])
        self.optimizer_d.load_state_dict(
            checkpoint["discriminator_optimizer_state_dict"])

    def compute_gradient_penalty(self, pg_node, real_samples, fake_samples):
        batch_size = real_samples.size(0)

        alpha = torch.rand(batch_size, 1, 1, 1).to(self.device)
        interp_samples = (alpha * real_samples +
                          ((1 - alpha) * fake_samples)).requires_grad_(True)

        interp_score, _, _ = self.discriminator.forward(
            pg_node, interp_samples)
        fake = torch.ones(interp_score.size()).to(self.device)

        gradients = torch.autograd.grad(
            outputs=interp_score,
            inputs=interp_samples,
            grad_outputs=fake,
            create_graph=True,
            retain_graph=True,
            only_inputs=True,
        )[0]

        gradients = gradients.contiguous().view(batch_size, -1)
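        # WGAN-GP penalty: E[(||grad||_2 - 1)^2]. The small 1e-4 inside sqrt()
        # below guards against a zero gradient norm, whose square root has an
        # unbounded derivative and would yield NaNs in backward().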
        gradient_penalty = ((gradients.pow(2).sum(dim=1) + 1e-4).sqrt() - 1)**2
        return gradient_penalty

    def train_iteration(self, dataset, data, iteration):
        num_pg = len(data[0])
        num_shape_per_pg = data[1][0].shape[0]

        # get all pg-templates
        pg_templates = []
        for i in range(num_pg):
            pg_templates.append(dataset.get_pg_template(data[0][i]))

        # train discriminator
        self.discriminator.train()
        self.generator.eval()

        self.optimizer_d.zero_grad()

        real_score = []
        fake_score = []
        gradient_penalty = []
        real_sn_score = []
        real_pn_score = []
        fake_sn_score = []
        fake_pn_score = []
        for i in range(num_pg):
            with torch.no_grad():
                zs = torch.randn(num_shape_per_pg,
                                 self.args.z_dim).to(self.device)
                fake_part_pcs = self.generator(pg_templates[i], zs).detach()
            real_part_pcs = Variable(torch.Tensor(data[1][i]).to(self.device))

            cur_real_score, cur_real_sn_score, cur_real_pn_score = self.discriminator(
                pg_templates[i], real_part_pcs)
            real_score.append(cur_real_score)
            real_sn_score.append(cur_real_sn_score)
            real_pn_score.append(cur_real_pn_score)
            cur_fake_score, cur_fake_sn_score, cur_fake_pn_score = self.discriminator(
                pg_templates[i], fake_part_pcs)
            fake_score.append(cur_fake_score)
            fake_sn_score.append(cur_fake_sn_score)
            fake_pn_score.append(cur_fake_pn_score)

            gradient_penalty.append(
                self.compute_gradient_penalty(pg_templates[i],
                                              real_part_pcs.data,
                                              fake_part_pcs.data))

        real_score = torch.cat(real_score)
        real_sn_score = torch.cat(real_sn_score)
        real_pn_score = torch.cat(real_pn_score)
        self.logger.add_scalar('real_score',
                               torch.mean(real_score).item(), iteration)

        fake_score = torch.cat(fake_score)
        fake_sn_score = torch.cat(fake_sn_score)
        fake_pn_score = torch.cat(fake_pn_score)
        self.logger.add_scalar('fake_score',
                               torch.mean(fake_score).item(), iteration)

        gradient_penalty = torch.cat(gradient_penalty)
        gradient_penalty = torch.mean(gradient_penalty)
        self.logger.add_scalar("gradient_penalty", gradient_penalty.item(),
                               iteration)

        wasserstein_estimate = torch.mean(real_score) - torch.mean(fake_score)
        self.logger.add_scalar('wasserstein_estimate',
                               wasserstein_estimate.item(), iteration)
        wasserstein_estimate_sn = torch.mean(real_sn_score) - torch.mean(
            fake_sn_score)
        self.logger.add_scalar('wasserstein_estimate_sn',
                               wasserstein_estimate_sn.item(), iteration)
        wasserstein_estimate_pn = torch.mean(real_pn_score) - torch.mean(
            fake_pn_score)
        self.logger.add_scalar('wasserstein_estimate_pn',
                               wasserstein_estimate_pn.item(), iteration)

        d_loss = self.args.loss_weight_gp * gradient_penalty - wasserstein_estimate
        self.logger.add_scalar('train_d_loss', d_loss.item(), iteration)

        d_loss.backward()
        self.optimizer_d.step()

        out_str = '  **Training DIS %s** [w_dist: %.4f] [real_scores: %.4f] [fake_scores: %.4f] [gp: %.4f]' \
                % (self.exp_name, wasserstein_estimate.item(), torch.mean(real_score).item(), torch.mean(fake_score).item(), gradient_penalty.item())
        print(out_str)
        self.flog.write(out_str + '\n')

        if iteration % self.args.n_critic == 0:
            # train generator
            self.discriminator.eval()
            self.generator.train()

            self.optimizer_g.zero_grad()

            fake_score = []
            for i in range(num_pg):
                zs = torch.randn(num_shape_per_pg,
                                 self.args.z_dim).to(self.device)
                fake_part_pcs = self.generator(pg_templates[i], zs)

                cur_fake_score, _, _ = self.discriminator(
                    pg_templates[i], fake_part_pcs)
                fake_score.append(cur_fake_score)

            fake_score = torch.cat(fake_score)

            g_loss = -torch.mean(fake_score)
            self.logger.add_scalar('train_g_loss', g_loss.item(), iteration)

            g_loss.backward()
            self.optimizer_g.step()

            out_str = '  **Training GEN %s** [fake_scores: %.4f]' \
                    % (self.exp_name, torch.mean(fake_score).item())
            print(out_str)
            self.flog.write(out_str + '\n')

    def eval_metric(self, dataset, epoch):
        self.generator.eval()

        # generate fake pcs
        with torch.no_grad():
            fake_pcs = []
            for i in range(self.args.num_fake_per_metric):
                idx = np.random.choice(len(dataset))
                pg_idx, _ = dataset[idx]
                pg_template = dataset.get_pg_template(pg_idx)
                z = torch.randn(1, self.args.z_dim).to(self.device)
                gen_part_pc = self.generator(pg_template, z)
                gen_pc = gen_part_pc.reshape(1, -1, 3)
                gen_pc_idx = furthest_point_sample(
                    gen_pc, self.args.num_point_per_shape)[0]
                gen_pc = gen_pc[0, gen_pc_idx.long()]
                gen_pc = gen_pc.cpu().detach().numpy()
                fake_pcs.append(np.expand_dims(gen_pc, 0))
            fake_pcs = np.concatenate(fake_pcs, 0)

        # compute FPD score
        fpd = self.fid.get_fid(fake_pcs)
        self.logger.add_scalar('eval_fpd', fpd, epoch)
        out_str = '##Eval Metric %s## [fpd: %.4f]' % (self.exp_name, fpd)
        print(out_str)
        self.flog.write(out_str + '\n')

    def train(self,
              train_dataset,
              train_dataloader,
              start_iteration=0,
              start_epoch=0):
        iteration = start_iteration
        for epoch in range(start_epoch, self.args.max_epochs):
            # train one epoch
            out_str = '\n %s [Epoch %03d/%03d]' % (time.asctime(
                time.localtime(time.time())), epoch, self.args.max_epochs)
            print(out_str)
            self.flog.write(out_str + '\n')

            for i, data in enumerate(train_dataloader):
                self.train_iteration(train_dataset, data, iteration)
                iteration = iteration + 1

            if (epoch + 1) % self.args.epochs_per_metric == 0:
                self.eval_metric(train_dataset, epoch)

            if (epoch + 1) % self.args.epochs_per_eval == 0:
                self.discriminator.eval()
                self.generator.eval()

                with torch.no_grad():
                    # save checkpoint
                    out_fn = os.path.join('log', self.args.exp_name,
                                          'model_%06d.ckpt' % epoch)
                    out_str = 'Saving checkpoint to %s' % out_fn
                    print(out_str)
                    self.flog.write(out_str + '\n')
                    torch.save(
                        {
                            'discriminator_state_dict':
                            self.discriminator.state_dict(),
                            'discriminator_optimizer_state_dict':
                            self.optimizer_d.state_dict(),
                            'generator_state_dict':
                            self.generator.state_dict(),
                            'generator_optimizer_state_dict':
                            self.optimizer_g.state_dict(),
                        }, out_fn)

                    # visualize current results
                    if self.args.num_visu is not None:
                        cur_visu_dir = os.path.join('log', self.args.exp_name,
                                                    'visu-%08d' % epoch)
                        os.mkdir(cur_visu_dir)
                        cur_gen_dir = os.path.join(cur_visu_dir, 'gen')
                        os.mkdir(cur_gen_dir)
                        cur_gen2_dir = os.path.join(cur_visu_dir, 'gen2')
                        os.mkdir(cur_gen2_dir)
                        cur_real_dir = os.path.join(cur_visu_dir, 'real')
                        os.mkdir(cur_real_dir)
                        cur_info_dir = os.path.join(cur_visu_dir, 'info')
                        os.mkdir(cur_info_dir)
                        print('Visualizing ...')
                        self.flog.write('Visualizing ...\n')
                        for pg_idx in self.args.visu_pg_list:
                            pg_node = train_dataset.get_pg_template(pg_idx)
                            zs = torch.randn(self.args.num_visu,
                                             self.args.z_dim).to(self.device)
                            part_pcs = self.generator(pg_node, zs)
                            shape_pcs = part_pcs.view(self.args.num_visu, -1,
                                                      3)
                            shape_pc_id1 = torch.arange(
                                self.args.num_visu).unsqueeze(1).repeat(
                                    1,
                                    self.args.num_point_per_shape).long().view(
                                        -1).to(self.device)
                            shape_pc_id2 = furthest_point_sample(
                                shape_pcs,
                                self.args.num_point_per_shape).long().view(-1)
                            shape_pcs = shape_pcs[
                                shape_pc_id1, shape_pc_id2].view(
                                    self.args.num_visu,
                                    self.args.num_point_per_shape, 3)
                            real_names, real_part_pcs = train_dataset.get_pg_real_pcs(
                                pg_idx, self.args.num_visu)
                            part_pcs = part_pcs.cpu().detach().numpy()
                            real_part_pcs = real_part_pcs.cpu().detach().numpy(
                            )
                            for pcid in range(self.args.num_visu):
                                fn = 'pg-%04d-shape-%04d' % (pg_idx, pcid)
                                render_part_pcs(
                                    [part_pcs[pcid]],
                                    title_list=['shape-%04d' % pcid],
                                    out_fn=os.path.join(
                                        cur_gen_dir, fn + '.png'))
                                export_part_pcs(os.path.join(cur_gen_dir, fn),
                                                part_pcs[pcid])
                                render_part_pcs(
                                    [real_part_pcs[pcid]],
                                    title_list=['shape-%04d' % pcid],
                                    out_fn=os.path.join(
                                        cur_real_dir, fn + '.png'))
                                export_part_pcs(os.path.join(cur_real_dir, fn),
                                                real_part_pcs[pcid])
                                cur_shape_pc = shape_pcs[pcid].cpu().detach(
                                ).numpy()
                                render_pc(
                                    os.path.join(cur_gen2_dir, fn + '.png'),
                                    cur_shape_pc)
                                export_pc(
                                    os.path.join(cur_gen2_dir, fn + '.obj'),
                                    cur_shape_pc)
                                with open(
                                        os.path.join(cur_info_dir,
                                                     fn + '.txt'),
                                        'w') as fout:
                                    fout.write('%s\n' % real_names[pcid])
                        sublist = 'gen,gen2,real,info'
                        cmd = 'cd %s && python %s . %d htmls %s %s > /dev/null' % (cur_visu_dir, \
                                os.path.join(BASE_DIR, 'gen_html_hierachy_local.py'), self.args.num_visu, sublist, sublist)
                        call(cmd, shell=True)

            self.flog.flush()
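
A rough sketch of how this Trainer might be driven; the constructor calls, args fields, and log layout below are assumptions inferred from the attributes the class reads, not taken from the source.

import os
import torch
from types import SimpleNamespace
from torch.utils.tensorboard import SummaryWriter

def main(train_dataset, train_dataloader):
    # train_dataset is assumed to expose get_pg_template() and
    # get_pg_real_pcs(), matching what train_iteration()/eval_metric() call.
    args = SimpleNamespace(exp_name='exp0', lr=1e-4, z_dim=128, n_critic=5,
                           loss_weight_gp=10.0, fid_mode='default',
                           category='chair', max_epochs=1000,
                           epochs_per_metric=20, epochs_per_eval=20,
                           num_fake_per_metric=200, num_point_per_shape=2048,
                           num_visu=None, visu_pg_list=[])
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    generator = Generator(args.z_dim).to(device)  # assumed constructor
    discriminator = Discriminator().to(device)    # assumed constructor
    os.makedirs(os.path.join('log', args.exp_name), exist_ok=True)
    logger = SummaryWriter(os.path.join('log', args.exp_name))
    with open(os.path.join('log', args.exp_name, 'train.log'), 'w') as flog:
        trainer = Trainer(args.exp_name, generator, discriminator, args,
                          device, flog, logger)
        trainer.train(train_dataset, train_dataloader)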
Example #8
import numpy as np
from glob import glob
from os.path import join
from skimage import io  # assumed tif I/O backend

# Batch size; could be bigger if you have enough GPU memory.
batch_size = args.batch_size

# Folders
root_folder = args.case_folder
gt_folder = glob(join(root_folder, 'GT*'))[0]
recon_folder = glob(join(root_folder, 'Recon*'))[0]
refine_folder = glob(join(root_folder, 'Refine*'))[0]

gt_paths = glob(join(gt_folder, '*.tif'))
recon_paths = glob(join(recon_folder, '*.tif'))
refine_paths = glob(join(refine_folder, '*.tif'))

rc = RandomCrop(256, 0)

fid = FID(gpu_id=args.gpu_id, batch_size=batch_size)

def load_stack(paths):
    # Load grayscale tifs, scale to [0, 1], and add a channel axis.
    imgs = np.array([io.imread(p) / 255 for p in paths])
    return imgs[:, :, :, np.newaxis]

gt_imgs = load_stack(gt_paths)
recon_imgs = load_stack(recon_paths)
refine_imgs = load_stack(refine_paths)
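
Presumably the script goes on to score the reconstruction and refinement stacks against ground truth. A hedged continuation; the compute_fid method name is an assumption, since the FID wrappers in these examples expose differing APIs (compute_fid, calculate_fid, get_fid).

# Hypothetical continuation: the compute_fid name is an assumption.
recon_score = fid.compute_fid(gt_imgs, recon_imgs)
refine_score = fid.compute_fid(gt_imgs, refine_imgs)
print('FID (recon vs GT):  %.4f' % recon_score)
print('FID (refine vs GT): %.4f' % refine_score)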