def main(args):
    """Pre-train and/or adversarially train a caption generator and discriminator.

    Steps, in order:

    1. Create ``args.model_path`` and ``args.figure_path`` if they are missing.
    2. Build the image transform, load the pickled vocabulary from
       ``args.vocab_path`` and build the caption data loader.
    3. Build ``Generator`` / ``Discriminator`` (moved to the GPU when one is
       available) and one Adam optimizer for each (default hyper-parameters).
    4. If ``int(args.pretraining) == 1``: load the checkpoints at
       ``args.pretrained_gen_path`` / ``args.pretrained_disc_path`` and continue
       pre-training -- the generator with cross-entropy (MLE) on the packed
       target captions, the discriminator with a "real caption" and a "wrong
       caption" loss -- then save the pre-trained weights. Otherwise just load
       those two checkpoints.
    5. Run ``args.num_epochs`` epochs of adversarial training. The generator is
       updated from discriminator scores of rolled-out partial captions. The
       discriminator is updated from real, generated and wrong captions.
       Checkpoints are written to ``args.model_path``.

    NOTE(review): another ``def main(args)`` appears later in this module and
    rebinds the name; callers get whichever definition comes last.
    """
    # Create model directory
    if not os.path.exists(args.model_path):
        os.makedirs(args.model_path)
    if not os.path.exists(args.figure_path):
        os.makedirs(args.figure_path)

    # Image preprocessing
    # For normalization, see https://github.com/pytorch/vision#models
    transform = transforms.Compose([
        transforms.RandomCrop(args.crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    # Load vocabulary wrapper.
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    # Build data loader
    data_loader = get_loader(args.image_dir, args.caption_path, vocab,
                             transform, args.batch_size, shuffle=True,
                             num_workers=args.num_workers)

    # Build the models (Gen)
    generator = Generator(args.embed_size, args.hidden_size, len(vocab),
                          args.num_layers)

    # Build the models (Disc)
    discriminator = Discriminator(args.embed_size, args.hidden_size,
                                  len(vocab), args.num_layers)

    if torch.cuda.is_available():
        generator.cuda()
        discriminator.cuda()

    # Loss and Optimizer (Gen)
    mle_criterion = nn.CrossEntropyLoss()
    params_gen = list(generator.parameters())
    optimizer_gen = torch.optim.Adam(params_gen)

    # Loss and Optimizer (Disc)
    params_disc = list(discriminator.parameters())
    optimizer_disc = torch.optim.Adam(params_disc)

    if int(args.pretraining) == 1:
        # Pre-training: train generator with MLE and discriminator with
        # real + wrong losses (no fake loss during pre-training).
        total_steps = len(data_loader)
        print(total_steps)
        disc_losses = []
        gen_losses = []
        print('pre-training')
        # Pre-training resumes from existing checkpoints.
        generator.load_state_dict(torch.load(args.pretrained_gen_path))
        discriminator.load_state_dict(torch.load(args.pretrained_disc_path))
        for epoch in range(
                max([
                    int(args.gen_pretrain_num_epochs),
                    int(args.disc_pretrain_num_epochs)
                ])):
            # NOTE(review): hard-coded skip of the first five epochs,
            # presumably because the loaded checkpoints already cover them --
            # confirm, and consider making this a command-line option.
            if epoch < 5:
                continue
            for i, (images, captions, lengths, wrong_captions,
                    wrong_lengths) in enumerate(data_loader):
                images = to_var(images, volatile=True)
                captions = to_var(captions)
                wrong_captions = to_var(wrong_captions)
                targets = pack_padded_sequence(captions, lengths,
                                               batch_first=True)[0]

                if epoch < int(args.gen_pretrain_num_epochs):
                    generator.zero_grad()
                    outputs, _ = generator(images, captions, lengths)
                    loss_gen = mle_criterion(outputs, targets)
                    # gen_losses.append(loss_gen.cpu().data.numpy()[0])
                    loss_gen.backward()
                    optimizer_gen.step()

                if epoch < int(args.disc_pretrain_num_epochs):
                    discriminator.zero_grad()
                    rewards_real = discriminator(images, captions, lengths)
                    rewards_wrong = discriminator(images, wrong_captions,
                                                  wrong_lengths)
                    real_loss = -torch.mean(torch.log(rewards_real))
                    # The clamp keeps log(1 - r) finite when r reaches 1.
                    wrong_loss = -torch.mean(
                        torch.clamp(torch.log(1 - rewards_wrong), min=-1000))
                    # No fake_loss because this is pre-training.
                    loss_disc = real_loss + wrong_loss
                    # disc_losses.append(loss_disc.cpu().data.numpy()[0])
                    loss_disc.backward()
                    optimizer_disc.step()

                # NOTE(review): when only one of the two pre-training phases
                # is active in this epoch, the other loss printed here is a
                # stale value from an earlier epoch (or undefined if that
                # phase never ran).
                if (i + 1) % args.log_step == 0:
                    print(
                        'Epoch [%d], Step [%d], Disc Loss: %.4f, Gen Loss: %.4f'
                        % (epoch + 1, i + 1, loss_disc, loss_gen))

                if (i + 1) % 500 == 0:
                    torch.save(
                        discriminator.state_dict(),
                        os.path.join(
                            args.model_path,
                            'pretrained-discriminator-%d-%d.pkl' %
                            (int(epoch) + 1, i + 1)))
                    torch.save(
                        generator.state_dict(),
                        os.path.join(
                            args.model_path,
                            'pretrained-generator-%d-%d.pkl' %
                            (int(epoch) + 1, i + 1)))

        # Save pretrained models
        torch.save(
            discriminator.state_dict(),
            os.path.join(
                args.model_path, 'pretrained-discriminator-%d.pkl' %
                int(args.disc_pretrain_num_epochs)))
        torch.save(
            generator.state_dict(),
            os.path.join(
                args.model_path, 'pretrained-generator-%d.pkl' %
                int(args.gen_pretrain_num_epochs)))
    else:
        generator.load_state_dict(torch.load(args.pretrained_gen_path))
        discriminator.load_state_dict(torch.load(args.pretrained_disc_path))

    # Train the Models (adversarial phase)
    total_step = len(data_loader)
    disc_gan_losses = []
    gen_gan_losses = []
    for epoch in range(args.num_epochs):
        for i, (images, captions, lengths, wrong_captions,
                wrong_lengths) in enumerate(data_loader):
            # Set mini-batch dataset
            images = to_var(images, volatile=True)
            captions = to_var(captions)
            wrong_captions = to_var(wrong_captions)

            generator.zero_grad()
            outputs, packed_lengths = generator(images, captions, lengths)
            outputs = PackedSequence(outputs, packed_lengths)
            outputs = pad_packed_sequence(outputs, batch_first=True)  # (b, T, V)

            Tmax = outputs[0].size(1)
            # One reward slot per (batch, time step, vocab entry); slots that
            # are never written below stay zero.
            if torch.cuda.is_available():
                rewards = torch.zeros_like(outputs[0]).type(
                    torch.cuda.FloatTensor)
            else:
                rewards = torch.zeros_like(outputs[0]).type(torch.FloatTensor)

            # Getting rewards from disc, every second time step.
            for t in range(2, Tmax, 2):
                # TODO: this makes things easier, but min(lengths) could be
                # too short.
                if t >= min(lengths):
                    break
                gen_samples = to_var(torch.zeros(
                    (captions.size(0), Tmax)).type(torch.FloatTensor),
                                     volatile=True)
                # part 1: taken from real caption
                gen_samples[:, :t] = captions[:, :t].data
                predicted_ids, saved_states = generator.pre_compute(
                    gen_samples, t)
                v = predicted_ids
                # part 2: token at step t predicted by the generator
                gen_samples[:, t] = v
                # part 3: taken from rollouts
                gen_samples[:, t:] = generator.rollout(gen_samples, t,
                                                       saved_states)

                # Caption length = position of the first id 2 (<end>) + 1,
                # or Tmax when no <end> occurs.
                sampled_lengths = []
                for batch in range(int(captions.size(0))):
                    for b_t in range(Tmax):
                        if gen_samples[batch, b_t].cpu().data.numpy() == 2:  # <end>
                            sampled_lengths.append(b_t + 1)
                            break
                        elif b_t == Tmax - 1:
                            sampled_lengths.append(Tmax)

                # Sort the lengths in place in descending order (ascending
                # sort of the reversed view).
                # NOTE(review): the rows of gen_samples are NOT reordered to
                # match -- confirm the discriminator tolerates that.
                sampled_lengths = np.array(sampled_lengths)
                sampled_lengths[::-1].sort()
                sampled_lengths = sampled_lengths.tolist()

                # NOTE(review): if `v` has shape (b,), `rewards[:, t, v]`
                # writes every row at every id in `v` (broadcast), not row r at
                # v[r]; use `rewards[torch.arange(b), t, v]` if per-row was
                # intended -- confirm.
                rewards[:, t, v] = discriminator(images, gen_samples.detach(),
                                                 sampled_lengths)

            # Rewards are treated as constants for the generator update.
            rewards_detached = rewards.data
            rewards_detached = to_var(rewards_detached)
            # Flattened dot product of the (b, T, V) generator outputs with
            # the negated rewards. torch.dot only accepts 1-D tensors in
            # PyTorch >= 0.4, so it is spelled as an elementwise product + sum.
            loss_gen = torch.sum(outputs[0] * -rewards_detached)
            # gen_gan_losses.append(loss_gen.cpu().data.numpy()[0])
            loss_gen.backward()
            optimizer_gen.step()

            # TODO get sampled_captions
            sampled_ids = generator.sample(images)
            # Same <end>-search as above. 20 is presumably the maximum sample
            # length produced by generator.sample -- confirm.
            sampled_lengths = []
            for batch in range(int(captions.size(0))):
                for b_t in range(20):
                    if sampled_ids[batch, b_t].cpu().data.numpy() == 2:  # <end>
                        sampled_lengths.append(b_t + 1)
                        break
                    elif b_t == 20 - 1:
                        sampled_lengths.append(20)

            # sort sampled_lengths (descending; rows not reordered -- see above)
            sampled_lengths = np.array(sampled_lengths)
            sampled_lengths[::-1].sort()
            sampled_lengths = sampled_lengths.tolist()

            # Train discriminator
            discriminator.zero_grad()
            images.volatile = False
            captions.volatile = False
            wrong_captions.volatile = False
            rewards_real = discriminator(images, captions, lengths)
            rewards_fake = discriminator(images, sampled_ids, sampled_lengths)
            rewards_wrong = discriminator(images, wrong_captions,
                                          wrong_lengths)
            real_loss = -torch.mean(torch.log(rewards_real))
            fake_loss = -torch.mean(
                torch.clamp(torch.log(1 - rewards_fake), min=-1000))
            wrong_loss = -torch.mean(
                torch.clamp(torch.log(1 - rewards_wrong), min=-1000))
            loss_disc = real_loss + fake_loss + wrong_loss
            # disc_gan_losses.append(loss_disc.cpu().data.numpy()[0])
            loss_disc.backward()
            optimizer_disc.step()

            # Print log info
            if i % args.log_step == 0:
                print(
                    'Epoch [%d/%d], Step [%d/%d], Disc Loss: %.4f, Gen Loss: %.4f'
                    % (epoch, args.num_epochs, i, total_step, loss_disc,
                       loss_gen))

            # Save the models every `log_step` iterations. (An older note
            # said "saving at the last iteration instead", which this
            # condition does not do.)
            if (i + 1) % args.log_step == 0:
                torch.save(
                    generator.state_dict(),
                    os.path.join(args.model_path,
                                 'generator-gan-%d-%d.pkl' % (epoch + 1, i + 1)))
                torch.save(
                    discriminator.state_dict(),
                    os.path.join(
                        args.model_path,
                        'discriminator-gan-%d-%d.pkl' % (epoch + 1, i + 1)))
def train(args):
    """Train a DCGAN generator/discriminator pair through DeepSpeed engines.

    Reads from ``args``:

    * ``tensorboard_path``, ``outf``, ``manualSeed``, ``batchSize``,
      ``workers``, ``local_rank``, ``nz``, ``ngf``, ``ndf``, ``lr``, ``beta1``,
      ``epochs``;
    * ``netG`` / ``netD``: checkpoint paths, where ``''`` means no load;
    * whatever ``deepspeed.initialize`` consumes from ``args``.

    For every batch it:

    * updates D on a real and a fake batch;
    * updates G to make D label the fakes as real;
    * prints the losses and logs ``Loss_D`` / ``Loss_G`` to TensorBoard.

    Every 100 batches it writes the current real batch and samples from a fixed
    noise batch as image files in ``args.outf``. The total wall-clock time is
    printed at the end. The TensorBoard writer is always closed on exit,
    including when training raises.
    """
    writer = SummaryWriter(log_dir=args.tensorboard_path)
    try:
        create_folder(args.outf)
        set_seed(args.manualSeed)
        cudnn.benchmark = True

        dataset, nc = get_dataset(args)
        dataloader = torch.utils.data.DataLoader(dataset,
                                                 batch_size=args.batchSize,
                                                 shuffle=True,
                                                 num_workers=int(args.workers))

        torch.cuda.set_device(args.local_rank)
        # Device chosen by the launcher-provided local rank
        # (previously: torch.device("cuda:0" if args.cuda else "cpu")).
        device = torch.device("cuda", args.local_rank)
        ngpu = 0
        nz = int(args.nz)
        ngf = int(args.ngf)
        ndf = int(args.ndf)

        netG = Generator(ngpu, ngf, nc, nz).to(device)
        netG.apply(weights_init)
        if args.netG != '':
            netG.load_state_dict(torch.load(args.netG))

        netD = Discriminator(ngpu, ndf, nc).to(device)
        netD.apply(weights_init)
        if args.netD != '':
            netD.load_state_dict(torch.load(args.netD))

        criterion = nn.BCELoss()

        fixed_noise = torch.randn(args.batchSize, nz, 1, 1, device=device)
        real_label = 1
        fake_label = 0

        # setup optimizer
        optimizerD = torch.optim.Adam(netD.parameters(),
                                      lr=args.lr,
                                      betas=(args.beta1, 0.999))
        optimizerG = torch.optim.Adam(netG.parameters(),
                                      lr=args.lr,
                                      betas=(args.beta1, 0.999))

        model_engineD, optimizerD, _, _ = deepspeed.initialize(
            args=args,
            model=netD,
            model_parameters=netD.parameters(),
            optimizer=optimizerD)
        model_engineG, optimizerG, _, _ = deepspeed.initialize(
            args=args,
            model=netG,
            model_parameters=netG.parameters(),
            optimizer=optimizerG)

        torch.cuda.synchronize()
        start = time()
        for epoch in range(args.epochs):
            for i, data in enumerate(dataloader, 0):
                ############################
                # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
                ###########################
                # train with real
                netD.zero_grad()
                real = data[0].to(device)
                batch_size = real.size(0)
                label = torch.full((batch_size, ),
                                   real_label,
                                   dtype=real.dtype,
                                   device=device)
                output = netD(real)
                errD_real = criterion(output, label)
                model_engineD.backward(errD_real)
                D_x = output.mean().item()

                # train with fake
                noise = torch.randn(batch_size, nz, 1, 1, device=device)
                fake = netG(noise)
                label.fill_(fake_label)
                # detach: the D update must not push gradients into G.
                output = netD(fake.detach())
                errD_fake = criterion(output, label)
                model_engineD.backward(errD_fake)
                D_G_z1 = output.mean().item()
                errD = errD_real + errD_fake
                # optimizerD.step() would be the alternative (equivalent) step
                model_engineD.step()

                ############################
                # (2) Update G network: maximize log(D(G(z)))
                ###########################
                netG.zero_grad()
                label.fill_(real_label)  # fake labels are real for generator cost
                output = netD(fake)
                errG = criterion(output, label)
                model_engineG.backward(errG)
                D_G_z2 = output.mean().item()
                # optimizerG.step() would be the alternative (equivalent) step
                model_engineG.step()

                print(
                    '[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
                    % (epoch, args.epochs, i, len(dataloader), errD.item(),
                       errG.item(), D_x, D_G_z1, D_G_z2))
                global_step = epoch * len(dataloader) + i
                writer.add_scalar("Loss_D", errD.item(), global_step)
                writer.add_scalar("Loss_G", errG.item(), global_step)
                if i % 100 == 0:
                    vutils.save_image(real,
                                      '%s/real_samples.png' % args.outf,
                                      normalize=True)
                    # Sampling only: building an autograd graph here just
                    # wastes memory.
                    with torch.no_grad():
                        samples = netG(fixed_noise)
                    vutils.save_image(samples.detach(),
                                      '%s/fake_samples_epoch_%03d.png' %
                                      (args.outf, epoch),
                                      normalize=True)

            # do checkpointing
            #torch.save(netG.state_dict(), '%s/netG_epoch_%d.pth' % (args.outf, epoch))
            #torch.save(netD.state_dict(), '%s/netD_epoch_%d.pth' % (args.outf, epoch))
        torch.cuda.synchronize()
        stop = time()
        print(
            f"total wall clock time for {args.epochs} epochs is {stop-start} secs")
    finally:
        # Flush pending TensorBoard events and release the event file.
        writer.close()
# Module-level DCGAN training loop (discriminator half). It relies on
# netG, netD, criterion, cuda, batch_size, latent_size, epoch_num and
# train_dataloader being defined elsewhere.
optimizerG = optim.Adam(netG.parameters(), lr=0.002, betas=(0.5, 0.999))
optimizerD = optim.Adam(netD.parameters(), lr=0.002, betas=(0.5, 0.999))
# Fixed noise batch. It follows the same `cuda` flag that decides where
# `data` lives below, instead of unconditionally calling .cuda().
fix_noise = torch.randn(batch_size, latent_size, 1, 1,
                        device=torch.device('cuda' if cuda else 'cpu'))
lossG = []
lossD = []
Dx = []
DG1 = []
DG2 = []
for epoch in range(epoch_num):
    for i, data in enumerate(train_dataloader):
        # Update D network
        # train with real
        if cuda:
            data = data.cuda()
        netD.zero_grad()
        # Float targets: torch.full with an int fill value infers int64
        # (PyTorch >= 1.7), which BCE-style criteria reject. Assumes
        # `criterion` is BCELoss-like (not visible here) -- confirm.
        # Created on data's device so it matches the model input.
        label = torch.full((data.size(0), ), 1.0, device=data.device)
        output = netD(data)
        errD_real = criterion(output, label)
        errD_real.backward()
        D_x = output.mean().item()

        # train with fake
        noise = torch.randn(data.size(0), latent_size, 1, 1,
                            device=data.device)
        fake = netG(noise)
        label.fill_(0)
        # detach: this backward must not reach the generator.
        output = netD(fake.detach())
        errD_fake = criterion(output, label)
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        errD = errD_real + errD_fake
        # NOTE(review): as written, the loop never calls optimizerD.step()
        # and never updates G (optimizerG / lossG / DG2 are unused), so
        # nothing trains. This looks truncated -- confirm against the
        # original script.
def train_gan(batch_size=64, epochs=100, disc_update=1, gen_update=5,
              latent_dimension=100, lambduh=10, checkpoint_dir="upsample"):
    """Train a WGAN-GP critic/generator pair and checkpoint after every epoch.

    Every argument defaults to the value that used to be hard-coded, so
    ``train_gan()`` keeps its previous configuration.

    Args:
        batch_size: mini-batch size passed to ``get_data_loader``.
        epochs: number of passes over the training loader.
        disc_update: the critic is updated on every ``disc_update``-th batch.
        gen_update: the generator is updated on every ``gen_update``-th batch.
        latent_dimension: size of the noise vector fed to the generator.
        lambduh: weight of the gradient-penalty term in the critic loss.
        checkpoint_dir: directory that receives ``checkpoint_<epoch>.pth``.
            It is created if missing.

    Each checkpoint stores the epoch index, both model and optimizer state
    dicts, and the summed critic/generator losses of that epoch. The printed
    per-epoch losses are averages over the updates actually performed.
    """
    import os

    device = torch.device(
        'cuda:0') if torch.cuda.is_available() else torch.device('cpu')

    # load data (only the training split is used here)
    train_loader, _valid_loader, _test_loader = get_data_loader(
        'data', batch_size)

    disc_model = Discriminator().to(device)
    gen_model = Generator(latent_dimension).to(device)
    disc_optim = Adam(disc_model.parameters(), lr=1e-4, betas=(0.5, 0.9))
    gen_optim = Adam(gen_model.parameters(), lr=1e-4, betas=(0.5, 0.9))

    # torch.save does not create missing directories.
    os.makedirs(checkpoint_dir, exist_ok=True)

    for e in range(epochs):
        # Running loss sums kept as 0-d tensors, so the checkpoint still
        # stores tensors. They accumulate *detached* values: summing the live
        # `loss` tensors kept every iteration's autograd graph alive.
        disc_loss = torch.zeros((), device=device)
        gen_loss = torch.zeros((), device=device)
        disc_steps = 0
        gen_steps = 0
        for i, (images, _) in enumerate(train_loader):
            images = images.to(device)
            b_size = images.shape[0]
            step = i + 1

            if step % disc_update == 0:
                disc_model.zero_grad()
                # sample noise
                noise = torch.randn((b_size, latent_dimension), device=device)
                # Critic loss: mean D(fake) - mean D(real) + lambduh * GP.
                # Fakes are detached so this step does not touch G.
                inputs = gen_model(noise).detach()
                f_outputs = disc_model(inputs)
                loss = f_outputs.mean()
                r_outputs = disc_model(images)
                loss -= r_outputs.mean()
                loss += lambduh * gradient_penalty(disc_model, images, inputs,
                                                   device)
                loss.backward()
                disc_optim.step()
                disc_loss += loss.detach()
                disc_steps += 1

            if step % gen_update == 0:
                gen_model.zero_grad()
                noise = torch.randn((b_size, latent_dimension), device=device)
                inputs = gen_model(noise)
                outputs = disc_model(inputs)
                # Generator maximizes the critic score of its samples.
                loss = -outputs.mean()
                loss.backward()
                gen_optim.step()
                gen_loss += loss.detach()
                gen_steps += 1

        torch.save(
            {
                'epoch': e,
                'disc_model': disc_model.state_dict(),
                'gen_model': gen_model.state_dict(),
                'disc_loss': disc_loss,
                'gen_loss': gen_loss,
                'disc_optim': disc_optim.state_dict(),
                'gen_optim': gen_optim.state_dict()
            }, os.path.join(checkpoint_dir, "checkpoint_{}.pth".format(e)))
        # Average over updates actually performed. The generator runs only
        # every `gen_update` batches, so dividing by the batch count would
        # understate its loss. max(..., 1) guards epochs with no updates.
        print("Epoch: {} Disc loss: {}".format(
            e + 1, disc_loss.item() / max(disc_steps, 1)))
        print("Epoch: {} Gen loss: {}".format(
            e + 1, gen_loss.item() / max(gen_steps, 1)))
def main(args):
    """Jointly train an encoder/decoder caption generator and a discriminator.

    Builds the data pipeline from ``args``, an ``EncoderCNN`` + ``DecoderRNN``
    generator and a ``Discriminator`` (moved to the GPU when available). Then,
    for ``args.num_epochs`` epochs, each batch:

    1. updates the generator with cross-entropy on the packed target captions;
    2. updates the discriminator with real, sampled ("fake") and wrong-caption
       losses.

    Per-batch discriminator losses are collected and plotted to
    ``disc_losses.png`` at the end of every epoch. Weights are saved to
    ``args.model_path`` on the last batch of each epoch.

    NOTE(review): ``.numpy()[0]`` and ``loss.data[0]`` below are PyTorch 0.3
    idioms. On PyTorch >= 0.4 they index 0-dim tensors and raise; ``.item()``
    is the modern equivalent.
    """
    # Create model directory
    if not os.path.exists(args.model_path):
        os.makedirs(args.model_path)

    # Image preprocessing
    # For normalization, see https://github.com/pytorch/vision#models
    transform = transforms.Compose([
        transforms.RandomCrop(args.crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    # Load vocabulary wrapper.
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    # Build data loader
    data_loader = get_loader(args.image_dir, args.caption_path, vocab,
                             transform, args.batch_size, shuffle=True,
                             num_workers=args.num_workers)

    # Build the models (Gen)
    # TODO: put these in generator
    encoder = EncoderCNN(args.embed_size)
    decoder = DecoderRNN(args.embed_size, args.hidden_size, len(vocab),
                         args.num_layers)

    # Build the models (Disc)
    discriminator = Discriminator(args.embed_size, args.hidden_size,
                                  len(vocab), args.num_layers)

    if torch.cuda.is_available():
        encoder.cuda()
        decoder.cuda()
        discriminator.cuda()

    # Loss and Optimizer (Gen): only the decoder and the encoder's `linear`
    # and `bn` sub-modules are optimized; other encoder parameters are not
    # passed to the optimizer.
    criterion = nn.CrossEntropyLoss()
    params = list(decoder.parameters()) + list(
        encoder.linear.parameters()) + list(encoder.bn.parameters())
    optimizer = torch.optim.Adam(params, lr=args.learning_rate)

    # Loss and Optimizer (Disc)
    params_disc = list(discriminator.parameters())
    optimizer_disc = torch.optim.Adam(params_disc)

    # Train the Models
    total_step = len(data_loader)
    disc_losses = []  # one discriminator loss per batch, across all epochs
    for epoch in range(args.num_epochs):
        for i, (images, captions, lengths, wrong_captions,
                wrong_lengths) in enumerate(data_loader):
            # TODO: train disc before gen

            # Set mini-batch dataset
            images = to_var(images, volatile=True)
            captions = to_var(captions)
            wrong_captions = to_var(wrong_captions)
            targets = pack_padded_sequence(captions, lengths,
                                           batch_first=True)[0]

            # Forward, Backward and Optimize
            decoder.zero_grad()
            encoder.zero_grad()
            features = encoder(images)
            outputs = decoder(features, captions, lengths)
            sampled_captions = decoder.sample(features)

            # Length of each sampled caption: position of the first '<end>'
            # word + 1, or the full sample width when no '<end>' occurs.
            sampled_lengths = []
            for row in range(sampled_captions.size(0)):
                for index, word_id in enumerate(sampled_captions[row, :]):
                    word = vocab.idx2word[word_id.cpu().data.numpy()[0]]
                    if word == '<end>':
                        sampled_lengths.append(index + 1)
                        break
                    elif index == sampled_captions.size(1) - 1:
                        sampled_lengths.append(sampled_captions.size(1))
                        break

            # Sort the lengths in place in descending order (ascending sort of
            # the reversed view).
            # NOTE(review): only the lengths are sorted; the rows of
            # sampled_captions (and images) are not reordered to match --
            # confirm the discriminator does not rely on that pairing.
            sampled_lengths = np.array(sampled_lengths)
            sampled_lengths[::-1].sort()
            sampled_lengths = sampled_lengths.tolist()

            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            # Train discriminator
            discriminator.zero_grad()
            rewards_real = discriminator(images, captions, lengths)
            rewards_fake = discriminator(images, sampled_captions,
                                         sampled_lengths)
            rewards_wrong = discriminator(images, wrong_captions,
                                          wrong_lengths)
            real_loss = -torch.mean(torch.log(rewards_real))
            # The clamp keeps log(1 - r) finite when r reaches 1.
            fake_loss = -torch.mean(
                torch.clamp(torch.log(1 - rewards_fake), min=-1000))
            wrong_loss = -torch.mean(
                torch.clamp(torch.log(1 - rewards_wrong), min=-1000))
            loss_disc = real_loss + fake_loss + wrong_loss
            disc_losses.append(loss_disc.cpu().data.numpy()[0])
            loss_disc.backward()
            optimizer_disc.step()

            # Print log info
            if i % args.log_step == 0:
                print(
                    'Epoch [%d/%d], Step [%d/%d], Loss: %.4f, Perplexity: %5.4f'
                    % (epoch, args.num_epochs, i, total_step, loss.data[0],
                       np.exp(loss.data[0])))

            # Save the models
            # if (i+1) % args.save_step == 0:
            if (
                    i + 1
            ) % total_step == 0:  # jm: saving at the last iteration instead
                torch.save(
                    decoder.state_dict(),
                    os.path.join(args.model_path,
                                 'decoder-%d-%d.pkl' % (epoch + 1, i + 1)))
                torch.save(
                    encoder.state_dict(),
                    os.path.join(args.model_path,
                                 'encoder-%d-%d.pkl' % (epoch + 1, i + 1)))
                torch.save(
                    discriminator.state_dict(),
                    os.path.join(
                        args.model_path,
                        'discriminator-%d-%d.pkl' % (epoch + 1, i + 1)))

        # plot at the end of every epoch (overwrites disc_losses.png)
        plt.plot(disc_losses, label='disc loss')
        plt.savefig('disc_losses.png')
        plt.clf()