示例#1
0
def train(epoch):
    """Run one WGAN-GP training epoch.

    Performs up to ``k + 1`` critic updates (one per mini-batch), logs the
    loss components of the last critic step, then performs a single
    generator update. Relies on module-level ``critic``, ``generator``,
    optimizers, ``train_loader``, ``writer``, ``args``, ``latent_size``,
    ``k`` and the gradient-penalty weight ``l``.

    Returns:
        (critic_loss, generator_loss) from the final updates of this epoch.
    """
    critic.train()
    generator.train()

    # --- critic updates: at most k + 1 mini-batches ---
    for i, (real, _) in enumerate(train_loader):
        if i > k:
            break
        critic_optimizer.zero_grad()
        # rescale inputs from [0, 1] to [-1, 1]
        real = ((real - 0.5) * 2).to(args.device)
        z = torch.rand(len(real), latent_size).to(args.device)
        fake = generator(z)
        # WGAN-GP objective: E[D(fake)] - E[D(real)] + l * gradient penalty
        g_loss = l * model.gradient_penalty(real, fake, critic)
        f_loss = critic(fake.detach()).mean()
        t_loss = -critic(real).mean()
        critic_loss = f_loss + t_loss + g_loss
        critic_loss.backward()
        critic_optimizer.step()

    # log the components from the last critic mini-batch
    writer.add_scalars('critic_loss_detail', {
        'g loss': g_loss,
        'f loss': f_loss,
        't loss': t_loss
    }, epoch)
    writer.add_scalar('critic loss', critic_loss, epoch)

    # --- single generator update ---
    generator_optimizer.zero_grad()
    z = torch.rand(args.batch_size, latent_size).to(args.device)
    generator_loss = -critic(generator(z)).mean()
    generator_loss.backward()
    generator_optimizer.step()
    writer.add_scalar('gen loss', generator_loss, epoch)

    return critic_loss, generator_loss
示例#2
0
def train():
    """Train an AC-GAN (optionally WGAN / WGAN-GP) on the configured dataset.

    Reads hyper-parameters from the module-level ``args``, logs losses and
    generated-image grids to TensorBoard, and periodically checkpoints the
    models so training can resume from ``ckpt_path``.
    """
    if len(args.gpu_idx) > 1:
        multi_gpu = True
        gpu_list = [int(i) for i in args.gpu_idx.split(',')]
    else:
        multi_gpu = False

    writer = SummaryWriter(log_dir=log_path)

    # Build the dataset; pixel values are normalized to [-1, 1].
    if args.dataset == 'cifar10':
        dataset = datasets.CIFAR10(args.data_dir,
                                   download=True,
                                   transform=transforms.Compose([
                                       transforms.Resize(args.imsize),
                                       transforms.ToTensor(),
                                       transforms.Normalize([0.5, 0.5, 0.5],
                                                            [0.5, 0.5, 0.5])
                                   ]))
    elif args.dataset == 'mnist':
        dataset = datasets.MNIST(args.data_dir,
                                 train=True,
                                 download=True,
                                 transform=transforms.Compose([
                                     # transforms.Scale was deprecated and
                                     # removed; Resize is the replacement.
                                     transforms.Resize(args.imsize),
                                     transforms.ToTensor(),
                                     # MNIST is single-channel, so Normalize
                                     # must take 1-tuples, not 3-tuples.
                                     transforms.Normalize((0.5,), (0.5,))
                                 ]))
        args.in_dim = 1
        args.out_dim = 1
    else:
        dataset = datasets.ImageFolder(args.data_dir,
                                       transform=transforms.Compose([
                                           transforms.Resize(args.imsize),
                                           transforms.ToTensor(),
                                           transforms.Normalize(
                                               (0.5, 0.5, 0.5),
                                               (0.5, 0.5, 0.5))
                                       ]))

    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=args.batch,
                                             shuffle=True,
                                             num_workers=args.worker)

    generator = model.Generator(args)
    discriminator = model.Discriminator(args)

    generator.apply(weights_init)
    discriminator.apply(weights_init)

    gan_criterion = nn.BCELoss()
    aux_criterion = nn.CrossEntropyLoss()

    if args.l_smooth:
        # one-sided label smoothing, per "Improved Techniques for Training GANs"
        real_label = 0.9
        fake_label = 0.1
    else:
        real_label = 1.0
        fake_label = 0.0

    step = 0

    if args.gpu:
        # Moving the criteria is a no-op (BCE/CCE hold no parameters);
        # kept for symmetry with the models.
        gan_criterion = gan_criterion.cuda()
        aux_criterion = aux_criterion.cuda()

        generator = generator.cuda()
        discriminator = discriminator.cuda()

        if multi_gpu:
            print('multi gpu')
            generator = nn.DataParallel(generator, device_ids=gpu_list)
            discriminator = nn.DataParallel(discriminator, device_ids=gpu_list)

    opt_d = optim.Adam(discriminator.parameters(),
                       lr=args.lr,
                       betas=(args.beta1, 0.999))
    opt_g = optim.Adam(generator.parameters(),
                       lr=args.lr,
                       betas=(args.beta1, 0.999))

    # Resume from a checkpoint if one exists for this run name.
    ckpt_file = os.path.join(ckpt_path, args.run_name + '.ckpt')
    if os.path.isfile(ckpt_file):
        print('found ckpt file' + ckpt_file)
        ckpt = torch.load(ckpt_file)
        generator.load_state_dict(ckpt['generator'])
        discriminator.load_state_dict(ckpt['discriminator'])
        opt_d.load_state_dict(ckpt['opt_d'])
        opt_g.load_state_dict(ckpt['opt_g'])
        step = ckpt['step']

    for i in range(args.epoch):

        for j, data in enumerate(dataloader):

            images, labels = data[0], data[1]
            batch = images.shape[0]

            # Sample class labels, then build class-conditional noise whose
            # first num_class dimensions carry a one-hot class code.
            input_label = torch.from_numpy(
                np.random.randint(0, args.num_class, [batch]))

            input_noise = np.random.normal(
                0, 1, [batch, args.dim_embed]).astype(np.float32)
            class_onehot = np.zeros((batch, args.num_class))
            class_onehot[np.arange(batch), input_label] = 1
            input_noise[np.arange(batch), :args.num_class] = class_onehot[
                np.arange(batch)]
            input_noise = torch.from_numpy(input_noise)

            real_target = torch.full((batch, 1), real_label)
            fake_target = torch.full((batch, 1), fake_label)
            # torch.autograd.Variable is deprecated (no-op since PyTorch 0.4)
            aux_target = labels

            if args.gpu:
                input_noise = input_noise.cuda()
                input_label = input_label.cuda()
                images = images.cuda()
                labels = labels.cuda()
                real_target = real_target.cuda()
                fake_target = fake_target.cuda()
                aux_target = aux_target.cuda()

            # ---------------- train generator ----------------
            opt_g.zero_grad()

            fake = generator(input_noise, input_label)
            gan_out_g, aux_out_g = discriminator(fake)

            if args.wgan:
                gan_loss_g = -torch.mean(gan_out_g)
            else:
                gan_loss_g = gan_criterion(gan_out_g, real_target)

            aux_loss_g = aux_criterion(aux_out_g, input_label)

            g_loss = (gan_loss_g + args.aux_weight * aux_loss_g) * 0.5
            g_loss.backward()
            opt_g.step()

            # -------- train discriminator: real samples --------
            opt_d.zero_grad()

            gan_out_r, aux_out_r = discriminator(images)

            if args.wgan:
                gan_loss_r = -torch.mean(gan_out_r)
            else:
                gan_loss_r = gan_criterion(gan_out_r, real_target)

            aux_loss_r = aux_criterion(aux_out_r, aux_target)
            d_real_loss = (gan_loss_r + args.aux_weight * aux_loss_r) / 2.0

            # -------- train discriminator: fake samples --------
            # detach so the generator graph is not traversed again
            gan_out_f, aux_out_f = discriminator(fake.detach())

            if args.wgan:
                gan_loss_f = torch.mean(gan_out_f)
            else:
                gan_loss_f = gan_criterion(gan_out_f, fake_target)

            aux_loss_f = aux_criterion(aux_out_f, input_label)
            d_fake_loss = (gan_loss_f + args.aux_weight * aux_loss_f) / 2.0

            if args.wgan and args.gp:
                gp = model.gradient_penalty(discriminator, images, fake,
                                            args.num_class, args.gpu)
                d_loss = 0.5 * (d_real_loss +
                                d_fake_loss) + args.gp_weight * gp
            else:
                d_loss = (d_real_loss + d_fake_loss) / 2.0

            d_loss.backward()
            opt_d.step()

            step = step + 1
            if step % 100 == 0:
                writer.add_scalar('losses/g_loss', g_loss, step)
                writer.add_scalar('losses/d_loss', d_loss, step)
                grid = vutils.make_grid(fake.detach(), normalize=True)
                writer.add_image('generated', grid, step)

            # weight clipping for plain WGAN (no gradient penalty)
            if args.wgan and not args.gp:
                for p in discriminator.parameters():
                    p.data.clamp_(-args.clip, args.clip)

            if step % args.save_freq == 0:
                torch.save(
                    {
                        'step': step,
                        'generator': generator.state_dict(),
                        'discriminator': discriminator.state_dict(),
                        'opt_d': opt_d.state_dict(),
                        'opt_g': opt_g.state_dict()
                    }, ckpt_file)

            if step % args.sample_freq == 0:
                sample_generator(generator, input_noise, input_label, step)

            # auxiliary-classifier accuracy over real + fake samples
            pred = np.concatenate(
                [aux_out_r.data.cpu().numpy(),
                 aux_out_f.data.cpu().numpy()],
                axis=0)
            gt = np.concatenate(
                [labels.data.cpu().numpy(),
                 input_label.data.cpu().numpy()],
                axis=0)
            d_acc = np.mean(np.argmax(pred, axis=1) == gt)

            print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %d%%] [G loss: %f]" %
                  (i, args.epoch, j, len(dataloader),
                   d_loss.item(), 100.0 * d_acc,
                   g_loss.item()))
示例#3
0
        # NOTE(review): fragment of a conditional-GAN training loop; `ep`, `i`,
        # `x`, `c_dense`, the models (D, G), optimizers, `d_loss_fn`, and the
        # gp settings are all defined outside this view.
        step = ep * len(train_loader) + i + 1
        D.train()
        G.train()

        # train D: real batch x conditioned on one-hot class vector c
        x = x.to(device)
        z = torch.randn(batch_size, z_dim).to(device)
        # one-hot encode the dense class labels (assumes 0 <= label < c_dim —
        # TODO confirm against the data loader)
        c = torch.tensor(np.eye(c_dim)[c_dense.cpu().numpy()],
                         dtype=z.dtype).to(device)

        # fakes are detached so this step only updates the discriminator
        x_f = G(z, c).detach()
        x_gan_logit = D(x, c)
        x_f_gan_logit = D(x_f, c)

        # adversarial loss + weighted gradient penalty
        d_x_gan_loss, d_x_f_gan_loss = d_loss_fn(x_gan_logit, x_f_gan_logit)
        gp = model.gradient_penalty(D, x, x_f, mode=gp_mode)
        d_loss = d_x_gan_loss + d_x_f_gan_loss + gp * gp_coef

        D.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        writer.add_scalar('D/d_gan_loss',
                          (d_x_gan_loss + d_x_f_gan_loss).data.cpu().numpy(),
                          global_step=step)
        writer.add_scalar('D/gp', gp.data.cpu().numpy(), global_step=step)

        # train G every n_d discriminator steps (body continues past this view)
        if step % n_d == 0:
            z = torch.randn(batch_size, z_dim).to(device)
示例#4
0
        # NOTE(review): fragment of a conditional WGAN-GP critic step; `xr`,
        # `yr`, `D`, `G`, `optim_D`, `data_iter`, and `gradient_penalty` are
        # defined outside this view. Batch size (32), latent size (100), and
        # label width (5) are hard-coded.
        xr = torch.from_numpy(xr).cuda()  # (32, 1, 24, 24)
        yr = torch.from_numpy(yr).cuda()  # (32, 5)

        # xr.requires_grad_()
        # yr = torch.tensor(yr, dtype=torch.float)
        predr = D(xr, yr)  # (32, 1)
        # critic score on real samples, negated for the WGAN objective
        lossr = -predr.mean()

        # noise concatenated with the label vector as the generator condition
        zx = torch.randn((32, 100), dtype=torch.float).cuda()
        zx = torch.cat([zx, yr], dim=1)  # (32, 105)
        # zx = torch.tensor(zx, dtype=torch.float)
        xf = G(zx).detach()  # xf (32, 1, 24, 24)
        predf = D(xf, yr)
        lossf = predf.mean()

        # WGAN-GP loss: E[D(fake)] - E[D(real)] + 10 * gradient penalty
        gp = gradient_penalty(D, xr, xf.detach(), yr)
        loss_D = lossr + lossf + 10 * gp

        optim_D.zero_grad()
        loss_D.backward()
        optim_D.step()

    # generator step: fresh noise plus a labeled batch for conditioning
    z = torch.randn((32, 100), dtype=torch.float).cuda()
    x, y = next(data_iter)
    # y = torch.tensor(y, dtype=torch.float)
    y = torch.from_numpy(y).cuda()
    zx = torch.cat([z, y], dim=1)
    # zx = torch.tensor(zx, dtype=torch.float)
    xf = G(zx)

    predf = D(xf, y)
示例#5
0
# NOTE(review): fragment of TensorFlow/tensorlayer graph-construction code;
# r_logit/f_logit, tree_loss_fn, lambdas, layer_mask, half_acgan, att, optim,
# lr_d/lr_g, and real/fake come from code outside this view. The final g_step
# statement is truncated here.

# d loss
d_r_loss, d_f_loss = d_loss_fn(r_logit, f_logit)
d_f_tree_losses = tree_loss_fn(f_c_logit, c, mask)
if att != '':
    # attribute supervision: per-layer tree losses on fakes (skipping layer 0
    # when half_acgan is set), plus the layer-0 tree loss on reals
    d_r_tree_losses = tree_loss_fn(r_c_logit, c, mask)
    start = 1 if half_acgan else 0
    d_tree_loss = sum([
        d_f_tree_losses[i] * lambdas[i] * layer_mask[i]
        for i in range(start, len(lambdas))
    ])
    d_tree_loss += d_r_tree_losses[0] * lambdas[0] * layer_mask[0]
else:
    # no attribute: weight every fake-sample tree loss uniformly by lambdas
    d_tree_loss = sum(
        [d_f_tree_losses[i] * lambdas[i] for i in range(len(lambdas))])
# WGAN-GP style total: adversarial + tree losses + 10x gradient penalty
gp = model.gradient_penalty(D, real, fake, gp_mode)
d_loss = d_r_loss + d_f_loss + d_tree_loss + gp * 10.0

# g loss
g_f_loss = g_loss_fn(f_logit)
g_f_tree_losses = tree_loss_fn(f_c_logit, c, mask)
g_tree_loss = sum([
    g_f_tree_losses[i] * lambdas[i] * layer_mask[i]
    for i in range(len(lambdas))
])
g_loss = g_f_loss + g_tree_loss

# optims
d_step = optim(learning_rate=lr_d).minimize(
    d_loss, var_list=tl.trainable_variables(includes='D'))
g_step = optim(learning_rate=lr_g).minimize(
示例#6
0
# NOTE(review): fragment of a TensorFlow 1.x attribute-editing GAN graph;
# D, image_tensor, s_decoder_ouput (sic — typo originates outside this view),
# label_tensor, shuffled_label_tensor, and the earlier discriminator outputs
# are defined in code not shown here. The last statement is truncated.
discriminate_image_img_s, discriminate_image_attr_s = D(s_decoder_ouput)

print("D(image_tensor) - img = ", discriminate_image_img)
print("D(image_tensor) - attr = ", discriminate_image_attr)

print("D(s_decoder_output) - img = ", discriminate_image_img_s)
print("D(s_decoder_output) - attr = ", discriminate_image_attr_s)

# --------------------------------    CALCULATE LOSSES  ----------------------------------------------------------

# Calculate discriminator loss
# Wasserstein distance estimate: E[D(real)] - E[D(generated)]
wd = tf.reduce_mean(discriminate_image_img) - tf.reduce_mean(
    discriminate_image_img_s)
discriminator_loss_gan = -wd
# Determine the loss function given the original image and Generator output
gradient_loss = model.gradient_penalty(D, image_tensor, s_decoder_ouput)

# Discriminator loss for image attr
image_attr_loss = tf.losses.sigmoid_cross_entropy(label_tensor,
                                                  discriminate_image_attr)

# Final loss for discriminator: WGAN term + 10x gradient penalty + attr term
discriminator_loss = discriminator_loss_gan + gradient_loss * 10.0 + image_attr_loss

# Calculate generator loss
discriminator_label_loss = -tf.reduce_mean(discriminate_image_attr_s)

# Attribute loss against the shuffled labels used to condition the decoder
label_decoder_loss = tf.losses.sigmoid_cross_entropy(
    shuffled_label_tensor, discriminate_image_attr_s)
image_decoder_loss = tf.losses.absolute_difference(image_tensor,