コード例 #1
0
    def get_gen_loss(self, num_images):
        """Generator loss: BCE of the discriminator's verdict on fakes
        against an all-ones target.

        The target is 1 ("real") because the generator is rewarded when
        the discriminator is fooled into classifying fakes as genuine.
        """
        noise = get_noise(num_images, self.z_dim)
        generated = self.gen(noise)
        pred_on_fake = self.disc(generated)
        # All-ones target: generator wants fakes judged as real.
        target = torch.ones_like(pred_on_fake)
        return self.crit(pred_on_fake, target)
コード例 #2
0
    def get_disc_loss(self, real, num_images):
        """Discriminator loss: mean of the fake-branch and real-branch BCE.

        The fake batch is detached so gradients from this loss stop at the
        discriminator and never reach the generator — only the
        discriminator is being updated here.
        """
        noise = get_noise(num_images, self.z_dim)
        generated = self.gen(noise)

        # Fake branch: detach() blocks generator gradients; target label 0.
        pred_fake = self.disc(generated.detach())
        loss_fake = self.crit(pred_fake, torch.zeros_like(pred_fake))

        # Real branch: target label 1.
        pred_real = self.disc(real)
        loss_real = self.crit(pred_real, torch.ones_like(pred_real))

        # Overall loss is the average of the two branches.
        return (loss_fake + loss_real) / 2
コード例 #3
0
    def _train_step(self, data):
        """Run one optimisation step for both discriminator and generator.

        Returns (errD, errG, D_x, D_G_z1, D_G_z2): the two losses plus the
        discriminator's outputs on the real batch, on the fake batch
        before the generator step, and on the fake batch after it.
        """
        nets = self.train_model
        cfg = self.config

        real_data = data[0].to(cfg["device"])

        noise = model.get_noise(real_data, cfg)
        fake_data = nets["netG"](noise)
        label = model.get_label(real_data, cfg)

        # Discriminator update sees detached fakes so G gets no gradient here.
        errD, D_x, D_G_z1 = model.get_Discriminator_loss(
            nets["netD"], nets["optimizerD"], real_data,
            fake_data.detach(), label, nets["criterion"], cfg)
        # Generator update uses the non-detached fakes.
        errG, D_G_z2 = model.get_Generator_loss(
            nets["netG"], nets["netD"], nets["optimizerG"],
            fake_data, label, nets["criterion"], cfg)

        return errD, errG, D_x, D_G_z1, D_G_z2
コード例 #4
0
ファイル: train.py プロジェクト: normalct/vcae
def train_model(model, dataset, ds_name,
                epochs=10,
                batch_size=32,
                sample_size=32,
                eval_size=32,
                img_size=32,
                lr=1e-3,
                weight_decay=1e-4,
                loss_log_interval=20,
                image_log_interval=20,
                model_log_interval=20,
                checkpoint_dir='./checkpoints',
                resume=False,
                cuda=False,
                seed=0,
                device=None,
                cores=1):
    """Train `model` on `dataset`, periodically logging losses, saving
    checkpoints, writing sample images and computing evaluation scores.

    The loop branches on `model.model_name` to support AE/VAE/CVAE
    variants, vine-copula autoencoders ('ae_vine*'), DEC clustering
    models ('dec_vine*') and a GAN.  Text logs and images go to
    ./results/<ds_name>/; checkpoints to `checkpoint_dir`.

    Fixes vs. the previous revision (behaviour otherwise unchanged):
      * `batch_index % k is 0` -> `== 0` (identity comparison against an
        int literal is unreliable and a SyntaxWarning since Python 3.8);
      * a tab-indented `break` (TabError in Python 3 amid space
        indentation) re-indented with spaces;
      * deprecated `BCELoss(size_average=False)` ->
        `BCELoss(reduction='sum')` (identical semantics).
    """
    if resume:
        epoch_start = utils.load_checkpoint(model, checkpoint_dir)
    else:
        epoch_start = 0

    # Fixed noise reused at every image checkpoint so samples are comparable.
    fixed_noise = torch.rand(sample_size, model.z_size).to(device)

    if model.model_name in ['vae', 'cvae', 'vae2', 'cvae2']:
        # Map the uniform noise to standard normal via the inverse CDF.
        m = dist.Normal(torch.Tensor([0.0]).to(device), torch.Tensor([1.0]).to(device))
        fixed_noise = m.icdf(fixed_noise)

    output_folder = './results/' + ds_name
    resfile_prefix = ds_name + "_" + \
                     model.model_name + \
                     "_ld_" + \
                     str(model.z_size) + \
                     "_bs_" + str(batch_size)

    if not os.path.exists(output_folder):
        os.makedirs(output_folder)

    data_root = './datasets'

    if model.model_name in ['dec_vine', 'dec_vine2', 'dec_vine3']:

        # Load the matching pre-trained autoencoder checkpoint.
        if model.model_name == 'dec_vine':
            pretrain_prefix = resfile_prefix.replace("dec", "ae")
        elif model.model_name == 'dec_vine2':
            pretrain_prefix = resfile_prefix.replace("dec_vine2", "ae_vine2")
        elif model.model_name == 'dec_vine3':
            pretrain_prefix = resfile_prefix.replace("dec_vine3", "ae_vine3")

        # Pick the checkpoint file with the largest epoch suffix.
        pretrain_files = [filename for filename in os.listdir(checkpoint_dir) if filename.startswith(pretrain_prefix)]
        pretrain_epochs = [int(filename.replace(pretrain_prefix + "_", "")) for filename in pretrain_files]
        pretrain_path = os.path.join(checkpoint_dir, pretrain_files[pretrain_epochs.index(max(pretrain_epochs))])
        model.pretrain(pretrain_path)

        # Form initial cluster centres by k-means on the encoded data.
        data_loader = utils.get_data_loader(dataset, batch_size, cuda=cuda)
        data_stream = tqdm(enumerate(data_loader, 1))
        features = []
        for batch_index, (x, _, _) in data_stream:

            tmp_x = Variable(x).to(device)
            if model.model_name == 'dec_vine':
                z = model.ae.encoder(tmp_x)
                z = model.ae.q(z)
            elif model.model_name == 'dec_vine2' or model.model_name == 'dec_vine3':
                z = torch.nn.functional.relu(model.ae.fc1(model.ae.encoder(tmp_x).view(x.size(0), -1)))
                z = model.ae.fc21(z)

            features.append(z)

        kmeans = KMeans(n_clusters=model.cluster_number, n_init=20)
        y_pred = kmeans.fit_predict(torch.cat(features).detach().cpu().numpy())
        model.cluster_layer.data = torch.tensor(kmeans.cluster_centers_).to(device)

    # Optional warm start for 'ae_vine3'; disabled (flag hard-coded to 0).
    pretrain = 0
    if pretrain == 1 and model.model_name == 'ae_vine3':
        pretrain_prefix = resfile_prefix
        pretrain_files = [filename for filename in os.listdir(checkpoint_dir) if filename.startswith(pretrain_prefix)]
        pretrain_epochs = [int(filename.replace(pretrain_prefix + "_", "")) for filename in pretrain_files]
        pretrain_path = os.path.join(checkpoint_dir, pretrain_files[pretrain_epochs.index(max(pretrain_epochs))])
        pretrained_ae = torch.load(pretrain_path, map_location=device)
        model.load_state_dict(pretrained_ae['state'])
        print('load pretrained ae3 from', pretrain_path)

    # Sum-reduced BCE (modern spelling of the deprecated size_average=False);
    # it is divided by the batch size inside the loop below.
    reconstruction_criterion = torch.nn.BCELoss(reduction='sum')
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    if model.model_name == 'gan':
        lr_g = lr_d = 0.0002
        k = 1  # generator is updated every k-th discriminator step
        fix_noise = get_noise(sample_size)
        opt_g = torch.optim.Adam(model.net_g.parameters(), lr=lr_g, betas=(0.5, 0.999))  # optimizer for Generator
        opt_d = torch.optim.Adam(model.net_d.parameters(), lr=lr_d, betas=(0.5, 0.999))  # optimizer for Discriminator

    for epoch in range(epoch_start, epochs + 1):
        print("Epoch {}".format(epoch))
        if model.model_name == "dec_vine" or model.model_name == "dec_vine2":
            # Refresh the DEC target distribution p over the full dataset.
            model.eval()
            p = []
            indices = []
            data_loader = utils.get_data_loader(dataset, batch_size, cuda=cuda)
            data_stream = tqdm(enumerate(data_loader, 1))

            for batch_index, (x, _, idx) in data_stream:
                tmp_x = Variable(x).to(device)
                _, tmp_p = model(tmp_x)
                p.append(tmp_p.detach().cpu())
                tmp_idx = idx
                indices.append(tmp_idx)

            p = torch.cat(p)
            indices = torch.cat(indices)
            p = model.target_distribution(p[indices])
            p = Variable(p).to(device)

        model.train()
        data_loader = utils.get_data_loader(dataset, batch_size, cuda=cuda)
        data_stream = tqdm(enumerate(data_loader, 1))

        for batch_index, (x, _, idx) in data_stream:

            # GAN learning-rate decay: /10 at epoch 8 and again at epoch 15.
            if model.model_name == 'gan' and (epoch) == 8:
                opt_g.param_groups[0]['lr'] /= 10
                opt_d.param_groups[0]['lr'] /= 10

            if model.model_name == 'gan' and (epoch) == 15:
                opt_g.param_groups[0]['lr'] /= 10
                opt_d.param_groups[0]['lr'] /= 10

            iteration = (epoch - 1) * (len(dataset) // batch_size) + batch_index
            x = Variable(x).to(device)
            idx = Variable(idx).to(device)

            if model.model_name == 'gan':
                # --- train Discriminator ---
                # NOTE(review): calls .cuda() directly instead of using
                # `device`; will fail on CPU-only runs — confirm intended.
                real_data = Variable(x.cuda())
                prob_fake = model.net_d(model.net_g(get_noise(real_data.size(0)).to(device)))
                prob_real = model.net_d(real_data)

                loss_d = - torch.mean(torch.log(prob_real) + torch.log(1 - prob_fake))

                opt_d.zero_grad()
                loss_d.backward()
                opt_d.step()

                # --- train Generator (every k-th step) ---
                if batch_index % k == 0:  # was `is 0`: identity check on an int literal
                    # NOTE(review): get_noise() is called with no batch size
                    # here, unlike the discriminator step above — confirm.
                    prob_fake = model.net_d(model.net_g(get_noise().to(device)))

                    loss_g = - torch.mean(torch.log(prob_fake))

                    opt_g.zero_grad()
                    loss_g.backward()
                    opt_g.step()

            else:

                if model.model_name == 'ae_vine' or model.model_name == 'ae_vine2' or model.model_name == 'ae_vine3':
                    x_reconstructed = model(x)

                elif model.model_name == 'dec_vine' or model.model_name == 'dec_vine2':
                    x_reconstructed, q = model(x)
                    p_batch = p[idx]
                    penalization_loss = 10 * F.kl_div(q.log(), p_batch)
                    del p_batch, q

                elif model.model_name == 'cvae' or model.model_name == "cvae2" or model.model_name == "cvae3":
                    (mean, logvar, atanhcor), x_reconstructed = model(x)
                    penalization_loss = model.kl_divergence_loss(mean, logvar, atanhcor)

                elif model.model_name == 'vae' or model.model_name == "vae2" or model.model_name == "vae3":
                    (mean, logvar), x_reconstructed = model(x)
                    penalization_loss = model.kl_divergence_loss(mean, logvar)

                # Per-sample reconstruction loss (criterion sums over the batch).
                reconstruction_loss = reconstruction_criterion(x_reconstructed, x) / x.size(0)

                if model.model_name == 'ae_vine' or model.model_name == 'ae_vine2' or model.model_name == 'ae_vine3':
                    loss = reconstruction_loss
                else:
                    loss = reconstruction_loss + penalization_loss

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            if iteration % loss_log_interval == 0:

                f = open(output_folder + "/" + resfile_prefix + "_losses" + ".txt", 'a')

                if model.model_name == 'gan':
                    f.write("\n{:<12} | {} | {} | {} | {} ".format(
                        model.model_name,
                        iteration,
                        loss_g,
                        loss_d,
                        seed
                    ))
                else:
                    if model.model_name == 'ae_vine' or model.model_name == 'ae_vine2' or model.model_name == 'ae_vine3':
                        f.write("\n{:<12} | {} | {} | {} ".format(
                            model.model_name,
                            iteration,
                            loss,
                            seed
                        ))

                    else:
                        f.write("\n{:<12} | {} | {} | {} | {} | {}".format(
                            model.model_name,
                            iteration,
                            reconstruction_loss.data.item(),
                            penalization_loss.data.item(),
                            loss.data.item(),
                            seed
                        ))

                f.close()

            # Adding this just to have a way of calculating the scores at 0
            # epochs.  (batch_index starts at 1, so at epoch 0 this breaks
            # after the first batch.)
            if batch_index > 0 and epoch == 0:
                break  # re-indented with spaces (was a literal tab: TabError)

        if epoch % model_log_interval == 0:
            print()
            print('###################')
            print('# model checkpoint!')
            print('###################')
            print()
            utils.save_checkpoint(model, checkpoint_dir, epoch, resfile_prefix + "_" + str(epoch))

        if epoch % image_log_interval == 0:

            print()
            print('###################')
            print('# image checkpoint!')
            print('###################')
            print()

            model.eval()

            ae_vine_models = ['ae_vine', 'ae_vine2', 'dec_vine', 'dec_vine2', 'ae_vine3', 'dec_vine3']

            if model.model_name in ae_vine_models:

                # Encode one large batch, fit a vine copula on the latent
                # codes, then sample fakes from the fitted vine.
                data_loader_vine = utils.get_data_loader(dataset, 5000, cuda=cuda)
                data_stream_vine = tqdm(enumerate(data_loader_vine, 1))
                features = []

                for batch_index, (x, _, _) in data_stream_vine:

                    tmp_x = Variable(x).to(device)
                    if model.model_name == 'ae_vine':
                        encoded = model.encoder(tmp_x)
                        e = model.q(encoded)

                    elif model.model_name == 'dec_vine':
                        encoded = model.ae.encoder(tmp_x)
                        e = model.ae.q(encoded)

                    elif model.model_name == 'ae_vine2':
                        encoded = torch.nn.functional.relu(model.fc1(model.encoder(tmp_x).view(x.size(0), -1)))
                        e = model.fc21(encoded)

                    elif model.model_name == 'dec_vine2':
                        encoded = torch.nn.functional.relu(model.ae.fc1(model.ae.encoder(tmp_x).view(x.size(0), -1)))
                        e = model.ae.fc21(encoded)

                    elif model.model_name == 'ae_vine3':
                        encoded = F.relu(model.fc1(model.encoder(tmp_x).view(x.size(0), -1)))
                        e = model.fc21(encoded)

                    elif model.model_name == 'dec_vine3':
                        encoded = F.relu(model.ae.fc1(model.ae.encoder(tmp_x).view(x.size(0), -1)))
                        e = model.ae.fc21(encoded)
                    features.append(e.detach().cpu())
                    if batch_index > 0:
                        break  # a single (up to 5000-sample) batch suffices

                features = torch.cat(features).numpy()
                copula_controls = base.list(family_set="tll", trunc_lvl=5, cores=cores)
                vine_obj = rvinecop.vine(features, copula_controls=copula_controls)

                model.vine = vine_obj

                fake = model.sample(sample_size, vine_obj, fixed_noise)

                del x, e, encoded, vine_obj, data_loader_vine

            elif model.model_name == 'gan':
                fake = model.net_g(fix_noise.to(device)).data.cpu()
                print(fake.shape)
            else:

                fake = model.sample(sample_size, fixed_noise)

            fake = fake.reshape(sample_size, model.channel_num,
                                model.image_size, model.image_size)
            name_str = resfile_prefix + '_fake_samples_epoch'
            vutils.save_image(fake.detach(),
                              '%s/%s_%03d.png' % (output_folder, name_str, epoch),
                              normalize=True)
            del fake

        if epoch % 10 == 0:

            # Compute evaluation scores every 10 epochs and append to the log.
            s = metric.compute_score_raw(ds_name, dataset, img_size, data_root,
                                         eval_size, batch_size,
                                         output_folder + '/real/',
                                         output_folder + '/fake/',
                                         model, model.z_size, 'resnet34', device)

            f = open(output_folder + "/" + resfile_prefix + "_scores" + ".txt", 'a')

            scr_arr = [str(a) for a in s]
            f.write("\n{:<12} | {} | {} | {}".format(
                model.model_name,
                epoch,
                ', '.join(scr_arr),
                seed
            ))

            f.close()
コード例 #5
0
ファイル: train.py プロジェクト: pzarker/conditional_tsgan
# Initialize loss functions (GeneratorLoss / DiscriminatorLoss are
# project-defined criteria — presumably standard GAN objectives; confirm).
gen_loss_fn = GeneratorLoss()
crit_loss_fn = DiscriminatorLoss()

# Per-step loss histories, kept for later inspection/plotting.
gen_loss_history = []
crit_loss_history = []

# Number of samples shown when plotting (inferred from the name — confirm).
n_plot_samples = 6

# Wall-clock start time, used to measure total training duration.
start = time.time()
for epoch in range(n_epochs):
    for real in tqdm(dataloader):
        cur_batch_size, seq_len, _ = real.shape
        real = real.to(device)

        fake_noise = get_noise(cur_batch_size, seq_len, z_dim, device=device)
        fake = gen(fake_noise)

        # Update discriminator
        if cur_step % n_rounds_g_per_d == 0:
            crit_opt.zero_grad()
            crit_loss = crit_loss_fn(real, fake, crit)
            crit_loss.backward(retain_graph=True)
            crit_opt.step()
            crit_loss_meter.add(crit_loss.item())

        # Update generator
        gen_opt.zero_grad()
        gen_loss = gen_loss_fn(fake, crit)
        gen_loss.backward()
        ''' Vanishing gradient problem?
コード例 #6
0
ファイル: metric.py プロジェクト: tagas/vcae
def sample_fake(model, nz, sample_size, batch_size, save_folder, device="cpu"):
    """Draw `sample_size` fake images from `model` and save each as a PNG
    under `save_folder`/0/.

    The sampling strategy depends on `model.model_name`:
      * 'gan'           — one image per generator call;
      * vine-copula AEs — one batch sampled from the fitted vine;
      * everything else — batched latent-noise sampling.

    Args:
        model: trained generative model exposing `sample(...)`.
        nz: latent dimensionality (used only by the generic branch).
        sample_size: total number of images to write.
        batch_size: batch size for the generic branch.
        save_folder: output directory root ('/0/' is appended).
        device: torch device string for the noise tensors.
    """
    print('sampling fake images ...')
    save_folder = save_folder + '/0/'
    # exist_ok: re-running must not fail when the folder already exists
    # (the old try/except OSError also hid real errors like EACCES).
    os.makedirs(save_folder, exist_ok=True)

    ae_vine_models = [
        'ae_vine', 'ae_vine2', 'dec_vine', 'dec_vine2', 'ae_vine3', 'dec_vine3'
    ]

    if model.model_name == 'gan':

        img_idx = 0  # renamed from `iter`, which shadowed the builtin
        for i in range(0, 1 + sample_size):
            noise = get_noise(1).to(device)
            fake = 0.5 * model.sample(noise) + 0.5  # rescale [-1, 1] -> [0, 1]
            fake = fake.reshape(1, model.channel_num, model.image_size,
                                model.image_size)

            for j in range(0, len(fake.data)):
                if img_idx < sample_size:
                    vutils.save_image(fake.data[j],
                                      save_folder + give_name(img_idx) + ".png",
                                      normalize=True)
                img_idx += 1
                if img_idx >= sample_size:
                    break

            del fake

    else:
        if model.model_name in ae_vine_models:
            fake = model.sample(sample_size, model.vine)
            fake = fake.reshape(sample_size, model.channel_num,
                                model.image_size, model.image_size)
            img_idx = 0
            for j in range(0, len(fake.data)):
                vutils.save_image(fake.data[j],
                                  save_folder + give_name(img_idx) + ".png",
                                  normalize=True)
                img_idx += 1

        else:
            noise = torch.FloatTensor(batch_size, nz, 1, 1).to(device)
            img_idx = 0
            for i in range(0, 1 + sample_size // batch_size):
                noise.data.normal_(0, 1)

                # NOTE(review): `noise` is refreshed above but model.sample()
                # is called with sample_size only — confirm whether this
                # model family actually consumes the noise tensor.
                fake = model.sample(sample_size)
                fake = fake.reshape(sample_size, model.channel_num,
                                    model.image_size, model.image_size)

                for j in range(0, len(fake.data)):
                    if img_idx < sample_size:
                        vutils.save_image(fake.data[j],
                                          save_folder + give_name(img_idx) +
                                          ".png",
                                          normalize=True)
                    img_idx += 1
                    if img_idx >= sample_size:
                        break

                del fake
コード例 #7
0
import model, data, train, config
from PIL import Image
import matplotlib.pyplot as plt
import torch
from torchvision.utils import make_grid


# function to demonstrate the anime images from tensor
def save_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):
    """Display up to `num_images` images from a batch and save the grid
    to ./img_result/testing.jpg.

    Args:
        image_tensor: batch of images, presumably scaled to [-1, 1]
            (the first line maps them back to [0, 1] — confirm).
        num_images: how many images to include in the 5-wide grid.
        size: unused; kept for interface compatibility with callers.
    """
    # Map generator output from [-1, 1] back to [0, 1] for display.
    image_tensor = (image_tensor + 1) / 2
    image_unflat = image_tensor.detach().cpu()
    image_grid = make_grid(image_unflat[:num_images], nrow=5)
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    plt.savefig("./img_result/testing.jpg")  # plain string: no placeholders, so no f-prefix
    plt.show()


if __name__ == "__main__":
    # specify the parameters
    z_dim = 64  # latent dimension of the generator
    size = 128  # NOTE(review): unused below — confirm it can be removed
    # generate random anime images
    # NOTE(review): get_noise(z_dim, z_dim) passes z_dim as both the sample
    # count and the noise dimension — confirm the first argument is intended.
    test_noise = model.get_noise(z_dim, z_dim)
    gen = model.Generator(z_dim)
    path = f"./weights/{config.gen_save_name}"
    gen.load_state_dict(torch.load(path))
    # Tensor -> PIL converter; currently unused in this script.
    unloader = data.transforms.ToPILImage()
    save_tensor_images(gen(test_noise))
コード例 #8
0
ファイル: train.py プロジェクト: Savio666/anime
 gen = gen.apply(weights_init)
 crit = crit.apply(weights_init)
 mean_generator_loss = 0
 cur_step = 0
 generator_losses = []
 critic_losses = []
 # training the model
 for e in range(config.n_epochs):
     for real, _ in tqdm(data.dataloader):
         size = len(real)
         real = real.to(config.device)
         mean_iteration_critic_loss = 0
         for _ in range(config.crit_repeats):
             crit_opt.zero_grad()
             noise = model.get_noise(size,
                                     config.z_dim,
                                     device=config.device)
             fake = gen(noise)
             fake_pred = crit(fake.detach())
             real_pred = crit(real)
             epsilon = torch.rand(len(real),
                                  1,
                                  1,
                                  1,
                                  device=config.device,
                                  requires_grad=True)
             gradient = get_gradient(crit, real, fake.detach(), epsilon)
             gp = gradient_penalty(gradient)
             crit_loss = get_crit_loss(fake_pred, real_pred, gp,
                                       config.c_lambda)
             mean_iteration_critic_loss += crit_loss.item(
コード例 #9
0
    def train(self):
        """Run the GAN training loop for `self.n_epochs` epochs.

        Each batch performs a discriminator step followed by a generator
        step, accumulates running mean losses, and every
        `self.display_step` steps prints the means and saves fake/real
        sample images.
        """
        cur_step = 0
        mean_generator_loss = 0
        mean_discriminator_loss = 0
        test_generator = True  # Whether the generator should be tested
        gen_loss = False

        for epoch in range(self.n_epochs):
            for real, _ in tqdm(self.dataloader):
                cur_batch_size = len(real)  # current batch size

                # Flatten the real images into vectors.
                real = real.view(cur_batch_size, -1).to(self.device)

                ### Update discriminator ###
                self.disc_opt.zero_grad()
                disc_loss = self.get_disc_loss(real, cur_batch_size)
                disc_loss.backward(retain_graph=True)  # retain_graph=True keeps the backward graph alive — TODO confirm it is actually required here
                self.disc_opt.step()

                # For testing purposes, to keep track of the generator weights
                if test_generator:
                    old_generator_weights = self.gen.gen[0][0].weight.detach(
                    ).clone()
                    # Snapshot of the generator's first-layer weights.

                ### Update generator ###
                self.gen_opt.zero_grad()
                gen_loss = self.get_gen_loss(cur_batch_size)
                gen_loss.backward()
                self.gen_opt.step()

                # For testing purposes, to check that your code changes the generator weights
                if test_generator:
                    assert torch.any(self.gen.gen[0][0].weight.detach().clone(
                    ) != old_generator_weights)
                    # torch.any(t) is True if any element is truthy — i.e.
                    # at least one generator weight changed during the step.

                # Keep track of the average discriminator loss
                mean_discriminator_loss += disc_loss.item() / self.display_step

                # Keep track of the average generator loss
                mean_generator_loss += gen_loss.item() / self.display_step

                ### Visualization code ###
                if cur_step % self.display_step == 0 and cur_step > 0:
                    print(
                        f"Step {cur_step}: Generator loss: {mean_generator_loss}, discriminator loss: {mean_discriminator_loss}"
                    )

                    fake_noise = get_noise(cur_batch_size, self.z_dim)
                    fake = self.gen(fake_noise)
                    save_tensor_images("fake_" + str(cur_step), fake)
                    save_tensor_images("real_" + str(cur_step), real)
                    mean_generator_loss = 0
                    mean_discriminator_loss = 0

                cur_step += 1