def train(self, num_epoch):
        """Alternate discriminator and generator updates for ``num_epoch`` epochs.

        Both networks are trained with a relativistic average least-squares
        style loss: each discriminator output is compared against the batch
        mean of the opposite kind of sample, offset by ``self.real_label``.
        Per-batch losses are appended to ``self.errD_records`` and
        ``self.errG_records``; sample figures and snapshots are written at the
        configured intervals.

        Args:
            num_epoch: Number of passes over ``self.train_dl``.
        """
        for epoch in range(num_epoch):
            # With resampling enabled the generator step draws its real batch
            # from a second iterator over the same loader.
            if self.resample:
                train_dl_iter = iter(self.train_dl)
            for i, data in enumerate(tqdm(self.train_dl)):
                # (1) Discriminator step, minimizing
                #     mean((D(x) - mean(D(G(z))) - t)**2)
                #     + mean((D(G(z)) - mean(D(x)) + t)**2), t = self.real_label
                self.netD.zero_grad()
                real_images = data[0].to(self.device)
                bs = real_images.size(0)
                # One target value per sample in the batch.
                real_label = torch.full((bs, ),
                                        self.real_label,
                                        device=self.device)
                noise = generate_noise(bs, self.nz, self.device)
                fake_images = self.netG(noise)
                # Discriminator outputs, flattened to 1-D.  The fake batch is
                # detached so this step does not backpropagate into netG.
                c_xr = self.netD(real_images).view(-1)
                c_xf = self.netD(fake_images.detach()).view(-1)
                errD = (torch.mean(
                    (c_xr - torch.mean(c_xf) - real_label)**2) + torch.mean(
                        (c_xf - torch.mean(c_xr) + real_label)**2)) / 2.0
                errD.backward()
                self.optimizerD.step()

                # (2) Generator step: the same loss with the roles of the
                #     real and fake outputs swapped.
                self.netG.zero_grad()
                if self.resample:
                    real_images = next(train_dl_iter)[0].to(self.device)
                    noise = generate_noise(bs, self.nz, self.device)
                    fake_images = self.netG(noise)
                # netD was just updated, so both outputs are recomputed; the
                # fake batch is NOT detached here so gradients reach netG.
                c_xr = self.netD(real_images).view(-1)
                c_xf = self.netD(fake_images).view(-1)
                errG = (torch.mean(
                    (c_xf - torch.mean(c_xr) - real_label)**2) + torch.mean(
                        (c_xr - torch.mean(c_xf) + real_label)**2)) / 2.0
                errG.backward()
                self.optimizerG.step()

                errD_value = float(errD)
                errG_value = float(errG)
                self.errD_records.append(errD_value)
                self.errG_records.append(errG_value)

                if i % self.loss_interval == 0:
                    print('[%d/%d] [%d/%d] errD : %.4f, errG : %.4f' %
                          (epoch + 1, num_epoch, i + 1,
                           self.train_iteration_per_epoch, errD_value,
                           errG_value))

                # Only the two known modes (None and 'Wave') produce a figure;
                # any other value of self.special writes nothing.
                if (i % self.image_interval == 0
                        and (self.special is None or self.special == 'Wave')):
                    if self.special is None:
                        samples = get_sample_images_list(
                            'Unsupervised', (self.fixed_noise, self.netG))
                        plot_fig = plot_multiple_images(samples, 4, 4)
                    else:
                        samples = get_sample_images_list(
                            'Unsupervised_Audio',
                            (self.fixed_noise, self.netG))
                        plot_fig = plot_multiple_spectrograms(
                            samples, 4, 4, freq=16000)
                    # The file name uses the counter value before it is bumped.
                    cur_file_name = os.path.join(
                        self.save_img_dir,
                        '{} : {}-{}.jpg'.format(self.save_cnt, epoch, i))
                    self.save_cnt += 1
                    save_fig(cur_file_name, plot_fig)
                    plot_fig.clf()

                if (self.snapshot_interval is not None
                        and i % self.snapshot_interval == 0):
                    snapshot_path = os.path.join(
                        self.save_snapshot_dir,
                        'Epoch{}_{}.state'.format(epoch, i))
                    save(snapshot_path, self.netD, self.netG,
                         self.optimizerD, self.optimizerG)
    def train(self, num_epoch):
        """Train the GAN with the binary cross-entropy objective.

        For every batch the discriminator is updated on a real and a fake
        batch, then the generator is updated so that the discriminator scores
        its samples as real.  Per-batch losses are appended to
        ``self.errD_records`` and ``self.errG_records``; sample figures and
        snapshots are written at the configured intervals.

        Args:
            num_epoch: Number of passes over ``self.train_dl``.
        """
        criterion = nn.BCELoss()
        # BCELoss needs floating-point targets.  torch.full infers an integer
        # dtype from an integer fill value, so the dtype is pinned explicitly;
        # for float fill values this is the dtype torch.full would pick anyway.
        label_dtype = torch.get_default_dtype()
        for epoch in range(num_epoch):
            for i, data in enumerate(tqdm(self.train_dl)):
                # (1) Discriminator step: maximize log(D(x)) + log(1 - D(G(z))),
                #     i.e. minimize -log(D(x)) - log(1 - D(G(z))).
                self.netD.zero_grad()

                # Real half: -log(D(x)).
                real_images = data[0].to(self.device)
                bs = real_images.size(0)
                label = torch.full((bs, ),
                                   self.real_label,
                                   dtype=label_dtype,
                                   device=self.device)
                output = self.netD(real_images).view(-1)
                errD_real = criterion(output, label)
                errD_real.backward()

                # Fake half: -log(1 - D(G(z))).  The fake batch is detached so
                # this step does not backpropagate into netG.
                noise = generate_noise(bs, self.nz, self.device)
                fake_images = self.netG(noise)
                label.fill_(self.fake_label)
                output = self.netD(fake_images.detach()).view(-1)
                errD_fake = criterion(output, label)
                errD_fake.backward()

                # Gradients of both halves have accumulated in netD; the sum
                # is kept only for reporting.
                errD = errD_real + errD_fake
                self.optimizerD.step()

                # (2) Generator step: maximize log(D(G(z))),
                #     i.e. minimize -log(D(G(z))).
                self.netG.zero_grad()
                if self.resample:
                    noise = generate_noise(bs, self.nz, self.device)
                    fake_images = self.netG(noise)

                # Fake samples are scored against the *real* target here.
                label.fill_(self.real_label)
                output = self.netD(fake_images).view(-1)
                errG = criterion(output, label)
                errG.backward()
                self.optimizerG.step()

                errD_value = float(errD)
                errG_value = float(errG)
                self.errD_records.append(errD_value)
                self.errG_records.append(errG_value)

                if i % self.loss_interval == 0:
                    print('[%d/%d] [%d/%d] errD : %.4f, errG : %.4f' %
                          (epoch + 1, num_epoch, i + 1,
                           self.train_iteration_per_epoch, errD_value,
                           errG_value))

                # Only the two known modes (None and 'Wave') produce a figure;
                # any other value of self.special writes nothing.
                if (i % self.image_interval == 0
                        and (self.special is None or self.special == 'Wave')):
                    if self.special is None:
                        samples = get_sample_images_list(
                            'Unsupervised', (self.fixed_noise, self.netG))
                        plot_fig = plot_multiple_images(samples, 4, 4)
                    else:
                        samples = get_sample_images_list(
                            'Unsupervised_Audio',
                            (self.fixed_noise, self.netG))
                        plot_fig = plot_multiple_spectrograms(
                            samples, 4, 4, freq=16000)
                    # The file name uses the counter value before it is bumped.
                    cur_file_name = os.path.join(
                        self.save_img_dir,
                        '{} : {}-{}.jpg'.format(self.save_cnt, epoch, i))
                    self.save_cnt += 1
                    save_fig(cur_file_name, plot_fig)
                    plot_fig.clf()

                if (self.snapshot_interval is not None
                        and i % self.snapshot_interval == 0):
                    snapshot_path = os.path.join(
                        self.save_snapshot_dir,
                        'Epoch{}_{}.state'.format(epoch, i))
                    save(snapshot_path, self.netD, self.netG,
                         self.optimizerD, self.optimizerG)
    # Example #3
    # 0
    def train(self, num_epoch):
        """Train a weight-clipped critic and a generator for ``num_epoch`` epochs.

        Every batch updates the critic (``self.netD``) and then clamps its
        parameters to ``[-self.c, self.c]``.  The generator is updated only on
        batches whose index is a multiple of ``self.n_critic``.  Per-batch
        values are appended to ``self.errD_records``, ``self.errG_records``
        and ``self.w_dist_records``; sample figures and snapshots are written
        at the configured intervals.

        Args:
            num_epoch: Number of passes over ``self.train_dl``.
        """
        for epoch in range(num_epoch):
            for i, data in enumerate(tqdm(self.train_dl)):
                # (1) Critic step: minimize -mean(D(x)) + mean(D(G(z))).
                self.netD.zero_grad()
                real_images = data[0].to(self.device)
                bs = real_images.size(0)
                noise = generate_noise(bs, self.nz, self.device)
                # These fakes are only ever scored by the critic (the
                # generator step below draws fresh ones), so no autograd
                # graph through netG is needed.
                with torch.no_grad():
                    fake_images = self.netG(noise)
                # Critic outputs, flattened to 1-D.
                c_xr = self.netD(real_images).view(-1)
                c_xf = self.netD(fake_images).view(-1)
                errD_real = -torch.mean(c_xr)
                errD_fake = torch.mean(c_xf)

                errD = errD_real + errD_fake
                errD.backward()
                self.optimizerD.step()

                # (2) Clip every critic parameter into [-c, c].
                for param in self.netD.parameters():
                    param.data.clamp_(-self.c, self.c)

                # (3) Generator step: minimize -mean(D(G(z))), only when
                #     i % n_critic == 0.  i == 0 always qualifies, so errG is
                #     bound before its first use below.
                if i % self.n_critic == 0:
                    self.netG.zero_grad()
                    # The critic was just updated, so fresh fakes are scored.
                    noise = generate_noise(bs, self.nz, self.device)
                    fake_images = self.netG(noise)
                    c_xf = self.netD(fake_images).view(-1)
                    errG = -torch.mean(c_xf)
                    errG.backward()
                    self.optimizerG.step()

                # mean(D(x)) - mean(D(G(z))) from the critic step above.
                w_dist = -float(errD_real) - float(errD_fake)
                errD_value = float(errD)
                # On batches without a generator update this repeats the most
                # recent generator loss.
                errG_value = float(errG)
                self.errD_records.append(errD_value)
                self.errG_records.append(errG_value)
                self.w_dist_records.append(w_dist)

                if i % self.loss_interval == 0:
                    print(
                        '[%d/%d] [%d/%d] errD : %.4f, errG : %.4f, Wasserstein Distance : %.4f'
                        % (epoch + 1, num_epoch, i + 1,
                           self.train_iteration_per_epoch, errD_value,
                           errG_value, w_dist))

                # Only the two known modes (None and 'Wave') produce a figure;
                # any other value of self.special writes nothing.
                if (i % self.image_interval == 0
                        and (self.special is None or self.special == 'Wave')):
                    if self.special is None:
                        samples = get_sample_images_list(
                            'Unsupervised', (self.fixed_noise, self.netG))
                        plot_fig = plot_multiple_images(samples, 4, 4)
                    else:
                        samples = get_sample_images_list(
                            'Unsupervised_Audio',
                            (self.fixed_noise, self.netG))
                        plot_fig = plot_multiple_spectrograms(
                            samples, 4, 4, freq=16000)
                    # The file name uses the counter value before it is bumped.
                    cur_file_name = os.path.join(
                        self.save_img_dir,
                        '{} : {}-{}.jpg'.format(self.save_cnt, epoch, i))
                    self.save_cnt += 1
                    save_fig(cur_file_name, plot_fig)
                    plot_fig.clf()

                if (self.snapshot_interval is not None
                        and i % self.snapshot_interval == 0):
                    snapshot_path = os.path.join(
                        self.save_snapshot_dir,
                        'Epoch{}_{}.state'.format(epoch, i))
                    save(snapshot_path, self.netD, self.netG,
                         self.optimizerD, self.optimizerG)
    def train(self, num_epoch):
        """Train a class-conditional GAN with a binary cross-entropy objective.

        Each batch supplies images (``data[0]``) and class indices
        (``data[1]``).  Both networks receive a one-hot class encoding next to
        their usual input.  Per-batch losses are appended to
        ``self.errD_records`` and ``self.errG_records``; sample figures and
        snapshots are written at the configured intervals.

        Args:
            num_epoch: Number of passes over ``self.train_dl``.
        """
        criterion = nn.BCELoss()
        # BCELoss needs floating-point targets.  torch.full infers an integer
        # dtype from an integer fill value, so the dtype is pinned explicitly;
        # for float fill values this is the dtype torch.full would pick anyway.
        label_dtype = torch.get_default_dtype()
        for epoch in range(num_epoch):
            for i, data in enumerate(tqdm(self.train_dl)):
                # (1) Discriminator step: minimize
                #     BCE(D(x, y), real) + BCE(D(G(z, y'), y'), fake)
                self.netD.zero_grad()
                real_images = data[0].to(self.device)
                real_class = data[1].to(self.device)
                bs = real_images.size(0)

                # Per-sample real / fake targets.
                real_label = torch.full((bs, ),
                                        self.real_label,
                                        dtype=label_dtype,
                                        device=self.device)
                fake_label = torch.full((bs, ),
                                        self.fake_label,
                                        dtype=label_dtype,
                                        device=self.device)

                # One-hot encoding (bs, n_classes) of the real class indices.
                one_hot_labels = torch.zeros(bs,
                                             self.n_classes,
                                             dtype=torch.float32,
                                             device=self.device)
                one_hot_labels.scatter_(1, real_class.view(bs, 1), 1.0)

                noise = generate_noise(bs, self.nz, self.device)

                # Random class per fake sample.  Drawn on the CPU and then
                # moved, which keeps the draw on the CPU random generator.
                fake_class = torch.randint(0, self.n_classes,
                                           size=(bs, 1)).to(self.device)
                one_hot_labels_fake = torch.zeros(bs,
                                                  self.n_classes,
                                                  dtype=torch.float32,
                                                  device=self.device)
                one_hot_labels_fake.scatter_(1, fake_class, 1.0)

                fake_images = self.netG(noise, one_hot_labels_fake)

                # Discriminator outputs, flattened to 1-D.  The fake batch is
                # detached so this step does not backpropagate into netG.
                c_xr = self.netD(real_images, one_hot_labels).view(-1)
                c_xf = self.netD(fake_images.detach(),
                                 one_hot_labels_fake).view(-1)
                errD = criterion(c_xr, real_label) + criterion(
                    c_xf, fake_label)
                errD.backward()
                self.optimizerD.step()

                # (2) Generator step: minimize BCE(D(G(z, y'), y'), real).
                self.netG.zero_grad()
                if self.resample:
                    # Only the noise is redrawn; the fake classes, and thus
                    # their one-hot encoding, are reused unchanged.
                    noise = generate_noise(bs, self.nz, self.device)
                    fake_images = self.netG(noise, one_hot_labels_fake)
                # netD was just updated, so the fake output is recomputed.
                c_xf = self.netD(fake_images, one_hot_labels_fake).view(-1)
                errG = criterion(c_xf, real_label)
                errG.backward()
                self.optimizerG.step()

                errD_value = float(errD)
                errG_value = float(errG)
                self.errD_records.append(errD_value)
                self.errG_records.append(errG_value)

                if i % self.loss_interval == 0:
                    print('[%d/%d] [%d/%d] errD : %.4f, errG : %.4f' %
                          (epoch + 1, num_epoch, i + 1,
                           self.train_iteration_per_epoch, errD_value,
                           errG_value))

                if i % self.image_interval == 0:
                    samples = get_sample_images_list(
                        'Conditional',
                        (self.fixed_noise, self.fixed_one_hot_labels,
                         self.n_classes, self.netG))
                    plot_fig = plot_multiple_images(samples, self.n_classes,
                                                    1)
                    # The file name uses the counter value before it is bumped.
                    cur_file_name = os.path.join(
                        self.save_img_dir,
                        '{} : {}-{}.jpg'.format(self.save_cnt, epoch, i))
                    self.save_cnt += 1
                    save_fig(cur_file_name, plot_fig)
                    plot_fig.clf()

                if (self.snapshot_interval is not None
                        and i % self.snapshot_interval == 0):
                    snapshot_path = os.path.join(
                        self.save_snapshot_dir,
                        'Epoch{}_{}.state'.format(epoch, i))
                    save(snapshot_path, self.netD, self.netG,
                         self.optimizerD, self.optimizerG)
    def train(self, num_epoch):
        for epoch in range(num_epoch):
            for i, data in enumerate(tqdm(self.train_dl)):
                self.netD.zero_grad()
                real_images = data[0].to(self.device)
                real_class = data[1].to(self.device)
                bs = real_images.size(0)

                noise = generate_noise(bs, self.nz, self.device)
                fake_class = torch.randint(0, self.n_classes,
                                           size=(bs,
                                                 1)).view(bs,
                                                          1).to(self.device)
                one_hot_labels_fake = torch.FloatTensor(bs, self.n_classes).to(
                    self.device)
                one_hot_labels_fake.zero_()
                one_hot_labels_fake.scatter_(1,
                                             fake_class.view(bs, 1).long(),
                                             1.0)
                fake_images = self.netG(noise, one_hot_labels_fake)

                one_hot_labels = torch.FloatTensor(bs, self.n_classes).to(
                    self.device)
                one_hot_labels.zero_()
                one_hot_labels.scatter_(1, real_class.view(bs, 1), 1.0)

                c_xr = self.netD(real_images, one_hot_labels)
                c_xr = c_xr.view(-1)
                c_xf = self.netD(fake_images.detach(), one_hot_labels_fake)
                c_xf = c_xf.view(-1)

                if (self.require_type == 0 or self.require_type == 1):
                    errD = self.loss.d_loss(c_xr, c_xf)
                elif (self.require_type == 2):
                    errD = self.loss.d_loss(c_xr, c_xf, real_images,
                                            fake_images)

                if (self.use_gradient_penalty != False):
                    errD += self.use_gradient_penalty * self.gradient_penalty(
                        real_images, fake_images, one_hot_labels,
                        one_hot_labels_fake)

                errD.backward()
                self.optimizerD.step()

                if (self.weight_clip != None):
                    for param in self.netD.parameters():
                        param.data.clamp_(-self.weight_clip, self.weight_clip)

                self.netG.zero_grad()
                if (self.resample):
                    noise = generate_noise(bs, self.nz, self.device)
                    one_hot_labels_fake = torch.FloatTensor(
                        bs, self.n_classes).to(self.device)
                    one_hot_labels_fake.zero_()
                    one_hot_labels_fake.scatter_(1,
                                                 fake_class.view(bs, 1).long(),
                                                 1.0)
                    fake_images = self.netG(noise, one_hot_labels_fake)

                if (self.require_type == 0):
                    c_xf = self.netD(fake_images, one_hot_labels_fake)
                    c_xf = c_xf.view(-1)
                    errG = self.loss.g_loss(c_xf)
                if (self.require_type == 1 or self.require_type == 2):
                    c_xr = self.netD(real_images,
                                     one_hot_labels)  # (bs, 1, 1, 1)
                    c_xr = c_xr.view(-1)  # (bs)
                    c_xf = self.netD(fake_images,
                                     one_hot_labels_fake)  # (bs, 1, 1, 1)
                    c_xf = c_xf.view(-1)
                    errG = self.loss.g_loss(c_xr, c_xf)
                errG.backward()
                self.optimizerG.step()

                self.errD_records.append(float(errD))
                self.errG_records.append(float(errG))

                if (i % self.loss_interval == 0):
                    print('[%d/%d] [%d/%d] errD : %.4f, errG : %.4f' %
                          (epoch + 1, num_epoch, i + 1,
                           self.train_iteration_per_epoch, errD, errG))

                if (i % self.image_interval == 0):
                    sample_images_list = get_sample_images_list(
                        'Conditional',
                        (self.fixed_noise, self.fixed_one_hot_labels,
                         self.n_classes, self.netG))
                    plot_fig = plot_multiple_images(sample_images_list,
                                                    self.n_classes, 7)
                    cur_file_name = os.path.join(
                        self.save_img_dir,
                        str(self.save_cnt) + ' : ' + str(epoch) + '-' +
                        str(i) + '.jpg')
                    self.save_cnt += 1
                    save_fig(cur_file_name, plot_fig)
                    plot_fig.clf()