Code Example #1
def train(discriminator, generator, show_every=250, num_epochs=20):
    start_t = time.time()
    iter_count = 0
    discriminator.add_optimizer()
    generator.add_optimizer()
    for epoch in range(num_epochs):
        for x, _ in loader_train:
            if len(x) != batch_size:
                continue
            discriminator.optimizer.zero_grad()
            real_data = x.float()  # Variable is deprecated; plain tensors carry gradients
            # rescale images from [0, 1] to [-1, 1] before the discriminator
            logits_real = discriminator(2 * (real_data.view(batch_size, -1) - 0.5))

            g_fake_seed = sample_noise(batch_size, noise_dim).float()
            fake_images = generator(g_fake_seed)
            # detach so the discriminator update does not backprop into the generator
            logits_fake = discriminator(fake_images.detach().view(batch_size, -1))

            d_total_error = discriminator.loss(logits_real, logits_fake)
            d_total_error.backward()
            discriminator.optimizer.step()

            generator.optimizer.zero_grad()
            g_fake_seed = sample_noise(batch_size, noise_dim).float()
            fake_images = generator(g_fake_seed)

            gen_logits_fake = discriminator(fake_images.view(batch_size, -1))
            g_loss = generator.loss(gen_logits_fake)
            g_loss.backward()
            generator.optimizer.step()

            if (iter_count % show_every == 0):
                checkpt_t = time.time()
                print("time : {:.2f} sec".format(checkpt_t - start_t))
                print('Iter: {}, D: {:.4}, G: {:.4}'.format(
                    iter_count, d_total_error.item(), g_loss.item()))
                print("real logits average ", torch.mean(logits_real).item())
                print("average output generator : ",
                      torch.mean(fake_images).item())
                print("fake logits average ", torch.mean(gen_logits_fake).item())
                imgs = fake_images[:16].detach().numpy()
                show_images(imgs,
                            iter_num=iter_count,
                            save=True,
                            show=False,
                            model=generator.label)
            iter_count += 1
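
Several of these snippets (Examples #1, #4, and #7) call a project-specific sample_noise helper that is not shown. A minimal sketch, assuming the common convention of uniform noise in [-1, 1]:

import torch

def sample_noise(batch_size, dim):
    # Assumed behavior: uniform noise in [-1, 1] with shape (batch_size, dim);
    # the individual projects may use a different noise distribution.
    return 2 * torch.rand(batch_size, dim) - 1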
Code Example #2
    def train(self):
        # self.fixed_idx =  datasets.get_fixed_test_idx(name=self.dataset)
        self.train_batches = utils.batch_generator(self.dataset.imgs,
                                                   self.batch_size)
        self.test_batches = utils.batch_generator(
            self.dataset.imgs, batch_size=50)  # drawn from the same images as training
        train_path = os.path.join(self.results_path, 'train')
        if not os.path.exists(train_path):
            os.mkdir(train_path)
        image_base_name = os.path.join(train_path, '{:s}_step_{:d}.png')
        metrics_history = {'iter': [], 'recons': [], 'disent': []}
        start_time = time.time()
        for stp in range(1, self.n_steps + 1):
            x_np = next(self.train_batches)
            _, loss_np, rec_np, reg_np = self.sess.run(
                [self.optim_op, self.loss, self.loglikelihood, self.regularizer],
                feed_dict={self.x: x_np,
                           self.Psi: self.Psi_np,
                           self.nu: self.nu_np})
            if stp % 1000 == 0 or stp == 1:
                end_time = time.time()
                print(
                    'Step: {:d} in {:.2f}s:: Loss: {:.3f} => Recons.: {:.3f}, Reg: {:.3f}'
                    .format(stp, end_time - start_time, loss_np, -rec_np,
                            -reg_np))
                start_time = end_time
                x_test_np = next(self.test_batches)
                x_recons_np = self.sess.run(self.x_recons,
                                            feed_dict={self.x_test: x_test_np})
                utils.render_reconstructions(
                    x_test_np, x_recons_np, image_base_name.format('rec', stp))

                z_np = utils.sample_noise(self.Psi_np, self.nu_np, 100)
                x_hat_np = self.sess.run(self.fake_images,
                                         feed_dict={self.noise: z_np})
                utils.render_images(x_hat_np,
                                    image_base_name.format('iw', stp))
            if stp % 10000 == 0:
                disent_metric = utils.compute_metric(self)[1]
                metrics_history['iter'].append(stp)
                metrics_history['recons'].append(-rec_np)
                metrics_history['disent'].append(disent_metric)
                print('Metric: {:.4f}'.format(disent_metric))

        with open(os.path.join(train_path, 'metrics.pkl'), 'wb') as pkl:
            pickle.dump(metrics_history, pkl)
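
utils.sample_noise in this example takes a scale matrix Psi and degrees of freedom nu, suggesting heavy-tailed Student-t latent noise. A hypothetical sketch using the standard Gaussian/chi-squared mixture construction; the project's actual utils.sample_noise may differ:

import numpy as np

def sample_noise(Psi, nu, n):
    # t = g * sqrt(nu / u) with g ~ N(0, Psi) and u ~ chi^2(nu) gives n
    # draws from a zero-mean multivariate Student-t with scale Psi and nu dof
    d = Psi.shape[0]
    g = np.random.multivariate_normal(np.zeros(d), Psi, size=n)
    u = np.random.chisquare(nu, size=(n, 1))
    return g * np.sqrt(nu / u)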
Code Example #3
    def generate(self):
        gen_path = os.path.join(self.results_path, 'gen')
        if not os.path.exists(gen_path):
            os.mkdir(gen_path)
        nu = self.z_dim + 1
        while True:
            # draw samples at the current degrees of freedom nu, sweeping it
            # from z_dim + 1 up to twice the trained value; the filename below
            # records the nu used for each batch
            z_np = utils.sample_noise(self.Psi_np, nu, 100)
            x_hat_np = self.sess.run(self.fake_images,
                                     feed_dict={self.noise: z_np})
            utils.render_images(
                x_hat_np, os.path.join(gen_path, 'nu_{:d}.png'.format(nu)))
            if nu >= 2 * self.nu_np:
                break
            nu = nu * 4

        x_test_np = next(self.test_batches)
        x_np = np.vstack(
            [x_test_np,
             next(self.train_batches),
             next(self.test_batches)])
        means, = self.sess.run([self.mu], feed_dict={self.x: x_np})
        for num, base_point in enumerate(means):
            n_images_per_latent = 20
            z_np = []
            # latent traversal: for each latent dimension i, sweep that
            # coordinate from -3 to 3 while holding the others at base_point
            for i in range(self.z_dim):
                dim_i = np.repeat(base_point[None, :], n_images_per_latent, 0)
                dim_i[np.arange(n_images_per_latent),
                      i] = np.linspace(-3, 3, n_images_per_latent)
                z_np.append(dim_i)
            z_np = np.vstack(z_np)
            x_hat_np = self.sess.run(self.fake_images,
                                     feed_dict={self.noise: z_np})
            img_path = os.path.join(gen_path,
                                    'inter_{:d}_{:d}.png'.format(1, num))
            utils.render_images(x_hat_np,
                                img_path,
                                n_rows=self.z_dim,
                                n_cols=n_images_per_latent)
Code Example #4
def train(D, G, D_solver, G_solver, discriminator_loss, generator_loss, show_every=250, 
              batch_size=128, noise_size=100, num_epochs=10, train_loader=None, device=None, MNIST=False):
    """
    Train loop for GAN.
    
    The loop will consist of two steps: a discriminator step and a generator step.
    
    (1) In the discriminator step, you should zero gradients in the discriminator 
    and sample noise to generate a fake data batch using the generator. Calculate 
    the discriminator output for real and fake data, and use the output to compute
    discriminator loss. Call backward() on the loss output and take an optimizer
    step for the discriminator.
    
    (2) For the generator step, you should once again zero gradients in the generator
    and sample noise to generate a fake data batch. Get the discriminator output
    for the fake data batch and use this to compute the generator loss. Once again
    call backward() on the loss and take an optimizer step.
    
    You will need to reshape the fake image tensor output by the generator to
    have dimensions (batch_size x input_channels x img_size x img_size).
    
    Use the sample_noise function to sample random noise, and the discriminator_loss
    and generator_loss functions for their respective loss computations.
    
    
    Inputs:
    - D, G: PyTorch models for the discriminator and generator
    - D_solver, G_solver: torch.optim Optimizers to use for training the
      discriminator and generator.
    - discriminator_loss, generator_loss: Functions to use for computing the generator and
      discriminator loss, respectively.
    - show_every: Show samples after every show_every iterations.
    - batch_size: Batch size to use for training.
    - noise_size: Dimension of the noise to use as input to the generator.
    - num_epochs: Number of epochs over the training dataset to use for training.
    - train_loader: image dataloader
    - device: PyTorch device
    """
    iter_count = 0

    # Run on the GPU when a CUDA device is passed in, otherwise on the CPU
    dtype = torch.cuda.FloatTensor if device is not None and device.type == 'cuda' else torch.FloatTensor
    for epoch in range(num_epochs):
        print('EPOCH: ', (epoch+1))
        
        for x, _ in train_loader:
            _, input_channels, img_size, _ = x.shape
            
            d_error = None
            g_error = None
            fake_images = None
            
            ####################################
            #          YOUR CODE HERE          #
            ####################################
            
            ######### Discriminator Step ####### 
            D_solver.zero_grad()
            real_data = x.type(dtype)  # torch.tensor(x) on a tensor copies and warns
            logits_real = D(2 * (real_data - 0.5))  # rescale images to [-1, 1]

            ####### Train with fake batch ######
            g_fake_seed = sample_noise(batch_size, noise_size).type(dtype)
            fake_images = G(g_fake_seed).detach()  # no generator gradients in this step
            logits_fake = D(fake_images.view(batch_size, input_channels, img_size, img_size))
            d_error = discriminator_loss(logits_real, logits_fake)
            d_error.backward()        
            D_solver.step()
            
            ########## Generator Step ##########
            G_solver.zero_grad()
            g_fake_seed = sample_noise(batch_size, noise_size).type(dtype)
            fake_images = G(g_fake_seed)
            gen_logits_fake = D(fake_images.view(batch_size, input_channels, img_size, img_size))
            g_error = generator_loss(gen_logits_fake)
            g_error.backward()
            G_solver.step()
            ##########       END      ##########
            
            # Logging and output visualization
            if (iter_count % show_every == 0):
                print('Iter: {}, D: {:.4}, G:{:.4}'.format(iter_count,d_error.item(),g_error.item()))
                disp_fake_images = deprocess_img(fake_images.data)  # denormalize
                imgs_numpy = (disp_fake_images).cpu().numpy()
                show_images(imgs_numpy[0:16], color=input_channels!=1)
                plt.show()
                print()
                
                print('Using device:', device)
                if device is not None and device.type == 'cuda':
                    print(torch.cuda.get_device_name(0))
                    print('Memory Usage:')
                    print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3, 1), 'GB')
                    print('Reserved: ', round(torch.cuda.memory_reserved(0)/1024**3, 1), 'GB')
                    print(" ")
            iter_count += 1
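
The deprocess_img call above undoes the 2 * (x - 0.5) rescaling applied to the real batch. A one-line sketch under that assumption:

def deprocess_img(x):
    # assumed inverse of the [0, 1] -> [-1, 1] rescaling used in the train loop
    return (x + 1.0) / 2.0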
Code Example #5
File: wdcgan_trainer.py  Project: jinczing/pokeGAN
	def Train(self):
		self.log_gpu_stats()
		dataset = PokeGanDataset(images_path='./pokemon_images/pre/', csv_path='./attributes.csv')
		dataloader = DataLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers
			, shuffle=True)

		if log_statistics:
			writer = SummaryWriter(log_dir=self.log_dir)

		# device = torch.device('cpu')
		device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

		debug_sample_noises = self.get_pytorch_variable(sample_noise(self.debug_sample_size, self.z_dim), device)

		generator, discriminator = self.get_models(device)

		g_opt, d_opt = self.get_optimizers(generator, discriminator)

		one = torch.tensor(1, dtype=torch.float).to(device)
		mone = one * -1

		generator.apply(weights_init_truncated_normal)
		discriminator.apply(weights_init_truncated_normal)

		batch_num = len(dataloader.dataset)//self.batch_size + 1

		for epoch in range(self.iterations):
			d_loss_average = 0
			fake_loss_average = 0
			real_loss_average = 0
			g_loss_average = 0
			gradient_penalty_average = 0

			discriminator.train()
			generator.train()

			for idx, dic in enumerate(dataloader):

				batch_size_t = dic['attribute'].size(0)

				generator.train()
				discriminator.train()

				real_images = self.get_pytorch_variable(dic['image'], device)

				# train discriminator
				for p in discriminator.parameters():
					p.requires_grad = True

				for i in range(self.d_iterations):
					discriminator.zero_grad()

					if self.mode =='wc':
						for p in discriminator.parameters():
							p.data.clamp_(-self.weight_clipping_limit, self.weight_clipping_limit)

					z_noise = self.get_pytorch_variable(sample_noise(batch_size_t, self.z_dim), device)
					fake_images = generator(z_noise).detach()
					# Wasserstein critic: push D(fake) down and D(real) up
					d_fake_loss = discriminator(fake_images).view(-1)
					d_fake_loss = d_fake_loss.mean()
					d_fake_loss.backward(one)
					fake_loss_average += d_fake_loss.item()

					d_real_loss = discriminator(real_images).view(-1)
					d_real_loss = d_real_loss.mean()
					d_real_loss.backward(mone)
					real_loss_average += d_real_loss.item()

					if self.mode == 'gp':
						gradient_penalty = self.gradient_penalty(discriminator, real_images, fake_images, batch_size_t, device)
						gradient_penalty.backward(one)
						gradient_penalty_average += gradient_penalty.item()
						d_loss = d_fake_loss - d_real_loss + gradient_penalty
					else:
						d_loss = d_fake_loss - d_real_loss

					d_loss_average += d_loss.item()
					d_opt.step()

				# train generator
				for p in discriminator.parameters():
					p.requires_grad = False  # freeze the critic while updating the generator

				for i in range(self.g_iterations):
					generator.zero_grad()
					z_noise = self.get_pytorch_variable(sample_noise(batch_size_t, self.z_dim), device)
					fake_images = generator(z_noise)
					g_fake_loss = discriminator(fake_images).view(-1)
					g_fake_loss = g_fake_loss.mean()
					g_fake_loss.backward(mone)
					g_loss_average += -g_fake_loss.item()
					g_opt.step()

			# print epoch summary
			d_loss_average /= batch_num*self.d_iterations
			fake_loss_average /= batch_num*self.d_iterations
			real_loss_average /= batch_num*self.d_iterations
			g_loss_average /= batch_num*self.g_iterations
			if self.mode == 'gp':
				gradient_penalty_average /= batch_num*self.d_iterations
			print('epoch: ', epoch+1, ', d_loss: ', d_loss_average, ', fake_loss: ', fake_loss_average,
				', real_loss: ', real_loss_average, ', g_loss: ', g_loss_average)

			if log_statistics:
				writer.add_scalar('data/discriminator_loss', d_loss_average, epoch+1)
				writer.add_scalar('data/generator_loss', g_loss_average, epoch+1)
				writer.add_scalar('data/real_loss', real_loss_average, epoch+1)
				writer.add_scalar('data/fake_loss', fake_loss_average, epoch+1)
				if self.mode == 'gp':
					writer.add_scalar('data/gradient_penalty', gradient_penalty_average, epoch+1)

				# generator.eval()
				with torch.no_grad():
					sample_images = generator(debug_sample_noises).cpu().data
				grid_images = torchvision.utils.make_grid(sample_images, padding=2)
				writer.add_image('images', grid_images, epoch+1)
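
The self.gradient_penalty method used above is not shown in the snippet. A sketch of the standard WGAN-GP penalty (Gulrajani et al., 2017) that such a trainer typically implements; the coefficient lam and the 4-D interpolation shape are assumptions:

	def gradient_penalty(self, discriminator, real_images, fake_images, batch_size, device, lam=10.0):
		# penalize critic gradients whose norm deviates from 1 along random
		# interpolations between real and fake samples
		eps = torch.rand(batch_size, 1, 1, 1, device=device)
		interp = (eps * real_images + (1 - eps) * fake_images).requires_grad_(True)
		d_interp = discriminator(interp).view(-1)
		grads = torch.autograd.grad(outputs=d_interp, inputs=interp,
			grad_outputs=torch.ones_like(d_interp),
			create_graph=True, retain_graph=True)[0]
		return lam * ((grads.view(batch_size, -1).norm(2, dim=1) - 1) ** 2).mean()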
Code Example #6
train_iter = iter(train_data)
val_iter = iter(val_data)
test_iter = iter(train_data)


"""
Prepare models
"""
gen_cfg = config["model_config"]["generator"]
g_pre_lr = gen_cfg["pre_lr"]
g_lr = gen_cfg["lr"]
d_lr = config["model_config"]["discriminator"]["lr"]

noise_size = config["model_config"]["noise_size"]

num_test_samples = 100
test_noise = sample_noise(num_test_samples, noise_size, device)

vocab = word_to_num_vocab.keys() if config["model_config"]["use_glove"].lower() == "true" else None

# initialize models
if ("hidden_size" not in gen_cfg and "mem_slots" in gen_cfg
		and "head_size" in gen_cfg and "num_heads" in gen_cfg):
	generator = getattr(generators, gen_cfg["name"])(
		mem_slots=gen_cfg["mem_slots"], head_size=gen_cfg["head_size"],
		num_heads=gen_cfg["num_heads"], noise_size=noise_size,
		output_size=num_classes, vocab=vocab, SOS_TOKEN=SOS_TOKEN,
		beam_width=gen_cfg["beam_width"]).to(device)
elif "hidden_size" in gen_cfg and "sim_size" in gen_cfg and "similarity" in gen_cfg:
	generator = getattr(generators, gen_cfg["name"])(
		hidden_size=gen_cfg["hidden_size"], noise_size=noise_size,
		output_size=num_classes, max_seq_len=max_seq_len,
		sim_size=gen_cfg["sim_size"],
		similarity=getattr(nn, gen_cfg["similarity"])(dim=-1),
		vocab=vocab, SOS_TOKEN=SOS_TOKEN,
		beam_width=gen_cfg["beam_width"]).to(device)
elif "TransformerGenerator" in gen_cfg["name"]:
Code Example #7
File: train.py  Project: kuzeydev/GanStarter
G = Generator(NOISE_DIM, dtype)
D = Discriminator(batch_size, dtype)

G_optim = get_optimizer(G)
D_optim = get_optimizer(D)

#imgs = data.mnist_train_loader.__iter__().next()[0].view(batch_size,1,img_size).numpy().squeeze()

for epoch in range(num_epochs):
    for x, _ in data.mnist_train_loader:
        if len(x) != batch_size:
            continue

        D_optim.zero_grad()
        real_images = x.type(dtype)  # Variable is deprecated; tensors suffice
        d_random_noise = sample_noise(batch_size, NOISE_DIM).type(dtype)
        # Discriminator loss:
        # Feed the generator with d_random_noise to create a batch of fake images,
        # then feed the discriminator twice, once with the fake images and once
        # with the real ones. Calculate the loss and do a backward pass.
        # Don't forget to .detach() the generated images so no gradients flow
        # back into the generator during this step.
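        # A minimal sketch of that step, as described above (assumptions: D
        # takes flattened images, and a discriminator_loss(logits_real,
        # logits_fake) helper exists as in Code Example #4):
        fake_images = G(d_random_noise).detach()
        logits_real = D(real_images.view(batch_size, -1))
        logits_fake = D(fake_images.view(batch_size, -1))
        d_error = discriminator_loss(logits_real, logits_fake)
        d_error.backward()
        D_optim.step()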

        G_optim.zero_grad()

        g_random_noise = sample_noise(batch_size, NOISE_DIM).type(dtype)
        # Generator loss:
        # Create fake images with the generator, feed the discriminator with
        # the newly synthesized images, calculate the loss, and do a backward pass.
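        # A matching sketch (assumed generator_loss helper as in Code Example #4):
        fake_images = G(g_random_noise)
        gen_logits_fake = D(fake_images.view(batch_size, -1))
        g_error = generator_loss(gen_logits_fake)
        g_error.backward()
        G_optim.step()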
Code Example #8
	# average the per-timestep pretraining loss over all sequence positions
	for i in range(fake_data.size(1)):
		loss += pretrain_loss_fun(fake_data[:, i, :], real_data[:, i])
	loss /= fake_data.size(1)
	return loss

# evaluate generator
nll_gen_error = []
hypothesis_list = []
reference = []
for n_batch,batch in enumerate(val_iter):
	real_data = batch.text.to(device)
	N = real_data.size(0)
	num_steps = real_data.size(1)

	# Generate fake data
	noise = sample_noise(N,noise_size,device)
	fake_data = generator(z=noise,num_steps=num_steps,temperature=max_temperature,
						  x=real_data.long())
	# Calculate nll_gen
	nll_g_error = nll_gen(real_data,fake_data)
	nll_gen_error.append(nll_g_error.item())

	# Save sentences for bleu score calculation
	fake_data = generator(z=noise,num_steps=num_steps,temperature=max_temperature)
	fake_data_vals = torch.argmax(fake_data,dim=2)
	fake_data_text = tensor_to_list_of_words(fake_data_vals,num_to_word_vocab)
	real_data_text = tensor_to_list_of_words(real_data,num_to_word_vocab)
	hypothesis_list.extend(fake_data_text)
	reference.extend(real_data_text)

nll_gen_error = np.array(nll_gen_error)
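
A possible next step (assumed, not part of the snippet): score the saved sentences with NLTK's corpus BLEU, treating the pool of real sentences as the reference set for every hypothesis, as is common in text-GAN evaluation:

from nltk.translate.bleu_score import corpus_bleu

# each hypothesis is scored against the full pool of real sentences
bleu = corpus_bleu([reference] * len(hypothesis_list), hypothesis_list)
print('nll_gen: {:.4f}, BLEU: {:.4f}'.format(nll_gen_error.mean(), bleu))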