コード例 #1
0
        # NOTE(review): fragment of a GAN training loop — it begins inside an
        # outer `step_i` loop and is truncated at the final `utils.plot_loss(`
        # call (no closing paren in this view).  `g_update`, `batch_size`,
        # `latent_dim`, `G`, `D`, `G_optim`, `criterion`, `real_label`,
        # `discrim_loss`, `steps`, `sample_dir`, `ckpt_dir` are all defined
        # outside this fragment.
        for _ in range(g_update):
            # Sample generator input from a standard normal distribution.

            z = torch.randn(batch_size, latent_dim, 1, 1).to(device)
            fake_img = G(z).to(device)

            # Generator objective: make D classify the fakes as real.
            gen_loss = criterion(D(fake_img), real_label)

            G_optim.zero_grad()
            gen_loss.backward()
            G_optim.step()

            ########## Updating logs ##########
            # NOTE(review): `discrim_loss` is computed outside this fragment,
            # so the same D-loss value is re-appended once per G update.
            discrim_log.append(discrim_loss.item())
            gen_log.append(gen_loss.item())
            utils.show_process(steps, step_i, gen_log, discrim_log)
        ########## Checkpointing ##########

        if step_i == 1:
            # Save a grid of (up to) 64 real images once, for reference.
            save_image(utils.denorm(real_img[:64, :, :, :]),
                       os.path.join(sample_dir, 'real.png'))
        if step_i % 500 == 0:
            save_image(
                utils.denorm(fake_img[:64, :, :, :]),
                os.path.join(sample_dir, 'fake_step_{}.png'.format(step_i)))
        if step_i % 2000 == 0:
            # NOTE(review): 'G.ckpt'.format(step_i) is a no-op — the string
            # has no '{}' placeholder, so every save overwrites 'G.ckpt'
            # (same for 'D.ckpt' below).  Probably meant 'G_{}.ckpt'.
            utils.save_model(G, G_optim, step_i, tuple(gen_log),
                             os.path.join(ckpt_dir, 'G.ckpt'.format(step_i)))
            utils.save_model(D, D_optim, step_i, tuple(discrim_log),
                             os.path.join(ckpt_dir, 'D.ckpt'.format(step_i)))
            utils.plot_loss(gen_log, discrim_log,
コード例 #2
0
def main():
    """Train an ACGAN on the cartoon face set, conditioned on four attributes.

    Hyper-parameters are read from the module-level ``args`` namespace.
    Each step trains the discriminator once (real/fake score plus auxiliary
    attribute classification on the real batch) and the generator once.
    Loss logs, image samples and model checkpoints are written periodically
    under ``args.sample_dir`` / ``args.checkpoint_dir``; the generator is
    additionally checkpointed whenever its loss reaches a new minimum.
    """
    batch_size = args.batch_size
    iterations = args.iterations
    device = args.device

    # Per-attribute class counts; the conditional label vector is their
    # concatenation (6 + 4 + 3 + 2 = 15 dims).
    hair_classes, eye_classes, face_classes, glasses_classes = 6, 4, 3, 2
    num_classes = hair_classes + eye_classes + face_classes + glasses_classes
    smooth = args.beta  # lower bound for the soft real labels (default 0.5)
    latent_dim = 200

    config = 'ACGAN-batch_size-[{}]-steps-[{}]'.format(batch_size, iterations)
    print('Configuration: {}'.format(config))

    root_dir = '../selected_cartoonset100k/'

    random_sample_dir = '{}/{}/random_generation'.format(
        args.sample_dir, config)
    fixed_attribute_dir = '{}/{}/fixed_attributes'.format(
        args.sample_dir, config)
    checkpoint_dir = '{}/{}'.format(args.checkpoint_dir, config)

    # exist_ok avoids the check-then-create race of os.path.exists().
    os.makedirs(random_sample_dir, exist_ok=True)
    os.makedirs(fixed_attribute_dir, exist_ok=True)
    os.makedirs(checkpoint_dir, exist_ok=True)

    ########## Start Training ##########
    transform = Transform.Compose([
        Transform.ToTensor(),
        Transform.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    dataset = Anime(root_dir=root_dir, transform=transform)
    shuffler = Shuffler(dataset=dataset, batch_size=args.batch_size)

    G = Generator(latent_dim=latent_dim, class_dim=num_classes).to(device)
    D = Discriminator(hair_classes=hair_classes,
                      eye_classes=eye_classes,
                      face_classes=face_classes,
                      glasses_classes=glasses_classes).to(device)

    G_optim = optim.Adam(G.parameters(), betas=[args.beta, 0.999], lr=args.lr)
    D_optim = optim.Adam(D.parameters(), betas=[args.beta, 0.999], lr=args.lr)

    d_log, g_log, classifier_log = [], [], []
    criterion = torch.nn.BCELoss()

    # Track the lowest generator loss seen so far.  float('inf') guarantees
    # the very first step sets the baseline (the old 999.9 sentinel could
    # silently never trigger if losses stayed above it).
    min_g_loss, min_step_i = float('inf'), 0

    for step_i in range(1, iterations + 1):
        # Target labels: ones for real, zeros for fake.  Real targets fed to
        # D are softened to U(smooth, 1) (one-sided label smoothing).
        real_label = torch.ones(batch_size).to(device)
        fake_label = torch.zeros(batch_size).to(device)
        soft_label = torch.Tensor(batch_size).uniform_(smooth, 1).to(device)

        # ----- Train discriminator -----
        real_img, hair_tags, eye_tags, face_tags, glasses_tags = shuffler.get_batch(
        )
        real_img, hair_tags, eye_tags, face_tags, glasses_tags = real_img.to(device), \
                                                                 hair_tags.to(device), \
                                                                 eye_tags.to(device), \
                                                                 face_tags.to(device), \
                                                                 glasses_tags.to(device)

        z = torch.randn(batch_size, latent_dim).to(device)
        fake_tag = get_random_label(batch_size=batch_size,
                                    hair_classes=hair_classes,
                                    eye_classes=eye_classes,
                                    face_classes=face_classes,
                                    glasses_classes=glasses_classes).to(device)
        fake_img = G(z, fake_tag).to(device)

        # D returns a real/fake score plus one prediction head per attribute.
        real_score, real_hair_predict, real_eye_predict, real_face_predict, real_glasses_predict = D(
            real_img)
        fake_score, _, _, _, _ = D(fake_img)

        real_discrim_loss = criterion(real_score, soft_label)
        fake_discrim_loss = criterion(fake_score, fake_label)

        # Auxiliary attribute-classification losses on the real batch only.
        real_hair_aux_loss = criterion(real_hair_predict, hair_tags.float())
        real_eye_aux_loss = criterion(real_eye_predict, eye_tags.float())
        real_face_aux_loss = criterion(real_face_predict, face_tags.float())
        real_glasses_aux_loss = criterion(real_glasses_predict,
                                          glasses_tags.float())
        real_classifier_loss = real_hair_aux_loss + real_eye_aux_loss + real_face_aux_loss + real_glasses_aux_loss

        discrim_loss = real_discrim_loss + fake_discrim_loss
        classifier_loss = real_classifier_loss * args.classification_weight

        classifier_log.append(classifier_loss.item())

        D_loss = discrim_loss + classifier_loss
        D_optim.zero_grad()
        D_loss.backward()
        D_optim.step()

        # ----- Train generator -----
        z = torch.randn(batch_size, latent_dim).to(device)
        fake_tag = get_random_label(batch_size=batch_size,
                                    hair_classes=hair_classes,
                                    eye_classes=eye_classes,
                                    face_classes=face_classes,
                                    glasses_classes=glasses_classes).to(device)
        # Split the concatenated label back into per-attribute targets.
        hair_tag = fake_tag[:, :hair_classes]
        eye_tag = fake_tag[:, hair_classes:(hair_classes + eye_classes)]
        face_tag = fake_tag[:, (hair_classes +
                                eye_classes):(hair_classes + eye_classes +
                                              face_classes)]
        glasses_tag = fake_tag[:,
                               (hair_classes + eye_classes +
                                face_classes):(hair_classes + eye_classes +
                                               face_classes + glasses_classes)]
        fake_img = G(z, fake_tag).to(device)

        fake_score, hair_predict, eye_predict, face_predict, glasses_predict = D(
            fake_img)

        # G wants D to score its fakes as real AND to recover the attribute
        # labels it was conditioned on.
        discrim_loss = criterion(fake_score, real_label.float())
        hair_aux_loss = criterion(hair_predict, hair_tag.float())
        eye_aux_loss = criterion(eye_predict, eye_tag.float())
        face_aux_loss = criterion(face_predict, face_tag.float())
        glasses_aux_loss = criterion(glasses_predict, glasses_tag.float())
        classifier_loss = hair_aux_loss + eye_aux_loss + face_aux_loss + glasses_aux_loss

        G_loss = classifier_loss * args.classification_weight + discrim_loss
        G_optim.zero_grad()
        G_loss.backward()
        G_optim.step()

        ########## Updating logs ##########
        d_log.append(D_loss.item())
        g_log.append(G_loss.item())

        # Save a sample grid and a generator checkpoint whenever the G loss
        # reaches a new minimum.
        if min_g_loss > g_log[-1]:
            min_g_loss = g_log[-1]
            min_step_i = step_i
            print('Update min G loss to {} @step{}!'.format(
                min_g_loss, min_step_i))
            save_image(
                denorm(fake_img[:64, :, :, :]),
                os.path.join(random_sample_dir,
                             'fake_step_{}.png'.format(step_i)))
            save_model(model=G,
                       optimizer=G_optim,
                       step=step_i,
                       log=tuple(g_log),
                       file_path=os.path.join(checkpoint_dir,
                                              'G_{}.ckpt'.format(step_i)))

        if step_i == iterations:
            # Fixed typo in the final report ('mim' -> 'min').
            print('[Final] Your min G loss is {} @step{}'.format(
                min_g_loss, min_step_i))

        show_process(total_steps=iterations,
                     step_i=step_i,
                     g_log=g_log,
                     d_log=d_log,
                     classifier_log=classifier_log)

        ########## Checkpointing ##########
        if step_i == 1:
            save_image(denorm(real_img[:64, :, :, :]),
                       os.path.join(random_sample_dir, 'real.png'))
        if step_i % args.sample == 0:
            save_image(
                denorm(fake_img[:64, :, :, :]),
                os.path.join(random_sample_dir,
                             'fake_step_{}.png'.format(step_i)))

        if step_i % args.check == 0:
            save_model(model=G,
                       optimizer=G_optim,
                       step=step_i,
                       log=tuple(g_log),
                       file_path=os.path.join(checkpoint_dir,
                                              'G_{}.ckpt'.format(step_i)))
            save_model(model=D,
                       optimizer=D_optim,
                       step=step_i,
                       log=tuple(d_log),
                       file_path=os.path.join(checkpoint_dir,
                                              'D_{}.ckpt'.format(step_i)))

            plot_loss(g_log=g_log,
                      d_log=d_log,
                      file_path=os.path.join(checkpoint_dir, 'loss.png'))
            plot_classifier_loss(log=classifier_log,
                                 file_path=os.path.join(
                                     checkpoint_dir, 'classifier loss.png'))

            generation_by_attributes(model=G,
                                     device=args.device,
                                     step=step_i,
                                     latent_dim=latent_dim,
                                     hair_classes=hair_classes,
                                     eye_classes=eye_classes,
                                     face_classes=face_classes,
                                     glasses_classes=glasses_classes,
                                     sample_dir=fixed_attribute_dir)
コード例 #3
0
def main():
    """Train an ACGAN on the anime face set, conditioned on hair/eye color.

    Hyper-parameters are read from the module-level ``args`` namespace.
    Each step trains the discriminator once (real/fake score plus auxiliary
    hair/eye classification on the real batch) and the generator once.
    Loss logs, image samples and model checkpoints are written periodically
    under ``args.sample_dir`` / ``args.checkpoint_dir``.
    """
    batch_size = args.batch_size
    iterations = args.iterations
    device = args.device

    # Per-attribute class counts; the conditional label vector is their
    # concatenation (12 + 10 = 22 dims).
    hair_classes, eye_classes = 12, 10
    num_classes = hair_classes + eye_classes
    latent_dim = 100
    smooth = 0.9  # lower bound for one-sided label smoothing on real labels

    config = 'ACGAN-batch_size-[{}]-steps-[{}]'.format(batch_size, iterations)
    print('Configuration: {}'.format(config))

    root_dir = '../{}/images'.format(args.train_dir)
    tags_file = '../{}/tags.pickle'.format(args.train_dir)
    # Empirical class priors used when sampling random conditioning labels.
    hair_prior = np.load('../{}/hair_prob.npy'.format(args.train_dir))
    eye_prior = np.load('../{}/eye_prob.npy'.format(args.train_dir))

    random_sample_dir = '{}/{}/random_generation'.format(
        args.sample_dir, config)
    fixed_attribute_dir = '{}/{}/fixed_attributes'.format(
        args.sample_dir, config)
    checkpoint_dir = '{}/{}'.format(args.checkpoint_dir, config)

    # exist_ok avoids the check-then-create race of os.path.exists().
    os.makedirs(random_sample_dir, exist_ok=True)
    os.makedirs(fixed_attribute_dir, exist_ok=True)
    os.makedirs(checkpoint_dir, exist_ok=True)

    ########## Start Training ##########

    transform = Transform.Compose([
        Transform.ToTensor(),
        Transform.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    dataset = Anime(root_dir=root_dir,
                    tags_file=tags_file,
                    transform=transform)
    shuffler = Shuffler(dataset=dataset, batch_size=args.batch_size)

    G = Generator(latent_dim=latent_dim, class_dim=num_classes).to(device)
    D = Discriminator(hair_classes=hair_classes,
                      eye_classes=eye_classes).to(device)

    G_optim = optim.Adam(G.parameters(), betas=[args.beta, 0.999], lr=args.lr)
    D_optim = optim.Adam(D.parameters(), betas=[args.beta, 0.999], lr=args.lr)

    d_log, g_log, classifier_log = [], [], []
    criterion = torch.nn.BCELoss()

    for step_i in range(1, iterations + 1):

        # Targets: ones for real, zeros for fake; the real targets fed to D
        # are softened to U(smooth, 1).
        real_label = torch.ones(batch_size).to(device)
        fake_label = torch.zeros(batch_size).to(device)
        soft_label = torch.Tensor(batch_size).uniform_(smooth, 1).to(device)

        # ----- Train discriminator -----
        real_img, hair_tags, eye_tags = shuffler.get_batch()
        real_img, hair_tags, eye_tags = real_img.to(device), hair_tags.to(
            device), eye_tags.to(device)

        z = torch.randn(batch_size, latent_dim).to(device)
        fake_tag = get_random_label(batch_size=batch_size,
                                    hair_classes=hair_classes,
                                    hair_prior=hair_prior,
                                    eye_classes=eye_classes,
                                    eye_prior=eye_prior).to(device)
        fake_img = G(z, fake_tag).to(device)

        # D returns a real/fake score plus hair/eye prediction heads.
        real_score, real_hair_predict, real_eye_predict = D(real_img)
        fake_score, _, _ = D(fake_img)

        real_discrim_loss = criterion(real_score, soft_label)
        fake_discrim_loss = criterion(fake_score, fake_label)

        # Auxiliary attribute-classification losses on the real batch only.
        real_hair_aux_loss = criterion(real_hair_predict, hair_tags)
        real_eye_aux_loss = criterion(real_eye_predict, eye_tags)
        real_classifier_loss = real_hair_aux_loss + real_eye_aux_loss

        discrim_loss = real_discrim_loss + fake_discrim_loss
        classifier_loss = real_classifier_loss * args.classification_weight

        classifier_log.append(classifier_loss.item())

        D_loss = discrim_loss + classifier_loss
        D_optim.zero_grad()
        D_loss.backward()
        D_optim.step()

        # ----- Train generator -----
        z = torch.randn(batch_size, latent_dim).to(device)
        fake_tag = get_random_label(batch_size=batch_size,
                                    hair_classes=hair_classes,
                                    hair_prior=hair_prior,
                                    eye_classes=eye_classes,
                                    eye_prior=eye_prior).to(device)
        # BUG FIX: the original sliced `fake_tag[:, hair_classes]` (a single
        # column, losing a dimension) and `fake_tag[:, hair_classes + 1:
        # eye_classes]` (13:10 — an *empty* slice).  Split the concatenated
        # label vector correctly instead.
        hair_tag = fake_tag[:, :hair_classes]
        eye_tag = fake_tag[:, hair_classes:hair_classes + eye_classes]
        fake_img = G(z, fake_tag).to(device)

        fake_score, hair_predict, eye_predict = D(fake_img)

        # G wants D to score its fakes as real AND to recover the attribute
        # labels it was conditioned on.
        discrim_loss = criterion(fake_score, real_label)
        hair_aux_loss = criterion(hair_predict, hair_tag)
        eye_aux_loss = criterion(eye_predict, eye_tag)
        classifier_loss = hair_aux_loss + eye_aux_loss

        G_loss = classifier_loss * args.classification_weight + discrim_loss
        G_optim.zero_grad()
        G_loss.backward()
        G_optim.step()

        ########## Updating logs ##########
        d_log.append(D_loss.item())
        g_log.append(G_loss.item())
        show_process(total_steps=iterations,
                     step_i=step_i,
                     g_log=g_log,
                     d_log=d_log,
                     classifier_log=classifier_log)

        ########## Checkpointing ##########

        if step_i == 1:
            save_image(denorm(real_img[:64, :, :, :]),
                       os.path.join(random_sample_dir, 'real.png'))
        if step_i % args.sample == 0:
            save_image(
                denorm(fake_img[:64, :, :, :]),
                os.path.join(random_sample_dir,
                             'fake_step_{}.png'.format(step_i)))

        if step_i % args.check == 0:
            save_model(model=G,
                       optimizer=G_optim,
                       step=step_i,
                       log=tuple(g_log),
                       file_path=os.path.join(checkpoint_dir,
                                              'G_{}.ckpt'.format(step_i)))
            save_model(model=D,
                       optimizer=D_optim,
                       step=step_i,
                       log=tuple(d_log),
                       file_path=os.path.join(checkpoint_dir,
                                              'D_{}.ckpt'.format(step_i)))

            plot_loss(g_log=g_log,
                      d_log=d_log,
                      file_path=os.path.join(checkpoint_dir, 'loss.png'))
            plot_classifier_loss(log=classifier_log,
                                 file_path=os.path.join(
                                     checkpoint_dir, 'classifier loss.png'))

            generation_by_attributes(model=G,
                                     device=args.device,
                                     step=step_i,
                                     latent_dim=latent_dim,
                                     hair_classes=hair_classes,
                                     eye_classes=eye_classes,
                                     sample_dir=fixed_attribute_dir)
コード例 #4
0
ファイル: train.py プロジェクト: ycchen1989/Fun-with-MNIST
            # NOTE(review): fragment — the tail of a WGAN-style epoch loop.
            # `d_loss_avg`, `d_updates`, `epoch_i`, `step_per_epoch`,
            # `real_img`, `fix_z`, `sample_dir`, `checkpoint_dir` are all
            # defined outside this view.
            d_loss_avg /= d_updates
            d_log.append(d_loss_avg.item())

            # Train G

            z = torch.randn(batch_size, latent_dim).to(device)
            fake_img = G(z)

            fake_score = D(fake_img)

            # Generator objective: maximize the critic's score on fakes.
            # NOTE(review): backward() on `-fake_score` requires it to be a
            # scalar — assumes D reduces over the batch; verify D's output.
            g_loss = -fake_score

            g_optim.zero_grad()
            g_loss.backward()
            g_optim.step()
            g_log.append(g_loss.item())

            utils.show_process(epoch_i, step_i + 1, step_per_epoch, g_log,
                               d_log)

        if epoch_i == 1:
            # Save a grid of real images once, for visual reference.
            torchvision.utils.save_image(real_img,
                                         os.path.join(sample_dir, 'real.png'),
                                         nrow=10)
        if epoch_i % 5 == 0:
            # Sample from a fixed noise vector so epochs are comparable.
            fake_img = G(fix_z)
            utils.save_image(fake_img, 10, epoch_i, step_i + 1, sample_dir)

        # Checkpoint every epoch; the same files are overwritten each time.
        utils.save_model(G, g_optim, g_log, checkpoint_dir, 'G.ckpt')
        utils.save_model(D, d_optim, d_log, checkpoint_dir, 'D.ckpt')
コード例 #5
0
def main():
    """Train a WGAN-GP whose critic also has ACGAN-style attribute heads.

    Hyper-parameters are read from the module-level ``args`` namespace.
    Each outer step runs five critic updates (Wasserstein loss + gradient
    penalty + BCE attribute losses on real images) followed by one
    generator update.  Samples, loss plots and checkpoints are written
    periodically under ``args.sample_dir`` / ``args.checkpoint_dir``.
    """
    batch_size = args.batch_size
    iterations = args.iterations
    device = args.device

    # Per-attribute class counts; the conditional label vector is their
    # concatenation (6 + 4 + 3 + 2 = 15 dims).
    hair_class, eye_class, face_class, glass_class = 6, 4, 3, 2
    num_classes = hair_class + eye_class + face_class + glass_class
    latent_dim = 100

    config = 'WGANGP_batch{}_steps{}'.format(batch_size, iterations)
    print('Configuration: {}'.format(config))

    root_dir = './{}/images'.format(args.train_dir)
    tags_file = './{}/cartoon_attr.txt'.format(args.train_dir)

    random_sample_dir = '{}/{}/random_generation'.format(args.sample_dir, config)
    fixed_attribute_dir = '{}/{}/fixed_attributes'.format(args.sample_dir, config)
    checkpoint_dir = '{}/{}'.format(args.checkpoint_dir, config)

    # exist_ok avoids the check-then-create race of os.path.exists().
    os.makedirs(random_sample_dir, exist_ok=True)
    os.makedirs(fixed_attribute_dir, exist_ok=True)
    os.makedirs(checkpoint_dir, exist_ok=True)

    ########## Start Training ##########

    transform = Transform.Compose([
        Transform.ToTensor(),
        Transform.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    dataset = Anime(root_dir=root_dir, tags_file=tags_file, transform=transform)
    shuffler = Shuffler(dataset=dataset, batch_size=args.batch_size)

    G = Generator(latent_dim=latent_dim, class_dim=num_classes).to(device)
    D = Discriminator(hair_classes=hair_class, eye_classes=eye_class,
                      face_classes=face_class, glass_classes=glass_class).to(device)

    G_optim = optim.Adam(G.parameters(), betas=[args.beta, 0.999], lr=args.lr)
    D_optim = optim.Adam(D.parameters(), betas=[args.beta, 0.999], lr=args.lr)

    d_log, g_log, classifier_log = [], [], []
    # BCE is used only for the auxiliary attribute heads; the real/fake
    # objective is the Wasserstein loss below.
    criterion = torch.nn.BCELoss()

    for step_i in range(1, iterations + 1):
        # NOTE(review): the original also built real/fake/soft BCE target
        # labels here; they were dead code under the WGAN objective and
        # have been removed.

        # Critic updates need gradients on D's parameters.
        for p in D.parameters():
            p.requires_grad = True

        # WGAN-GP: several critic updates per generator update.
        for d_iter in range(5):

            D_optim.zero_grad()

            # ----- Train discriminator (critic) on a fresh real batch -----
            real_img, hair_tags, eye_tags, face_tags, glass_tags = shuffler.get_batch()
            real_img = real_img.to(device)
            hair_tags = hair_tags.to(device)
            eye_tags = eye_tags.to(device)
            face_tags = face_tags.to(device)
            glass_tags = glass_tags.to(device)

            z = torch.randn(batch_size, latent_dim).to(device)
            fake_tag = get_random_label(batch_size=batch_size,
                                        hair_classes=hair_class,
                                        eye_classes=eye_class,
                                        face_classes=face_class,
                                        glass_classes=glass_class).to(device)

            # NOTE(review): not needed for the gradient penalty below (it
            # uses detached copies); kept from the original in case D
            # inspects input gradients — confirm before removing.
            real_img.requires_grad = True
            real_score, real_hair_predict, real_eye_predict, real_face_predict, real_glass_predict = D(real_img)

            real_discrim_loss = real_score.mean()

            # Auxiliary attribute-classification losses on real images.
            real_hair_aux_loss = criterion(real_hair_predict, hair_tags)
            real_eye_aux_loss = criterion(real_eye_predict, eye_tags)
            real_face_aux_loss = criterion(real_face_predict, face_tags)
            real_glass_aux_loss = criterion(real_glass_predict, glass_tags)
            real_classifier_loss = (real_hair_aux_loss + real_eye_aux_loss +
                                    real_face_aux_loss + real_glass_aux_loss)

            # Freeze G while training the critic: generate fakes without
            # building a graph through G.  (The original wrapped only the
            # randn call in no_grad, which had no effect.)
            with torch.no_grad():
                fake_img = G(z, fake_tag).to(device)
            fake_score, _, _, _, _ = D(fake_img)

            fake_discrim_loss = fake_score.mean()

            gradient_penalty = calculate_gradient_penalty(
                D, real_img.detach(), fake_img.detach(), batch_size, device)

            # Wasserstein critic loss: E[D(fake)] - E[D(real)] + GP.
            discrim_loss = fake_discrim_loss - real_discrim_loss + gradient_penalty
            classifier_loss = real_classifier_loss * args.classification_weight

            classifier_log.append(classifier_loss.item())

            D_loss = discrim_loss + classifier_loss
            D_loss.backward()
            D_optim.step()

        # ----- Train generator: freeze the critic's parameters -----
        for p in D.parameters():
            p.requires_grad = False

        G_optim.zero_grad()

        z = torch.randn(batch_size, latent_dim).to(device)
        fake_tag = get_random_label(batch_size=batch_size,
                                    hair_classes=hair_class,
                                    eye_classes=eye_class,
                                    face_classes=face_class,
                                    glass_classes=glass_class).to(device)

        # Split the concatenated label back into per-attribute targets
        # (derived from the class counts instead of hard-coded 6/10/13/15).
        hair_tag = fake_tag[:, :hair_class]
        eye_tag = fake_tag[:, hair_class:hair_class + eye_class]
        face_tag = fake_tag[:, hair_class + eye_class:
                            hair_class + eye_class + face_class]
        glass_tag = fake_tag[:, hair_class + eye_class + face_class:
                             hair_class + eye_class + face_class + glass_class]

        fake_img = G(z, fake_tag).to(device)

        fake_score, hair_predict, eye_predict, face_predict, glass_predict = D(fake_img)

        # Generator maximizes the critic's score on its fakes...
        G_discrim_loss = -fake_score.mean()

        # ...and must also make the attribute heads recover its conditioning.
        hair_aux_loss = criterion(hair_predict, hair_tag)
        eye_aux_loss = criterion(eye_predict, eye_tag)
        face_aux_loss = criterion(face_predict, face_tag)
        glass_aux_loss = criterion(glass_predict, glass_tag)
        classifier_loss = hair_aux_loss + eye_aux_loss + face_aux_loss + glass_aux_loss

        G_loss = classifier_loss * args.classification_weight + G_discrim_loss
        G_loss.backward()
        G_optim.step()

        ########## Updating logs ##########
        d_log.append(D_loss.item())
        g_log.append(G_loss.item())
        show_process(total_steps=iterations, step_i=step_i,
                     g_log=g_log, d_log=d_log, classifier_log=classifier_log)

        ########## Checkpointing ##########

        if step_i == 1:
            save_image(denorm(real_img[:16, :, :, :]),
                       os.path.join(random_sample_dir, 'real.png'), nrow=4)
        if step_i % args.sample == 0:
            save_image(denorm(fake_img[:16, :, :, :]),
                       os.path.join(random_sample_dir,
                                    'fake_step_{}.png'.format(step_i)),
                       nrow=4)

        if step_i % args.check == 0:
            save_model(model=G, optimizer=G_optim, step=step_i, log=tuple(g_log),
                       file_path=os.path.join(checkpoint_dir, 'G_{}.ckpt'.format(step_i)))
            save_model(model=D, optimizer=D_optim, step=step_i, log=tuple(d_log),
                       file_path=os.path.join(checkpoint_dir, 'D_{}.ckpt'.format(step_i)))

            plot_loss(g_log=g_log, d_log=d_log,
                      file_path=os.path.join(checkpoint_dir, 'loss.png'))
            plot_classifier_loss(log=classifier_log,
                                 file_path=os.path.join(checkpoint_dir,
                                                        'classifier loss.png'))

            generation_by_attributes(model=G, device=args.device, step=step_i,
                                     latent_dim=latent_dim, hair_classes=hair_class,
                                     eye_classes=eye_class, face_classes=face_class,
                                     glass_classes=glass_class,
                                     sample_dir=fixed_attribute_dir)