Example #1
def main():
    os.makedirs('checkpoints', exist_ok=True)

    # create models
    G = Generator(z_dim=20, image_size=64)
    D = Discriminator(z_dim=20, image_size=64)
    G.apply(weights_init)
    D.apply(weights_init)
    print('*** initialize weights')

    # load data
    train_img_list = make_datapath_list()
    print('*** num of data:', len(train_img_list))

    mean = (0.5, )
    std = (0.5, )
    train_dataset = GAN_Img_Dataset(file_list=train_img_list,
                                    transform=ImageTransform(mean, std))

    batch_size = 64
    train_dataloader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=batch_size,
                                                   shuffle=True)

    num_epochs = 300
    G_update, D_update = train_model(G, D, train_dataloader, num_epochs)

    torch.save(G.state_dict(), 'checkpoints/G.pt')
    torch.save(D.state_dict(), 'checkpoints/D.pt')
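Note: weights_init is applied in almost every example here but is only shown much later (compare the weights_init method in Example #11 and weights_init_normal in Example #16). A minimal standalone sketch, assuming the standard DCGAN scheme that those two definitions follow:

import torch.nn as nn

def weights_init(m):
    # DCGAN-style initialization: N(0, 0.02) for conv weights,
    # N(1, 0.02) for batch-norm scale, zero for batch-norm bias.
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0.0)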
Example #2
def main():

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    if device == 'cuda':
        torch.backends.cudnn.benchmark = True
        print('===>>> cuda')

    train_img_list = make_datapath_list()
    mean = (0.5,)
    std = (0.5,)

    train_dataset = GAN_img_Dataset(file_list=train_img_list,
                                    transform=ImageTransform(mean, std))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=64,
                                               shuffle=True)

    G = Generator(z_dim=20, image_size=64)
    G.apply(weights_init)

    D = Discriminator(z_dim=20, image_size=64)
    D.apply(weights_init)


    num_epochs = 200
    print('===>>> start training')
    G_update, D_update = train(G, D, train_loader, num_epochs, device)

    print('===>>> finish training')


    batch_size = 8
    z_dim = 20
    fixed_z = torch.randn(batch_size, z_dim)
    fixed_z = fixed_z.view(fixed_z.size(0), fixed_z.size(1), 1, 1)

    fake_images = G_update(fixed_z.to(device))
    batch_iterator = iter(train_loader)
    images = next(batch_iterator)

    fig = plt.figure(figsize=(15, 6))
    for i in range(0, 5):
        # top row: real images, bottom row: generated images
        plt.subplot(2, 5, i + 1)
        plt.imshow(images[i][0].cpu().detach().numpy(), 'gray')
        plt.subplot(2, 5, 5 + i + 1)
        plt.imshow(fake_images[i][0].cpu().detach().numpy(), 'gray')

    plt.savefig('output.jpg')
Example #3
def get_networks():
    netG = Generator().to(device)
    netG.apply(weights_init)

    discriminator = Discriminator().to(device)
    discriminator.apply(weights_init)

    netD = DHead().to(device)
    netD.apply(weights_init)

    netQ = QHead().to(device)
    netQ.apply(weights_init)
    return netG, discriminator, netD, netQ
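get_networks above follows the InfoGAN layout: discriminator is the shared convolutional body, while DHead and QHead are separate output heads. A hedged usage sketch (the latent size and the exact head outputs are assumptions, and it relies on the same global device the function itself uses):

nz = 100  # hypothetical latent size (noise plus structured codes)
netG, discriminator, netD, netQ = get_networks()

z = torch.randn(64, nz, 1, 1, device=device)
fake = netG(z)

features = discriminator(fake)   # shared trunk
adv_out = netD(features)         # real/fake head
q_out = netQ(features)           # latent-code head for the InfoGAN mutual-information term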
Example #4
def main(
    data_dir: Path = "data/mini-birds",
    batch_size: int = 128,
    # Number of channels in the training images.
    # For color images this is 3
    nc: int = 3,
    # Size of z latent vector (i.e. size of generator input)
    nz: int = 100,
    # Size of feature maps in generator
    ngf: int = 64,
    # Size of feature maps in discriminator
    ndf: int = 64,
    device: str = "cuda",
):
    device = torch.device(device) if torch.cuda.is_available() else torch.device("cpu")

    discriminator = Discriminator(nz, ndf, nc).to(device)
    generator = Generator(nz, ngf, nc).to(device)
    discriminator.apply(weights_init)
    generator.apply(weights_init)
    discriminator_optimiser = torch.optim.Adam(
        discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999)
    )
    generator_optimiser = torch.optim.Adam(
        generator.parameters(), lr=0.0002, betas=(0.5, 0.999)
    )
    if not os.path.exists(f"./{data_dir}"):
        os.mkdir(f"./{data_dir}")
    train_loader = create_dataloader(data_dir, batch_size=batch_size)

    if os.path.exists("./sec_4_lec_2_netG.pytorch"):
        generator.load_state_dict(torch.load("./sec_4_lec_2_netG.pytorch"))
    if os.path.exists("./sec_4_lec_2_netD.pytorch"):
        discriminator.load_state_dict(torch.load("./sec_4_lec_2_netD.pytorch"))
    if os.path.exists("./sec_4_lec_2_optG.pytorch"):
        generator_optimiser.load_state_dict(torch.load("./sec_4_lec_2_optG.pytorch"))
    if os.path.exists("./sec_4_lec_2_optD.pytorch"):
        discriminator_optimiser.load_state_dict(
            torch.load("./sec_4_lec_2_optD.pytorch")
        )

    train(
        train_loader,
        discriminator,
        generator,
        discriminator_optimiser,
        generator_optimiser,
        batch_size,
        nz,
        device,
    )
Example #5
def choose_model(model_options):
    generator = Generator(model_options)
    discriminator = Discriminator(model_options)

    if torch.cuda.is_available():
        print("CUDA is available")
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        print("Moved models to GPU")

    # Initialize weights
    generator.apply(weights_init)
    discriminator.apply(weights_init)

    return generator, discriminator
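A short usage sketch for choose_model; the model_options object and the Adam hyper-parameters are assumptions (the learning rate and betas mirror the values used in Example #4):

generator, discriminator = choose_model(model_options)

# Optimizers attached after the models have been moved to the GPU.
g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))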
Example #6
def main():
    # Parameters
    nz = 100
    mini_batch_size = 64
    image_size = 64
    num_epochs = 200

    # Select the device (GPU if available)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("Using device: ", device)

    # Load the MNIST data
    load_dataset()

    # Create the Dataset
    train_data_list = make_datapath_list()
    mean = (0.5, )
    std = (0.5, )
    train_dataset = GAN_Img_Dataset(file_list=train_data_list,
                                    transform=ImageTransform(mean, std))

    # Create the DataLoader
    train_dataloader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=mini_batch_size,
                                                   shuffle=True)

    # Create the model instances
    G = Generator(nz=nz, image_size=image_size)
    D = Discriminator(nz=nz, image_size=image_size)

    # Initialize the weights
    G.apply(weights_init)
    D.apply(weights_init)
    print("Finish initialize network")

    # Start training
    G_update, D_update, G_losses, D_losses = train_model(
        G, D, train_dataloader, num_epochs, nz, mini_batch_size, device)

    # Visualize the loss curves
    loss_process(G_losses, D_losses)
Example #7
def load_network(gpus):
    # Generator
    netG = Generator()
    netG.apply(weights_init)
    netG = torch.nn.DataParallel(netG, device_ids=gpus)
    print(netG)

    # Discriminator
    netD = Discriminator()
    netD.apply(weights_init)
    netD = torch.nn.DataParallel(netD, device_ids=gpus)
    print(netD)

    # Load pretrained weights, if they exist.
    training_iter = 0
    if cfg.TRAIN.NET_G != '':
        state_dict = torch.load(cfg.TRAIN.NET_G)
        netG.load_state_dict(state_dict)
        print('Loaded Generator from saved model.', cfg.TRAIN.NET_G)

        istart = cfg.TRAIN.NET_G.rfind('_') + 1
        iend = cfg.TRAIN.NET_G.rfind('.')
        training_iter = cfg.TRAIN.NET_G[istart:iend]
        training_iter = int(training_iter) + 1

    if cfg.TRAIN.NET_D != '':
        print('Loading Discriminator from %s.pth' % (cfg.TRAIN.NET_D))
        state_dict = torch.load('%s.pth' % (cfg.TRAIN.NET_D))
        netD.load_state_dict(state_dict)

    inception_model = INCEPTION_V3()

    # Moving to GPU
    if cfg.CUDA:
        netG.cuda()
        netD.cuda()
        inception_model = inception_model.cuda()

    inception_model.eval()

    return netG, netD, inception_model, training_iter
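Because netG and netD are wrapped in torch.nn.DataParallel before load_state_dict is called, the checkpoint keys must carry the 'module.' prefix. If a checkpoint was saved from an unwrapped model (an assumption; the checkpoint format is not shown here), one common remap before loading is:

state_dict = torch.load(cfg.TRAIN.NET_G, map_location='cpu')
if not any(key.startswith('module.') for key in state_dict):
    # Rename plain keys to the names DataParallel expects.
    state_dict = {'module.' + key: value for key, value in state_dict.items()}
netG.load_state_dict(state_dict)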
Example #8
def build_models(args):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    netG_A2B = Generator(args.input_nc, args.output_nc)
    netG_B2A = Generator(args.input_nc, args.output_nc)
    netD_A = Discriminator(args.input_nc)
    netD_B = Discriminator(args.output_nc)

    netG_A2B.apply(weights_init_normal)
    netG_B2A.apply(weights_init_normal)
    netD_A.apply(weights_init_normal)
    netD_B.apply(weights_init_normal)

    print('Weights Initialized')

    netG_A2B.to(device)
    netG_B2A.to(device)
    netD_A.to(device) 
    netD_B.to(device) 

    print(f'Transferred to {device}')

    return netG_A2B, netG_B2A, netD_A, netD_B
Example #9
def main():
    # load data
    annotationfile = image_dir + 'edited_annotations.csv'
    animefacedata = AnimeFaceDataset(annotationfile, image_dir)
    dataloader = DataLoader(animefacedata,
                            batch_size=batch_size,
                            shuffle=True,
                            collate_fn=my_collate,
                            drop_last=True)
    print("Data loaded : %d" % (len(animefacedata)))

    G = Generator()
    D = Discriminator()
    G.apply(init_weight)
    D.apply(init_weight)

    if args.cuda:
        G = G.cuda()
        D = D.cuda()

    criterion = nn.BCELoss()
    print("Start Training")
    train(G, D, dataloader, criterion)
    print("Finished training!")
Example #10
    # !!! Minimizes MSE instead of BCE
    adversarial_loss = torch.nn.MSELoss()

    # Initialize generator and discriminator
    generator = Generator(opt)
    discriminator = Discriminator(opt)

    if cuda:
        generator.cuda()
        discriminator.cuda()
        adversarial_loss.cuda()

    # Initialize weights
    generator.apply(weights_init_normal)
    discriminator.apply(weights_init_normal)

    dataloader = get_data_loader(opt)

    # optimizer
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr)
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr)

    Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

    # resume checkpoint
    if opt.resume_generator and opt.resume_discriminator:
        print('Resuming checkpoint from {} and {}'.format(
            opt.resume_generator, opt.resume_discriminator))
        checkpoint_generator = torch.load(opt.resume_generator)
        checkpoint_discriminator = torch.load(opt.resume_discriminator)
Example #11
class Solver(object):
    def __init__(self, config, data_loader):
        self.generator = None
        self.discriminator = None
        self.g_optimizer = None
        self.d_optimizer = None
        self.pc_name = config.pc_name
        self.base_path = config.base_path
        self.time_now = config.time_now
        self.inject_z = config.inject_z
        self.data_loader = data_loader
        self.num_epochs = config.num_epochs
        self.sample_size = config.sample_size
        self.logs_path = config.logs_path
        self.save_every = config.save_every
        self.activation_fn = config.activation_fn
        self.max_score = config.max_score
        self.lr = config.lr
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.validation_step = config.validation_step
        self.sample_path = config.sample_path
        self.model_path = config.model_path
        self.g_layers = config.g_layers
        self.d_layers = config.d_layers
        self.z_dim = self.g_layers[0]
        self.num_imgs_val = config.num_imgs_val
        self.criterion = nn.BCEWithLogitsLoss()
        self.ckpt_gen_path = config.ckpt_gen_path
        self.gp_weight = config.gp_weight
        self.loss = config.loss
        self.seed = config.seed
        self.validation_path = config.validation_path
        self.FID_images = config.FID_images
        self.transform_rep = config.transform_rep
        self.transform_z = config.transform_z
        self.spectral_norm = config.spectral_norm
        self.cifar10_path = config.cifar10_path
        self.fid_score = 100000
        self.concat_injection = config.concat_injection
        self.norm = config.norm
        self.build_model()

    def build_model(self):
        torch.manual_seed(self.seed)
        self.generator = Generator(g_layers=self.g_layers,
                                   activation_fn=self.activation_fn,
                                   inject_z=self.inject_z,
                                   transform_rep=self.transform_rep,
                                   transform_z=self.transform_z,
                                   concat_injection=self.concat_injection,
                                   norm=self.norm)
        self.discriminator = Discriminator(d_layers=self.d_layers,
                                           activation_fn=self.activation_fn,
                                           spectral_norm=self.spectral_norm)
        self.generator.apply(self.weights_init)
        self.discriminator.apply(self.weights_init)
        self.g_optimizer = optim.Adam(self.generator.parameters(),
                                      self.lr,
                                      betas=(self.beta1, self.beta2))
        self.d_optimizer = optim.Adam(self.discriminator.parameters(),
                                      self.lr,
                                      betas=(self.beta1, self.beta2))
        self.logger = Logger(self.logs_path)

        self.gen_params = sum(p.numel() for p in self.generator.parameters()
                              if p.requires_grad)
        self.disc_params = sum(p.numel()
                               for p in self.discriminator.parameters()
                               if p.requires_grad)
        self.total_params = self.gen_params + self.disc_params

        print("Generator params: {}".format(self.gen_params))
        print("Discrimintor params: {}".format(self.disc_params))
        print("Total params: {}".format(self.total_params))

        if torch.cuda.is_available():
            self.generator.cuda()
            self.discriminator.cuda()

    def reset_grad(self):
        self.discriminator.zero_grad()
        self.generator.zero_grad()

    # custom weights initialization called on netG and netD
    def weights_init(self, m):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            m.weight.data.normal_(0.0, 0.02)
        elif classname.find('BatchNorm') != -1:
            m.weight.data.normal_(1.0, 0.02)
            m.bias.data.fill_(0)

    def gradient_penalty(self, real_data, generated_data):
        batch_size = real_data.size()[0]

        # Calculate interpolation
        alpha = torch.rand(batch_size, 1)
        alpha = alpha.expand_as(real_data)
        alpha = alpha.cuda()
        interpolated = alpha * real_data.data + (1 -
                                                 alpha) * generated_data.data
        interpolated = Variable(interpolated, requires_grad=True)
        interpolated = interpolated.cuda()

        # Calculate probability of interpolated examples
        prob_interpolated = self.discriminator(interpolated)

        # Calculate gradients of probabilities with respect to examples
        gradients = torch_grad(outputs=prob_interpolated,
                               inputs=interpolated,
                               grad_outputs=torch.ones(
                                   prob_interpolated.size()).cuda(),
                               create_graph=True,
                               retain_graph=True)[0]

        # Gradients have shape (batch_size, num_channels, height, width),
        # so flatten to easily take norm per example in batch
        gradients = gradients.view(batch_size, -1)

        # Derivatives of the gradient close to 0 can cause problems because of
        # the square root, so manually calculate norm and add epsilon
        gradients_norm = torch.sqrt(torch.sum(gradients**2, dim=1) + 1e-12)

        # Return gradient penalty
        return self.gp_weight * ((gradients_norm - 1)**2).mean()

    def train(self):
        total_step = len(self.data_loader)
        for epoch in range(self.num_epochs):
            for i, (data, _) in enumerate(self.data_loader):

                batch_size = data.size(0)
                # train Discriminator
                data = data.type(torch.FloatTensor)
                data = to_cuda(data)

                real_labels = to_cuda(torch.ones(batch_size,
                                                 self.d_layers[-1]))
                fake_labels = to_cuda(
                    torch.zeros(batch_size, self.d_layers[-1]))

                outputs_real = self.discriminator(data)
                z = to_cuda(torch.randn(batch_size, self.z_dim, 1, 1))
                fake_data = self.generator(z)
                outputs_fake = self.discriminator(fake_data)

                if self.loss == 'original':
                    d_loss_real = self.criterion(outputs_real.squeeze(),
                                                 real_labels.squeeze())
                    d_loss_fake = self.criterion(outputs_fake.squeeze(),
                                                 fake_labels.squeeze())
                    d_loss = d_loss_real + d_loss_fake

                elif self.loss == 'wgan-gp':
                    gradient_penalty = self.gradient_penalty(data, fake_data)
                    d_loss = (-outputs_real.mean() + outputs_fake.mean() +
                              gradient_penalty)

                self.reset_grad()
                d_loss.backward()
                self.d_optimizer.step()

                # train Generator
                z = to_cuda(torch.randn(batch_size, self.z_dim, 1, 1))
                fake_data = self.generator(z)
                outputs_fake = self.discriminator(fake_data)

                if self.loss == 'original':
                    g_loss = self.criterion(outputs_fake.squeeze(),
                                            real_labels.squeeze())
                elif self.loss == 'wgan-gp':
                    g_loss = -outputs_fake.mean()

                self.reset_grad()
                g_loss.backward()
                self.g_optimizer.step()

                if (i + 1) % self.log_step == 0:
                    print(
                        'Epoch [{0:d}/{1:d}], Step [{2:d}/{3:d}], d_loss: {4:.4f}, '
                        ' g_loss: {5:.4f}'.format(epoch + 1, self.num_epochs,
                                                  i + 1, total_step,
                                                  d_loss.item(),
                                                  g_loss.item()))

                    # log scalars in tensorboard
                    info = {
                        'd_loss': d_loss.item(),
                        'g_loss': g_loss.item(),
                        'inception_score': self.max_score
                    }

                    for tag, value in info.items():
                        self.logger.scalar_summary(tag, value,
                                                   epoch * total_step + i + 1)

                if (i + 1) % self.sample_step == 0:
                    save_image(denorm(fake_data).cpu(),
                               self.sample_path +
                               "/epoch_{}_{}.png".format(i + 1, epoch + 1),
                               nrow=8)

                if (i + 1) % self.validation_step == 0:
                    fake_data_all = np.zeros(
                        (self.num_imgs_val, fake_data.size(1),
                         fake_data.size(2), fake_data.size(3)))
                    for j in range(self.num_imgs_val // batch_size):
                        fake_data_all[j * batch_size:(j + 1) *
                                      batch_size] = to_numpy(fake_data)
                    npy_path = os.path.join(
                        self.model_path,
                        '{}_{}_val_data.pkl'.format(epoch + 1, i + 1))
                    np.save(npy_path, fake_data_all)
                    score, _ = IS(fake_data_all,
                                  cuda=True,
                                  batch_size=batch_size)
                    if score > self.max_score:
                        print("Found new best IS score: {}".format(score))
                        self.max_score = score
                        data = "IS " + str(self.seed) + " " + str(
                            epoch + 1) + " " + str(i + 1) + " " + str(
                                self.max_score)
                        save_is(self.base_path, data)
                        g_path = os.path.join(self.model_path,
                                              'generator-best.pkl')
                        d_path = os.path.join(self.model_path,
                                              'discriminator-best.pkl')
                        torch.save(self.generator.state_dict(), g_path)
                        torch.save(self.discriminator.state_dict(), d_path)
                    for j in range(self.FID_images):
                        z = to_cuda(torch.randn(1, self.z_dim, 1, 1))
                        fake_datum = self.generator(z)
                        save_image(
                            denorm(fake_datum.squeeze()).cpu(),
                            self.validation_path + "/" + str(j) + ".png")
                    fid_value = FID([self.validation_path, self.cifar10_path],
                                    64, True, 2048)
                    if fid_value < self.fid_score:
                        self.fid_score = fid_value
                        print("Found new best FID score: {}".format(
                            self.fid_score))
                        data = "FID " + str(self.seed) + " " + str(
                            epoch + 1) + " " + str(i + 1) + " " + str(
                                self.fid_score)
                        save_is(self.base_path, data)
                        g_path = os.path.join(self.model_path,
                                              'generator-best-fid.pkl')
                        d_path = os.path.join(self.model_path,
                                              'discriminator-best-fid.pkl')
                        torch.save(self.generator.state_dict(), g_path)
                        torch.save(self.discriminator.state_dict(), d_path)

            if (epoch + 1) % self.save_every == 0:
                g_path = os.path.join(self.model_path,
                                      'generator-{}.pkl'.format(epoch + 1))
                d_path = os.path.join(self.model_path,
                                      'discriminator-{}.pkl'.format(epoch + 1))
                torch.save(self.generator.state_dict(), g_path)
                torch.save(self.discriminator.state_dict(), d_path)

    def sample(self, n_samples):
        self.n_samples = n_samples
        self.generator = Generator(g_layers=self.g_layers,
                                   inject_z=self.inject_z)
        self.generator.load_state_dict(torch.load(self.ckpt_gen_path))
        if torch.cuda.is_available():
            self.generator.cuda()
        self.generator.eval()

        z_samples = to_cuda(torch.randn(n_samples, self.z_dim, 1, 1))
        generated_samples = self.generator(z_samples)
        generated_samples = to_numpy(generated_samples)
        np.save('./saved/generated_samples.npy', generated_samples)
        z_samples = to_numpy(z_samples)
        np.save('./saved/z_samples.npy', z_samples)
Example #12
        opt.batch = state_json['batch']

        netG_A2B.load_state_dict(torch.load(get_state_path('netG_A2B')))
        netG_B2A.load_state_dict(torch.load(get_state_path('netG_B2A')))
        netD_A.load_state_dict(torch.load(get_state_path('netD_A')))
        netD_B.load_state_dict(torch.load(get_state_path('netD_B')))

        if opt.use_mask:
            netD_Am.load_state_dict(torch.load(get_state_path('netD_Am')))
            netD_Bm.load_state_dict(torch.load(get_state_path('netD_Bm')))
else:
    os.makedirs(_run_dir)

    netG_A2B.apply(weights_init_normal)
    netG_B2A.apply(weights_init_normal)
    netD_A.apply(weights_init_normal)
    netD_B.apply(weights_init_normal)

    netD_Am.apply(weights_init_normal)
    netD_Bm.apply(weights_init_normal)

# Losses
criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()

# Optimizers & LR schedulers
optimizer_G = torch.optim.Adam(itertools.chain(netG_A2B.parameters(),
                                               netG_B2A.parameters()),
                               lr=opt.lr,
                               betas=(0.5, 0.999))
Example #13
def main():

    parse = argparse.ArgumentParser()

    parse.add_argument("--lr", type=float, default=0.00005, 
                        help="learning rate of generate and discriminator")
    parse.add_argument("--clamp", type=float, default=0.01, 
                        help="clamp discriminator parameters")
    parse.add_argument("--batch_size", type=int, default=10,
                        help="number of dataset in every train or test iteration")
    parse.add_argument("--dataset", type=str, default="faces",
                        help="base path for dataset")
    parse.add_argument("--epochs", type=int, default=500,
                        help="number of training epochs")
    parse.add_argument("--loaders", type=int, default=4,
                        help="number of parallel data loading processing")
    parse.add_argument("--size_per_dataset", type=int, default=30000,
                        help="number of training data")

    args = parse.parse_args()

    if not os.path.exists("saved"):
        os.mkdir("saved")
    if not os.path.exists("saved/img"):
        os.mkdir("saved/img")

    if os.path.exists("faces"):
        pass
    else:
        print("Don't find the dataset directory, please copy the link in website ,download and extract faces.tar.gz .\n \
        https://drive.google.com/drive/folders/1mCsY5LEsgCnc0Txv0rpAUhKVPWVkbw5I \n ")
        exit()

    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    generate = Generate().to(device)
    discriminator = Discriminator().to(device)

    generate.apply(weight_init)
    discriminator.apply(weight_init)

    dataset = AnimeDataset(os.getcwd(), args.dataset, args.size_per_dataset)
    dataload = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.loaders)

    optimizer_G = RMSprop(generate.parameters(), lr=args.lr)
    optimizer_D = RMSprop(discriminator.parameters(), lr=args.lr)

    fixed_noise = torch.randn(64, 100, 1, 1).to(device)
    step = 0
    for epoch in range(args.epochs):

        print("Main epoch{}:".format(epoch))
        progress = tqdm(total=len(dataload.dataset))
        
        for i, inp in enumerate(dataload):
            step += 1
            # train discriminator   
            real_data = inp.float().to(device)
            noise = torch.randn(inp.size()[0], 100, 1, 1).to(device)
            fake_data = generate(noise)
            optimizer_D.zero_grad()
            real_output = torch.mean(discriminator(real_data).squeeze())
            fake_output = torch.mean(discriminator(fake_data).squeeze())
            # WGAN critic loss: push real scores up and fake scores down
            output = fake_output - real_output
            output.backward()
            optimizer_D.step()
            
            for param in discriminator.parameters():
                param.data.clamp_(-args.clamp, args.clamp)

            # train the generator once every 5 discriminator steps
            if step % 5 == 0:
                optimizer_G.zero_grad()
                fake_data = generate(noise)
                fake_output = -torch.mean(discriminator(fake_data).squeeze())
                fake_output.backward()
                optimizer_G.step()
            
            progress.update(dataload.batch_size)

        if epoch % 20 == 0:

            torch.save(generate, os.path.join(os.getcwd(), "saved/generate.t7"))
            torch.save(discriminator, os.path.join(os.getcwd(), "saved/discriminator.t7"))

            img = generate(fixed_noise).to("cpu").detach().numpy()

            display_grid = np.zeros((8*96,8*96,3))
            
            for j in range(int(64/8)):
                for k in range(int(64/8)):
                    display_grid[j*96:(j+1)*96,k*96:(k+1)*96,:] = (img[k+8*j].transpose(1, 2, 0)+1)/2

            img_save_path = os.path.join(os.getcwd(),"saved/img/{}.png".format(epoch))
            scipy.misc.imsave(img_save_path, display_grid)

    creat_gif("evolution.gif", os.path.join(os.getcwd(),"saved/img"))
Example #14
class Trainer(object):
    def __init__(self, args):
        self.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

        self.batch_size = args.batch_size
        self.half_size = self.batch_size // 2
        assert self.batch_size % 2 == 0, '[!] batch_size must be even'
        self.nz = args.nz

        self.lambda_kl = args.lambda_kl
        self.lambda_img = args.lambda_img
        self.lambda_z = args.lambda_z

        if args.img_size == 128:
            d_n_blocks = 2
            g_n_blocks = 7
            e_n_blocks = 4
        elif args.img_size == 256:
            d_n_blocks = 3
            g_n_blocks = 8
            e_n_blocks = 5
        else:
            raise ValueError('img_size must be 128 or 256')

        # Discriminator for cVAE-GAN(encoded vector z)
        self.D_cVAE = Discriminator(args.input_nc + args.output_nc,
                                    args.ndf,
                                    n_blocks=d_n_blocks).to(self.device)
        self.D_cVAE.apply(weights_init)
        # print(self.D_cVAE)
        # Discriminator for cLR-GAN(random vector z)
        self.D_cLR = Discriminator(args.input_nc + args.output_nc,
                                   args.ndf,
                                   n_blocks=d_n_blocks).to(self.device)
        self.D_cLR.apply(weights_init)

        self.G = Generator(args.input_nc,
                           args.output_nc,
                           args.ngf,
                           args.nz,
                           n_blocks=g_n_blocks).to(self.device)
        self.G.apply(weights_init)
        # print(self.G)

        self.E = Encoder(args.input_nc, args.nz, args.nef,
                         n_blocks=e_n_blocks).to(self.device)
        self.E.apply(weights_init)
        # print(self.E)

        # Optimizers
        self.optim_D_cVAE = optim.Adam(self.D_cVAE.parameters(),
                                       lr=args.lr,
                                       betas=(args.beta1, args.beta2))
        self.optim_D_cLR = optim.Adam(self.D_cLR.parameters(),
                                      lr=args.lr,
                                      betas=(args.beta1, args.beta2))
        self.optim_G = optim.Adam(self.G.parameters(),
                                  lr=args.lr,
                                  betas=(args.beta1, args.beta2))
        self.optim_E = optim.Adam(self.E.parameters(),
                                  lr=args.lr,
                                  betas=(args.beta1, args.beta2))

        time_str = time.strftime("%Y%m%d-%H%M%S")
        self.writer = SummaryWriter('{}/{}-{}'.format(args.log_dir,
                                                      args.dataset_name,
                                                      time_str))

    def __del__(self):
        self.writer.close()

    def all_zero_grad(self):
        self.optim_D_cVAE.zero_grad()
        self.optim_D_cLR.zero_grad()
        self.optim_G.zero_grad()
        self.optim_E.zero_grad()

    def save_weights(self, save_dir, global_step):
        d_cVAE_name = 'D_cVAE_{}.pth'.format(global_step)
        d_cLR_name = 'D_cLR_{}.pth'.format(global_step)
        g_name = 'G_{}.pth'.format(global_step)
        e_name = 'E_{}.pth'.format(global_step)

        torch.save(self.D_cVAE.state_dict(),
                   os.path.join(save_dir, d_cVAE_name))
        torch.save(self.D_cLR.state_dict(), os.path.join(save_dir, d_cLR_name))
        torch.save(self.G.state_dict(), os.path.join(save_dir, g_name))
        torch.save(self.E.state_dict(), os.path.join(save_dir, e_name))

    def optimize(self, A, B, global_step):
        if A.size(0) <= 1:
            return

        A = A.to(self.device)
        B = B.to(self.device)

        cVAE_data = {'A': A[0:self.half_size], 'B': B[0:self.half_size]}
        cLR_data = {'A': A[self.half_size:], 'B': B[self.half_size:]}

        # Logging the input images
        log_imgs = torch.cat([cVAE_data['A'], cVAE_data['B']], 0)
        log_imgs = torchvision.utils.make_grid(log_imgs)
        log_imgs = denormalize(log_imgs)
        self.writer.add_image('cVAE_input', log_imgs, global_step)

        log_imgs = torch.cat([cLR_data['A'], cLR_data['B']], 0)
        log_imgs = torchvision.utils.make_grid(log_imgs)
        log_imgs = denormalize(log_imgs)
        self.writer.add_image('cLR_input', log_imgs, global_step)

        # ----------------------------------------------------------------
        # 1. Train D
        # ----------------------------------------------------------------

        # -----------------------------
        # Optimize D in cVAE-GAN
        # -----------------------------
        # Generate encoded latent vector
        mu, logvar = self.E(cVAE_data['B'])
        std = torch.exp(logvar / 2)
        random_z = sample_z(self.half_size, self.nz, 'gauss').to(self.device)
        encoded_z = (random_z * std) + mu

        # Generate fake image
        fake_img_cVAE = self.G(cVAE_data['A'], encoded_z)
        log_imgs = torchvision.utils.make_grid(fake_img_cVAE)
        log_imgs = denormalize(log_imgs)
        self.writer.add_image('cVAE_fake_encoded', log_imgs, global_step)

        real_pair_cVAE = torch.cat([cVAE_data['A'], cVAE_data['B']], dim=1)
        fake_pair_cVAE = torch.cat([cVAE_data['A'], fake_img_cVAE], dim=1)

        real_D_cVAE_1, real_D_cVAE_2 = self.D_cVAE(real_pair_cVAE)
        fake_D_cVAE_1, fake_D_cVAE_2 = self.D_cVAE(fake_pair_cVAE.detach())

        # The loss for small patch & big patch
        loss_D_cVAE_1 = mse_loss(real_D_cVAE_1, target=1) + mse_loss(
            fake_D_cVAE_1, target=0)
        loss_D_cVAE_2 = mse_loss(real_D_cVAE_2, target=1) + mse_loss(
            fake_D_cVAE_2, target=0)

        self.writer.add_scalar('loss/loss_D_cVAE_1', loss_D_cVAE_1.item(),
                               global_step)
        self.writer.add_scalar('loss/loss_D_cVAE_2', loss_D_cVAE_2.item(),
                               global_step)

        # -----------------------------
        # Optimize D in cLR-GAN
        # -----------------------------
        # Generate fake image
        fake_img_cLR = self.G(cLR_data['A'], random_z)
        log_imgs = torchvision.utils.make_grid(fake_img_cLR)
        log_imgs = denormalize(log_imgs)
        self.writer.add_image('cLR_fake_random', log_imgs, global_step)

        real_pair_cLR = torch.cat([cLR_data['A'], cLR_data['B']], dim=1)
        fake_pair_cLR = torch.cat([cLR_data['A'], fake_img_cLR], dim=1)

        real_D_cLR_1, real_D_cLR_2 = self.D_cLR(real_pair_cLR)
        fake_D_cLR_1, fake_D_cLR_2 = self.D_cLR(fake_pair_cLR.detach())

        # Loss for small patch & big patch
        loss_D_cLR_1 = mse_loss(real_D_cLR_1, target=1) + mse_loss(
            fake_D_cLR_1, target=0)
        loss_D_cLR_2 = mse_loss(real_D_cLR_2, target=1) + mse_loss(
            fake_D_cLR_2, target=0)

        self.writer.add_scalar('loss/loss_D_cLR_1', loss_D_cLR_1.item(),
                               global_step)
        self.writer.add_scalar('loss/loss_D_cLR_2', loss_D_cLR_2.item(),
                               global_step)

        loss_D = loss_D_cVAE_1 + loss_D_cVAE_2 + loss_D_cLR_1 + loss_D_cLR_2
        self.writer.add_scalar('loss/loss_D', loss_D.item(), global_step)

        # -----------------------------
        # Update D
        # -----------------------------
        # set_requires_grad([], False)
        self.all_zero_grad()
        loss_D.backward()
        self.optim_D_cVAE.step()
        self.optim_D_cLR.step()

        # ----------------------------------------------------------------
        # 2. Train G & E
        # ----------------------------------------------------------------

        # -----------------------------
        # GAN loss
        # -----------------------------
        # Generate encoded latent vector
        mu, logvar = self.E(cVAE_data['B'])
        std = torch.exp(logvar / 2)
        random_z = sample_z(self.half_size, self.nz, 'gauss').to(self.device)
        encoded_z = (random_z * std) + mu

        # Generate fake image
        fake_img_cVAE = self.G(cVAE_data['A'], encoded_z)
        # self.writer.add_images('cVAE_output', fake_img_cVAE.add(1.0).mul(0.5), global_step)
        fake_pair_cVAE = torch.cat([cVAE_data['A'], fake_img_cVAE], dim=1)

        # Fool D_cVAE
        fake_D_cVAE_1, fake_D_cVAE_2 = self.D_cVAE(fake_pair_cVAE)

        # Loss for small patch & big patch
        loss_G_cVAE_1 = mse_loss(fake_D_cVAE_1, target=1)
        loss_G_cVAE_2 = mse_loss(fake_D_cVAE_2, target=1)

        # Random latent vector and generate fake image
        random_z = sample_z(self.half_size, self.nz, 'gauss').to(self.device)
        fake_img_cLR = self.G(cLR_data['A'], random_z)
        fake_pair_cLR = torch.cat([cLR_data['A'], fake_img_cLR], dim=1)

        # Fool D_cLR
        fake_D_cLR_1, fake_D_cLR_2 = self.D_cLR(fake_pair_cLR)

        # Loss for small patch & big patch
        loss_G_cLR_1 = mse_loss(fake_D_cLR_1, target=1)
        loss_G_cLR_2 = mse_loss(fake_D_cLR_2, target=1)

        loss_G = loss_G_cVAE_1 + loss_G_cVAE_2 + loss_G_cLR_1 + loss_G_cLR_2
        self.writer.add_scalar('loss/loss_G', loss_G.item(), global_step)

        # -----------------------------
        # KL-divergence (cVAE-GAN)
        # -----------------------------
        kl_div = torch.sum(
            0.5 * (mu**2 + torch.exp(logvar) - logvar - 1)) * self.lambda_kl
        self.writer.add_scalar('loss/kl_div', kl_div.item(), global_step)

        # -----------------------------
        # Reconstruction of image B (|G(A, z) - B|) (cVAE-GAN)
        # -----------------------------
        loss_img_recon = l1_loss(fake_img_cVAE,
                                 cVAE_data['B']) * self.lambda_img
        self.writer.add_scalar('loss/loss_img_recon', loss_img_recon.item(),
                               global_step)

        loss_E_G = loss_G + kl_div + loss_img_recon
        self.writer.add_scalar('loss/loss_E_G', loss_E_G.item(), global_step)

        # -----------------------------
        # Update E & G
        # -----------------------------
        self.all_zero_grad()
        loss_E_G.backward(retain_graph=True)
        self.optim_E.step()
        self.optim_G.step()

        # ----------------------------------------------------------------
        # 3. Train only G
        # ----------------------------------------------------------------

        # -----------------------------
        # Reconstruction of random latent code (|E(G(A, z)) - z|) (cLR-GAN)
        # -----------------------------
        # This step should update only G.
        # See https://github.com/junyanz/BicycleGAN/issues/5 for details.
        mu, logvar = self.E(fake_img_cLR)

        loss_z_recon = l1_loss(mu, random_z) * self.lambda_z
        self.writer.add_scalar('loss/loss_z_recon', loss_z_recon.item(),
                               global_step)

        # -----------------------------
        # Update G
        # -----------------------------
        self.all_zero_grad()
        loss_z_recon.backward()
        self.optim_G.step()
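The mse_loss, l1_loss and sample_z helpers used by this trainer are not shown. A minimal sketch under the assumption of an LSGAN-style loss against a constant target and a standard Gaussian latent sampler:

import torch
import torch.nn.functional as F

def mse_loss(pred, target):
    # LSGAN-style objective: compare predictions against a constant 0/1 target map.
    return F.mse_loss(pred, torch.full_like(pred, float(target)))

def l1_loss(pred, target):
    return F.l1_loss(pred, target)

def sample_z(batch_size, nz, mode='gauss'):
    # 'gauss' is the only mode used above: standard normal latent vectors.
    return torch.randn(batch_size, nz)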
Example #15
def main(input_size, output_size, gen_filter_size, dis_filter_size, num_epochs,
         beta1, lr, dataloader, device, ngpu):
    torch.manual_seed(0)

    G = Generator(input_size, gen_filter_size, output_size,
                  ngpu=ngpu).to(device)
    D = Discriminator(output_size, dis_filter_size, ngpu=ngpu).to(device)

    if (device.type == 'cuda') and (ngpu > 1):
        G = nn.DataParallel(G, list(range(ngpu)))
        D = nn.DataParallel(D, list(range(ngpu)))

    G.apply(weights_init)
    D.apply(weights_init)

    criterion = nn.BCELoss()
    fixed_noise = torch.randn(64, input_size, 1, 1, device=device)
    real_label = 1.
    fake_label = 0.

    optimizerG = optim.Adam(G.parameters(), lr=lr, betas=(beta1, 0.999))
    optimizerD = optim.Adam(D.parameters(), lr=lr, betas=(beta1, 0.999))

    img_list = []
    G_losses = []
    D_losses = []

    iters = 0

    for epoch in range(num_epochs):
        for i, data in enumerate(dataloader, 0):
            D.zero_grad()
            real = data[0].to(device)
            b_size = real.size(0)
            label = torch.full((b_size, ),
                               real_label,
                               dtype=torch.float,
                               device=device)

            output = D(real).view(-1)
            errD_real = criterion(output, label)
            errD_real.backward()
            D_x = output.mean().item()

            noise = torch.randn(b_size, input_size, 1, 1, device=device)
            fake = G(noise)
            label.fill_(fake_label)
            output = D(fake.detach()).view(-1)
            errD_fake = criterion(output, label)
            errD_fake.backward()
            D_G_z1 = output.mean().item()

            errD = errD_real + errD_fake
            optimizerD.step()

            G.zero_grad()
            label.fill_(real_label)
            output = D(fake).view(-1)
            errG = criterion(output, label)
            errG.backward()

            D_G_z2 = output.mean().item()

            optimizerG.step()
            if i % 50 == 0:
                logger.info(
                    f'[{epoch}/{num_epochs}][{i}/{len(dataloader)}]\tLoss_D: {errD.item():.4f}\tLoss_G: {errG.item():.4f}\tD(x): {D_x:.4f}\tD(G(z)): {D_G_z1:.4f} / {D_G_z2:.4f}'
                )
            G_losses.append(errG.item())
            D_losses.append(errD.item())
            if (iters % 500 == 0) or ((epoch == num_epochs - 1) and
                                      (i == len(dataloader) - 1)):
                with torch.no_grad():
                    fake = G(fixed_noise).detach().cpu()
                img_list.append(
                    vutils.make_grid(fake, padding=2, normalize=True))

            iters += 1
    torch.save(G.state_dict(), "./generator.pth")
    torch.save(D.state_dict(), "./discriminator.pth")
    return img_list, G_losses, D_losses
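A hedged post-training sketch for the values returned by main above: plotting the loss curves and the last image grid in img_list (the argument values passed to main here are placeholders assumed to have been prepared already):

import matplotlib.pyplot as plt
import numpy as np

img_list, G_losses, D_losses = main(input_size, output_size, gen_filter_size,
                                    dis_filter_size, num_epochs, beta1, lr,
                                    dataloader, device, ngpu)

plt.figure()
plt.plot(G_losses, label='G')
plt.plot(D_losses, label='D')
plt.xlabel('iteration')
plt.legend()
plt.savefig('losses.png')

plt.figure(figsize=(8, 8))
plt.axis('off')
# Each entry of img_list is a make_grid tensor in CHW layout.
plt.imshow(np.transpose(img_list[-1].numpy(), (1, 2, 0)))
plt.savefig('fake_grid.png')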
Example #16
g = Generator().to(device)
g_opt = Adam(g.parameters())
d = Discriminator().to(device)
d_opt = Adagrad(d.parameters())
loss_func = nn.BCELoss()

def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm2d') != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)

g.apply(weights_init_normal)
d.apply(weights_init_normal)



def generate(save_dir='output', save_png=False, save_mid=False):
    z = torch.randn(64, 64).to(device)
    gen_x = g(z)
    gen_x = gen_x.detach().reshape(64, 64, 64).to('cpu').numpy()
    os.makedirs(save_dir, exist_ok=True)
    if save_png:
        from PIL import Image
        imgs = (gen_x * 255).astype(np.uint8)
        path = os.path.join(save_dir, 'all.png')
        Image.fromarray(np.concatenate(imgs, 0).T).save(path)
    if save_mid:
        from notes import array_to_pm
Example #17
class Trainer(object):
    def __init__(self, args):
        self.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

        self.G = Generator(args.input_nc, args.output_nc,
                           args.ngf).to(self.device)
        self.G.apply(weights_init)
        print(self.G)

        self.D = Discriminator(args.input_nc + args.output_nc,
                               args.ndf).to(self.device)
        self.D.apply(weights_init)
        print(self.D)

        self.optim_G = optim.Adam(self.G.parameters(),
                                  lr=args.lr,
                                  betas=(args.beta1, args.beta2))
        self.optim_D = optim.Adam(self.D.parameters(),
                                  lr=args.lr,
                                  betas=(args.beta1, args.beta2))

        # Arguments
        self.lambda_l1 = args.lambda_l1
        self.log_freq = args.log_freq
        self.dataset_name = args.dataset_name
        self.img_size = args.crop_size

        time_str = time.strftime("%Y%m%d-%H%M%S")
        self.writer = SummaryWriter('{}/{}-{}'.format(args.log_dir,
                                                      args.dataset_name,
                                                      time_str))

    def __del__(self):
        self.writer.close()

    def optimize(self, A, B, global_step):
        A = A.to(self.device)
        B = B.to(self.device)

        # Logging the input images
        if global_step % self.log_freq == 0:
            log_real_A = torchvision.utils.make_grid(A)
            log_real_A = denormalize(log_real_A)
            self.writer.add_image('real_A', log_real_A, global_step)

            log_real_B = torchvision.utils.make_grid(B)
            log_real_B = denormalize(log_real_B)
            self.writer.add_image('real_B', log_real_B, global_step)

        # Forward pass
        fake_B = self.G(A)

        if global_step % self.log_freq == 0:
            log_fake_B = torchvision.utils.make_grid(fake_B)
            log_fake_B = denormalize(log_fake_B)
            self.writer.add_image('fake_B', log_fake_B, global_step)

        # ==================================================================
        # 1. Train D
        # ==================================================================
        self._set_requires_grad(self.D, True)

        # Real
        real_pair = torch.cat([A, B], dim=1)
        real_D = self.D(real_pair)
        loss_real_D = gan_loss(real_D, target=1)

        # Fake
        fake_pair = torch.cat([A, fake_B], dim=1)
        fake_D = self.D(fake_pair.detach())
        loss_fake_D = gan_loss(fake_D, target=0)

        loss_D = (loss_real_D + loss_fake_D) * 0.5

        self._all_zero_grad()
        loss_D.backward()
        self.optim_D.step()

        # Logging
        self.writer.add_scalar('loss/loss_D', loss_D.item(), global_step)

        # ==================================================================
        # 2. Train G
        # ==================================================================
        self._set_requires_grad(self.D, False)

        # Fake
        fake_D2 = self.D(fake_pair)

        loss_G_GAN = gan_loss(fake_D2, target=1)
        loss_G_L1 = l1_loss(fake_B, B)
        loss_G = loss_G_GAN + loss_G_L1 * self.lambda_l1

        self._all_zero_grad()
        loss_G.backward()
        self.optim_G.step()

        # Logging
        self.writer.add_scalar('loss/loss_G_GAN', loss_G_GAN.item(),
                               global_step)
        self.writer.add_scalar('loss/loss_G_L1', loss_G_L1.item(), global_step)
        self.writer.add_scalar('loss/loss_G', loss_G.item(), global_step)

    def _all_zero_grad(self):
        self.optim_D.zero_grad()
        self.optim_G.zero_grad()

    def _set_requires_grad(self, nets, requires_grad=False):
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad

    def save_weights(self, save_dir, global_step):
        d_name = '{}_D_{}.pth'.format(self.dataset_name, global_step)
        g_name = '{}_G_{}.pth'.format(self.dataset_name, global_step)

        torch.save(self.D.state_dict(), os.path.join(save_dir, d_name))
        torch.save(self.G.state_dict(), os.path.join(save_dir, g_name))

    def save_video(self, video_dir, global_step):
        output_dir = os.path.join(video_dir, 'step_{}'.format(global_step))
        os.mkdir(output_dir)

        input_img = Image.open('imgs/test.png').convert('RGB').resize(
            (self.img_size, self.img_size), Image.BICUBIC)
        input_tensor = get_input_tensor(input_img).unsqueeze(0).to(self.device)

        self.G.eval()
        for i in range(450):
            with torch.no_grad():
                out = self.G(input_tensor)

            out_denormalized = denormalize(out.squeeze()).cpu()
            out_img = toPIL(out_denormalized)
            out_img.save('{0}/{1:04d}.png'.format(output_dir, i))

            input_tensor = out

        self.G.train()

        cmd = 'ffmpeg -r 30 -i {}/%04d.png -vcodec libx264 -pix_fmt yuv420p -r 30 {}/movie.mp4'.format(
            output_dir, output_dir)
        subprocess.call(cmd.split())
Example #18
def train(batchsize, epochs):
    dataset = dset.ImageFolder(root="./data/",
                               transform=transforms.Compose([
                                   transforms.Resize(64),
                                   transforms.CenterCrop(64),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                              ]))
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batchsize, shuffle=True, num_workers=1)
    

    device = torch.device("cuda")

    nz = 150
    netG = Generator(nz, (64,64,3))
    netG = netG.cuda()
    netG.apply(weights_init)
    netD = Discriminator((64,64,3))
    netD = netD.cuda()
    netD.apply(weights_init)

    optimizerD = optim.RMSprop(netD.parameters(), lr=0.00005, alpha=0.9)
    optimizerG = optim.RMSprop(netG.parameters(), lr=0.00005, alpha=0.9)

    img_list = []
    G_losses = []
    D_losses = []

    netG.train()
    netD.train()
    fixed_noise = torch.randn(25, nz, 1, 1, device=device)
    for epoch in range(epochs):
        d_loss = 0
        g_loss = 0
        count = 0
        print("Epoch: " + str(epoch) + "/" + str(epochs))
        is_d = 0
        for data in tqdm(dataloader):
            optimizerD.zero_grad()
            # Format batch
            real_cpu = data[0].to(device)
            b_size = real_cpu.size(0)

            #discriminate real image
            D_real = netD(real_cpu).view(-1)
            D_real_loss = torch.mean(D_real)
            #generate fake image from noise vector
            noise = torch.randn(b_size, nz, 1, 1, device=device)
            fake = netG(noise).detach()
            #discriminate fake image
            D_fake = netD(fake).view(-1)
            D_fake_loss = torch.mean(D_fake)

            gradient_penalty = calc_gradient_penalty(netD, real_cpu, fake , b_size)
            # discriminator loss
            D_loss =  D_fake_loss - D_real_loss + gradient_penalty
            D_loss.backward()
            # Update D
            optimizerD.step()
            d_loss += D_loss.item()
            D_losses.append(D_loss.item())
            is_d += 1
            # weight clipping
            for p in netD.parameters():
                p.data.clamp_(-0.01, 0.01)

            # update the generator every 5 batches
            if is_d % 5 == 0:
                is_d = 1
                # freeze discriminator
                for p in netD.parameters():
                    p.requires_grad = False
                optimizerG.zero_grad()
                #generate fake image
                noise = torch.randn(b_size, nz, 1, 1, device=device)
                fake = netG(noise)
                #to confuse discriminator
                G_fake = netD(fake).view(-1)
                #generator loss
                G_loss = -torch.mean(G_fake)
                G_loss.backward()
            # Update G
                optimizerG.step()
                g_loss += G_loss.item()
                G_losses.append(G_loss.item())
                for p in netD.parameters():
                    p.requires_grad = True
        print("D_real_loss:%.6f, D_fake_loss:%.6f"%(D_real_loss,D_fake_loss))

        # output image every 3 epoch
        if epoch%3 == 0:
            with torch.no_grad():
                test_img = netG(fixed_noise).detach().cpu()
            test_img = test_img.numpy()
            test_img = np.transpose(test_img,(0,2,3,1))
            fig, axs = plt.subplots(5, 5)
            cnt = 0
            for i in range(5):
                for j in range(5):
                    axs[i,j].imshow(test_img[cnt, :,:,:])
                    axs[i,j].axis('off')
                    cnt += 1
            fig.savefig("./output_grad/"+str(epoch)+".png")
            plt.close()
        print("d loss: "+str(d_loss)+", g loss: "+str(g_loss))
    torch.save({'g': netG.state_dict(), 'd': netD.state_dict()},"model_best")
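calc_gradient_penalty is not defined in this example. A minimal sketch following the same pattern as the gradient_penalty method in Example #11 (the penalty weight of 10 is an assumption):

import torch
from torch import autograd

def calc_gradient_penalty(netD, real_data, fake_data, batch_size, gp_weight=10.0):
    # Interpolate between real and fake samples.
    alpha = torch.rand(batch_size, 1, 1, 1, device=real_data.device)
    interpolated = alpha * real_data + (1 - alpha) * fake_data
    interpolated.requires_grad_(True)

    prob_interpolated = netD(interpolated)

    # Gradient of the critic output w.r.t. the interpolated inputs.
    gradients = autograd.grad(outputs=prob_interpolated,
                              inputs=interpolated,
                              grad_outputs=torch.ones_like(prob_interpolated),
                              create_graph=True,
                              retain_graph=True)[0]

    gradients = gradients.view(batch_size, -1)
    gradients_norm = torch.sqrt(torch.sum(gradients ** 2, dim=1) + 1e-12)
    return gp_weight * ((gradients_norm - 1) ** 2).mean()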
Example #19
class EGBADTrainer:
    def __init__(self, args, data, device):
        self.args = args
        self.train_loader, _ = data
        self.device = device
        self.build_models()

    def train(self):
        """Training the AGBAD"""

        if self.args.pretrained:
            self.load_weights()

        optimizer_ge = optim.Adam(list(self.G.parameters()) +
                                  list(self.E.parameters()),
                                  lr=self.args.lr)
        optimizer_d = optim.Adam(self.D.parameters(), lr=self.args.lr)

        fixed_z = Variable(torch.randn((16, self.args.latent_dim, 1, 1)),
                           requires_grad=False).to(self.device)
        criterion = nn.BCELoss()
        for epoch in range(self.args.num_epochs + 1):
            ge_losses = 0
            d_losses = 0
            for x, _ in Bar(self.train_loader):
                #Defining labels
                y_true = Variable(torch.ones((x.size(0), 1)).to(self.device))
                y_fake = Variable(torch.zeros((x.size(0), 1)).to(self.device))

                #Noise for improving training.
                noise1 = Variable(torch.Tensor(x.size()).normal_(
                    0, 0.1 * (self.args.num_epochs - epoch) /
                    self.args.num_epochs),
                                  requires_grad=False).to(self.device)
                noise2 = Variable(torch.Tensor(x.size()).normal_(
                    0, 0.1 * (self.args.num_epochs - epoch) /
                    self.args.num_epochs),
                                  requires_grad=False).to(self.device)

                #Cleaning gradients.
                optimizer_d.zero_grad()
                optimizer_ge.zero_grad()

                #Generator:
                z_fake = Variable(torch.randn(
                    (x.size(0), self.args.latent_dim, 1, 1)).to(self.device),
                                  requires_grad=False)
                x_fake = self.G(z_fake)

                #Encoder:
                x_true = x.float().to(self.device)
                z_true = self.E(x_true)

                #Discriminator
                out_true = self.D(x_true + noise1, z_true)
                out_fake = self.D(x_fake + noise2, z_fake)

                #Losses
                loss_d = criterion(out_true, y_true) + criterion(
                    out_fake, y_fake)
                loss_ge = criterion(out_fake, y_true) + criterion(
                    out_true, y_fake)

                #Computing gradients and backpropagate.
                loss_d.backward(retain_graph=True)
                optimizer_d.step()

                loss_ge.backward()
                optimizer_ge.step()

                ge_losses += loss_ge.item()
                d_losses += loss_d.item()

            if epoch % 10 == 0:
                vutils.save_image((self.G(fixed_z).data + 1) / 2.,
                                  './images/{}_fake.png'.format(epoch))

            print(
                "Training... Epoch: {}, Discrimiantor Loss: {:.3f}, Generator Loss: {:.3f}"
                .format(epoch, d_losses / len(self.train_loader),
                        ge_losses / len(self.train_loader)))
        self.save_weights()

    def build_models(self):
        self.G = Generator(self.args.latent_dim).to(self.device)
        self.E = Encoder(self.args.latent_dim).to(self.device)
        self.D = Discriminator(self.args.latent_dim).to(self.device)
        self.G.apply(weights_init_normal)
        self.E.apply(weights_init_normal)
        self.D.apply(weights_init_normal)

    def save_weights(self):
        """Save weights."""
        state_dict_D = self.D.state_dict()
        state_dict_E = self.E.state_dict()
        state_dict_G = self.G.state_dict()
        torch.save(
            {
                'Generator': state_dict_G,
                'Encoder': state_dict_E,
                'Discriminator': state_dict_D
            }, 'weights/model_parameters.pth')

    def load_weights(self):
        """Load weights."""
        state_dict = torch.load('weights/model_parameters.pth')

        self.D.load_state_dict(state_dict['Discriminator'])
        self.G.load_state_dict(state_dict['Generator'])
        self.E.load_state_dict(state_dict['Encoder'])
Example #20
a_test_loader = torch.utils.data.DataLoader(a_test_data,
                                            batch_size=args.test_batch_size,
                                            shuffle=True,
                                            num_workers=4)
b_test_loader = torch.utils.data.DataLoader(b_test_data,
                                            batch_size=args.test_batch_size,
                                            shuffle=True,
                                            num_workers=4)

disc_a = Discriminator()
disc_b = Discriminator()
gen_a = Generator()
gen_b = Generator()

# weight initialization
disc_a.apply(weights_init_normal)
disc_b.apply(weights_init_normal)
gen_a.apply(weights_init_normal)
gen_b.apply(weights_init_normal)

MSE = nn.MSELoss()
L1 = nn.L1Loss()
cuda([disc_a, disc_b, gen_a, gen_b])

disc_a_optimizer = torch.optim.Adam(disc_a.parameters(),
                                    lr=args.lr,
                                    betas=(0.5, 0.999))
disc_b_optimizer = torch.optim.Adam(disc_b.parameters(),
                                    lr=args.lr,
                                    betas=(0.5, 0.999))
gen_a_optimizer = torch.optim.Adam(gen_a.parameters(),
Exemple #21
0
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=8)
    test_loader = DataLoader(test_data,
                             batch_size=1,
                             shuffle=False,
                             num_workers=8)

    # model setup
    G_A = Generator(3, 3).cuda()
    G_B = Generator(3, 3).cuda()
    D_A = Discriminator(3).cuda()
    D_B = Discriminator(3).cuda()
    G_A.apply(weights_init_normal)
    G_B.apply(weights_init_normal)
    D_A.apply(weights_init_normal)
    D_B.apply(weights_init_normal)

    # optimizer setup
    optimizer_G = Adam(itertools.chain(G_A.parameters(), G_B.parameters()),
                       lr=lr,
                       betas=(0.5, 0.999))
    optimizer_DA = Adam(D_A.parameters(), lr=lr, betas=(0.5, 0.999))
    optimizer_DB = Adam(D_B.parameters(), lr=lr, betas=(0.5, 0.999))
    lr_scheduler_G = LambdaLR(
        optimizer_G,
        lr_lambda=lambda eiter: 1.0 - max(0, eiter - decay) / float(decay))
    lr_scheduler_DA = LambdaLR(
        optimizer_DA,
        lr_lambda=lambda eiter: 1.0 - max(0, eiter - decay) / float(decay))
    lr_scheduler_DB = LambdaLR(
Exemple #22
0
class BiAAE(object):
    def __init__(self, params):

        self.params = params
        self.tune_dir = "{}/{}-{}/{}".format(params.exp_id, params.src_lang,
                                             params.tgt_lang,
                                             params.norm_embeddings)
        self.tune_best_dir = "{}/best".format(self.tune_dir)

        self.X_AE = AE(params)
        self.Y_AE = AE(params)
        self.D_X = Discriminator(input_size=params.d_input_size,
                                 hidden_size=params.d_hidden_size,
                                 output_size=params.d_output_size)
        self.D_Y = Discriminator(input_size=params.d_input_size,
                                 hidden_size=params.d_hidden_size,
                                 output_size=params.d_output_size)

        self.nets = [self.X_AE, self.Y_AE, self.D_X, self.D_Y]
        self.loss_fn = torch.nn.BCELoss()
        self.loss_fn2 = torch.nn.CosineSimilarity(dim=1, eps=1e-6)

    def weights_init(self, m):  # orthogonal initialization
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.orthogonal_(m.weight)
            if m.bias is not None:
                torch.nn.init.constant_(m.bias, 0.01)

    def weights_init2(self, m):  # xavier_normal initialization
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.xavier_normal_(m.weight)
            if m.bias is not None:
                torch.nn.init.constant_(m.bias, 0.01)

    def weights_init3(self, m):  # identity-matrix initialization
        if isinstance(m, torch.nn.Linear):
            m.weight.data.copy_(
                torch.diag(torch.ones(self.params.g_input_size)))

    def freeze(self, m):
        for p in m.parameters():
            p.requires_grad = False

    def defreeze(self, m):
        for p in m.parameters():
            p.requires_grad = True

    def init_state(self, seed=-1):
        if torch.cuda.is_available():
            # Move the network and the optimizer to the GPU
            for net in self.nets:
                net.cuda()
            self.loss_fn = self.loss_fn.cuda()
            self.loss_fn2 = self.loss_fn2.cuda()

        print('Initializing the model...')
        self.X_AE.apply(self.weights_init)  # the generator initialization scheme can be changed here
        self.Y_AE.apply(self.weights_init)  # the generator initialization scheme can be changed here

        self.D_X.apply(self.weights_init2)
        #print(self.D_X.map1.weight)
        self.D_Y.apply(self.weights_init2)

    def train(self, src_dico, tgt_dico, src_emb, tgt_emb, seed):
        # Load data
        if not os.path.exists(self.params.data_dir):
            print("Data path doesn't exists: %s" % self.params.data_dir)
        if not os.path.exists(self.tune_dir):
            os.makedirs(self.tune_dir)
        if not os.path.exists(self.tune_best_dir):
            os.makedirs(self.tune_best_dir)

        src_word2id = src_dico[1]
        tgt_word2id = tgt_dico[1]
        en = src_emb
        it = tgt_emb

        #eval = Evaluator(self.params, en,it, torch.cuda.is_available())

        AE_optimizer = optim.SGD(filter(
            lambda p: p.requires_grad,
            list(self.X_AE.parameters()) + list(self.Y_AE.parameters())),
                                 lr=self.params.g_learning_rate)
        D_optimizer = optim.SGD(list(self.D_X.parameters()) +
                                list(self.D_Y.parameters()),
                                lr=self.params.d_learning_rate)

        D_A_acc_epochs = []
        D_B_acc_epochs = []
        D_A_loss_epochs = []
        D_B_loss_epochs = []
        d_loss_epochs = []
        G_AB_loss_epochs = []
        G_BA_loss_epochs = []
        G_AB_recon_epochs = []
        G_BA_recon_epochs = []
        g_loss_epochs = []
        L_Z_loss_epoches = []

        acc_epochs = []

        criterion_epochs = []
        best_valid_metric = -100

        try:
            for epoch in range(self.params.num_epochs):
                D_A_losses = []
                D_B_losses = []
                G_AB_losses = []
                G_AB_recon = []
                G_BA_losses = []
                G_adv_losses = []
                G_BA_recon = []
                L_Z_losses = []
                d_losses = []
                g_losses = []
                hit_A = 0
                hit_B = 0
                total = 0
                start_time = timer()
                # lowest_loss = 1e5
                label_D = to_variable(
                    torch.FloatTensor(2 * self.params.mini_batch_size).zero_())
                label_D[:self.params.
                        mini_batch_size] = 1 - self.params.smoothing
                label_D[self.params.mini_batch_size:] = self.params.smoothing

                label_G = to_variable(
                    torch.FloatTensor(self.params.mini_batch_size).zero_())
                label_G = label_G + 1 - self.params.smoothing

                for mini_batch in range(
                        0, self.params.iters_in_epoch //
                        self.params.mini_batch_size):
                    for d_index in range(self.params.d_steps):
                        D_optimizer.zero_grad()  # Reset the gradients
                        self.D_X.train()
                        self.D_Y.train()

                        view_X, view_Y = self.get_batch_data_fast(en, it)

                        # Discriminator X
                        Y_Z = self.Y_AE.encode(view_Y).detach()
                        fake_X = self.X_AE.decode(Y_Z).detach()
                        input = torch.cat([view_X, fake_X], 0)

                        pred_A = self.D_X(input)
                        D_A_loss = self.loss_fn(pred_A, label_D)

                        # Discriminator Y
                        X_Z = self.X_AE.encode(view_X).detach()
                        fake_Y = self.Y_AE.decode(X_Z).detach()

                        input = torch.cat([view_Y, fake_Y], 0)
                        pred_B = self.D_Y(input)
                        D_B_loss = self.loss_fn(pred_B, label_D)

                        D_loss = D_A_loss + self.params.gate * D_B_loss

                        D_loss.backward(
                        )  # compute/store gradients, but don't change params
                        d_losses.append(to_numpy(D_loss.data))
                        D_A_losses.append(to_numpy(D_A_loss.data))
                        D_B_losses.append(to_numpy(D_B_loss.data))

                        discriminator_decision_A = to_numpy(pred_A.data)
                        hit_A += np.sum(
                            discriminator_decision_A[:self.params.
                                                     mini_batch_size] >= 0.5)
                        hit_A += np.sum(
                            discriminator_decision_A[self.params.
                                                     mini_batch_size:] < 0.5)

                        discriminator_decision_B = to_numpy(pred_B.data)
                        hit_B += np.sum(
                            discriminator_decision_B[:self.params.
                                                     mini_batch_size] >= 0.5)
                        hit_B += np.sum(
                            discriminator_decision_B[self.params.
                                                     mini_batch_size:] < 0.5)

                        D_optimizer.step(
                        )  # Only optimizes D's parameters; changes based on stored gradients from backward()

                        # Clip weights
                        #_clip(self.D_X, self.params.clip_value)
                        #_clip(self.D_Y, self.params.clip_value)

                        sys.stdout.write(
                            "[%d/%d] :: Discriminator Loss: %.3f \r" %
                            (mini_batch, self.params.iters_in_epoch //
                             self.params.mini_batch_size,
                             np.asscalar(np.mean(d_losses))))
                        sys.stdout.flush()

                    total += 2 * self.params.mini_batch_size * self.params.d_steps

                    for g_index in range(self.params.g_steps):
                        # 2. Train G on D's response (but DO NOT train D on these labels)
                        AE_optimizer.zero_grad()
                        self.D_X.eval()
                        self.D_Y.eval()
                        view_X, view_Y = self.get_batch_data_fast(en, it)

                        # Generator X_AE
                        ## adversarial loss
                        X_Z = self.X_AE.encode(view_X)
                        X_recon = self.X_AE.decode(X_Z)
                        Y_fake = self.Y_AE.decode(X_Z)
                        pred_Y = self.D_Y(Y_fake)
                        L_adv_X = self.loss_fn(pred_Y, label_G)

                        L_recon_X = 1.0 - torch.mean(
                            self.loss_fn2(view_X, X_recon))

                        # Generator Y_AE
                        # adversarial loss
                        Y_Z = self.Y_AE.encode(view_Y)
                        Y_recon = self.Y_AE.decode(Y_Z)
                        X_fake = self.X_AE.decode(Y_Z)
                        pred_X = self.D_X(X_fake)
                        L_adv_Y = self.loss_fn(pred_X, label_G)

                        ### autoAE Loss
                        L_recon_Y = 1.0 - torch.mean(
                            self.loss_fn2(view_Y, Y_recon))

                        # cross-lingual Loss
                        L_Z = 1.0 - torch.mean(self.loss_fn2(X_Z, Y_Z))

                        G_loss = self.params.adv_weight * (self.params.gate*L_adv_X + L_adv_Y) + \
                                self.params.mono_weight * (L_recon_X+L_recon_Y) + \
                                self.params.cross_weight * L_Z

                        G_loss.backward()

                        g_losses.append(to_numpy(G_loss.data))
                        G_AB_losses.append(to_numpy(L_adv_X.data))
                        G_BA_losses.append(to_numpy(L_adv_Y.data))
                        G_adv_losses.append(
                            to_numpy(L_adv_Y.data + L_adv_X.data))
                        G_AB_recon.append(to_numpy(L_recon_X.data))
                        G_BA_recon.append(to_numpy(L_recon_Y.data))
                        L_Z_losses.append(to_numpy(L_Z.data))

                        AE_optimizer.step()  # Only optimizes G's parameters

                        sys.stdout.write(
                            "[%d/%d] ::                                     Generator Loss: %.3f \r"
                            % (mini_batch, self.params.iters_in_epoch //
                               self.params.mini_batch_size,
                               np.asscalar(np.mean(g_losses))))
                        sys.stdout.flush()
                # End of the mini-batch loop: collect per-epoch statistics.

                D_A_acc_epochs.append(hit_A / total)
                D_B_acc_epochs.append(hit_B / total)
                G_AB_loss_epochs.append(np.asscalar(np.mean(G_AB_losses)))
                G_BA_loss_epochs.append(np.asscalar(np.mean(G_BA_losses)))
                D_A_loss_epochs.append(np.asscalar(np.mean(D_A_losses)))
                D_B_loss_epochs.append(np.asscalar(np.mean(D_B_losses)))
                G_AB_recon_epochs.append(np.asscalar(np.mean(G_AB_recon)))
                G_BA_recon_epochs.append(np.asscalar(np.mean(G_BA_recon)))
                L_Z_loss_epoches.append(np.asscalar(np.mean(L_Z_losses)))
                d_loss_epochs.append(np.asscalar(np.mean(d_losses)))
                g_loss_epochs.append(np.asscalar(np.mean(g_losses)))

                print(
                    "Epoch {} : Discriminator Loss: {:.3f}, Discriminator Accuracy: {:.3f}, Generator Loss: {:.3f}, Time elapsed {:.2f} mins"
                    .format(epoch, np.asscalar(np.mean(d_losses)),
                            0.5 * (hit_A + hit_B) / total,
                            np.asscalar(np.mean(g_losses)),
                            (timer() - start_time) / 60))

                if (epoch + 1) % self.params.print_every == 0:
                    # No need for discriminator weights

                    X_Z = self.X_AE.encode(Variable(en)).data
                    Y_Z = self.Y_AE.encode(Variable(it)).data

                    mstart_time = timer()
                    for method in [self.params.eval_method]:
                        results = get_word_translation_accuracy(
                            self.params.src_lang,
                            src_word2id,
                            X_Z,
                            self.params.tgt_lang,
                            tgt_word2id,
                            Y_Z,
                            method=method,
                            dico_eval=self.params.eval_file)
                        acc1 = results[0][1]

                    print('{} takes {:.2f}s'.format(method,
                                                    timer() - mstart_time))
                    print('Method:{} score:{:.4f}'.format(method, acc1))

                    csls, size = dist_mean_cosine(self.params, X_Z, Y_Z)
                    criterion = size
                    if criterion > best_valid_metric:
                        print("New criterion value: {}".format(criterion))
                        best_valid_metric = criterion
                        fp = open(
                            self.tune_best_dir +
                            "/seed_{}_dico_{}_gate_{}_epoch_{}_acc_{:.3f}.tmp".
                            format(seed, self.params.dico_build,
                                   self.params.gate, epoch, acc1), 'w')
                        fp.close()
                        torch.save(
                            self.X_AE.state_dict(), self.tune_best_dir +
                            '/seed_{}_dico_{}_gate_{}_best_X.t7'.format(
                                seed, self.params.dico_build,
                                self.params.gate))
                        torch.save(
                            self.Y_AE.state_dict(), self.tune_best_dir +
                            '/seed_{}_dico_{}_gate_{}_best_Y.t7'.format(
                                seed, self.params.dico_build,
                                self.params.gate))
                        torch.save(
                            self.D_X.state_dict(), self.tune_best_dir +
                            '/seed_{}_dico_{}_gate_{}_best_Dx.t7'.format(
                                seed, self.params.dico_build,
                                self.params.gate))
                        torch.save(
                            self.D_Y.state_dict(), self.tune_best_dir +
                            '/seed_{}_dico_{}_gate_{}_best_Dy.t7'.format(
                                seed, self.params.dico_build,
                                self.params.gate))

                    # Saving generator weights
                    fp = open(
                        self.tune_dir +
                        "/seed_{}_gate_{}_epoch_{}_acc_{:.3f}.tmp".format(
                            seed, self.params.gate, epoch, acc1), 'w')
                    fp.close()

                    acc_epochs.append(acc1)
                    criterion_epochs.append(criterion)

            criterion_fb, epoch_fb = max([
                (score, index) for index, score in enumerate(criterion_epochs)
            ])
            fp = open(
                self.tune_best_dir +
                "/seed_{}_dico_{}_gate_{}_epoch_{}_Acc_{:.3f}_{:.4f}.cslsfb".
                format(seed, self.params.dico_build, self.params.gate,
                       epoch_fb, acc_epochs[epoch_fb], criterion_fb), 'w')
            fp.close()

            # Save the plot for discriminator accuracy and generator loss
            fig = plt.figure()
            plt.plot(range(0, len(D_A_acc_epochs)),
                     D_A_acc_epochs,
                     color='b',
                     label='D_A')
            plt.plot(range(0, len(D_B_acc_epochs)),
                     D_B_acc_epochs,
                     color='r',
                     label='D_B')
            plt.ylabel('D_accuracy')
            plt.xlabel('epochs')
            plt.legend()
            fig.savefig(self.tune_dir + '/seed_{}_D_acc.png'.format(seed))

            fig = plt.figure()
            plt.plot(range(0, len(D_A_loss_epochs)),
                     D_A_loss_epochs,
                     color='b',
                     label='D_A')
            plt.plot(range(0, len(D_B_loss_epochs)),
                     D_B_loss_epochs,
                     color='r',
                     label='D_B')
            plt.ylabel('D_losses')
            plt.xlabel('epochs')
            plt.legend()
            fig.savefig(self.tune_dir + '/seed_{}_D_loss.png'.format(seed))

            fig = plt.figure()
            plt.plot(range(0, len(G_AB_loss_epochs)),
                     G_AB_loss_epochs,
                     color='b',
                     label='G_AB')
            plt.plot(range(0, len(G_BA_loss_epochs)),
                     G_BA_loss_epochs,
                     color='r',
                     label='G_BA')
            plt.ylabel('G_losses')
            plt.xlabel('epochs')
            plt.legend()
            fig.savefig(self.tune_dir + '/seed_{}_G_loss.png'.format(seed))

            fig = plt.figure()
            plt.plot(range(0, len(G_AB_recon_epochs)),
                     G_AB_recon_epochs,
                     color='b',
                     label='G_AB')
            plt.plot(range(0, len(G_BA_recon_epochs)),
                     G_BA_recon_epochs,
                     color='r',
                     label='G_BA')
            plt.ylabel('G_recon_loss')
            plt.xlabel('epochs')
            plt.legend()
            fig.savefig(self.tune_dir + '/seed_{}_G_Recon.png'.format(seed))

            # fig = plt.figure()
            # plt.plot(range(0, len(L_Z_loss_epoches)), L_Z_loss_epoches, color='b', label='L_Z')
            # plt.ylabel('L_Z_loss')
            # plt.xlabel('epochs')
            # plt.legend()
            # fig.savefig(tune_dir + '/seed_{}_L_Z.png'.format(seed))

            fig = plt.figure()
            plt.plot(range(0, len(acc_epochs)),
                     acc_epochs,
                     color='b',
                     label='trans_acc1')
            plt.ylabel('trans_acc')
            plt.xlabel('epochs')
            plt.legend()
            fig.savefig(self.tune_dir + '/seed_{}_trans_acc.png'.format(seed))
            '''
            fig = plt.figure()
            plt.plot(range(0, len(csls_epochs)), csls_epochs, color='b', label='csls')
            plt.ylabel('csls')
            plt.xlabel('epochs')
            plt.legend()
            fig.savefig(self.tune_dir + '/seed_{}_csls.png'.format(seed))
            '''
            fig = plt.figure()
            plt.plot(range(0, len(g_loss_epochs)),
                     g_loss_epochs,
                     color='b',
                     label='G_loss')
            plt.ylabel('g_loss')
            plt.xlabel('epochs')
            plt.legend()
            fig.savefig(self.tune_dir + '/seed_{}_g_loss.png'.format(seed))

            fig = plt.figure()
            plt.plot(range(0, len(d_loss_epochs)),
                     d_loss_epochs,
                     color='b',
                     label='D_loss')
            plt.ylabel('D_loss')
            plt.xlabel('epochs')
            plt.legend()
            fig.savefig(self.tune_dir + '/seed_{}_d_loss.png'.format(seed))
            plt.close('all')

        except KeyboardInterrupt:
            print("Interrupted.. saving model !!!")
            torch.save(self.X_AE.state_dict(),
                       self.tune_dir + '/X_AE_model_interrupt.t7')
            torch.save(self.Y_AE.state_dict(),
                       self.tune_dir + '/Y_AE_model_interrupt.t7')
            torch.save(self.D_X.state_dict(),
                       self.tune_dir + '/D_X_model_interrupt.t7')
            torch.save(self.D_Y.state_dict(),
                       self.tune_dir + '/D_y_model_interrupt.t7')
            exit()

        return

    def get_batch_data_fast(self, emb_en, emb_it):

        params = self.params
        random_en_indices = torch.LongTensor(params.mini_batch_size).random_(
            params.most_frequent_sampling_size)
        random_it_indices = torch.LongTensor(params.mini_batch_size).random_(
            params.most_frequent_sampling_size)
        en_batch = to_variable(emb_en)[random_en_indices.cuda()]
        it_batch = to_variable(emb_it)[random_it_indices.cuda()]

        return en_batch, it_batch
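
# A small self-contained sketch of two ingredients used above: the one-sided label
# smoothing applied to the discriminator targets, and the cosine-similarity
# reconstruction term (the concrete sizes below are assumptions for illustration).
import torch

mini_batch_size, smoothing = 64, 0.1
label_D = torch.zeros(2 * mini_batch_size)
label_D[:mini_batch_size] = 1 - smoothing   # first half of the batch is real
label_D[mini_batch_size:] = smoothing       # second half of the batch is fake

cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
view_X = torch.randn(8, 300)
X_recon = torch.randn(8, 300)
L_recon_X = 1.0 - torch.mean(cos(view_X, X_recon))  # 0 when the reconstruction is exact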
Exemple #23
0
def train(args):

    # Decide which device we want to run on
    device = torch.device("cuda:0" if (
        torch.cuda.is_available() and args.ngpu > 0) else "cpu")

    dataloader = load_data(args)

    # Plot some training images
    real_batch = next(iter(dataloader))
    plt.figure(figsize=(8, 8))
    plt.axis("off")
    plt.title("Training Images")
    plt.imshow(
        np.transpose(
            vutils.make_grid(real_batch[0].to(device)[:64],
                             padding=2,
                             normalize=True).cpu(), (1, 2, 0)))

    # Create the generator
    netG = Generator(args.ngpu, args.nc, args.nz, args.ngf).to(device)

    # Handle multi-gpu if desired
    if (device.type == 'cuda') and (args.ngpu > 1):
        netG = nn.DataParallel(netG, list(range(args.ngpu)))

    # Apply the weights_init function to randomly initialize all weights
    # to mean=0, stdev=0.02.
    netG.apply(weights_init)

    # Print the model
    print(netG)

    # Create the Discriminator
    netD = Discriminator(args.ngpu, args.nc, args.ndf).to(device)

    # Handle multi-gpu if desired
    if (device.type == 'cuda') and (args.ngpu > 1):
        netD = nn.DataParallel(netD, list(range(args.ngpu)))

    # Apply the weights_init function to randomly initialize all weights
    #  to mean=0, stdev=0.02.
    netD.apply(weights_init)

    # Print the model
    print(netD)

    # Initialize BCELoss function
    criterion = nn.BCELoss()

    # Create batch of latent vectors that we will use to visualize
    #  the progression of the generator
    fixed_noise = torch.randn(64, args.nz, 1, 1, device=device)

    # Establish convention for real and fake labels during training
    real_label = 1
    fake_label = 0

    # Setup Adam optimizers for both G and D
    optimizerD = optim.Adam(netD.parameters(),
                            lr=args.lr,
                            betas=(args.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(),
                            lr=args.lr,
                            betas=(args.beta1, 0.999))

    # Lists to keep track of progress
    img_list = []
    G_losses = []
    D_losses = []
    losses_for_csv = []
    iters = 0

    print("Starting Training Loop...")
    # For each epoch
    for epoch in range(args.num_epochs):
        print("Epoch: {}".format(epoch))
        # For each batch in the dataloader
        for i, data in enumerate(tqdm(dataloader, 0)):

            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ###########################
            ## Train with all-real batch
            netD.zero_grad()
            # Format batch
            real_cpu = data[0].to(device)
            b_size = real_cpu.size(0)
            label = torch.full((b_size, ), real_label, dtype=torch.float, device=device)
            # Forward pass real batch through D
            output = netD(real_cpu).view(-1)
            # Calculate loss on all-real batch
            errD_real = criterion(output, label)
            # Calculate gradients for D in backward pass
            errD_real.backward()
            D_x = output.mean().item()

            ## Train with all-fake batch
            # Generate batch of latent vectors
            noise = torch.randn(b_size, args.nz, 1, 1, device=device)
            # Generate fake image batch with G
            fake = netG(noise)
            label.fill_(fake_label)
            # Classify all fake batch with D
            output = netD(fake.detach()).view(-1)
            # Calculate D's loss on the all-fake batch
            errD_fake = criterion(output, label)
            # Calculate the gradients for this batch
            errD_fake.backward()
            D_G_z1 = output.mean().item()
            # Add the gradients from the all-real and all-fake batches
            errD = errD_real + errD_fake
            # Update D
            optimizerD.step()

            ############################
            # (2) Update G network: maximize log(D(G(z)))
            ###########################
            netG.zero_grad()
            label.fill_(real_label)  # fake labels are real for generator cost
            # Since we just updated D, perform another forward pass of all-fake batch through D
            output = netD(fake).view(-1)
            # Calculate G's loss based on this output
            errG = criterion(output, label)
            # Calculate gradients for G
            errG.backward()
            D_G_z2 = output.mean().item()
            # Update G
            optimizerG.step()

            # Output training stats
            if i % 50 == 0:
                print(
                    '[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                    % (epoch, args.num_epochs, i, len(dataloader), errD.item(),
                       errG.item(), D_x, D_G_z1, D_G_z2))

            # Save Losses for plotting later
            G_losses.append(errG.item())
            D_losses.append(errD.item())
            losses_for_csv.append([iters, errG.item(), errD.item()])

            # Check how the generator is doing by saving G's output on fixed_noise
            if (iters % 500 == 0) or ((epoch == args.num_epochs - 1) and
                                      (i == len(dataloader) - 1)):
                with torch.no_grad():
                    fake = netG(fixed_noise).detach().cpu()
                img_list.append(
                    vutils.make_grid(fake, padding=2, normalize=True))

                vutils.save_image(fake,
                                  os.path.join(args.save_sample_path,
                                               f"fake_iter_{iters:03}.png"),
                                  nrow=8,
                                  range=(-1.0, 1.0),
                                  normalize=True)

            iters += 1

        if epoch % 10 == 0 or epoch == args.num_epochs - 1:
            torch.save(
                netG.state_dict(),
                os.path.join(args.save_model_path, f"gen_{epoch:03}.pt"))
            torch.save(
                netD.state_dict(),
                os.path.join(args.save_model_path, f"dis_{epoch:03}.pt"))

    save_file_name = os.path.join(args.save_image_path, "Gen_Dis_loss.png")
    plt.figure(figsize=(10, 5))
    plt.title("Generator and Discriminator Loss During Training")
    plt.plot(G_losses, label="G")
    plt.plot(D_losses, label="D")
    plt.xlabel("iterations")
    plt.ylabel("Loss")
    plt.legend()
    print("Saving ==>> {}".format(save_file_name))
    plt.savefig(save_file_name)

    save_file_name = os.path.join(args.save_image_path, "Gen_animation.gif")
    fig = plt.figure(figsize=(8, 8))
    plt.axis("off")
    ims = [[plt.imshow(np.transpose(i, (1, 2, 0)), animated=True)]
           for i in img_list]
    ani = animation.ArtistAnimation(fig,
                                    ims,
                                    interval=1000,
                                    repeat_delay=1000,
                                    blit=True)
    print("Saving ==>> {}".format(save_file_name))
    ani.save(save_file_name, writer='imagemagick', fps=4)

    # Grab a batch of real images from the dataloader
    real_batch = next(iter(dataloader))

    # Plot the real images
    save_file_name = os.path.join(args.save_image_path, "Gen_img.png")
    plt.figure(figsize=(15, 15))
    plt.subplot(1, 2, 1)
    plt.axis("off")
    plt.title("Real Images")
    plt.imshow(
        np.transpose(
            vutils.make_grid(real_batch[0].to(device)[:64],
                             padding=5,
                             normalize=True).cpu(), (1, 2, 0)))

    # Plot the fake images from the last epoch
    plt.subplot(1, 2, 2)
    plt.axis("off")
    plt.title("Fake Images")
    plt.imshow(np.transpose(img_list[-1], (1, 2, 0)))
    print("Saving ==>> {}".format(save_file_name))
    plt.savefig(save_file_name)

    save_file_name = os.path.join(args.save_csv_path, "Gen_Dis_loss.csv")
    df = pd.DataFrame(
        losses_for_csv,
        columns=['Iteration', 'Generator Loss', 'Discriminator Loss'])
    print("Saving ==>> {}".format(save_file_name))
    df.to_csv(save_file_name, index=False)
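
# The weights_init / weights_init_normal helpers applied throughout these examples are
# not shown here; a typical DCGAN-style definition (an assumption, not necessarily the
# authors' exact code) draws conv weights from N(0, 0.02) and batch-norm weights from
# N(1.0, 0.02) with zero bias:
import torch.nn as nn

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)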
Exemple #24
0
def train():
    opt = parse_args()
    cuda = torch.cuda.is_available()

    input_shape = (opt.channels, opt.img_width, opt.img_height)
    FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

    transform = transforms.Compose([
        transforms.Resize(int(opt.img_height * 1.12), Image.BICUBIC),
        transforms.RandomCrop((opt.img_height, opt.img_width)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # Get dataloader
    train_loader = coco_loader(opt, mode='train', transform=transform)
    test_loader = coco_loader(opt, mode='test', transform=transform)

    # Get vgg
    vgg = VGGNet()

    # Initialize two generators and the discriminator
    shared_E = Encoder(opt.channels, opt.dim, opt.n_downsample)
    shared_D = Decoder(3, 256, opt.n_upsample)

    G_A = GeneratorA(opt.n_residual, 256, shared_E, shared_D)
    G_B = GeneratorB(opt.n_residual, 256, shared_E, shared_D)

    D_B = Discriminator(input_shape)

    # Initialize weights
    G_A.apply(weights_init_normal)
    G_B.apply(weights_init_normal)
    D_B.apply(weights_init_normal)

    # Losses
    criterion_GAN = torch.nn.MSELoss()
    criterion_pixel = torch.nn.L1Loss()

    if cuda:
        vgg = vgg.cuda().eval()
        G_A = G_A.cuda()
        G_B = G_B.cuda()
        D_B = D_B.cuda()
        criterion_GAN.cuda()
        criterion_pixel.cuda()

    optimizer_G = torch.optim.Adam(itertools.chain(G_A.parameters(),
                                                   G_B.parameters()),
                                   lr=opt.lr,
                                   betas=(0.5, 0.999))
    optimizer_D = torch.optim.Adam(D_B.parameters(),
                                   lr=opt.lr,
                                   betas=(0.5, 0.999))

    lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
        optimizer_G,
        lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)
    lr_scheduler_D = torch.optim.lr_scheduler.LambdaLR(
        optimizer_D,
        lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step)

    # Compute the style features in advance
    style_img = Variable(load_img(opt.style_img, transform).type(FloatTensor))
    style_feature = vgg(style_img)

    prev_time = time.time()
    for epoch in range(opt.epoch, opt.n_epochs):
        for batch_i, content_img in enumerate(train_loader):
            content_img = Variable(content_img.type(FloatTensor))

            valid = Variable(FloatTensor(
                np.ones((content_img.size(0), *D_B.output_shape))),
                             requires_grad=False)
            fake = Variable(FloatTensor(
                np.zeros((content_img.size(0), *D_B.output_shape))),
                            requires_grad=False)

            # ---------------------
            #  Train Generators
            # ---------------------

            optimizer_G.zero_grad()

            # The generated images are not de-normalized, so the content image, style image, and generated image must share the same preprocessing!
            stylized_img = G_A(content_img)

            target_feature = vgg(stylized_img)
            content_feature = vgg(content_img)
            loss_st = opt.lambda_st * vgg.compute_st_loss(
                target_feature, content_feature, style_feature,
                opt.lambda_style)

            reconstructed_img = G_B(stylized_img)
            loss_adv = opt.lambda_adv * criterion_GAN(D_B(reconstructed_img),
                                                      valid)

            loss_G = loss_st + loss_adv
            loss_G.backward(retain_graph=True)
            optimizer_G.step()

            # ----------------------
            #  Train Discriminator
            # ----------------------

            optimizer_D.zero_grad()

            loss_D = criterion_GAN(D_B(content_img), valid) + criterion_GAN(
                D_B(reconstructed_img.detach()), fake)
            loss_D.backward()
            optimizer_D.step()

            # ------------------
            # Log Information
            # ------------------

            batches_done = epoch * len(train_loader) + batch_i
            batches_left = opt.n_epochs * len(train_loader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left *
                                           (time.time() - prev_time))
            prev_time = time.time()

            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] ETA: %s"
                % (epoch, opt.n_epochs, batch_i, len(train_loader),
                   loss_D.item(), loss_G.item(), time_left))

            if batches_done % opt.sample_interval == 0:
                save_sample(opt.style_name, test_loader, batches_done, G_A,
                            G_B, FloatTensor)

            if batches_done % opt.checkpoint_interval == 0:
                torch.save(
                    G_A.state_dict(),
                    "checkpoints/%s/G_A_%d.pth" % (opt.style_name, epoch))
                torch.save(
                    G_B.state_dict(),
                    "checkpoints/%s/G_B_%d.pth" % (opt.style_name, epoch))

        # Update learning rates
        lr_scheduler_G.step()
        lr_scheduler_D.step()

    torch.save(G_A.state_dict(),
               "checkpoints/%s/G_A_done.pth" % opt.style_name)
    torch.save(G_B.state_dict(),
               "checkpoints/%s/G_B_done.pth" % opt.style_name)
    print("Training Process has been Done!")
Exemple #25
0
from torchvision.utils import save_image

if __name__ == '__main__':
    opt = get_config()

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # Load Training Data
    train_data, train_loader = load_data(opt)

    # Define Generator and Discriminator
    G = Generator(opt).to(device)
    G.apply(weights_init)

    D = Discriminator(opt).to(device)
    D.apply(weights_init)

    # Define Optimizer
    G_optim = optim.Adam(G.parameters(),
                         lr=opt['lr'],
                         betas=(opt['b1'], opt['b2']))
    D_optim = optim.Adam(D.parameters(),
                         lr=opt['lr'],
                         betas=(opt['b1'], opt['b2']))

    # Loss Function
    criterion = HyperSphereLoss()

    # Load CheckPoint
    if os.path.exists(opt['checkpoint']):
        state = torch.load(opt['checkpoint'])
def train():

    hparams = get_hparams()
    model_path = os.path.join(hparams.model_path, hparams.task_name,
                              hparams.spec_opt)
    if not os.path.exists(model_path):
        os.makedirs(model_path)

    # Load Dataset Loader

    normalizer_clean = Tanhize('clean')
    normalizer_noisy = Tanhize('noisy')

    print('Load dataset2d loader')
    dataset_A_2d = npyDataset2d(hparams.dataset_root,
                                hparams.list_dir_train_A_2d,
                                hparams.frame_len,
                                normalizer=normalizer_noisy)
    dataset_B_2d = npyDataset2d(hparams.dataset_root,
                                hparams.list_dir_train_B_2d,
                                hparams.frame_len,
                                normalizer=normalizer_clean)

    dataloader_A = DataLoader(
        dataset_A_2d,
        batch_size=hparams.batch_size,
        shuffle=True,
        drop_last=True,
    )
    dataloader_B = DataLoader(
        dataset_B_2d,
        batch_size=hparams.batch_size,
        shuffle=True,
        drop_last=True,
    )

    # Load Generator / Discriminator models
    generator_A = Generator()
    generator_B = Generator()

    discriminator_A = Discriminator()
    discriminator_B = Discriminator()

    ContEncoder_A = ContentEncoder()
    ContEncoder_B = ContentEncoder()

    StEncoder_A = StyleEncoder()
    StEncoder_B = StyleEncoder()

    generator_A.apply(weights_init)
    generator_B.apply(weights_init)

    discriminator_A.apply(weights_init)
    discriminator_B.apply(weights_init)

    ContEncoder_A.apply(weights_init)
    ContEncoder_B.apply(weights_init)

    StEncoder_A.apply(weights_init)
    StEncoder_B.apply(weights_init)

    real_label = 1
    fake_label = 0
    real_tensor = Variable(torch.FloatTensor(hparams.batch_size))
    _ = real_tensor.data.fill_(real_label)

    fake_tensor = Variable(torch.FloatTensor(hparams.batch_size))
    _ = fake_tensor.data.fill_(fake_label)

    # Define Loss function
    d = nn.MSELoss()
    bce = nn.BCELoss()

    # Cuda Process
    if hparams.cuda == True:
        print('-- Activate with CUDA --')

        generator_A = nn.DataParallel(generator_A).cuda()
        generator_B = nn.DataParallel(generator_B).cuda()
        discriminator_A = nn.DataParallel(discriminator_A).cuda()
        discriminator_B = nn.DataParallel(discriminator_B).cuda()
        ContEncoder_A = nn.DataParallel(ContEncoder_A).cuda()
        ContEncoder_B = nn.DataParallel(ContEncoder_B).cuda()
        StEncoder_A = nn.DataParallel(StEncoder_A).cuda()
        StEncoder_B = nn.DataParallel(StEncoder_B).cuda()

        d.cuda()
        bce.cuda()
        real_tensor = real_tensor.cuda()
        fake_tensor = fake_tensor.cuda()

    else:
        print('-- Activate without CUDA --')

    gen_params = chain(
        generator_A.parameters(),
        generator_B.parameters(),
        ContEncoder_A.parameters(),
        ContEncoder_B.parameters(),
        StEncoder_A.parameters(),
        StEncoder_B.parameters(),
    )

    dis_params = chain(
        discriminator_A.parameters(),
        discriminator_B.parameters(),
    )

    optimizer_g = optim.Adam(gen_params, lr=hparams.learning_rate)
    optimizer_d = optim.Adam(dis_params, lr=hparams.learning_rate)

    iters = 0
    for e in range(hparams.epoch_size):

        # input Tensor

        A_loader, B_loader = iter(dataloader_A), iter(dataloader_B)

        for i in range(len(A_loader) - 1):

            batch_A = next(A_loader)
            batch_B = next(B_loader)

            A_indx = torch.LongTensor(list(range(hparams.batch_size)))
            B_indx = torch.LongTensor(list(range(hparams.batch_size)))

            A_ = torch.FloatTensor(batch_A)
            B_ = torch.FloatTensor(batch_B)

            if hparams.cuda == True:

                x_A = Variable(A_.cuda())
                x_B = Variable(B_.cuda())

            else:
                x_A = Variable(A_)
                x_B = Variable(B_)

            real_tensor.data.resize_(hparams.batch_size).fill_(real_label)
            fake_tensor.data.resize_(hparams.batch_size).fill_(fake_label)

            ## Discriminator Update Steps

            discriminator_A.zero_grad()
            discriminator_B.zero_grad()

            # x_A, x_B, x_AB, x_BA
            # [#_batch, max_time_len, dim]

            A_c = ContEncoder_A(x_A).detach()
            B_c = ContEncoder_B(x_B).detach()

            # A,B :  N ~ (0,1)
            A_s = Variable(get_z_random(hparams.batch_size, 8))
            B_s = Variable(get_z_random(hparams.batch_size, 8))

            x_AB = generator_B(A_c, B_s).detach()
            x_BA = generator_A(B_c, A_s).detach()

            # We recommend LSGAN-loss for adversarial loss

            l_d_A_real = 0.5 * torch.mean(
                (discriminator_A(x_A) - real_tensor)**2)
            l_d_A_fake = 0.5 * torch.mean(
                (discriminator_A(x_BA) - fake_tensor)**2)

            l_d_B_real = 0.5 * torch.mean(
                (discriminator_B(x_B) - real_tensor)**2)
            l_d_B_fake = 0.5 * torch.mean(
                (discriminator_B(x_AB) - fake_tensor)**2)

            l_d_A = l_d_A_real + l_d_A_fake
            l_d_B = l_d_B_real + l_d_B_fake

            l_d = l_d_A + l_d_B

            l_d.backward()
            optimizer_d.step()

            ## Generator Update Steps

            generator_A.zero_grad()
            generator_B.zero_grad()
            ContEncoder_A.zero_grad()
            ContEncoder_B.zero_grad()
            StEncoder_A.zero_grad()
            StEncoder_B.zero_grad()

            A_c = ContEncoder_A(x_A)
            B_c = ContEncoder_B(x_B)

            A_s_prime = StEncoder_A(x_A)
            B_s_prime = StEncoder_B(x_B)

            # A,B : N ~ (0,1)
            A_s = Variable(get_z_random(hparams.batch_size, 8))
            B_s = Variable(get_z_random(hparams.batch_size, 8))

            x_BA = generator_A(B_c, A_s)
            x_AB = generator_B(A_c, B_s)

            x_A_recon = generator_A(A_c, A_s_prime)
            x_B_recon = generator_B(B_c, B_s_prime)

            B_c_recon = ContEncoder_A(x_BA)
            A_s_recon = StEncoder_A(x_BA)

            A_c_recon = ContEncoder_B(x_AB)
            B_s_recon = StEncoder_B(x_AB)

            x_ABA = generator_A(A_c_recon, A_s_prime)
            x_BAB = generator_B(B_c_recon, B_s_prime)

            l_cy_A = recon_criterion(x_ABA, x_A)
            l_cy_B = recon_criterion(x_BAB, x_B)

            l_f_A = recon_criterion(x_A_recon, x_A)
            l_f_B = recon_criterion(x_B_recon, x_B)

            l_c_A = recon_criterion(A_c_recon, A_c)
            l_c_B = recon_criterion(B_c_recon, B_c)

            l_s_A = recon_criterion(A_s_recon, A_s)
            l_s_B = recon_criterion(B_s_recon, B_s)

            # We recommend LSGAN-loss for adversarial loss

            l_gan_A = 0.5 * torch.mean(
                (discriminator_A(x_BA) - real_tensor)**2)
            l_gan_B = 0.5 * torch.mean(
                (discriminator_B(x_AB) - real_tensor)**2)

            l_g = l_gan_A + l_gan_B + lambda_f * (l_f_A + l_f_B) + lambda_s * (
                l_s_A + l_s_B) + lambda_c * (l_c_A + l_c_B) + lambda_cy * (
                    l_cy_A + l_cy_B)

            l_g.backward()
            optimizer_g.step()

            if iters % hparams.log_interval == 0:
                print("---------------------")

                print("Gen Loss :{} disc loss :{}".format(
                    l_g / hparams.batch_size, l_d / hparams.batch_size))
                print("epoch :", e, " ", "total ", hparams.epoch_size)
                print("iteration :", iters)

            if iters % hparams.model_save_interval == 0:
                torch.save(
                    generator_A.state_dict(),
                    os.path.join(model_path,
                                 'model_gen_A_{}.pth'.format(iters)))
                torch.save(
                    generator_B.state_dict(),
                    os.path.join(model_path,
                                 'model_gen_B_{}.pth'.format(iters)))
                torch.save(
                    discriminator_A.state_dict(),
                    os.path.join(model_path,
                                 'model_dis_A_{}.pth'.format(iters)))
                torch.save(
                    discriminator_B.state_dict(),
                    os.path.join(model_path,
                                 'model_dis_B_{}.pth'.format(iters)))

                torch.save(
                    ContEncoder_A.state_dict(),
                    os.path.join(model_path,
                                 'model_ContEnc_A_{}.pth'.format(iters)))
                torch.save(
                    ContEncoder_B.state_dict(),
                    os.path.join(model_path,
                                 'model_ContEnc_B_{}.pth'.format(iters)))
                torch.save(
                    StEncoder_A.state_dict(),
                    os.path.join(model_path,
                                 'model_StEnc_A_{}.pth'.format(iters)))
                torch.save(
                    StEncoder_B.state_dict(),
                    os.path.join(model_path,
                                 'model_StEnc_B_{}.pth'.format(iters)))

            iters += 1
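
# get_z_random and recon_criterion are used above but not defined in the snippet;
# plausible minimal versions (assumptions) are a standard-normal latent sampler and an
# L1 reconstruction loss:
import torch

def get_z_random(batch_size, nz):
    return torch.randn(batch_size, nz)

def recon_criterion(pred, target):
    return torch.mean(torch.abs(pred - target))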
Exemple #27
0
'''

# Generator(ngpu, nc, nz, ngf)
netG = Generator(params['ngpu'], params['nc'], params['nz'],
                 params['ngf']).to(device)
''' this part can be used later when ngpu > 1
# Handle multi-gpu if desired
if (device.type == 'cuda') and (ngpu > 1):
    netG = nn.DataParallel(netG, list(range(ngpu)))
'''
netG.apply(weights_init)
print(netG)

# Discriminator(ngpu, nc, ndf)
netD = Discriminator(params['ngpu'], params['nc'], params['ndf']).to(device)
netD.apply(weights_init)
print(netD)

# Criterion & Optimizer
criterion = nn.BCELoss()
fixed_noise = torch.randn(64, params['nz'], 1, 1).to(device)
real_label = 1
fake_label = 0
optimizerD = optim.Adam(netD.parameters(),
                        lr=params['lr'],
                        betas=(params['beta1'], 0.999))
optimizerG = optim.Adam(netG.parameters(),
                        lr=params['lr'],
                        betas=(params['beta1'], 0.999))
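
# A short usage sketch (assuming netG and fixed_noise defined in the snippet above):
# the fixed latent batch can be used to visualise generator progress during or after training.
import torchvision.utils as vutils

with torch.no_grad():
    fake = netG(fixed_noise).detach().cpu()
vutils.save_image(fake, 'samples.png', nrow=8, normalize=True)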

Exemple #28
0
class Solver(object):
    def __init__(self, trainset_loader, config):
        self.use_cuda = torch.cuda.is_available()
        self.device = torch.device('cuda' if self.use_cuda else 'cpu')
        self.trainset_loader = trainset_loader
        self.nc = config.nc
        self.nz = config.nz
        self.ngf = config.ngf
        self.ndf = config.ndf
        self.n_epochs = config.n_epochs
        self.resume_iters = config.resume_iters
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.exp_name = config.name
        os.makedirs(config.ckp_dir, exist_ok=True)
        self.ckp_dir = os.path.join(config.ckp_dir, self.exp_name)
        os.makedirs(self.ckp_dir, exist_ok=True)
        self.example_dir = os.path.join(self.ckp_dir, "output")
        os.makedirs(self.example_dir, exist_ok=True)
        self.log_interval = config.log_interval
        self.save_interval = config.save_interval
        self.use_wandb = config.use_wandb

        self.build_model()

    def build_model(self):
        self.G = Generator(nc=self.nc, ngf=self.ngf,
                           nz=self.nz).to(self.device)
        self.D = Discriminator(nc=self.nc, ndf=self.ndf).to(self.device)
        self.G.apply(weights_init)
        self.D.apply(weights_init)
        self.g_optimizer = torch.optim.Adam(self.G.parameters(),
                                            lr=self.g_lr,
                                            betas=[self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(self.D.parameters(),
                                            lr=self.d_lr,
                                            betas=[self.beta1, self.beta2])

    def save_checkpoint(self, step):
        G_path = os.path.join(self.ckp_dir, '{}-G.pth'.format(step + 1))
        D_path = os.path.join(self.ckp_dir, '{}-D.pth'.format(step + 1))
        torch.save(self.G.state_dict(), G_path)
        torch.save(self.D.state_dict(), D_path)
        print('Saved model checkpoints into {}...'.format(self.ckp_dir))

    def load_checkpoint(self, resume_iters):
        print(
            'Loading the trained models from step {}...'.format(resume_iters))
        G_path = os.path.join(self.ckp_dir, '{}-G.pth'.format(resume_iters))
        D_path = os.path.join(self.ckp_dir, '{}-D.pth'.format(resume_iters))
        self.G.load_state_dict(torch.load(G_path))
        self.D.load_state_dict(torch.load(D_path))

    def reset_grad(self):
        """Reset the gradient buffers."""
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()

    def train(self):
        criterion = nn.BCELoss()
        torch.manual_seed(66666)
        fixed_noise = torch.randn(32, self.nz, 1, 1, device=self.device)
        real_label = 1.
        fake_label = 0.

        iteration = 0
        if self.resume_iters:
            print("resuming step %d ..." % self.resume_iters)
            iteration = self.resume_iters
            self.load_checkpoint(self.resume_iters)

        for ep in range(self.n_epochs):
            self.G.train()
            self.D.train()

            D_loss_t = 0.0
            D_x_t = 0.0
            D_G_z1_t = 0.0
            G_loss_t = 0.0
            D_G_z2_t = 0.0

            for batch_idx, (real_data, _) in enumerate(self.trainset_loader):
                ################
                #   update D   #
                ################
                self.D.zero_grad()

                real_data = real_data.to(self.device)
                b_size = real_data.size(0)
                label = torch.full((b_size, ),
                                   real_label,
                                   dtype=torch.float,
                                   device=self.device)
                output = self.D(real_data).view(-1)
                D_loss_real = criterion(output, label)
                D_loss_real.backward()
                D_x = output.mean().item()
                D_x_t += D_x

                noise = torch.randn(b_size, self.nz, 1, 1, device=self.device)
                fake = self.G(noise)
                label.fill_(fake_label)
                output = self.D(fake.detach()).view(-1)
                D_loss_fake = criterion(output, label)
                D_loss_fake.backward()
                D_G_z1 = output.mean().item()
                D_G_z1_t += D_G_z1
                D_loss = D_loss_real + D_loss_fake
                D_loss_t += D_loss.item()
                self.d_optimizer.step()

                ################
                #   update G   #
                ################
                self.G.zero_grad()
                label.fill_(real_label)
                output = self.D(fake).view(-1)
                G_loss = criterion(output, label)
                G_loss.backward()
                G_loss_t += G_loss.item()
                D_G_z2 = output.mean().item()
                D_G_z2_t += D_G_z2
                self.g_optimizer.step()

                # Output training stats
                if (iteration + 1) % self.log_interval == 0:
                    print(
                        'Epoch: {:3d} [{:5d}/{:5d} ({:3.0f}%)]\tIteration: {:5d}\tD_loss: {:.6f}\tG_loss: {:.6f}\tD(x): {:.6f}\tD(G(z)): {:.6f} / {:.6f}'
                        .format(
                            ep, (batch_idx + 1) * len(real_data),
                            len(self.trainset_loader.dataset),
                            100. * (batch_idx + 1) / len(self.trainset_loader),
                            iteration + 1, D_loss.item(), G_loss.item(), D_x,
                            D_G_z1, D_G_z2))

                # Save model checkpoints
                if (iteration + 1) % self.save_interval == 0 and iteration > 0:
                    self.save_checkpoint(iteration)
                    g_example = self.G(fixed_noise)
                    g_example_path = os.path.join(self.example_dir,
                                                  '%d.png' % (iteration + 1))
                    torchvision.utils.save_image(g_example.data,
                                                 g_example_path,
                                                 nrow=8,
                                                 normalize=True)

                iteration += 1

            print(
                'Epoch: {:3d} [{:5d}/{:5d} ({:3.0f}%)]\tIteration: {:5d}\tD_loss: {:.6f}\tG_loss: {:.6f}\tD(x): {:.6f}\tD(G(z)): {:.6f} / {:.6f}\n'
                .format(ep, len(self.trainset_loader.dataset),
                        len(self.trainset_loader.dataset), 100., iteration,
                        D_loss_t / len(self.trainset_loader),
                        G_loss_t / len(self.trainset_loader),
                        D_x_t / len(self.trainset_loader),
                        D_G_z1_t / len(self.trainset_loader),
                        D_G_z2_t / len(self.trainset_loader)))

            if self.use_wandb:
                import wandb
                wandb.log({
                    "Loss_D": D_loss_t / len(self.trainset_loader),
                    "Loss_G": G_loss_t / len(self.trainset_loader),
                    "D(x)": D_x_t / len(self.trainset_loader),
                    "D(G(z1))": D_G_z1_t / len(self.trainset_loader),
                    "D(G(z2))": D_G_z2_t / len(self.trainset_loader)
                })

        self.save_checkpoint(iteration)
        g_example = self.G(fixed_noise)
        g_example_path = os.path.join(self.example_dir,
                                      '%d.png' % (iteration + 1))
        torchvision.utils.save_image(g_example.data,
                                     g_example_path,
                                     nrow=8,
                                     normalize=True)
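
# A hedged usage sketch for the Solver above: it only needs a DataLoader plus a config
# object exposing the attributes read in __init__ (the field values below are assumptions).
from types import SimpleNamespace

config = SimpleNamespace(nc=3, nz=100, ngf=64, ndf=64, n_epochs=25, resume_iters=0,
                         g_lr=2e-4, d_lr=2e-4, beta1=0.5, beta2=0.999, name='dcgan',
                         ckp_dir='checkpoints', log_interval=100, save_interval=1000,
                         use_wandb=False)
# solver = Solver(trainset_loader, config)
# solver.train()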
Exemple #29
0
def denorm(x):
    out = (x + 1.0) / 2.0
    return nn.Tanh()(out)


num_epoch = 5
batchSize = 64
lr = 0.0002
l1_lambda = 10

text_logger = setup_logger('Train')
logger = Logger('./logs')

discriminator = Discriminator()
generator = Generator()
discriminator.apply(weights_init)
generator.apply(weights_init)
if torch.cuda.is_available():
    discriminator.cuda()
    generator.cuda()

loss_function = nn.CrossEntropyLoss()
d_optim = torch.optim.Adam(discriminator.parameters(), lr, [0.5, 0.999])
g_optim = torch.optim.Adam(generator.parameters(), lr, [0.5, 0.999])

dataloader = DataLoader(batchSize)
data_size = len(dataloader.train_index)
num_batch = data_size // batchSize
#text_logger.info('Total number of videos for train = ' + str(data_size))
#text_logger.info('Total number of batches per epoch = ' + str(num_batch))
def train(dataloader):
    """ Train the model on `num_steps` batches
    Args:
        dataloader : (DataLoader) a torch.utils.data.DataLoader object that fetches training data
        num_steps : (int) # of batches to train on, each of size args.batch_size
    """

    # Define Generator, Discriminator
    G = Generator(out_channel=ch).to(device) # MNIST channel: 1, CIFAR-10 channel: 3
    D = Discriminator(in_channel=ch).to(device)

    # adversarial loss
    loss_fn = nn.BCELoss()

    # Initialize weights
    G.apply(init_weights)
    D.apply(init_weights)

    optimizer_G = optim.Adam(G.parameters(), lr=lr, betas=(b1, b2))
    optimizer_D = optim.Adam(D.parameters(), lr=lr, betas=(b1, b2))

    # Establish convention for real and fake labels during training
    real_label = 1.
    fake_label = 0.

    # -----Training----- #
    for epoch in range(epochs):
        # For each batch in the dataloader

        for i, data in enumerate(dataloader, 0):
            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ###########################
            ## Train with all-real batch
            D.zero_grad()
            # Format batch
            real_cpu = data[0].to(device) # real image batch
            b_size = real_cpu.size(0) # batch size
            label = torch.full((b_size,), real_label, dtype=torch.float, device=device, requires_grad=False) # real batch

            # Forward pass **real batch** through D
            output = D(real_cpu).view(-1)
            # Calculate loss on all-real batch
            errD_real = loss_fn(output, label)
            # Calculate gradients for D in backward pass
            errD_real.backward()

            ## Train with **all-fake** batch
            # Generate noise batch of latent vectors
            noise = torch.randn(b_size, latent_dim, 1, 1, device=device)
            # Generate fake image batch with G
            fake = G(noise)
            label.fill_(fake_label) # fake batch

            # Classify all fake batch with D
            output = D(fake.detach()).view(-1)
            # Calculate D's loss on the all-fake batch
            errD_fake = loss_fn(output, label)
            # Calculate the gradients for this batch
            errD_fake.backward()
            # Add the gradients from the all-real and all-fake batches
            errD = errD_real + errD_fake
            # Update D
            optimizer_D.step()

            ############################
            # (2) Update G network: maximize log(D(G(z)))
            ###########################
            G.zero_grad()
            label.fill_(real_label)  # fake labels are real for generator cost

            # Since we just updated D, perform another forward pass of all-fake batch through D
            output = D(fake).view(-1)
            # Calculate G's loss based on this output
            errG = loss_fn(output, label)
            # Calculate gradients for G
            errG.backward()
            # Update G
            optimizer_G.step()

            # Save fake images generated by Generator
            batches_done = epoch * len(dataloader) + i
            if batches_done % 400 == 0:
                save_image(fake.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)

        print(f"[Epoch {epoch + 1}/{epochs}] [D loss: {errD.item():.4f}] [G loss: {errG.item():.4f}]")

        # Save Generator model's parameters
        save_checkpoints(
            {'epoch': epoch + 1,
             'state_dict': G.state_dict(),
             'optim_dict': optimizer_G.state_dict()},
            checkpoint='./ckpt/',
            is_G=True
        )

        # Save Discriminator model's parameters
        save_checkpoints(
            {'epoch': epoch + 1,
             'state_dict': D.state_dict(),
             'optim_dict': optimizer_D.state_dict()},
            checkpoint='./ckpt/',
            is_G=False
        )
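
# save_checkpoints is called above but not shown; a minimal sketch matching the
# (state, checkpoint, is_G) call signature (an assumption, not the original helper):
import os
import torch

def save_checkpoints(state, checkpoint, is_G):
    os.makedirs(checkpoint, exist_ok=True)
    filename = 'G_last.pth.tar' if is_G else 'D_last.pth.tar'
    torch.save(state, os.path.join(checkpoint, filename))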