def train(self, num_epoch):
        """Run the adversarial training loop for ``num_epoch`` epochs.

        Every batch performs one discriminator update followed by one
        generator update.  Both losses compare each critic score with the
        mean score of the opposite set (real vs. generated) and penalise the
        squared distance to the target ``t = self.real_label``:

            errD = (mean((D(x) - mean(D(G(z))) - t)**2)
                    + mean((D(G(z)) - mean(D(x)) + t)**2)) / 2
            errG = (mean((D(G(z)) - mean(D(x)) - t)**2)
                    + mean((D(x) - mean(D(G(z))) + t)**2)) / 2

        Side effects per iteration:
          * appends the scalar losses to ``self.errD_records`` and
            ``self.errG_records``;
          * prints the losses every ``self.loss_interval`` iterations;
          * every ``self.image_interval`` iterations saves a figure built
            from ``self.fixed_noise`` into ``self.save_img_dir`` (only when
            ``self.special`` is ``None`` or ``'Wave'``);
          * every ``self.snapshot_interval`` iterations (when it is not
            ``None``) saves both networks and optimizers into
            ``self.save_snapshot_dir``.

        Args:
            num_epoch: number of passes over ``self.train_dl``.
        """
        for epoch in range(num_epoch):
            # With resampling enabled the generator step draws its real batch
            # from a second, independent iterator over the same loader.
            if self.resample:
                train_dl_iter = iter(self.train_dl)
            for i, data in enumerate(tqdm(self.train_dl)):
                # (1) Discriminator update.
                self.netD.zero_grad()
                real_images = data[0].to(self.device)
                bs = real_images.size(0)
                # Per-sample target value, shape (bs,).
                real_label = torch.full((bs, ),
                                        self.real_label,
                                        device=self.device)
                noise = generate_noise(bs, self.nz, self.device)
                fake_images = self.netG(noise)
                # Critic scores flattened to 1-D.  The fake batch is detached
                # so this backward pass does not reach the generator.
                c_xr = self.netD(real_images).view(-1)
                c_xf = self.netD(fake_images.detach()).view(-1)
                errD = (torch.mean(
                    (c_xr - torch.mean(c_xf) - real_label)**2) + torch.mean(
                        (c_xf - torch.mean(c_xr) + real_label)**2)) / 2.0
                errD.backward()
                self.optimizerD.step()

                # (2) Generator update.
                self.netG.zero_grad()
                if self.resample:
                    # Fresh real batch and fresh noise for the generator step.
                    # NOTE(review): real_label and the noise keep the size
                    # ``bs`` of the discriminator batch; this assumes both
                    # iterators yield equally sized batches at the same
                    # index - confirm for the loader in use.
                    real_images = next(train_dl_iter)[0].to(self.device)
                    noise = generate_noise(bs, self.nz, self.device)
                    fake_images = self.netG(noise)
                # D has just been updated, so both scores are recomputed;
                # the fake batch is NOT detached because G needs gradients.
                c_xr = self.netD(real_images).view(-1)
                c_xf = self.netD(fake_images).view(-1)
                errG = (torch.mean(
                    (c_xf - torch.mean(c_xr) - real_label)**2) + torch.mean(
                        (c_xr - torch.mean(c_xf) + real_label)**2)) / 2.0
                errG.backward()
                self.optimizerG.step()

                self.errD_records.append(float(errD))
                self.errG_records.append(float(errG))

                if i % self.loss_interval == 0:
                    print('[%d/%d] [%d/%d] errD : %.4f, errG : %.4f' %
                          (epoch + 1, num_epoch, i + 1,
                           self.train_iteration_per_epoch, errD, errG))

                # Any other value of ``special`` produces no sample output.
                if i % self.image_interval == 0 and (
                        self.special is None or self.special == 'Wave'):
                    if self.special is None:
                        sample_images_list = get_sample_images_list(
                            'Unsupervised', (self.fixed_noise, self.netG))
                        plot_fig = plot_multiple_images(
                            sample_images_list, 4, 4)
                    else:
                        sample_audios_list = get_sample_images_list(
                            'Unsupervised_Audio',
                            (self.fixed_noise, self.netG))
                        plot_fig = plot_multiple_spectrograms(
                            sample_audios_list, 4, 4, freq=16000)
                    cur_file_name = os.path.join(
                        self.save_img_dir,
                        str(self.save_cnt) + ' : ' + str(epoch) + '-' +
                        str(i) + '.jpg')
                    self.save_cnt += 1
                    save_fig(cur_file_name, plot_fig)
                    # Clear the figure after it has been written.
                    plot_fig.clf()

                if (self.snapshot_interval is not None
                        and i % self.snapshot_interval == 0):
                    save(
                        os.path.join(
                            self.save_snapshot_dir, 'Epoch' + str(epoch) +
                            '_' + str(i) + '.state'), self.netD, self.netG,
                        self.optimizerD, self.optimizerG)
예제 #2
0
	def train(self, num_epoch):
		"""Run the adversarial training loop for ``num_epoch`` epochs.

		Each batch performs one discriminator update followed by one
		generator update.  The losses come from ``self.loss``; which
		arguments they receive is selected by ``self.require_type``:

		  * 0 - ``d_loss(c_xr, c_xf)`` and ``g_loss(c_xf)``
		  * 1 - ``d_loss(c_xr, c_xf)`` and ``g_loss(c_xr, c_xf)``
		  * 2 - ``d_loss(c_xr, c_xf, real_images, fake_images)`` and
		        ``g_loss(c_xr, c_xf)``

		Optional extras on the discriminator step: a gradient penalty
		weighted by ``self.use_gradient_penalty`` and clamping of the
		discriminator parameters to ``[-self.weight_clip, self.weight_clip]``.

		Side effects per iteration: appends the scalar losses to
		``self.errD_records`` / ``self.errG_records``, prints them every
		``self.loss_interval`` iterations and writes sample output to
		``self.save_img_dir`` every ``self.image_interval`` iterations
		(only when ``self.special`` is ``None`` or ``'Wave'``).

		Args:
			num_epoch: number of passes over ``self.train_dl``.

		Raises:
			ValueError: if ``self.require_type`` is not 0, 1 or 2.
		"""
		for epoch in range(num_epoch):
			for i, data in enumerate(tqdm(self.train_dl)):
				# (1) Discriminator update.
				self.netD.zero_grad()
				real_images = data[0].to(self.device)
				bs = real_images.size(0)

				noise = generate_noise(bs, self.nz, self.device)
				fake_images = self.netG(noise)

				# Critic scores flattened to 1-D; the fake batch is detached
				# so scoring it does not build a path back into G.
				c_xr = self.netD(real_images).view(-1)
				c_xf = self.netD(fake_images.detach()).view(-1)

				if self.require_type in (0, 1):
					errD = self.loss.d_loss(c_xr, c_xf)
				elif self.require_type == 2:
					# NOTE(review): fake_images is passed without detach();
					# whether gradients reach netG depends on d_loss - confirm.
					errD = self.loss.d_loss(c_xr, c_xf, real_images, fake_images)
				else:
					# Fail with a clear message instead of an unbound-local
					# error at errD.backward().
					raise ValueError('unsupported require_type: %r' % (self.require_type,))

				# use_gradient_penalty doubles as the penalty weight; the
				# ``!= False`` comparison also skips a weight of 0.
				if self.use_gradient_penalty != False:
					errD += self.use_gradient_penalty * self.gradient_penalty(real_images, fake_images)

				errD.backward()
				self.optimizerD.step()

				if self.weight_clip is not None:
					for param in self.netD.parameters():
						param.data.clamp_(-self.weight_clip, self.weight_clip)

				# (2) Generator update.
				self.netG.zero_grad()
				if self.resample:
					# Draw fresh noise for the generator step.
					noise = generate_noise(bs, self.nz, self.device)
					fake_images = self.netG(noise)

				# D has just been updated, so the scores are recomputed; the
				# fake batch is not detached because G needs the gradients.
				if self.require_type == 0:
					c_xf = self.netD(fake_images).view(-1)
					errG = self.loss.g_loss(c_xf)
				else:
					# require_type is 1 or 2 here (validated above).
					c_xr = self.netD(real_images).view(-1)
					c_xf = self.netD(fake_images).view(-1)
					errG = self.loss.g_loss(c_xr, c_xf)
				errG.backward()
				self.optimizerG.step()

				self.errD_records.append(float(errD))
				self.errG_records.append(float(errG))

				if i % self.loss_interval == 0:
					print('[%d/%d] [%d/%d] errD : %.4f, errG : %.4f'
						  %(epoch+1, num_epoch, i+1, self.train_iteration_per_epoch, errD, errG))

				if i % self.image_interval == 0:
					if self.special is None:
						# Image mode: build a 7x7 sample grid and write it with OpenCV.
						sample_images_list = get_sample_images_list('Unsupervised', (self.fixed_noise, self.netG))
						plot_img = get_display_samples(sample_images_list, 7, 7)
						cur_file_name = os.path.join(self.save_img_dir, str(self.save_cnt)+' : '+str(epoch)+'-'+str(i)+'.jpg')
						self.save_cnt += 1
						cv2.imwrite(cur_file_name, plot_img)

					elif self.special == 'Wave':
						# Audio mode: save a 7x7 grid of spectrograms as a figure.
						sample_audios_list = get_sample_images_list('Unsupervised_Audio', (self.fixed_noise, self.netG))
						plot_fig = plot_multiple_spectrograms(sample_audios_list, 7, 7, freq = 16000)
						cur_file_name = os.path.join(self.save_img_dir, str(self.save_cnt)+' : '+str(epoch)+'-'+str(i)+'.jpg')
						self.save_cnt += 1
						save_fig(cur_file_name, plot_fig)
						plot_fig.clf()
						
    def train(self, num_epoch):
        """Run the adversarial training loop with a binary cross-entropy loss.

        Each batch performs one discriminator update followed by one
        generator update:

          * D minimises ``-log(D(x)) - log(1 - D(G(z)))``;
          * G minimises ``-log(D(G(z)))``.

        Side effects per iteration:
          * appends the scalar losses to ``self.errD_records`` and
            ``self.errG_records``;
          * prints the losses every ``self.loss_interval`` iterations;
          * every ``self.image_interval`` iterations saves a figure built
            from ``self.fixed_noise`` into ``self.save_img_dir`` (only when
            ``self.special`` is ``None`` or ``'Wave'``);
          * every ``self.snapshot_interval`` iterations (when it is not
            ``None``) saves both networks and optimizers into
            ``self.save_snapshot_dir``.

        Args:
            num_epoch: number of passes over ``self.train_dl``.
        """
        criterion = nn.BCELoss()
        for epoch in range(num_epoch):
            for i, data in enumerate(tqdm(self.train_dl)):
                # (1) Discriminator update.
                self.netD.zero_grad()

                # -log(D(x)) on the real batch.
                real_images = data[0].to(self.device)
                bs = real_images.size(0)
                output = self.netD(real_images).view(-1)  # flattened to 1-D
                # Build the target with the same dtype as D's output:
                # torch.full would otherwise infer an integer dtype from an
                # int fill value, which BCELoss does not accept as a target.
                label = torch.full((bs, ),
                                   self.real_label,
                                   dtype=output.dtype,
                                   device=self.device)
                errD_real = criterion(output, label)
                errD_real.backward()

                # -log(1 - D(G(z))) on a generated batch.  The batch is
                # detached so this backward pass does not reach the generator.
                noise = generate_noise(bs, self.nz, self.device)
                fake_images = self.netG(noise)
                label.fill_(self.fake_label)
                output = self.netD(fake_images.detach()).view(-1)
                errD_fake = criterion(output, label)
                errD_fake.backward()

                # Gradients of both terms are already accumulated; the sum is
                # only kept for reporting.
                errD = errD_real + errD_fake
                self.optimizerD.step()

                # (2) Generator update: minimise -log(D(G(z))).
                self.netG.zero_grad()
                if self.resample:
                    # Draw fresh noise for the generator step.
                    noise = generate_noise(bs, self.nz, self.device)
                    fake_images = self.netG(noise)

                # The generator wants its samples to be scored as real.
                label.fill_(self.real_label)
                output = self.netD(fake_images).view(-1)
                errG = criterion(output, label)
                errG.backward()
                self.optimizerG.step()

                self.errD_records.append(float(errD))
                self.errG_records.append(float(errG))

                if i % self.loss_interval == 0:
                    print('[%d/%d] [%d/%d] errD : %.4f, errG : %.4f' %
                          (epoch + 1, num_epoch, i + 1,
                           self.train_iteration_per_epoch, errD, errG))

                # Any other value of ``special`` produces no sample output.
                if i % self.image_interval == 0 and (
                        self.special is None or self.special == 'Wave'):
                    if self.special is None:
                        sample_images_list = get_sample_images_list(
                            'Unsupervised', (self.fixed_noise, self.netG))
                        plot_fig = plot_multiple_images(
                            sample_images_list, 4, 4)
                    else:
                        sample_audios_list = get_sample_images_list(
                            'Unsupervised_Audio',
                            (self.fixed_noise, self.netG))
                        plot_fig = plot_multiple_spectrograms(
                            sample_audios_list, 4, 4, freq=16000)
                    cur_file_name = os.path.join(
                        self.save_img_dir,
                        str(self.save_cnt) + ' : ' + str(epoch) + '-' +
                        str(i) + '.jpg')
                    self.save_cnt += 1
                    save_fig(cur_file_name, plot_fig)
                    # Clear the figure after it has been written.
                    plot_fig.clf()

                if (self.snapshot_interval is not None
                        and i % self.snapshot_interval == 0):
                    save(
                        os.path.join(
                            self.save_snapshot_dir, 'Epoch' + str(epoch) +
                            '_' + str(i) + '.state'), self.netD, self.netG,
                        self.optimizerD, self.optimizerG)
예제 #4
0
    def train(self, num_epoch):
        """Run the adversarial training loop with a weight-clipped critic.

        For every batch the discriminator (critic) minimises
        ``-mean(D(x)) + mean(D(G(z)))`` and its parameters are then clamped
        to ``[-self.c, self.c]``.  The generator minimises
        ``-mean(D(G(z)))`` on fresh noise, but only on iterations where
        ``i % self.n_critic == 0``.

        Side effects per iteration:
          * appends to ``self.errD_records``, ``self.errG_records`` and
            ``self.w_dist_records`` (``mean(D(x)) - mean(D(G(z)))``).  On
            iterations without a generator update the most recent ``errG``
            is recorded again;
          * prints the losses every ``self.loss_interval`` iterations;
          * every ``self.image_interval`` iterations saves a figure built
            from ``self.fixed_noise`` into ``self.save_img_dir`` (only when
            ``self.special`` is ``None`` or ``'Wave'``);
          * every ``self.snapshot_interval`` iterations (when it is not
            ``None``) saves both networks and optimizers into
            ``self.save_snapshot_dir``.

        Args:
            num_epoch: number of passes over ``self.train_dl``.
        """
        for epoch in range(num_epoch):
            for i, data in enumerate(tqdm(self.train_dl)):
                # (1) Critic update: minimise -mean(D(x)) + mean(D(G(z))).
                self.netD.zero_grad()
                real_images = data[0].to(self.device)
                bs = real_images.size(0)
                noise = generate_noise(bs, self.nz, self.device)
                fake_images = self.netG(noise)
                # Critic scores flattened to 1-D.  The fake batch is detached
                # so this backward pass does not reach the generator.
                c_xr = self.netD(real_images).view(-1)
                c_xf = self.netD(fake_images.detach()).view(-1)
                errD_real = -torch.mean(c_xr)
                errD_fake = torch.mean(c_xf)

                errD = errD_real + errD_fake
                errD.backward()
                self.optimizerD.step()

                # (2) Clamp every critic parameter into [-c, c].
                for param in self.netD.parameters():
                    param.data.clamp_(-self.c, self.c)

                # (3) Generator update, once every n_critic iterations.
                #     i == 0 always qualifies, so errG is defined before it
                #     is first read below.
                if i % self.n_critic == 0:
                    self.netG.zero_grad()
                    # The critic has just changed, so score a fresh
                    # generated batch (only c_xf is needed here).
                    noise = generate_noise(bs, self.nz, self.device)
                    fake_images = self.netG(noise)
                    c_xf = self.netD(fake_images).view(-1)
                    errG = -torch.mean(c_xf)
                    errG.backward()
                    self.optimizerG.step()

                # mean(D(x)) - mean(D(G(z))) from the critic step above.
                w_dist = -float(errD_real) - float(errD_fake)
                self.errD_records.append(float(errD))
                self.errG_records.append(float(errG))
                self.w_dist_records.append(w_dist)

                if i % self.loss_interval == 0:
                    print(
                        '[%d/%d] [%d/%d] errD : %.4f, errG : %.4f, Wasserstein Distance : %.4f'
                        % (epoch + 1, num_epoch, i + 1,
                           self.train_iteration_per_epoch, errD, errG, w_dist))

                # Any other value of ``special`` produces no sample output.
                if i % self.image_interval == 0 and (
                        self.special is None or self.special == 'Wave'):
                    if self.special is None:
                        sample_images_list = get_sample_images_list(
                            'Unsupervised', (self.fixed_noise, self.netG))
                        plot_fig = plot_multiple_images(
                            sample_images_list, 4, 4)
                    else:
                        sample_audios_list = get_sample_images_list(
                            'Unsupervised_Audio',
                            (self.fixed_noise, self.netG))
                        plot_fig = plot_multiple_spectrograms(
                            sample_audios_list, 4, 4, freq=16000)
                    cur_file_name = os.path.join(
                        self.save_img_dir,
                        str(self.save_cnt) + ' : ' + str(epoch) + '-' +
                        str(i) + '.jpg')
                    self.save_cnt += 1
                    save_fig(cur_file_name, plot_fig)
                    # Clear the figure after it has been written.
                    plot_fig.clf()

                if (self.snapshot_interval is not None
                        and i % self.snapshot_interval == 0):
                    save(
                        os.path.join(
                            self.save_snapshot_dir, 'Epoch' + str(epoch) +
                            '_' + str(i) + '.state'), self.netD, self.netG,
                        self.optimizerD, self.optimizerG)