Example #1
    def train(self):
        self.train_hist = {}
        self.train_hist['D_loss'] = []
        self.train_hist['G_loss'] = []
        self.train_hist['E_loss'] = []
        self.train_hist['per_epoch_time'] = []
        self.train_hist['total_time'] = []

        if torch.cuda.is_available():
            self.y_real_, self.y_fake_ = Variable(
                torch.ones(self.batch_size, 1).cuda()), Variable(
                    torch.zeros(self.batch_size, 1).cuda())
        else:
            self.y_real_, self.y_fake_ = Variable(
                torch.ones(self.batch_size,
                           1)), Variable(torch.zeros(self.batch_size, 1))

        self.D.train()
        print('training start!!')
        start_time = time.time()
        for epoch in range(self.epoch):
            # reset training mode of G and E
            self.G.train()
            self.E.train()
            epoch_start_time = time.time()

            for iter, (X, _) in enumerate(self.data_loader):
                X = utils.to_var(X)
                """Discriminator"""
                z = utils.to_var(
                    torch.randn((self.batch_size,
                                 self.z_dim)).view(-1, self.z_dim, 1, 1))
                X_hat = self.G(z)
                D_real = self.D(X).squeeze().view(-1, 1)
                D_fake = self.D(X_hat).squeeze().view(-1, 1)
                D_loss = self.BCE_loss(D_real, self.y_real_) + self.BCE_loss(
                    D_fake, self.y_fake_)
                self.train_hist['D_loss'].append(D_loss.item())
                # Optimize
                D_loss.backward()
                self.D_optimizer.step()
                self.__reset_grad()
                """Encoder"""
                z = utils.to_var(
                    torch.randn((self.batch_size,
                                 self.z_dim)).view(-1, self.z_dim, 1, 1))
                X_hat = self.G(z)
                z_mu, z_sigma = self.E(X_hat)
                z_mu, z_sigma = z_mu.squeeze(), z_sigma.squeeze()
                # negative log-likelihood under N(z_mu, exp(z_sigma))
                E_loss = torch.mean(
                    torch.mean(
                        0.5 * (z - z_mu)**2 * torch.exp(-z_sigma) +
                        0.5 * z_sigma + 0.5 * np.log(2 * np.pi), 1))
                self.train_hist['E_loss'].append(E_loss.item())
                # Optimize
                E_loss.backward()
                self.E_optimizer.step()
                self.__reset_grad()
                """Generator"""
                # Use both Discriminator and Encoder to update Generator
                z = utils.to_var(
                    torch.randn((self.batch_size,
                                 self.z_dim)).view(-1, self.z_dim, 1, 1))
                X_hat = self.G(z)
                D_fake = self.D(X_hat).squeeze().view(-1, 1)
                z_mu, z_sigma = self.E(X_hat)
                z_mu, z_sigma = z_mu.squeeze(), z_sigma.squeeze()
                mode_loss = torch.mean(
                    torch.mean(
                        0.5 * (z - z_mu)**2 * torch.exp(-z_sigma) +
                        0.5 * z_sigma + 0.5 * np.log(2 * np.pi), 1))
                G_loss = self.BCE_loss(D_fake, self.y_real_)
                total_loss = G_loss + mode_loss
                self.train_hist['G_loss'].append(G_loss.item())
                # Optimize
                total_loss.backward()
                self.G_optimizer.step()
                self.__reset_grad()
                """ Plot """
                if (iter + 1
                    ) == self.data_loader.dataset.__len__() // self.batch_size:
                    # Print and plot every epoch
                    print(
                        'Epoch-{}; D_loss: {:.4}; G_loss: {:.4}; E_loss: {:.4}\n'
                        .format(epoch, D_loss.item(), G_loss.item(),
                                E_loss.item()))
                    for iter, (X, _) in enumerate(self.valid_loader):
                        X = utils.to_var(X)
                        self.visualize_results(X, epoch + 1)
                        break

                    break

            self.train_hist['per_epoch_time'].append(time.time() -
                                                     epoch_start_time)

            # Save model every 5 epochs
            if epoch % 5 == 0:
                self.save()

        self.train_hist['total_time'].append(time.time() - start_time)
        print("Avg one epoch time: %.2f, total %d epochs time: %.2f" %
              (np.mean(self.train_hist['per_epoch_time']), self.epoch,
               self.train_hist['total_time'][0]))
        print("Training finish!... save final training results")
        self.save()

        # Generate animation of reconstructed plot
        utils.generate_animation(
            self.root + '/' + self.result_dir + '/' + self.dataset + '/' +
            self.model_name + '/reconstructed', self.epoch)
        utils.loss_plot(
            self.train_hist,
            os.path.join(self.root, self.save_dir, self.dataset,
                         self.model_name), self.model_name)
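
Note: the E_loss/mode_loss term above is the negative log-likelihood of z under a diagonal Gaussian N(z_mu, exp(z_sigma)), with z_sigma acting as a log-variance. A minimal standalone sketch of that term (the helper name and the log-variance convention are assumptions inferred from the code above, not part of the original repo):

    import numpy as np
    import torch

    def gaussian_nll(z, mu, log_var):
        # per-dimension NLL of z under N(mu, exp(log_var)), averaged over
        # dimensions and then over the batch, as in E_loss above
        nll = (0.5 * (z - mu) ** 2 * torch.exp(-log_var)
               + 0.5 * log_var + 0.5 * np.log(2 * np.pi))
        return torch.mean(torch.mean(nll, 1))
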
Example #2
    def train(self):
        self.train_hist = {}
        self.train_hist['D_loss'] = []
        self.train_hist['G_loss'] = []
        self.train_hist['per_epoch_time'] = []
        self.train_hist['total_time'] = []

        if self.gpu_mode:
            self.y_real_, self.y_fake_ = Variable(torch.ones(self.batch_size, 1).cuda()), Variable(torch.zeros(self.batch_size, 1).cuda())
        else:
            self.y_real_, self.y_fake_ = Variable(torch.ones(self.batch_size, 1)), Variable(torch.zeros(self.batch_size, 1))

        self.D.train()
        print('training start!!')
        start_time = time.time()
        for epoch in range(self.epoch):
            self.G.train()
            epoch_start_time = time.time()
            for iter, (x_, _) in enumerate(self.data_loader):
                if iter == self.data_loader.dataset.__len__() // self.batch_size:
                    break

                z_ = torch.rand((self.batch_size, self.z_dim))

                if self.gpu_mode:
                    x_, z_ = Variable(x_.cuda()), Variable(z_.cuda())
                else:
                    x_, z_ = Variable(x_), Variable(z_)

                # update D network
                self.D_optimizer.zero_grad()

                D_real = self.D(x_)
                D_real_loss = self.BCE_loss(D_real, self.y_real_)

                G_ = self.G(z_)
                D_fake = self.D(G_)
                D_fake_loss = self.BCE_loss(D_fake, self.y_fake_)

                D_loss = D_real_loss + D_fake_loss
                self.train_hist['D_loss'].append(D_loss.item())

                D_loss.backward()
                self.D_optimizer.step()

                # update G network
                self.G_optimizer.zero_grad()

                G_ = self.G(z_)
                D_fake = self.D(G_)
                G_loss = self.BCE_loss(D_fake, self.y_real_)
                self.train_hist['G_loss'].append(G_loss.item())

                G_loss.backward()
                self.G_optimizer.step()

                if ((iter + 1) % 100) == 0:
                    print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f" %
                          ((epoch + 1), (iter + 1), self.data_loader.dataset.__len__() // self.batch_size, D_loss.data[0], G_loss.data[0]))

            self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time)
            self.visualize_results((epoch+1))

        self.train_hist['total_time'].append(time.time() - start_time)
        print("Avg one epoch time: %.2f, total %d epochs time: %.2f" % (np.mean(self.train_hist['per_epoch_time']),
              self.epoch, self.train_hist['total_time'][0]))
        print("Training finish!... save training results")

        self.save()
        utils.generate_animation(self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name,
                                 self.epoch)
        utils.loss_plot(self.train_hist, os.path.join(self.save_dir, self.dataset, self.model_name), self.model_name)
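
This is the plain BCE GAN update. For reference, a minimal sketch of the same two updates in current PyTorch (no Variable wrappers, labels built per batch; all names here are illustrative assumptions, not the repo's API):

    import torch
    import torch.nn.functional as F

    def gan_step(D, G, D_opt, G_opt, x, z):
        # D is assumed to output probabilities, G to output images
        real = torch.ones(x.size(0), 1, device=x.device)
        fake = torch.zeros(x.size(0), 1, device=x.device)

        D_opt.zero_grad()
        d_loss = (F.binary_cross_entropy(D(x), real)
                  + F.binary_cross_entropy(D(G(z).detach()), fake))
        d_loss.backward()
        D_opt.step()

        G_opt.zero_grad()
        g_loss = F.binary_cross_entropy(D(G(z)), real)
        g_loss.backward()
        G_opt.step()
        return d_loss.item(), g_loss.item()
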
Example #3
    def train(self):
        self.train_hist = {}
        self.train_hist['D_loss'] = []
        self.train_hist['G_loss'] = []
        self.train_hist['per_epoch_time'] = []
        self.train_hist['total_time'] = []

        self.y_real_, self.y_fake_ = torch.ones(self.batch_size,
                                                1), torch.zeros(
                                                    self.batch_size, 1)
        if self.gpu_mode:
            self.y_real_, self.y_fake_ = self.y_real_.cuda(
            ), self.y_fake_.cuda()

        self.D.train()
        print('training start!!')
        start_time = time.time()
        Q_stack = []
        for epoch in range(self.epoch):
            self.G.train()
            epoch_start_time = time.time()

            for iter, (x_, _) in enumerate(self.data_loader):

                if iter == self.data_loader.dataset.__len__(
                ) // self.batch_size:
                    break

                z_ = torch.rand((self.batch_size, self.z_dim))

                if self.gpu_mode:
                    x_, z_ = x_.cuda(), z_.cuda()

                # update D network
                self.D_optimizer.zero_grad()

                D_real = self.D(x_)
                D_real_loss = self.BCE_loss(D_real, self.y_real_)

                G_ = self.G(z_)
                D_fake = self.D(G_)
                D_fake_loss = self.BCE_loss(D_fake, self.y_fake_)

                D_loss = D_real_loss + D_fake_loss
                self.train_hist['D_loss'].append(D_loss.item())

                D_loss.backward()
                self.D_optimizer.step()

                # update G network

                self.G_optimizer.zero_grad()
                G_ = self.G(z_)
                D_fake = self.D(G_)

                ######## the following code is a practical implementation of the paper ##########
                perturbation_del = (torch.randn(self.batch_size,
                                                self.z_dim)).cuda()
                eps = self.eps
                perturbation_length = torch.norm(perturbation_del, dim=1, keepdim=True)
                perturbation_del = (perturbation_del / perturbation_length) * eps
                z_prime = z_ + perturbation_del
                perturbed_images = self.G(z_) - self.G(z_prime)
                perturbed_latent_var = z_ - z_prime
                Q = torch.norm(perturbed_images.view(
                    self.batch_size, -1), dim=1) / torch.norm(
                        perturbed_latent_var.view(self.batch_size, -1), dim=1)
                print(Q)

                L_max = 0.0
                L_min = 0.0
                count_max = 0
                count_min = 0

                for i in range(self.batch_size):
                    if Q[i] > self.eig_max:
                        L_max += (Q[i] - self.eig_max)**2
                        count_max += 1
                    if Q[i] < self.eig_min:
                        L_min += (Q[i] - self.eig_min)**2
                        count_min += 1
                L = L_max + L_min
                #################### end of implementation for the paper ####################
                # NOTE: L and the counts are only printed for monitoring below;
                # L is not added to G_loss in this snippet.

                G_loss = self.BCE_loss(D_fake, self.y_real_)

                self.train_hist['G_loss'].append(G_loss.item())

                G_loss.backward()

                self.G_optimizer.step()

                if ((iter + 1) % 100) == 0:
                    print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f" %
                          ((epoch + 1),
                           (iter + 1), self.data_loader.dataset.__len__() //
                           self.batch_size, D_loss.item(), G_loss.item()))
                    print(L)
                    print(count_max)
                    print(count_min)

            self.train_hist['per_epoch_time'].append(time.time() -
                                                     epoch_start_time)
            with torch.no_grad():
                self.gen_mode(4, epoch)
                self.visualize_results((epoch + 1))

        self.train_hist['total_time'].append(time.time() - start_time)
        print("Avg one epoch time: %.2f, total %d epochs time: %.2f" %
              (np.mean(self.train_hist['per_epoch_time']), self.epoch,
               self.train_hist['total_time'][0]))
        print("Training finish!... save training results")

        self.save()
        utils.generate_animation(
            self.result_dir + '/' + self.dataset + '/' + self.model_name +
            '/' + self.model_name, self.epoch)
        utils.loss_plot(
            self.train_hist,
            os.path.join(self.save_dir, self.dataset, self.model_name),
            self.model_name)
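
The per-sample Python loop over Q can be vectorized. A sketch that should compute the same L, count_max, and count_min under the same eig_min/eig_max semantics (the helper name is an assumption):

    import torch

    def eigenvalue_penalty(Q, eig_min, eig_max):
        # squared distance of Q outside [eig_min, eig_max], summed over the batch
        over = torch.clamp(Q - eig_max, min=0.0)
        under = torch.clamp(eig_min - Q, min=0.0)
        L = (over ** 2 + under ** 2).sum()
        return L, (Q > eig_max).sum().item(), (Q < eig_min).sum().item()
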
Example #4
    def train(self):

        self.G.apply(self.G.weights_init)
        print(' training start!! (no conditional)')
        start_time = time.time()

        for classe in range(10):
            self.train_hist = {}
            self.train_hist['D_loss'] = []
            self.train_hist['G_loss'] = []
            self.train_hist['per_epoch_time'] = []
            self.train_hist['total_time'] = []
            # self.G.apply(self.G.weights_init) does not work on an instance,
            # so recreate the encoder instead
            del self.E
            self.E = Encoder(self.z_dim, self.dataset, self.conditional)
            self.E_optimizer = optim.Adam(
                self.E.parameters(),
                lr=self.lr)  #, lr=args.lrD, betas=(args.beta1, args.beta2))
            if self.gpu_mode:
                self.E.cuda(self.device)

            best = 100000
            self.data_loader_train = get_iter_dataset(self.dataset_train,
                                                      self.list_class_train,
                                                      self.batch_size, classe)
            self.data_loader_valid = get_iter_dataset(self.dataset_valid,
                                                      self.list_class_valid,
                                                      self.batch_size, classe)
            early_stop = 0
            for epoch in range(self.epoch):

                epoch_start_time = time.time()
                # print("number of batch data")
                # print(len(self.data_loader_train))
                self.E.train()
                self.G.train()
                sum_loss_train = 0.
                n_batch = 0.
                #for iter in range(self.size_epoch):
                for iter, (x_, t_) in enumerate(self.data_loader_train):
                    n_batch += 1
                    #x_ = sort_utils.get_batch(list_classes, classe, self.batch_size)
                    #x_ = torch.FloatTensor(x_)
                    x_ = Variable(x_)
                    if self.gpu_mode:
                        x_ = x_.cuda(self.device)
                    # VAE
                    z_, mu, logvar = self.E(x_)
                    recon_batch = self.G(z_)

                    # train
                    self.G_optimizer.zero_grad()
                    self.E_optimizer.zero_grad()
                    g_loss = self.loss_function(recon_batch, x_, mu, logvar)
                    g_loss.backward()
                    sum_loss_train += g_loss.item()
                    self.G_optimizer.step()
                    self.E_optimizer.step()

                    self.train_hist['D_loss'].append(g_loss.item())
                    self.train_hist['G_loss'].append(g_loss.item())

                    if ((iter + 1) % 100) == 0:
                        print(
                            "classe : [%1d] Epoch: [%2d] [%4d/%4d] G_loss: %.8f, E_loss: %.8f"
                            % (classe, (epoch + 1),
                               (iter + 1), self.size_epoch, g_loss.item(),
                               g_loss.item()))
                sum_loss_train = sum_loss_train / float(n_batch)
                sum_loss_valid = 0.
                n_batch = 1.  # start at 1 to avoid division by zero on an empty loader
                self.E.eval()
                self.G.eval()
                for iter, (x_, t_) in enumerate(self.data_loader_valid):
                    n_batch += 1
                    max_val, max_indice = torch.max(t_, 0)
                    mask_idx = torch.nonzero(t_ == classe)
                    if mask_idx.numel() == 0:
                        continue
                    x_ = torch.index_select(x_, 0, mask_idx[:, 0])
                    t_ = torch.index_select(t_, 0, mask_idx[:, 0])
                    if self.gpu_mode:
                        x_ = Variable(x_.cuda(self.device), volatile=True)
                    else:
                        x_ = Variable(x_)
                    # VAE
                    z_, mu, logvar = self.E(x_)
                    recon_batch = self.G(z_)

                    G_loss = self.loss_function(recon_batch, x_, mu, logvar)
                    sum_loss_valid += G_loss.item()

                sum_loss_valid = sum_loss_valid / float(n_batch)
                print(
                    "classe : [%1d] Epoch: [%2d] Train_loss: %.8f, Valid_loss: %.8f"
                    % (classe, (epoch + 1), sum_loss_train, sum_loss_valid))
                self.train_hist['per_epoch_time'].append(time.time() -
                                                         epoch_start_time)
                self.visualize_results((epoch + 1), classe)
                if sum_loss_valid < best:
                    best = sum_loss_valid
                    self.save_G(classe)
                    early_stop = 0
                # We do early stopping if the validation performance doesn't
                # improve anymore after 150 epochs
                if early_stop == 150:
                    break
                else:
                    early_stop += 1
            result_dir = self.result_dir + '/' + 'classe-' + str(classe)
            utils.generate_animation(result_dir + '/' + self.model_name,
                                     epoch + 1)
            utils.loss_plot(self.train_hist, result_dir, self.model_name)

            np.savetxt(
                os.path.join(result_dir,
                             'vae_training_' + self.dataset + '.txt'),
                np.transpose([self.train_hist['G_loss']]))

        self.train_hist['total_time'].append(time.time() - start_time)
        print("Avg one epoch time: %.2f, total %d epochs time: %.2f" %
              (np.mean(self.train_hist['per_epoch_time']), self.epoch,
               self.train_hist['total_time'][0]))
        print("Training finish!... save training results")
Example #5
    def train(self):
        self.train_hist = {}
        self.train_hist['D_loss'] = []
        self.train_hist['G_loss'] = []
        self.train_hist['per_epoch_time'] = []
        self.train_hist['total_time'] = []

        self.y_real_, self.y_fake_ = torch.ones(self.batch_size,
                                                1), torch.zeros(
                                                    self.batch_size, 1)
        if self.gpu_mode:
            self.y_real_, self.y_fake_ = self.y_real_.cuda(
            ), self.y_fake_.cuda()

        self.D.train()
        print('training start!!')
        start_time = time.time()
        for epoch in range(self.epoch):
            self.G.train()
            epoch_start_time = time.time()
            for iter, (x_, _) in enumerate(self.data_loader):
                if iter == self.data_loader.dataset.__len__(
                ) // self.batch_size:
                    break

                z_ = torch.rand((self.batch_size, self.z_dim))
                if self.gpu_mode:
                    x_, z_ = x_.cuda(), z_.cuda()

                # update D network
                self.D_optimizer.zero_grad()

                D_real = self.D(x_)
                D_real_loss = -torch.mean(D_real)

                G_ = self.G(z_)
                D_fake = self.D(G_)
                D_fake_loss = torch.mean(D_fake)

                # gradient penalty
                alpha = torch.rand((self.batch_size, 1, 1, 1))
                if self.gpu_mode:
                    alpha = alpha.cuda()

                x_hat = alpha * x_.data + (1 - alpha) * G_.data
                x_hat.requires_grad = True

                pred_hat = self.D(x_hat)
                if self.gpu_mode:
                    gradients = grad(outputs=pred_hat,
                                     inputs=x_hat,
                                     grad_outputs=torch.ones(
                                         pred_hat.size()).cuda(),
                                     create_graph=True,
                                     retain_graph=True,
                                     only_inputs=True)[0]
                else:
                    gradients = grad(outputs=pred_hat,
                                     inputs=x_hat,
                                     grad_outputs=torch.ones(pred_hat.size()),
                                     create_graph=True,
                                     retain_graph=True,
                                     only_inputs=True)[0]

                gradient_penalty = self.lambda_ * (
                    (gradients.view(gradients.size()[0], -1).norm(2, 1) - 1)**
                    2).mean()

                D_loss = D_real_loss + D_fake_loss + gradient_penalty

                D_loss.backward()
                self.D_optimizer.step()

                if ((iter + 1) % self.n_critic) == 0:
                    # update G network
                    self.G_optimizer.zero_grad()

                    G_ = self.G(z_)
                    D_fake = self.D(G_)
                    G_loss = -torch.mean(D_fake)
                    self.train_hist['G_loss'].append(G_loss.item())

                    G_loss.backward()
                    self.G_optimizer.step()

                    self.train_hist['D_loss'].append(D_loss.item())

                if ((iter + 1) % 100) == 0:
                    print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f" %
                          ((epoch + 1),
                           (iter + 1), self.data_loader.dataset.__len__() //
                           self.batch_size, D_loss.item(), G_loss.item()))

            self.train_hist['per_epoch_time'].append(time.time() -
                                                     epoch_start_time)
            with torch.no_grad():
                self.visualize_results((epoch + 1))

        self.train_hist['total_time'].append(time.time() - start_time)
        print("Avg one epoch time: %.2f, total %d epochs time: %.2f" %
              (np.mean(self.train_hist['per_epoch_time']), self.epoch,
               self.train_hist['total_time'][0]))
        print("Training finish!... save training results")

        self.save()
        # utils.generate_animation(self.result_dir + '/' + self.dataset + '/' + '%03d'% self.epoch +'/' + self.model_name + '/'
        #                   + self.model_name,self.epoch)
        utils.loss_plot(
            self.train_hist,
            os.path.join(self.save_dir, self.dataset, self.model_name),
            self.model_name)
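
The gradient-penalty block above is self-contained enough to factor out. A device-agnostic sketch (the function name and default lambda value are assumptions):

    import torch
    from torch.autograd import grad

    def gradient_penalty(D, x_real, x_fake, lambda_=10.0):
        # WGAN-GP penalty on random interpolates between real and fake batches
        alpha = torch.rand(x_real.size(0), 1, 1, 1, device=x_real.device)
        x_hat = (alpha * x_real + (1 - alpha) * x_fake).requires_grad_(True)
        pred = D(x_hat)
        g = grad(outputs=pred, inputs=x_hat,
                 grad_outputs=torch.ones_like(pred),
                 create_graph=True, retain_graph=True, only_inputs=True)[0]
        return lambda_ * ((g.view(g.size(0), -1).norm(2, dim=1) - 1) ** 2).mean()
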
Example #6
    def train(self):
        self.train_hist = {}
        self.train_hist['D_loss'] = []
        self.train_hist['G_loss'] = []
        self.train_hist['per_epoch_time'] = []
        self.train_hist['total_time'] = []

        self.D.train()
        self.E.train()

        print('training start!!')
        start_time = time.time()
        first_time = True
        self.accuracy_hist = []

        for epoch in range(self.epoch):
            self.G.train()  # check here!!
            epoch_start_time = time.time()
            decay = 0.98 ** epoch
            self.E_optimizer = optim.Adam(self.E.parameters(), lr=decay * 0.3 * self.args.lrD,
                                          betas=(self.args.beta1, self.args.beta2))
            self.G_optimizer = optim.Adam(self.G.parameters(), lr=decay * 3 * self.args.lrG, betas=(self.args.beta1, self.args.beta2))
            self.D_optimizer = optim.Adam(self.D.parameters(), lr=decay * self.args.lrD, betas=(self.args.beta1, self.args.beta2))
            for M_epoch in range(5):
                for iter, (batch_x, batch_y) in enumerate(self.train_loader):

                    x_ = batch_x
                    z_ = torch.rand((self.batch_size, self.z_dim))
                    x_, z_ = Variable(x_.cuda()), Variable(z_.cuda())
                    G_batch_size = batch_x.size()[0]
                    if G_batch_size < self.batch_size:
                        break
                    # x_  (batch, 1L, 28L, 28L)
                    # z_  (batch, 62L)

                    # update D network:

                    image_real = Variable(batch_x.cuda())
                    self.E.eval()
                    y_real = self.E(image_real)
                    y_real = nn.Softmax(dim=1)(y_real)
                    y_real = y_real.data.cpu().numpy()


                    self.D_optimizer.zero_grad()

                    D_real = self.D(x_)
                    if first_time:
                        y_real = (1 / float(self.class_num)) * np.ones((G_batch_size, self.class_num)) # first_time

                    y_real = np.concatenate((y_real, 2 * np.ones((np.shape(y_real)[0], 1))), axis=1)

                    ones = np.ones((np.shape(y_real)[0], np.shape(y_real)[1]))
                    ones[:, -1] = 0
                    ones = torch.FloatTensor(ones)
                    ones = Variable(ones).cuda()
                    y_real = torch.FloatTensor(y_real).cuda()

                    D_real_loss = torch.nn.BCEWithLogitsLoss(weight=y_real)(D_real, ones)

                    G_input, conditional_label = self.gen_cond_label(self.batch_size)
                    G_ = self.G(G_input, 0)
                    D_fake = self.D(G_)
                    y_fake_1 = np.tile(np.zeros((self.class_num)), (self.batch_size, 1))
                    y_fake_2 = np.tile(np.ones((1)), (self.batch_size, 1))
                    y_fake = np.concatenate((y_fake_1, y_fake_2), axis=1)
                    y_fake = Variable(torch.FloatTensor(y_fake).cuda())
                    D_fake_loss = torch.nn.BCEWithLogitsLoss()(D_fake, y_fake)

                    D_loss = D_real_loss + D_fake_loss

                    self.train_hist['D_loss'].append(D_loss.item())
                    D_loss.backward()
                    self.D_optimizer.step()


                    # update G network:

                    self.G_optimizer.zero_grad()
                    G_input, conditional_label = self.gen_cond_label(self.batch_size)
                    G_ = self.G(G_input, 0)
                    D_fake = self.D(G_)

                    G_y_real=np.concatenate((conditional_label.numpy(),np.tile([0],(self.batch_size,1))),axis=1)
                    G_y_real=Variable(torch.FloatTensor(G_y_real)).cuda()
                    G_loss=torch.nn.BCEWithLogitsLoss()(D_fake,G_y_real)


                    self.train_hist['G_loss'].append(G_loss.item())
                    G_loss.backward()
                    self.G_optimizer.step()

                    if ((iter + 1) % 100) == 0:
                        print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f" %
                              ((epoch + 1), (iter + 1), len(self.data_X) // self.batch_size, D_loss.data[0], G_loss.data[0]))



            self.E_training(200)
            first_time = False
            self.visualize_results((epoch+1))
            self.compute_accuracy()
            self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time)
            self.save()

        self.train_hist['total_time'].append(time.time() - start_time)
        print("Avg one epoch time: %.2f, total %d epochs time: %.2f" % (np.mean(self.train_hist['per_epoch_time']),
              self.epoch, self.train_hist['total_time'][0]))
        print("Training finish!... save training results")

        self.save()
        utils.generate_animation(self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name,
                                 self.epoch)
        utils.loss_plot(self.train_hist, os.path.join(self.save_dir, self.dataset, self.model_name), self.model_name)
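
The numpy round-trips that build the (class_num + 1)-column weight and target tensors for the real branch can be done directly in torch. A sketch of that construction (the helper name is an assumption; the weight of 2 on the extra column follows the code above):

    import torch

    def make_real_targets(y_soft, weight_last=2.0):
        # y_soft: (batch, class_num) softmax scores from the classifier E
        extra = weight_last * torch.ones(y_soft.size(0), 1, device=y_soft.device)
        weights = torch.cat([y_soft, extra], dim=1)   # per-column BCE weights
        targets = torch.ones_like(weights)            # all-real targets ...
        targets[:, -1] = 0                            # ... except the fake column
        return weights, targets
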
Example #7
    def train(self):
        self.train_hist = {}
        self.train_hist['D_loss'] = []
        self.train_hist['G_loss'] = []
        self.train_hist['per_epoch_time'] = []
        self.train_hist['total_time'] = []
        self.train_hist['D_norm'] = []

        f = open("%s/results.txt" % self.log_dir, "w")
        f.write("d_loss,g_loss,d_norm\n")
    
        if self.gpu_mode:
            self.y_real_, self.y_fake_ = Variable(torch.ones(self.batch_size, 1).cuda()), Variable(torch.zeros(self.batch_size, 1).cuda())
        else:
            self.y_real_, self.y_fake_ = Variable(torch.ones(self.batch_size, 1)), Variable(torch.zeros(self.batch_size, 1))


        self.D.train()
        print('training start!!')
        start_time = time.time()
        for epoch in range(self.epoch):
            self.G.train()
            epoch_start_time = time.time()
            for iter, (x_, _) in enumerate(self.data_loader):
                
                if iter == self.data_loader.dataset.__len__() // self.batch_size:
                    break

                z_ = torch.rand((self.batch_size, self.z_dim))

                if self.gpu_mode:
                    x_, z_ = Variable(x_.cuda(), requires_grad=True), \
                            Variable(z_.cuda())
                else:
                    x_, z_ = Variable(x_, requires_grad=True), \
                            Variable(z_)

                # update D network

                D_real = self.D(x_)
                # compute the gradient-norm penalty w.r.t. the real inputs
                grad_wrt_x = grad(outputs=D_real, inputs=x_,
                                  grad_outputs=torch.ones_like(D_real),
                                  create_graph=True, retain_graph=True, only_inputs=True)[0]
                g_norm = ((grad_wrt_x.view(grad_wrt_x.size()[0], -1).norm(2, 1) - 1) ** 2).mean()
                self.train_hist['D_norm'].append(g_norm.item())

                self.D_optimizer.zero_grad()

                G_ = self.G(z_).detach()
                alpha = float(np.random.random())
                Xz = Variable(alpha*x_.data + (1.-alpha)*G_.data)
                D_Xz = self.D(Xz)
                D_loss = self.BCE_loss(D_Xz, alpha*self.y_real_)
                
                self.train_hist['D_loss'].append(D_loss.item())

                D_loss.backward()
                self.D_optimizer.step()

                # update G network
                self.G_optimizer.zero_grad()

                G_ = self.G(z_)
                D_fake = self.D(G_)
                G_loss = self.BCE_loss(D_fake, self.y_real_)
                self.train_hist['G_loss'].append(G_loss.item())

                G_loss.backward()
                self.G_optimizer.step()

                if ((iter + 1) % 100) == 0:
                    print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f, D_norm: %.8f" %
                          ((epoch + 1),
                           (iter + 1),
                           self.data_loader.dataset.__len__() // self.batch_size,
                           D_loss.item(),
                           G_loss.item(),
                           g_norm.item()))
                    f.write("%.8f,%.8f,%.8f\n" % (D_loss.item(), G_loss.item(), g_norm.item()))
                    f.flush()

            self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time)
            self.visualize_results((epoch+1))

        self.train_hist['total_time'].append(time.time() - start_time)
        print("Avg one epoch time: %.2f, total %d epochs time: %.2f" % (np.mean(self.train_hist['per_epoch_time']),
              self.epoch, self.train_hist['total_time'][0]))
        print("Training finish!... save training results")

        f.close()

        self.save()
        utils.generate_animation(self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name,
                                 self.epoch)
        utils.loss_plot(self.train_hist, os.path.join(self.save_dir, self.dataset, self.model_name), self.model_name)
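
Here the gradient norm w.r.t. the real inputs is only logged (D_norm), not added to the loss. A reusable, device-agnostic sketch of that measurement (the function name is an assumption):

    import torch
    from torch.autograd import grad

    def input_grad_norm(D, x):
        # mean squared deviation of ||dD/dx|| from 1, as logged above
        x = x.detach().requires_grad_(True)
        out = D(x)
        g = grad(outputs=out, inputs=x,
                 grad_outputs=torch.ones_like(out),
                 create_graph=True, retain_graph=True, only_inputs=True)[0]
        return ((g.view(g.size(0), -1).norm(2, dim=1) - 1) ** 2).mean()
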
Example #8
    def train(self):
        self.train_hist = {}
        self.train_hist['D_loss'] = []
        self.train_hist['G_loss'] = []
        self.train_hist['per_epoch_time'] = []
        self.train_hist['total_time'] = []

        if self.gpu_mode:
            self.y_real_, self.y_fake_ = Variable(torch.ones(self.batch_size, 1).cuda()), Variable(torch.zeros(self.batch_size, 1).cuda())
        else:
            self.y_real_, self.y_fake_ = Variable(torch.ones(self.batch_size, 1)), Variable(torch.zeros(self.batch_size, 1))

        self.D.train()
        print('training start!!')
        start_time = time.time()
        for epoch in range(self.epoch):
            self.G.train()
            epoch_start_time = time.time()
            for iter, (x_, _) in enumerate(self.data_loader):
                if iter == self.data_loader.dataset.__len__() // self.batch_size:
                    break

                z_ = torch.rand((self.batch_size, self.z_dim))

                if self.gpu_mode:
                    x_, z_ = Variable(x_.cuda()), Variable(z_.cuda())
                else:
                    x_, z_ = Variable(x_), Variable(z_)

                # update D network
                self.D_optimizer.zero_grad()

                D_real = self.D(x_)
                D_real_loss = -torch.mean(D_real)

                G_ = self.G(z_)
                D_fake = self.D(G_)
                D_fake_loss = torch.mean(D_fake)

                D_loss = D_real_loss + D_fake_loss

                D_loss.backward()
                self.D_optimizer.step()

                # clipping D
                for p in self.D.parameters():
                    p.data.clamp_(-self.c, self.c)

                if ((iter+1) % self.n_critic) == 0:
                    # update G network
                    self.G_optimizer.zero_grad()

                    G_ = self.G(z_)
                    D_fake = self.D(G_)
                    G_loss = -torch.mean(D_fake)
                    self.train_hist['G_loss'].append(G_loss.item())

                    G_loss.backward()
                    self.G_optimizer.step()

                    self.train_hist['D_loss'].append(D_loss.item())

                if ((iter + 1) % 100) == 0:
                    print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f" %
                          ((epoch + 1), (iter + 1), self.data_loader.dataset.__len__() // self.batch_size, D_loss.data[0], G_loss.data[0]))

            self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time)
            self.visualize_results((epoch+1))

        self.train_hist['total_time'].append(time.time() - start_time)
        print("Avg one epoch time: %.2f, total %d epochs time: %.2f" % (np.mean(self.train_hist['per_epoch_time']),
              self.epoch, self.train_hist['total_time'][0]))
        print("Training finish!... save training results")

        self.save()
        utils.generate_animation(self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name,
                                 self.epoch)
        utils.loss_plot(self.train_hist, os.path.join(self.save_dir, self.dataset, self.model_name), self.model_name)
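
The clipping loop is the original WGAN's Lipschitz constraint; the gradient-penalty variant in the earlier example replaced it. As a tiny helper (the name is an assumption):

    def clip_weights(model, c):
        # clamp every critic parameter into [-c, c] in place, as above
        for p in model.parameters():
            p.data.clamp_(-c, c)
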
Example #9
    def train(self):
        train_hist_keys = [
            'D2d_loss', 'D3d_loss', 'D2d_acc', 'D3d_acc', 'G_loss',
            'per_epoch_time', 'total_time'
        ]
        if 'recon' in self.loss_option:
            train_hist_keys.append('G_loss_recon')
        if 'dist' in self.loss_option:
            train_hist_keys.append('G_loss_dist')

        if not hasattr(self, 'epoch_start'):
            self.epoch_start = 0
        if not hasattr(self, 'train_hist'):
            self.train_hist = {}
            for key in train_hist_keys:
                self.train_hist[key] = []
        else:
            existing_keys = self.train_hist.keys()
            num_hist = [len(self.train_hist[key]) for key in existing_keys]
            num_hist = max(num_hist)
            for key in train_hist_keys:
                if key not in existing_keys:
                    self.train_hist[key] = [0] * num_hist
                    print('new key added: {}'.format(key))

        if self.gpu_mode:
            self.y_real_ = Variable((torch.ones(self.batch_size, 1)).cuda())
            self.y_fake_ = Variable((torch.zeros(self.batch_size, 1)).cuda())
        else:
            self.y_real_ = Variable((torch.ones(self.batch_size, 1)))
            self.y_fake_ = Variable((torch.zeros(self.batch_size, 1)))

        nPairs = self.batch_size * (self.batch_size - 1)
        normalizerA = self.data_loader.dataset.muA / self.data_loader.dataset.stddevA  # normalization
        normalizerB = self.data_loader.dataset.muB / self.data_loader.dataset.stddevB  # normalization
        eps = 1e-16

        self.D2d.train()
        self.D3d.train()
        start_time = time.time()

        nBatchesPerEpoch = self.data_loader.dataset.__len__(
        ) // self.batch_size
        print('training start from epoch {}!!'.format(self.epoch_start + 1))
        for epoch in range(self.epoch_start, self.epoch):
            self.G.train()
            epoch_start_time = time.time()
            start_time_epoch = time.time()

            for iB, (x3D_, y_, x2D_) in enumerate(self.data_loader):
                if iB == nBatchesPerEpoch:
                    break

                projected, _ = torch.max(x3D_[:, 1:, :, :, :],
                                         4,
                                         keepdim=False)
                x3D_ = x3D_[:, 0:1, :, :, :]
                y_random_pcode_ = torch.floor(
                    torch.rand(self.batch_size) * self.num_c_expr).long()
                y_random_pcode_onehot_ = torch.zeros(self.batch_size,
                                                     self.num_c_expr)
                y_random_pcode_onehot_.scatter_(1, y_random_pcode_.view(-1, 1),
                                                1)
                y_id_ = y_['id']
                y_pcode_ = y_['pcode']
                y_pcode_onehot_ = torch.zeros(self.batch_size, self.num_c_expr)
                y_pcode_onehot_.scatter_(1, y_pcode_.view(-1, 1), 1)

                if self.gpu_mode:
                    x2D_ = Variable(x2D_.cuda())
                    x3D_ = Variable(x3D_.cuda())
                    projected = Variable(projected.cuda())
                    y_id_ = Variable(y_id_.cuda())
                    y_pcode_ = Variable(y_pcode_.cuda())
                    y_pcode_onehot_ = Variable(y_pcode_onehot_.cuda())
                    y_random_pcode_ = Variable(y_random_pcode_.cuda())
                    y_random_pcode_onehot_ = Variable(
                        y_random_pcode_onehot_.cuda())
                else:
                    x2D_ = Variable(x2D_)
                    x3D_ = Variable(x3D_)
                    projected = Variable(projected)
                    y_id_ = Variable(y_id_)
                    y_pcode_ = Variable(y_pcode_)
                    y_pcode_onehot_ = Variable(y_pcode_onehot_)
                    y_random_pcode_ = Variable(y_random_pcode_)
                    y_random_pcode_onehot_ = Variable(y_random_pcode_onehot_)

                # update D network
                for iD in range(self.n_critic):
                    self.D2d_optimizer.zero_grad()
                    self.D3d_optimizer.zero_grad()

                    d_gan2d, d_id2d, d_expr2d = self.D2d(projected)
                    loss_d_real_gan2d = self.BCE_loss(d_gan2d, self.y_real_)
                    loss_d_real_id2d = self.CE_loss(d_id2d, y_id_)
                    loss_d_real_expr2d = self.CE_loss(d_expr2d, y_pcode_)

                    d_gan3d, d_id3d, d_expr3d = self.D3d(x3D_)
                    loss_d_real_gan3d = self.BCE_loss(d_gan3d, self.y_real_)
                    loss_d_real_id3d = self.CE_loss(d_id3d, y_id_)
                    loss_d_real_expr3d = self.CE_loss(d_expr3d, y_pcode_)

                    xhat2d, xhat3d = self.G(x2D_, y_random_pcode_onehot_)
                    d_fake_gan2d, _, _ = self.D2d(xhat2d)
                    d_fake_gan3d, _, _ = self.D3d(xhat3d)
                    loss_d_fake_gan2d = self.BCE_loss(d_fake_gan2d,
                                                      self.y_fake_)
                    loss_d_fake_gan3d = self.BCE_loss(d_fake_gan3d,
                                                      self.y_fake_)

                    num_correct_real2d = torch.sum(d_gan2d > 0.5)
                    num_correct_fake2d = torch.sum(d_fake_gan2d < 0.5)
                    D2d_acc = float(num_correct_real2d.item() +
                                    num_correct_fake2d.item()) / (
                                        self.batch_size * 2)
                    num_correct_real3d = torch.sum(d_gan3d > 0.5)
                    num_correct_fake3d = torch.sum(d_fake_gan3d < 0.5)
                    D3d_acc = float(num_correct_real3d.item() +
                                    num_correct_fake3d.item()) / (
                                        self.batch_size * 2)

                    D2d_loss = loss_d_real_gan2d + loss_d_real_id2d + loss_d_real_expr2d + loss_d_fake_gan2d
                    D3d_loss = loss_d_real_gan3d + loss_d_real_id3d + loss_d_real_expr3d + loss_d_fake_gan3d

                    if iD == 0:
                        self.train_hist['D2d_loss'].append(D2d_loss.item())
                        self.train_hist['D3d_loss'].append(D3d_loss.item())
                        self.train_hist['D2d_acc'].append(D2d_acc)
                        self.train_hist['D3d_acc'].append(D3d_acc)

                    D2d_loss.backward(retain_graph=True)
                    D3d_loss.backward()
                    if D2d_acc < 0.8:
                        self.D2d_optimizer.step()
                    if D3d_acc < 0.8:
                        self.D3d_optimizer.step()

                # update G network
                for iG in range(self.n_gen):
                    self.G_optimizer.zero_grad()

                    xhat2d, xhat3d = self.G(x2D_, y_pcode_onehot_)

                    d_gan2d, d_id2d, d_expr2d = self.D2d(xhat2d)
                    loss_g_gan2d = self.BCE_loss(d_gan2d, self.y_real_)
                    loss_g_id2d = self.CE_loss(d_id2d, y_id_)
                    loss_g_expr2d = self.CE_loss(d_expr2d, y_pcode_)

                    d_gan3d, d_id3d, d_expr3d = self.D3d(xhat3d)
                    loss_g_gan3d = self.BCE_loss(d_gan3d, self.y_real_)
                    loss_g_id3d = self.CE_loss(d_id3d, y_id_)
                    loss_g_expr3d = self.CE_loss(d_expr3d, y_pcode_)

                    G_loss = loss_g_gan2d + loss_g_id2d + loss_g_expr2d + \
                       loss_g_gan3d + loss_g_id3d + loss_g_expr3d

                    if iG == 0:
                        self.train_hist['G_loss'].append(G_loss.item())

                    G_loss.backward()
                    self.G_optimizer.step()

                if ((iB + 1) % 10) == 0 or (iB + 1) == nBatchesPerEpoch:
                    secs = time.time() - start_time_epoch
                    hours = secs // 3600
                    mins = (secs // 60) % 60
                    print(
                        "%2dh%2dm E[%2d] B[%d/%d] D: %.4f/%.4f, G: %.4f, D_acc:%.4f/%.4f"
                        % (hours, mins, (epoch + 1),
                           (iB + 1), nBatchesPerEpoch, D2d_loss.item(),
                           D3d_loss.item(), G_loss.item(), D3d_acc, D2d_acc))

            self.train_hist['per_epoch_time'].append(time.time() -
                                                     epoch_start_time)
            if epoch == 0 or (epoch + 1) % 5 == 0:
                self.dump_x_hat(xhat2d, xhat3d, epoch + 1)
            self.save()
            utils.loss_plot(self.train_hist,
                            os.path.join(self.save_dir, self.dataset,
                                         self.model_name),
                            self.model_name,
                            use_subplot=True)

        self.train_hist['total_time'].append(time.time() - start_time)
        print("Avg one epoch time: %.2f, total %d epochs time: %.2f" %
              (np.mean(self.train_hist['per_epoch_time']), self.epoch,
               self.train_hist['total_time'][0]))
        print("Training finish!... save training results")

        self.save()
        utils.loss_plot(self.train_hist,
                        os.path.join(self.save_dir, self.dataset,
                                     self.model_name),
                        self.model_name,
                        use_subplot=True)
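
The scatter_-based one-hot construction appears twice above; factored out it looks like this (the helper name is an assumption; recent PyTorch also offers torch.nn.functional.one_hot):

    import torch

    def one_hot(labels, num_classes):
        # labels: (batch,) LongTensor; returns a (batch, num_classes) float tensor
        out = torch.zeros(labels.size(0), num_classes)
        out.scatter_(1, labels.view(-1, 1), 1)
        return out
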
Example #10
    def train(self):
        self.train_hist = {}
        self.train_hist['D_loss'] = []
        self.train_hist['G_loss'] = []
        self.train_hist["per_epoch_time"] = []
        self.train_hist['total_time'] = []

        if self.gpu_mode:
            self.y_real_, self.y_fake_ = Variable(
                torch.ones(self.batch_size, 1).cuda()), Variable(
                    torch.zeros(self.batch_size, 1).cuda())
        else:
            self.y_real_, self.y_fake_ = Variable(
                torch.ones(self.batch_size,
                           1)), Variable(torch.zeros(self.batch_size, 1))

        # self.D.train()
        print('training start!!')
        start_time = time.time()

        for epoch in range(self.epoch):
            # self.G.train()
            epoch_start_time = time.time()

            for iter, (x_, _) in enumerate(self.data_loader):
                if iter == self.data_loader.dataset.__len__(
                ) // self.batch_size:
                    break

                z_ = torch.rand((self.batch_size, self.z_dim))

                if self.gpu_mode:
                    x_, z_ = Variable(x_.cuda()), Variable(z_.cuda())
                else:
                    x_, z_ = Variable(x_), Variable(z_)
                """Update D network"""
                self.D_optimizer.zero_grad()

                # train with real images
                y_hat_real = self.D(x_)  # forward pass
                D_real_loss = self.BCE_loss(y_hat_real, self.y_real_)

                generated_images_ = self.G(z_)
                y_hat_fake = self.D(generated_images_)
                D_fake_loss = self.BCE_loss(y_hat_fake, self.y_fake_)

                D_loss = D_fake_loss + D_real_loss
                self.train_hist['D_loss'].append(D_loss.item())

                D_loss.backward()
                self.D_optimizer.step()
                """Update generator network"""
                self.G_optimizer.zero_grad()

                generated_images_ = self.G(z_)
                y_hat_fake = self.D(generated_images_)
                G_loss = self.BCE_loss(y_hat_fake, self.y_real_)
                self.train_hist['G_loss'].append(G_loss.item())

                G_loss.backward()
                self.G_optimizer.step()
                if ((iter + 1) % 100) == 0:
                    print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f" %
                          ((epoch + 1),
                           (iter + 1), self.data_loader.dataset.__len__() //
                           self.batch_size, D_loss.item(), G_loss.item()))

            self.train_hist['per_epoch_time'].append(time.time() -
                                                     epoch_start_time)
            self.visualize_results((epoch + 1))

        self.train_hist['total_time'].append(time.time() - start_time)
        print("Avg one epoch time: %.2f, total %d epochs time: %.2f" %
              (np.mean(self.train_hist['per_epoch_time']), self.epoch,
               self.train_hist['total_time'][0]))
        print("Training finish!... save training results")

        self.save()
        utils.generate_animation(
            self.result_dir + '/' + self.dataset + '/' + self.model_name +
            '/' + self.model_name, self.epoch)
        utils.loss_plot(
            self.train_hist,
            os.path.join(self.save_dir, self.dataset, self.model_name),
            self.model_name)
Example #11
    def train(self):
        self.train_hist = {}
        self.train_hist['D_loss'] = []
        self.train_hist['G_loss'] = []
        self.train_hist['per_epoch_time'] = []
        self.train_hist['total_time'] = []

        # print('train at LSGAN')
        # self.y_real_, self.y_fake_ = torch.ones(self.batch_size, 1), torch.zeros(self.batch_size, 1)
        self.y_real_, self.y_fake_ = torch.zeros(self.batch_size,
                                                 1), torch.ones(
                                                     self.batch_size, 1)

        if self.gpu_mode:
            self.y_real_, self.y_fake_ = self.y_real_.cuda(
            ), self.y_fake_.cuda()

        self.D.train()
        print('{} training start!!,epoch:{},module stored at:{}'.format(
            self.model_name, self.epoch, self.dataset))
        start_time = time.time()
        # url = os.path.join(self.save_dir, self.dataset, self.model_name)

        for epoch in range(self.epoch):
            # if epoch == 105:
            #     self.G = torch.load(os.path.join(url,'LSGAN_105_G.pkl'))
            #     self.D = torch.load(os.path.join(url,'LSGAN_105_D.pkl'))
            #     print('reload success!','*'*40)
            self.G.train()
            try:
                # check thresholds in descending order; the original ascending
                # if/elif chain could never reach the later branches
                if epoch >= 110:
                    self.G_optimizer.param_groups[0]['lr'] = 0.0000001
                    self.D_optimizer.param_groups[0]['lr'] = 0.0000001
                elif epoch >= 90:
                    self.G_optimizer.param_groups[0]['lr'] = 0.000001
                    self.D_optimizer.param_groups[0]['lr'] = 0.000001
                elif epoch >= 70:
                    self.G_optimizer.param_groups[0]['lr'] = 0.000009
                    self.D_optimizer.param_groups[0]['lr'] = 0.000009
                elif epoch >= 40:
                    self.G_optimizer.param_groups[0]['lr'] = 0.00001
                    self.D_optimizer.param_groups[0]['lr'] = 0.00001
                elif epoch >= 15:
                    self.G_optimizer.param_groups[0]['lr'] = 0.00009
                    self.D_optimizer.param_groups[0]['lr'] = 0.00009
            except Exception:
                print('error raised for param_groups in train')

            epoch_start_time = time.time()
            # the loader yields tuples, so take the image batch from x_[0]
            for iter, x_ in enumerate(self.data_loader):
                x_ = x_[0]

                if iter == self.data_loader.dataset.__len__(
                ) // self.batch_size:
                    break

                z_ = torch.rand((self.batch_size, self.z_dim))
                if self.gpu_mode:
                    x_, z_ = x_.cuda(), z_.cuda()

                # update D network
                self.D_optimizer.zero_grad()

                D_real = self.D(x_)
                D_real_loss = self.MSE_loss(D_real, self.y_real_)

                G_ = self.G(z_)
                D_fake = self.D(G_)
                D_fake_loss = self.MSE_loss(D_fake, self.y_fake_)

                D_loss = D_real_loss + D_fake_loss
                self.train_hist['D_loss'].append(D_loss.item())

                D_loss.backward()
                self.D_optimizer.step()

                # update G network
                self.G_optimizer.zero_grad()

                G_ = self.G(z_)
                D_fake = self.D(G_)
                G_loss = self.MSE_loss(D_fake, self.y_real_)
                self.train_hist['G_loss'].append(G_loss.item())

                G_loss.backward()
                self.G_optimizer.step()

                if ((iter + 1) % 100) == 0:
                    print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f" %
                          ((epoch + 1),
                           (iter + 1), self.data_loader.dataset.__len__() //
                           self.batch_size, D_loss.item(), G_loss.item()),
                          end=',')
                    print('lr:%7f' % self.G_optimizer.param_groups[0]['lr'])
                    self.writer.add_scalar('G_loss', G_loss.item(), self.X)
                    # writer.add_scalar('G_loss', -G_loss_D, X)
                    self.writer.add_scalar('D_loss', D_loss.item(), self.X)
                    self.writer.add_scalars('cross loss', {
                        'G_loss': G_loss.item(),
                        'D_loss': D_loss.item()
                    }, self.X)

            self.train_hist['per_epoch_time'].append(time.time() -
                                                     epoch_start_time)
            # with torch.no_grad():
            #     self.visualize_results((epoch+1))
            if epoch % 5 == 0:
                self.load_interval(epoch)

        self.train_hist['total_time'].append(time.time() - start_time)
        print("Avg one epoch time: %.2f, total %d epochs time: %.2f" %
              (np.mean(self.train_hist['per_epoch_time']), self.epoch,
               self.train_hist['total_time'][0]))
        print("Training finish!... save training results")
        save_dir = os.path.join(self.save_dir, self.dataset, self.model_name)

        with open(os.path.join(save_dir, self.model_name + '_train_hist.json'),
                  "a") as f:
            json.dump(self.train_hist, f)

        self.writer.export_scalars_to_json(
            os.path.join(save_dir, self.model_name + '.json'))
        self.writer.close()
        self.load_interval(epoch)

        # self.save()
        # utils.generate_animation(self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name,
        #                          self.epoch)
        utils.loss_plot(
            self.train_hist,
            os.path.join(self.save_dir, self.dataset, self.model_name),
            self.model_name)
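
The hand-written schedule above (once the branch order is fixed) can also be expressed with torch.optim.lr_scheduler. A sketch using LambdaLR; the base learning rate of 2e-4 is an assumption chosen so the factors reproduce the absolute values set in the example:

    from torch.optim.lr_scheduler import LambdaLR

    BASE_LR = 2e-4  # assumed initial lr of both optimizers

    def lr_factor(epoch):
        # reproduces the absolute rates set manually in the example
        for boundary, lr in [(110, 1e-7), (90, 1e-6), (70, 9e-6),
                             (40, 1e-5), (15, 9e-5)]:
            if epoch >= boundary:
                return lr / BASE_LR
        return 1.0

    def make_scheduler(optimizer):
        # call scheduler.step() once per epoch
        return LambdaLR(optimizer, lr_lambda=lr_factor)
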
Example #12
    def train(self):
        self.train_hist = {}
        self.train_hist['D_loss'] = []
        self.train_hist['G_loss'] = []
        self.train_hist['info_loss'] = []
        self.train_hist['per_epoch_time'] = []
        self.train_hist['total_time'] = []

        self.y_real_, self.y_fake_ = torch.ones(self.batch_size, 1), torch.zeros(self.batch_size, 1)
        self.y_real_, self.y_fake_ = self.y_real_.cpu(), self.y_fake_.cpu()

        self.D.train()
        print('training start!!')
        start_time = time.time()
        for epoch in range(self.epoch):
            self.G.train()
            epoch_start_time = time.time()
            for iter, (x_, y_) in enumerate(self.data_loader):
                if iter == self.data_loader.dataset.__len__() // self.batch_size:
                    break
                z_ = torch.rand((self.batch_size, self.z_dim))
                if self.SUPERVISED:
                    y_disc_ = torch.zeros((self.batch_size, self.len_discrete_code)).scatter_(1, y_.type(torch.LongTensor).unsqueeze(1), 1)
                else:
                    y_disc_ = torch.from_numpy(
                        np.random.multinomial(1, self.len_discrete_code * [float(1.0 / self.len_discrete_code)],
                                              size=[self.batch_size])).type(torch.FloatTensor)

                y_cont_ = torch.from_numpy(np.random.uniform(-1, 1, size=(self.batch_size, 2))).type(torch.FloatTensor)

                x_, z_, y_disc_, y_cont_ = x_.cpu(), z_.cpu(), y_disc_.cpu(), y_cont_.cpu()

                # update D network
                self.D_optimizer.zero_grad()

                D_real, _, _ = self.D(x_)
                D_real_loss = self.BCE_loss(D_real, self.y_real_)

                G_ = self.G(z_, y_cont_, y_disc_)
                D_fake, _, _ = self.D(G_)
                D_fake_loss = self.BCE_loss(D_fake, self.y_fake_)

                D_loss = D_real_loss + D_fake_loss
                self.train_hist['D_loss'].append(D_loss.item())

                D_loss.backward(retain_graph=True)
                self.D_optimizer.step()

                # update G network
                self.G_optimizer.zero_grad()

                G_ = self.G(z_, y_cont_, y_disc_)
                D_fake, D_cont, D_disc = self.D(G_)

                G_loss = self.BCE_loss(D_fake, self.y_real_)
                self.train_hist['G_loss'].append(G_loss.item())

                G_loss.backward(retain_graph=True)
                self.G_optimizer.step()

                # information loss
                disc_loss = self.CE_loss(D_disc, torch.max(y_disc_, 1)[1])
                cont_loss = self.MSE_loss(D_cont, y_cont_)
                info_loss = disc_loss + cont_loss
                self.train_hist['info_loss'].append(info_loss.item())

                info_loss.backward()
                self.info_optimizer.step()


                if ((iter + 1) % 100) == 0:
                    print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f, info_loss: %.8f" %
                          ((epoch + 1), (iter + 1), self.data_loader.dataset.__len__() // self.batch_size, D_loss.item(), G_loss.item(), info_loss.item()))
                    with torch.no_grad():
                        self.visualize_results((epoch + 1), (iter + 1))
            self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time)


        self.train_hist['total_time'].append(time.time() - start_time)
        print("Avg one epoch time: %.2f, total %d epochs time: %.2f" % (np.mean(self.train_hist['per_epoch_time']),
                                                                        self.epoch, self.train_hist['total_time'][0]))
        print("Training finish!... save training results")

        self.save()
        utils.generate_animation(self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name,
                                 self.epoch)
        utils.generate_animation(self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name + '_cont',
                                 self.epoch)
        utils.loss_plot(self.train_hist, os.path.join(self.save_dir, self.dataset, self.model_name), self.model_name)
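
The latent codes in this InfoGAN loop are sampled through numpy; the same sampling can be done natively in torch. A sketch (the shapes and the two continuous codes follow the example; the helper itself is an assumption):

    import torch

    def sample_latent(batch_size, z_dim, n_disc, n_cont=2):
        z = torch.rand(batch_size, z_dim)                         # uniform noise
        idx = torch.randint(0, n_disc, (batch_size,))             # discrete code
        y_disc = torch.zeros(batch_size, n_disc).scatter_(1, idx.view(-1, 1), 1)
        y_cont = torch.empty(batch_size, n_cont).uniform_(-1, 1)  # continuous codes
        return z, y_disc, y_cont
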
Example #13
    def train(self):
        self.train_hist = {}
        self.train_hist['D_loss'] = []
        self.train_hist['G_loss'] = []
        self.train_hist['per_epoch_time'] = []
        self.train_hist['total_time'] = []

        if self.gpu_mode:
            self.y_real_, self.y_fake_ = Variable(torch.ones(self.batch_size, 1).cuda()), Variable(torch.zeros(self.batch_size, 1).cuda())
        else:
            self.y_real_, self.y_fake_ = Variable(torch.ones(self.batch_size, 1)), Variable(torch.zeros(self.batch_size, 1))

        self.D.train()
        print('training start!!')
        start_time = time.time()
        for epoch in range(self.epoch):
            self.G.train()
            epoch_start_time = time.time()
            for iter, (x_, _) in enumerate(self.data_loader):
                if iter == self.data_loader.dataset.__len__() // self.batch_size:
                    break

                z_ = torch.rand((self.batch_size, self.z_dim))

                if self.gpu_mode:
                    x_, z_ = Variable(x_.cuda()), Variable(z_.cuda())
                else:
                    x_, z_ = Variable(x_), Variable(z_)

                # update D network
                self.D_optimizer.zero_grad()

                D_real = self.D(x_)
                D_real_loss = -torch.mean(D_real)

                G_ = self.G(z_)
                D_fake = self.D(G_)
                D_fake_loss = torch.mean(D_fake)

                # gradient penalty
                if self.gpu_mode:
                    alpha = torch.rand(x_.size()).cuda()
                else:
                    alpha = torch.rand(x_.size())

                x_hat = Variable(alpha * x_.data + (1 - alpha) * G_.data, requires_grad=True)

                pred_hat = self.D(x_hat)
                if self.gpu_mode:
                    gradients = grad(outputs=pred_hat, inputs=x_hat, grad_outputs=torch.ones(pred_hat.size()).cuda(),
                                 create_graph=True, retain_graph=True, only_inputs=True)[0]
                else:
                    gradients = grad(outputs=pred_hat, inputs=x_hat, grad_outputs=torch.ones(pred_hat.size()),
                                     create_graph=True, retain_graph=True, only_inputs=True)[0]

                gradient_penalty = self.lambda_ * ((gradients.view(gradients.size()[0], -1).norm(2, 1) - 1) ** 2).mean()

                D_loss = D_real_loss + D_fake_loss + gradient_penalty

                D_loss.backward()
                self.D_optimizer.step()

                if ((iter+1) % self.n_critic) == 0:
                    # update G network
                    self.G_optimizer.zero_grad()

                    G_ = self.G(z_)
                    D_fake = self.D(G_)
                    G_loss = -torch.mean(D_fake)
                    self.train_hist['G_loss'].append(G_loss.item())

                    G_loss.backward()
                    self.G_optimizer.step()

                    self.train_hist['D_loss'].append(D_loss.item())

                if ((iter + 1) % 100) == 0:
                    print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f" %
                          ((epoch + 1), (iter + 1), self.data_loader.dataset.__len__() // self.batch_size, D_loss.data[0], G_loss.data[0]))

            self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time)
            self.visualize_results((epoch+1))

        self.train_hist['total_time'].append(time.time() - start_time)
        print("Avg one epoch time: %.2f, total %d epochs time: %.2f" % (np.mean(self.train_hist['per_epoch_time']),
              self.epoch, self.train_hist['total_time'][0]))
        print("Training finish!... save training results")

        self.save()
        utils.generate_animation(self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name,
                                 self.epoch)
        utils.loss_plot(self.train_hist, os.path.join(self.save_dir, self.dataset, self.model_name), self.model_name)
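
A closing note on the Variable wrappers used throughout these examples: since PyTorch 0.4 they are no-ops, because tensors carry autograd state themselves. A minimal migration sketch (shapes are illustrative):

    import torch

    x = torch.randn(64, 1, 28, 28)
    x = x.cuda() if torch.cuda.is_available() else x   # replaces Variable(x.cuda())
    x.requires_grad_(True)                             # replaces Variable(x, requires_grad=True)
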