def train(self, num_epoch):
        """Train the GAN with the standard binary cross-entropy objective.

        Every iteration performs one discriminator step and one generator
        step:

        1. D minimizes ``-log(D(x)) - log(1 - D(G(z)))``.  Gradients of the
           real and the (detached) generated batch are accumulated before a
           single optimizer step.
        2. G minimizes ``-log(D(G(z)))``; when ``self.resample`` is truthy a
           fresh noise batch is drawn for this step.

        Per-iteration losses are appended to ``self.errD_records`` and
        ``self.errG_records``.  A sample figure is saved every
        ``self.image_interval`` iterations and, when
        ``self.snapshot_interval`` is not None, a snapshot every
        ``self.snapshot_interval`` iterations.

        Args:
            num_epoch: number of passes over ``self.train_dl``.
        """
        criterion = nn.BCELoss()
        for epoch in range(num_epoch):
            for i, data in enumerate(tqdm(self.train_dl)):
                real_images = data[0].to(self.device)
                bs = real_images.size(0)

                # BCELoss needs floating-point targets.  torch.full infers the
                # dtype from the fill value, so an integer label would give a
                # Long tensor and make the loss raise; pin the dtype instead.
                real_label = torch.full((bs, ),
                                        self.real_label,
                                        dtype=torch.float,
                                        device=self.device)
                fake_label = torch.full((bs, ),
                                        self.fake_label,
                                        dtype=torch.float,
                                        device=self.device)

                # (1) discriminator step
                self.netD.zero_grad()

                # -log(D(x)) on the real batch
                errD_real = criterion(
                    self.netD(real_images).view(-1), real_label)
                errD_real.backward()

                # -log(1 - D(G(z))) on a generated batch; detach() keeps this
                # backward pass from reaching the generator
                noise = generate_noise(bs, self.nz, self.device)
                fake_images = self.netG(noise)
                errD_fake = criterion(
                    self.netD(fake_images.detach()).view(-1), fake_label)
                errD_fake.backward()

                # the sum is only used for logging; both gradients are
                # already accumulated
                errD = errD_real + errD_fake
                self.optimizerD.step()

                # (2) generator step: -log(D(G(z)))
                self.netG.zero_grad()
                if self.resample:
                    noise = generate_noise(bs, self.nz, self.device)
                    fake_images = self.netG(noise)

                errG = criterion(self.netD(fake_images).view(-1), real_label)
                errG.backward()
                self.optimizerG.step()

                self.errD_records.append(float(errD))
                self.errG_records.append(float(errG))

                if i % self.loss_interval == 0:
                    print('[%d/%d] [%d/%d] errD : %.4f, errG : %.4f' %
                          (epoch + 1, num_epoch, i + 1,
                           self.train_iteration_per_epoch, errD, errG))

                if i % self.image_interval == 0:
                    plot_fig = None
                    if self.special is None:
                        sample_images_list = get_sample_images_list(
                            'Unsupervised', (self.fixed_noise, self.netG))
                        plot_fig = plot_multiple_images(
                            sample_images_list, 4, 4)
                    elif self.special == 'Wave':
                        sample_audios_list = get_sample_images_list(
                            'Unsupervised_Audio',
                            (self.fixed_noise, self.netG))
                        plot_fig = plot_multiple_spectrograms(
                            sample_audios_list, 4, 4, freq=16000)

                    # any other value of self.special produces no figure
                    if plot_fig is not None:
                        cur_file_name = os.path.join(
                            self.save_img_dir,
                            '%s : %s-%s.jpg' % (self.save_cnt, epoch, i))
                        self.save_cnt += 1
                        save_fig(cur_file_name, plot_fig)
                        plot_fig.clf()

                if (self.snapshot_interval is not None
                        and i % self.snapshot_interval == 0):
                    save(
                        os.path.join(self.save_snapshot_dir,
                                     'Epoch%s_%s.state' % (epoch, i)),
                        self.netD, self.netG, self.optimizerD,
                        self.optimizerG)
    def train(self, num_epoch):
        """Train with the relativistic average least-squares objective.

        With ``r = self.real_label``, the discriminator minimizes
        ``(mean((D(x) - mean(D(G(z))) - r)**2)
        + mean((D(G(z)) - mean(D(x)) + r)**2)) / 2`` and the generator
        minimizes the same expression with the roles of real and generated
        scores swapped.

        When ``self.resample`` is truthy the generator step uses a fresh real
        batch (drawn from a second iterator over ``self.train_dl``) and a
        fresh noise batch.

        Per-iteration losses are appended to ``self.errD_records`` and
        ``self.errG_records``; sample figures and snapshots are written at
        the configured intervals.

        Args:
            num_epoch: number of passes over ``self.train_dl``.
        """
        # The label only acts as a constant offset in the loss, so the scalar
        # is broadcast directly instead of allocating a (bs,) tensor per step.
        offset = self.real_label

        def relativistic_ls(raised, lowered):
            # Drives (raised - mean(lowered)) towards +offset and
            # (lowered - mean(raised)) towards -offset.
            return (torch.mean((raised - torch.mean(lowered) - offset)**2) +
                    torch.mean(
                        (lowered - torch.mean(raised) + offset)**2)) / 2.0

        for epoch in range(num_epoch):
            if self.resample:
                # independent pass over the data that feeds the generator step
                train_dl_iter = iter(self.train_dl)
            for i, data in enumerate(tqdm(self.train_dl)):
                # (1) discriminator step
                self.netD.zero_grad()
                real_images = data[0].to(self.device)
                bs = real_images.size(0)

                noise = generate_noise(bs, self.nz, self.device)
                fake_images = self.netG(noise)

                c_xr = self.netD(real_images).view(-1)
                # detach() keeps the discriminator update out of G's graph
                c_xf = self.netD(fake_images.detach()).view(-1)

                errD = relativistic_ls(c_xr, c_xf)
                errD.backward()
                self.optimizerD.step()

                # (2) generator step
                self.netG.zero_grad()
                if self.resample:
                    real_images = next(train_dl_iter)[0].to(self.device)
                    noise = generate_noise(bs, self.nz, self.device)
                    fake_images = self.netG(noise)

                # D was just updated, so both scores are recomputed
                c_xr = self.netD(real_images).view(-1)
                c_xf = self.netD(fake_images).view(-1)

                errG = relativistic_ls(c_xf, c_xr)
                errG.backward()
                self.optimizerG.step()

                self.errD_records.append(float(errD))
                self.errG_records.append(float(errG))

                if i % self.loss_interval == 0:
                    print('[%d/%d] [%d/%d] errD : %.4f, errG : %.4f' %
                          (epoch + 1, num_epoch, i + 1,
                           self.train_iteration_per_epoch, errD, errG))

                if i % self.image_interval == 0:
                    plot_fig = None
                    if self.special is None:
                        sample_images_list = get_sample_images_list(
                            'Unsupervised', (self.fixed_noise, self.netG))
                        plot_fig = plot_multiple_images(
                            sample_images_list, 4, 4)
                    elif self.special == 'Wave':
                        sample_audios_list = get_sample_images_list(
                            'Unsupervised_Audio',
                            (self.fixed_noise, self.netG))
                        plot_fig = plot_multiple_spectrograms(
                            sample_audios_list, 4, 4, freq=16000)

                    # any other value of self.special produces no figure
                    if plot_fig is not None:
                        cur_file_name = os.path.join(
                            self.save_img_dir,
                            '%s : %s-%s.jpg' % (self.save_cnt, epoch, i))
                        self.save_cnt += 1
                        save_fig(cur_file_name, plot_fig)
                        plot_fig.clf()

                if (self.snapshot_interval is not None
                        and i % self.snapshot_interval == 0):
                    save(
                        os.path.join(self.save_snapshot_dir,
                                     'Epoch%s_%s.state' % (epoch, i)),
                        self.netD, self.netG, self.optimizerD,
                        self.optimizerG)
    def train(self, num_epoch):
        """Train a class-conditional GAN with binary cross-entropy.

        The discriminator scores (image, one-hot class) pairs.  Each
        iteration:

        1. D minimizes ``BCE(D(x, y), real) + BCE(D(G(z, y'), y'), fake)``
           where ``y'`` are uniformly drawn classes.
        2. G minimizes ``BCE(D(G(z, y'), y'), real)``; when ``self.resample``
           is truthy a fresh noise batch is drawn (the sampled classes ``y'``
           are kept).

        Per-iteration losses are appended to ``self.errD_records`` and
        ``self.errG_records``; sample figures and snapshots are written at
        the configured intervals.

        Args:
            num_epoch: number of passes over ``self.train_dl``.
        """
        criterion = nn.BCELoss()

        def to_one_hot(class_idx):
            # class_idx: (bs, 1) integer indices -> (bs, n_classes) one-hot
            encoded = torch.zeros(class_idx.size(0),
                                  self.n_classes,
                                  device=self.device)
            return encoded.scatter_(1, class_idx, 1.0)

        for epoch in range(num_epoch):
            for i, data in enumerate(tqdm(self.train_dl)):
                # (1) discriminator step
                self.netD.zero_grad()
                real_images = data[0].to(self.device)
                real_class = data[1].to(self.device)
                bs = real_images.size(0)

                # BCELoss needs floating-point targets.  torch.full infers the
                # dtype from the fill value, so an integer label would give a
                # Long tensor and make the loss raise; pin the dtype instead.
                real_label = torch.full((bs, ),
                                        self.real_label,
                                        dtype=torch.float,
                                        device=self.device)
                fake_label = torch.full((bs, ),
                                        self.fake_label,
                                        dtype=torch.float,
                                        device=self.device)

                one_hot_labels = to_one_hot(real_class.view(bs, 1))

                noise = generate_noise(bs, self.nz, self.device)
                # drawn on the CPU and then moved, so the same random
                # generator is used regardless of self.device
                fake_class = torch.randint(0, self.n_classes,
                                           size=(bs, 1)).to(self.device)
                one_hot_labels_fake = to_one_hot(fake_class)
                fake_images = self.netG(noise, one_hot_labels_fake)

                c_xr = self.netD(real_images, one_hot_labels).view(-1)
                # detach() keeps the discriminator update out of G's graph
                c_xf = self.netD(fake_images.detach(),
                                 one_hot_labels_fake).view(-1)

                errD = criterion(c_xr, real_label) + criterion(
                    c_xf, fake_label)
                errD.backward()
                self.optimizerD.step()

                # (2) generator step
                self.netG.zero_grad()
                if self.resample:
                    # only the noise is redrawn; fake_class is unchanged, so
                    # the existing one-hot encoding is still valid
                    noise = generate_noise(bs, self.nz, self.device)
                    fake_images = self.netG(noise, one_hot_labels_fake)

                # D was just updated, so the fake score is recomputed
                c_xf = self.netD(fake_images, one_hot_labels_fake).view(-1)
                errG = criterion(c_xf, real_label)
                errG.backward()
                self.optimizerG.step()

                self.errD_records.append(float(errD))
                self.errG_records.append(float(errG))

                if i % self.loss_interval == 0:
                    print('[%d/%d] [%d/%d] errD : %.4f, errG : %.4f' %
                          (epoch + 1, num_epoch, i + 1,
                           self.train_iteration_per_epoch, errD, errG))

                if i % self.image_interval == 0:
                    sample_images_list = get_sample_images_list(
                        'Conditional',
                        (self.fixed_noise, self.fixed_one_hot_labels,
                         self.n_classes, self.netG))
                    plot_fig = plot_multiple_images(sample_images_list,
                                                    self.n_classes, 1)
                    cur_file_name = os.path.join(
                        self.save_img_dir,
                        '%s : %s-%s.jpg' % (self.save_cnt, epoch, i))
                    self.save_cnt += 1
                    save_fig(cur_file_name, plot_fig)
                    plot_fig.clf()

                if (self.snapshot_interval is not None
                        and i % self.snapshot_interval == 0):
                    save(
                        os.path.join(self.save_snapshot_dir,
                                     'Epoch%s_%s.state' % (epoch, i)),
                        self.netD, self.netG, self.optimizerD,
                        self.optimizerG)
Esempio n. 4
0
	def train(self, res_num_epochs, res_percentage, bs):
		"""Train the paired (x -> y) GAN in successive stages.

		Stage ``i`` draws its loaders from ``get_loader(32 * (2**i), ...)``
		(presumably the image resolution -- confirm against the dataset class).
		During the first ``int(num_epoch * percentage)`` epochs of a stage the
		value ``p`` passed to the generator ramps linearly from ``i - 1`` up to
		``i``; afterwards it stays at ``i``.

		Each iteration performs one discriminator step (optionally with a
		gradient penalty and weight clipping) and one generator step whose loss
		is ``g_loss + rec_weight * rec_loss + ds_weight * ds_loss``.

		Args:
			res_num_epochs: number of epochs for each stage.
			res_percentage: transition fraction for every stage after the
				first (one entry fewer than ``res_num_epochs``); the first
				stage never has a transition.
			bs: training batch size for each stage.
		"""
		l1 = nn.L1Loss()

		# the first stage has nothing to fade in from, hence the leading None
		res_percentage = [None] + list(res_percentage)

		for i, (num_epoch, percentage, cur_bs) in enumerate(zip(res_num_epochs, res_percentage, bs)):
			train_dl = self.train_ds.get_loader(32 * (2**i), cur_bs)
			# only the first validation batch is used for the sample images
			val_dl = next(iter(self.val_ds.get_loader(32 * (2**i), 3)))
			train_dl_len = len(train_dl)

			if percentage is None:
				num_epoch_transition = 0
			else:
				num_epoch_transition = int(num_epoch * percentage)

			cnt = 1
			for epoch in range(num_epoch):
				p = i
				if self.resample:
					train_dl_iter = iter(train_dl)

				for j, (x, y) in enumerate(tqdm(train_dl)):
					if epoch < num_epoch_transition:
						# linear ramp over all iterations of the transition epochs
						p = i + cnt / (train_dl_len * num_epoch_transition) - 1
						cnt += 1

					x = x.to(self.device)
					y = y.to(self.device)
					# named n so the ``bs`` argument is not shadowed
					n = x.size(0)
					noise = generate_noise(n, self.nz, self.device)
					fake_y = self.netG(x, p, noise)

					# ---- discriminator step ----
					self.netD.zero_grad()

					c_xr = self.netD(x, y).view(-1)
					# detach() keeps the discriminator update out of G's graph
					c_xf = self.netD(x, fake_y.detach()).view(-1)

					if self.require_type == 0 or self.require_type == 1:
						errD = self.loss.d_loss(c_xr, c_xf)
					elif self.require_type == 2:
						errD = self.loss.d_loss(c_xr, c_xf, y, fake_y)

					# use_gradient_penalty doubles as the penalty weight; the
					# original ``!= False`` comparison is kept on purpose
					if self.use_gradient_penalty != False:
						errD += self.use_gradient_penalty * self.gradient_penalty(x, y, fake_y)

					errD.backward()
					self.optimizerD.step()

					if self.weight_clip is not None:
						for param in self.netD.parameters():
							param.data.clamp_(-self.weight_clip, self.weight_clip)

					# ---- generator step ----
					self.netG.zero_grad()
					if self.resample:
						x, y = next(train_dl_iter)
						x = x.to(self.device)
						y = y.to(self.device)
						n = x.size(0)
						noise = generate_noise(n, self.nz, self.device)
						fake_y = self.netG(x, p, noise)

					# f1 / f2: intermediate discriminator outputs for the
					# generated / real pair; f2 is computed lazily when only
					# the feature reconstruction loss needs it
					f2 = None
					if self.require_type == 0:
						c_xf, f1 = self.netD(x, fake_y, True)
						c_xf = c_xf.view(-1)
						errG_1 = self.loss.g_loss(c_xf)
					if self.require_type == 1 or self.require_type == 2:
						c_xr, f2 = self.netD(x, y, True)
						c_xr = c_xr.view(-1)
						c_xf, f1 = self.netD(x, fake_y, True)
						c_xf = c_xf.view(-1)
						errG_1 = self.loss.g_loss(c_xr, c_xf)

					if self.ds_weight == 0:
						ds_loss = 0
					else:
						noise1 = generate_noise(n, self.nz, self.device)
						noise2 = generate_noise(n, self.nz, self.device)
						# p must be passed here exactly like in every other
						# generator call; it was previously omitted
						fake_y1 = self.netG(x, p, noise1)
						fake_y2 = self.netG(x, p, noise2)
						ds_loss = self.ds_loss.get_loss(fake_y1, fake_y2, noise1, noise2)

					if self.rec_weight == 0:
						rec_loss = 0
					elif self.use_rec_feature:
						if f2 is None:
							_, f2 = self.netD(x, y, True)
						# mean absolute feature difference, averaged over the
						# feature maps; computed for every require_type
						# (previously it was skipped whenever f2 already existed)
						rec_loss = sum((f1_ - f2_).abs().mean() for f1_, f2_ in zip(f1, f2))
						rec_loss /= len(f1)
					else:
						rec_loss = l1(fake_y, y)

					errG = errG_1 + rec_loss * self.rec_weight + ds_loss * self.ds_weight
					errG.backward()
					self.optimizerG.step()

					if j % self.loss_interval == 0:
						# report the batch index j, not the stage index i
						print('[%d/%d] [%d/%d] errD : %.4f, errG : %.4f'
							  %(epoch+1, num_epoch, j+1, train_dl_len, errD, errG))

					if j % self.image_interval == 0:
						if self.nz is None:
							sample_images_list = get_sample_images_list((val_dl, self.netG, p, self.device))
							plot_image = get_display_samples(sample_images_list, 3, 3)
						else:
							sample_images_list = get_sample_images_list_noise((val_dl, self.netG, p, self.fixed_noise, self.device))
							plot_image = get_display_samples(sample_images_list, 9, 3)

						cur_file_name = os.path.join(self.save_img_dir, '%s : %s-%s.jpg' % (self.save_cnt, epoch, j))
						self.save_cnt += 1
						cv2.imwrite(cur_file_name, plot_image)
	def train(self, num_epoch):
		"""Train an unconditional GAN with a pluggable loss object.

		``self.loss`` supplies ``d_loss`` / ``g_loss``; ``self.require_type``
		selects which arguments they receive:

		* 0 -- ``d_loss(c_xr, c_xf)``, ``g_loss(c_xf)``
		* 1 -- ``d_loss(c_xr, c_xf)``, ``g_loss(c_xr, c_xf)``
		* 2 -- ``d_loss(c_xr, c_xf, real, fake)``, ``g_loss(c_xr, c_xf)``

		An optional gradient penalty (weighted by ``self.use_gradient_penalty``)
		is added to the discriminator loss, and the discriminator weights are
		clamped to ``[-weight_clip, weight_clip]`` when ``self.weight_clip`` is
		set.  Losses are appended to ``self.errD_records`` /
		``self.errG_records`` and samples are saved every
		``self.image_interval`` iterations.

		Args:
			num_epoch: number of passes over ``self.train_dl``.
		"""
		for epoch in range(num_epoch):
			for i, data in enumerate(tqdm(self.train_dl)):
				# ---- discriminator step ----
				self.netD.zero_grad()
				real_images = data[0].to(self.device)
				bs = real_images.size(0)

				noise = generate_noise(bs, self.nz, self.device)
				fake_images = self.netG(noise)

				c_xr = self.netD(real_images).view(-1)
				# detach() keeps the discriminator update out of G's graph
				c_xf = self.netD(fake_images.detach()).view(-1)

				if self.require_type == 0 or self.require_type == 1:
					errD = self.loss.d_loss(c_xr, c_xf)
				elif self.require_type == 2:
					errD = self.loss.d_loss(c_xr, c_xf, real_images, fake_images)

				# use_gradient_penalty doubles as the penalty weight; the
				# original ``!= False`` comparison is kept on purpose
				if self.use_gradient_penalty != False:
					errD += self.use_gradient_penalty * self.gradient_penalty(real_images, fake_images)

				errD.backward()
				self.optimizerD.step()

				if self.weight_clip is not None:
					for param in self.netD.parameters():
						param.data.clamp_(-self.weight_clip, self.weight_clip)

				# ---- generator step ----
				self.netG.zero_grad()
				if self.resample:
					noise = generate_noise(bs, self.nz, self.device)
					fake_images = self.netG(noise)

				# D was just updated, so the scores are recomputed
				if self.require_type == 0:
					c_xf = self.netD(fake_images).view(-1)
					errG = self.loss.g_loss(c_xf)
				elif self.require_type == 1 or self.require_type == 2:
					c_xr = self.netD(real_images).view(-1)
					c_xf = self.netD(fake_images).view(-1)
					errG = self.loss.g_loss(c_xr, c_xf)
				errG.backward()
				self.optimizerG.step()

				self.errD_records.append(float(errD))
				self.errG_records.append(float(errG))

				if i % self.loss_interval == 0:
					print('[%d/%d] [%d/%d] errD : %.4f, errG : %.4f'
						  %(epoch+1, num_epoch, i+1, self.train_iteration_per_epoch, errD, errG))

				if i % self.image_interval == 0:
					if self.special is None:
						# image samples are written with OpenCV
						sample_images_list = get_sample_images_list('Unsupervised', (self.fixed_noise, self.netG))
						plot_img = get_display_samples(sample_images_list, 7, 7)
						cur_file_name = os.path.join(self.save_img_dir, '%s : %s-%s.jpg' % (self.save_cnt, epoch, i))
						self.save_cnt += 1
						cv2.imwrite(cur_file_name, plot_img)

					elif self.special == 'Wave':
						# audio samples are rendered as a spectrogram figure
						sample_audios_list = get_sample_images_list('Unsupervised_Audio', (self.fixed_noise, self.netG))
						plot_fig = plot_multiple_spectrograms(sample_audios_list, 7, 7, freq = 16000)
						cur_file_name = os.path.join(self.save_img_dir, '%s : %s-%s.jpg' % (self.save_cnt, epoch, i))
						self.save_cnt += 1
						save_fig(cur_file_name, plot_fig)
						plot_fig.clf()
						
Esempio n. 6
0
    def train(self, num_epoch):
        """Train a Wasserstein GAN with weight clipping.

        Every iteration the critic minimizes ``-mean(D(x)) + mean(D(G(z)))``
        and its parameters are then clamped to ``[-self.c, self.c]``.  The
        generator minimizes ``-mean(D(G(z)))`` on a fresh noise batch, but
        only on iterations where ``i % self.n_critic == 0``.

        ``self.errD_records``, ``self.errG_records`` and
        ``self.w_dist_records`` receive one entry per iteration.  On
        iterations that skip the generator step, the recorded ``errG`` is the
        value from the most recent generator step.

        Args:
            num_epoch: number of passes over ``self.train_dl``.
        """
        for epoch in range(num_epoch):
            for i, data in enumerate(tqdm(self.train_dl)):
                # (1) critic step
                self.netD.zero_grad()
                real_images = data[0].to(self.device)
                bs = real_images.size(0)

                noise = generate_noise(bs, self.nz, self.device)
                fake_images = self.netG(noise)

                c_xr = self.netD(real_images).view(-1)
                # detach() keeps the critic update out of G's graph
                c_xf = self.netD(fake_images.detach()).view(-1)

                errD_real = -torch.mean(c_xr)
                errD_fake = torch.mean(c_xf)
                errD = errD_real + errD_fake
                errD.backward()
                self.optimizerD.step()

                # (2) clip the critic parameters to the range [-c, c]
                for param in self.netD.parameters():
                    param.data.clamp_(-self.c, self.c)

                # (3) generator step, once every n_critic iterations.  i == 0
                # always qualifies, so errG is bound before it is first read.
                if i % self.n_critic == 0:
                    self.netG.zero_grad()
                    noise = generate_noise(bs, self.nz, self.device)
                    fake_images = self.netG(noise)
                    c_xf = self.netD(fake_images).view(-1)
                    errG = -torch.mean(c_xf)
                    errG.backward()
                    self.optimizerG.step()

                # mean(D(x)) - mean(D(G(z))), taken from the critic step
                w_dist = -float(errD_real) - float(errD_fake)
                self.errD_records.append(float(errD))
                self.errG_records.append(float(errG))
                self.w_dist_records.append(w_dist)

                if i % self.loss_interval == 0:
                    print(
                        '[%d/%d] [%d/%d] errD : %.4f, errG : %.4f, Wasserstein Distance : %.4f'
                        % (epoch + 1, num_epoch, i + 1,
                           self.train_iteration_per_epoch, errD, errG, w_dist))

                if i % self.image_interval == 0:
                    plot_fig = None
                    if self.special is None:
                        sample_images_list = get_sample_images_list(
                            'Unsupervised', (self.fixed_noise, self.netG))
                        plot_fig = plot_multiple_images(
                            sample_images_list, 4, 4)
                    elif self.special == 'Wave':
                        sample_audios_list = get_sample_images_list(
                            'Unsupervised_Audio',
                            (self.fixed_noise, self.netG))
                        plot_fig = plot_multiple_spectrograms(
                            sample_audios_list, 4, 4, freq=16000)

                    # any other value of self.special produces no figure
                    if plot_fig is not None:
                        cur_file_name = os.path.join(
                            self.save_img_dir,
                            '%s : %s-%s.jpg' % (self.save_cnt, epoch, i))
                        self.save_cnt += 1
                        save_fig(cur_file_name, plot_fig)
                        plot_fig.clf()

                if (self.snapshot_interval is not None
                        and i % self.snapshot_interval == 0):
                    save(
                        os.path.join(self.save_snapshot_dir,
                                     'Epoch%s_%s.state' % (epoch, i)),
                        self.netD, self.netG, self.optimizerD,
                        self.optimizerG)
    def train(self, num_epoch):
        """Train a class-conditional GAN with a pluggable loss object.

        The discriminator scores (image, one-hot class) pairs.
        ``self.loss`` supplies ``d_loss`` / ``g_loss``; ``self.require_type``
        selects which arguments they receive:

        * 0 -- ``d_loss(c_xr, c_xf)``, ``g_loss(c_xf)``
        * 1 -- ``d_loss(c_xr, c_xf)``, ``g_loss(c_xr, c_xf)``
        * 2 -- ``d_loss(c_xr, c_xf, real, fake)``, ``g_loss(c_xr, c_xf)``

        An optional gradient penalty (weighted by
        ``self.use_gradient_penalty``) is added to the discriminator loss,
        and the discriminator weights are clamped when ``self.weight_clip``
        is set.  Losses are appended to ``self.errD_records`` /
        ``self.errG_records`` and a sample figure is saved every
        ``self.image_interval`` iterations.

        Args:
            num_epoch: number of passes over ``self.train_dl``.
        """

        def to_one_hot(class_idx):
            # class_idx: (bs, 1) integer indices -> (bs, n_classes) one-hot
            encoded = torch.zeros(class_idx.size(0),
                                  self.n_classes,
                                  device=self.device)
            return encoded.scatter_(1, class_idx, 1.0)

        for epoch in range(num_epoch):
            for i, data in enumerate(tqdm(self.train_dl)):
                # ---- discriminator step ----
                self.netD.zero_grad()
                real_images = data[0].to(self.device)
                real_class = data[1].to(self.device)
                bs = real_images.size(0)

                noise = generate_noise(bs, self.nz, self.device)
                # drawn on the CPU and then moved, so the same random
                # generator is used regardless of self.device
                fake_class = torch.randint(0, self.n_classes,
                                           size=(bs, 1)).to(self.device)
                one_hot_labels_fake = to_one_hot(fake_class)
                fake_images = self.netG(noise, one_hot_labels_fake)

                one_hot_labels = to_one_hot(real_class.view(bs, 1))

                c_xr = self.netD(real_images, one_hot_labels).view(-1)
                # detach() keeps the discriminator update out of G's graph
                c_xf = self.netD(fake_images.detach(),
                                 one_hot_labels_fake).view(-1)

                if self.require_type == 0 or self.require_type == 1:
                    errD = self.loss.d_loss(c_xr, c_xf)
                elif self.require_type == 2:
                    errD = self.loss.d_loss(c_xr, c_xf, real_images,
                                            fake_images)

                # use_gradient_penalty doubles as the penalty weight; the
                # original ``!= False`` comparison is kept on purpose
                if self.use_gradient_penalty != False:
                    errD += self.use_gradient_penalty * self.gradient_penalty(
                        real_images, fake_images, one_hot_labels,
                        one_hot_labels_fake)

                errD.backward()
                self.optimizerD.step()

                if self.weight_clip is not None:
                    for param in self.netD.parameters():
                        param.data.clamp_(-self.weight_clip, self.weight_clip)

                # ---- generator step ----
                self.netG.zero_grad()
                if self.resample:
                    # only the noise is redrawn; fake_class is unchanged, so
                    # the existing one-hot encoding is still valid
                    noise = generate_noise(bs, self.nz, self.device)
                    fake_images = self.netG(noise, one_hot_labels_fake)

                # D was just updated, so the scores are recomputed
                if self.require_type == 0:
                    c_xf = self.netD(fake_images,
                                     one_hot_labels_fake).view(-1)
                    errG = self.loss.g_loss(c_xf)
                elif self.require_type == 1 or self.require_type == 2:
                    c_xr = self.netD(real_images, one_hot_labels).view(-1)
                    c_xf = self.netD(fake_images,
                                     one_hot_labels_fake).view(-1)
                    errG = self.loss.g_loss(c_xr, c_xf)
                errG.backward()
                self.optimizerG.step()

                self.errD_records.append(float(errD))
                self.errG_records.append(float(errG))

                if i % self.loss_interval == 0:
                    print('[%d/%d] [%d/%d] errD : %.4f, errG : %.4f' %
                          (epoch + 1, num_epoch, i + 1,
                           self.train_iteration_per_epoch, errD, errG))

                if i % self.image_interval == 0:
                    sample_images_list = get_sample_images_list(
                        'Conditional',
                        (self.fixed_noise, self.fixed_one_hot_labels,
                         self.n_classes, self.netG))
                    plot_fig = plot_multiple_images(sample_images_list,
                                                    self.n_classes, 7)
                    cur_file_name = os.path.join(
                        self.save_img_dir,
                        '%s : %s-%s.jpg' % (self.save_cnt, epoch, i))
                    self.save_cnt += 1
                    save_fig(cur_file_name, plot_fig)
                    plot_fig.clf()