Example #1
# Imports this snippet relies on; the module path of the repo-local helpers
# (weights_init_normal, ComputeLosses) is an assumption.
import numpy as np
import torch
import torch.nn.functional as F
from itertools import cycle
from sklearn.mixture import GaussianMixture
from torch import optim
from utils import weights_init_normal, ComputeLosses  # assumed location

class TrainerVaDE:
    """This is the trainer for the Variational Deep Embedding (VaDE).
    """
    def __init__(self, args, device, dataloader_sup, dataloader_unsup,
                 dataloader_test, n_classes):
        if args.dataset == 'mnist':
            from models import Autoencoder, VaDE
            self.autoencoder = Autoencoder().to(device)
            self.autoencoder.apply(weights_init_normal)
            self.VaDE = VaDE().to(device)
        elif args.dataset == 'webcam':
            from models_office import Autoencoder, VaDE, feature_extractor
            self.autoencoder = Autoencoder().to(device)
            checkpoint = torch.load('weights/imagenet_params.pth.tar',
                                    map_location=device)
            self.autoencoder.load_state_dict(checkpoint['state_dict'],
                                             strict=False)
            checkpoint = torch.load('weights/feature_extractor_params.pth.tar',
                                    map_location=device)
            self.feature_extractor = feature_extractor().to(device)
            self.feature_extractor.load_state_dict(checkpoint['state_dict'])
            self.freeze_extractor()
            self.VaDE = VaDE().to(device)

        self.dataloader_sup = dataloader_sup
        self.dataloader_unsup = dataloader_unsup
        self.dataloader_test = dataloader_test
        self.device = device
        self.args = args
        self.n_classes = n_classes

    def pretrain(self):
        """Here we train an stacked autoencoder which will be used as the initialization for the VaDE. 
        This initialization is usefull because reconstruction in VAEs would be weak at the begining
        and the models are likely to get stuck in local minima.
        """
        optimizer = optim.Adam(self.autoencoder.parameters(),
                               lr=self.args.lr_ae)
        self.autoencoder.train()
        print('Training the autoencoder...')
        for epoch in range(1500):
            total_loss = 0
            for x, _ in self.dataloader_unsup:
                optimizer.zero_grad()
                x = x.to(self.device)
                if self.args.dataset == 'webcam':
                    x = self.feature_extractor(x)
                    x = x.detach()
                x_hat = self.autoencoder(x)
                loss = F.binary_cross_entropy(
                    x_hat, x, reduction='mean')  #reconstruction error
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
            print('Training Autoencoder... Epoch: {}, Loss: {}'.format(
                epoch, total_loss / len(self.dataloader_unsup)))
        self.save_weights_ae()
        #self.train_GMM() #training a GMM to initialize the VaDE
        #self.predict_GMM() #predict and assign supervised points to their GMM components
        #self.save_weights_for_VaDE() #saving weights for the VaDE

    def train_GMM(self):
        """It is possible to fit a Gaussian Mixture Model (GMM) using the latent space 
        generated by the stacked autoencoder. This way, we generate an initialization for 
        the priors (pi, mu, var) of the VaDE model.
        """
        print('Fitting Gaussian Mixture Model...')
        x = torch.cat([data[0] for data in self.dataloader_unsup
                       ]).to(self.device)  #all x samples.
        if self.args.dataset == 'webcam':
            x = self.feature_extractor(x)
            x = x.detach()
        z = self.autoencoder.encode(x)
        self.gmm = GaussianMixture(n_components=self.n_classes,
                                   covariance_type='diag')
        self.gmm.fit(z.cpu().detach().numpy())

    def predict_GMM(self):
        """It is possible to fit a Gaussian Mixture Model (GMM) using the latent space 
        generated by the stacked autoencoder. This way, we generate an initialization for 
        the priors (pi, mu, var) of the VaDE model.
        """
        print('Predicting over Gaussian Mixture Model...')
        x = torch.cat([data[0] for data in self.dataloader_sup
                       ]).to(self.device)  #all supervised x samples.
        y = torch.cat([data[1] for data in self.dataloader_sup])
        x = x[np.argsort(y.cpu().detach().numpy())]  #sort samples by label
        if self.args.dataset == 'webcam':
            x = self.feature_extractor(x)
            x = x.detach()
        z = self.autoencoder.encode(x)
        probas = self.gmm.predict_proba(z.cpu().detach().numpy())
        self.assign_GMMS(probas)

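    # Greedy matching of class labels to GMM components. This assumes row i
    # of `probas` corresponds to class i (predict_GMM sorts the supervised
    # samples by label): each class in turn takes its most probable component
    # among those not yet assigned.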
    def assign_GMMS(self, probas):
        assignation = []
        possibilities = np.arange(self.n_classes)
        index = 0
        toselect = 1
        while len(possibilities) > 0:
            sorted_ = np.argsort(probas[index])
            max_ = sorted_[-toselect]
            if max_ in possibilities:
                assignation.append(max_)
                possibilities = np.setdiff1d(possibilities, max_)
                index += 1
                toselect = 1
            else:
                toselect += 1
        self.assignation = assignation

    def save_weights_for_VaDE(self):
        """Saving the pretrained weights for the encoder, decoder, pi, mu, var.
        """
        print('Saving weights.')
        state_dict = self.autoencoder.state_dict()

        self.VaDE.load_state_dict(state_dict, strict=False)
        self.VaDE.pi_prior.data = torch.from_numpy(
            self.gmm.weights_[self.assignation]).float().to(self.device)
        self.VaDE.mu_prior.data = torch.from_numpy(
            self.gmm.means_[self.assignation]).float().to(self.device)
        self.VaDE.log_var_prior.data = torch.log(
            torch.from_numpy(
                self.gmm.covariances_[self.assignation])).float().to(
                    self.device)
        torch.save(
            self.VaDE.state_dict(),
            'weights/pretrained_parameters_{}.pth'.format(self.args.dataset))

    def save_weights_ae(self):
        """Saving the pretrained weights for the encoder, decoder, pi, mu, var.
        """
        print('Saving weights.')
        state = {'state_dict': self.autoencoder.state_dict()}
        torch.save(
            state, 'weights/autoencoder_parameters_{}.pth.tar'.format(
                self.args.dataset))

    def save_weights_vade(self):
        """Saving the pretrained weights for the encoder, decoder, pi, mu, var.
        """
        print('Saving weights.')
        state = {'state_dict': self.autoencoder.state_dict()}
        torch.save(
            state,
            'weights/vade_parameters_{}.pth.tar'.format(self.args.dataset))

    def train(self):
        """
        """
        if self.args.pretrain == True:
            self.VaDE.load_state_dict(
                torch.load('weights/pretrained_parameters_{}.pth'.format(
                    self.args.dataset),
                           map_location=self.device))
        else:
            self.VaDE.apply(weights_init_normal)
        self.optimizer = optim.Adam(self.VaDE.parameters(), lr=self.args.lr)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer,
                                                       step_size=10,
                                                       gamma=0.9)

        self.acc = []
        self.acc_t = []
        self.rec = []
        self.rec_t = []
        self.dkl = []
        self.dkl_t = []

        self.forward_step = ComputeLosses(self.VaDE, self.args)
        print('Training VaDE...')
        self.test_VaDE(-1)
        for epoch in range(self.args.epochs):
            self.train_VaDE(epoch)
            self.test_VaDE(epoch)
            lr_scheduler.step()
        self.save_weights_vade()

    def train_VaDE(self, epoch):
        self.VaDE.train()
        total_loss = 0
        total_dkl = 0
        total_rec = 0
        for (x_s, y_s), (x_u, _) in zip(cycle(self.dataloader_sup),
                                        self.dataloader_unsup):
            self.optimizer.zero_grad()
            x_s, y_s = x_s.to(self.device), y_s.to(self.device)
            x_u = x_u.to(self.device)
            if self.args.dataset == 'webcam':
                x_s = self.feature_extractor(x_s)
                x_s = x_s.detach()
                x_u = self.feature_extractor(x_u)
                x_u = x_u.detach()

            loss, reconst_loss, kl_div, acc = self.forward_step.forward(
                'train', x_s, y_s, x_u)
            loss.backward()
            self.optimizer.step()
            total_loss += loss.item()
            total_dkl += kl_div.item()
            total_rec += reconst_loss.item()
        self.acc.append(acc)  #accuracy of the last training batch in the epoch
        self.dkl.append(total_dkl / len(self.dataloader_unsup))
        self.rec.append(total_rec / len(self.dataloader_unsup))
        print('Training VaDE... Epoch: {}, Loss: {}, Acc: {}'.format(
            epoch, total_loss / len(self.dataloader_unsup), acc))

    def test_VaDE(self, epoch):
        self.VaDE.eval()
        with torch.no_grad():
            total_loss = 0
            total_acc = 0
            total_dkl = 0
            total_rec = 0
            for x, y in self.dataloader_test:
                x, y = x.to(self.device), y.to(self.device)
                if self.args.dataset == 'webcam':
                    x = self.feature_extractor(x)
                    x = x.detach()
                loss, reconst_loss, kl_div, acc = self.forward_step.forward(
                    'test', x, y)
                total_loss += loss.item()
                total_acc += acc.item()
                total_dkl += kl_div.item()
                total_rec += reconst_loss.item()
        self.acc_t.append(total_acc / len(self.dataloader_test))
        self.dkl_t.append(total_dkl / len(self.dataloader_test))
        self.rec_t.append(total_rec / len(self.dataloader_test))
        print('Testing VaDE... Epoch: {}, Loss: {}, Acc: {}'.format(
            epoch, total_loss / len(self.dataloader_test),
            total_acc / len(self.dataloader_test)))

    def freeze_extractor(self):
        for param in self.feature_extractor.parameters():
            param.requires_grad = False
        self.feature_extractor.eval()
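
A minimal driver for this trainer might look like the sketch below. It is not part of the repository: the argparse fields simply mirror the attributes the class reads (dataset, lr, lr_ae, epochs, pretrain), and the MNIST loaders with a flattening transform are an assumption about the input format the Autoencoder expects.

import argparse
import torch
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms

parser = argparse.ArgumentParser()
parser.add_argument('--dataset', default='mnist')
parser.add_argument('--lr', type=float, default=2e-3)      # VaDE learning rate
parser.add_argument('--lr_ae', type=float, default=2e-3)   # autoencoder learning rate
parser.add_argument('--epochs', type=int, default=300)
parser.add_argument('--pretrain', action='store_true')
args = parser.parse_args()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
to_vec = transforms.Compose([transforms.ToTensor(),
                             transforms.Lambda(lambda t: t.view(-1))])  # flatten 28x28 -> 784 (assumed)
train_set = datasets.MNIST('data', train=True, download=True, transform=to_vec)
test_set = datasets.MNIST('data', train=False, download=True, transform=to_vec)

# Hypothetical semi-supervised split: a small labeled subset plus the
# full training set as the unlabeled pool.
loader_sup = DataLoader(Subset(train_set, range(1000)), batch_size=128, shuffle=True)
loader_unsup = DataLoader(train_set, batch_size=128, shuffle=True)
loader_test = DataLoader(test_set, batch_size=128)

trainer = TrainerVaDE(args, device, loader_sup, loader_unsup, loader_test, n_classes=10)
trainer.pretrain()  # stacked-AE pretraining (the GMM init calls are commented out above)
trainer.train()     # full VaDE training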
Example #2
# Imports this snippet relies on; the module paths of the repo-local
# Autoencoder, VaDE and weights_init_normal are an assumption.
import numpy as np
import torch
import torch.nn.functional as F
from scipy.optimize import linear_sum_assignment
from sklearn.mixture import GaussianMixture
from torch import optim
from models import Autoencoder, VaDE  # assumed location
from utils import weights_init_normal  # assumed location

class TrainerVaDE:
    """This is the trainer for the Variational Deep Embedding (VaDE).
    """
    def __init__(self, args, device, dataloader):
        self.autoencoder = Autoencoder().to(device)
        self.VaDE = VaDE().to(device)
        self.dataloader = dataloader
        self.device = device
        self.args = args

    def pretrain(self):
        """Here we train an stacked autoencoder which will be used as the initialization for the VaDE. 
        This initialization is usefull because reconstruction in VAEs would be weak at the begining
        and the models are likely to get stuck in local minima.
        """
        optimizer = optim.Adam(self.autoencoder.parameters(), lr=0.002)
        self.autoencoder.apply(
            weights_init_normal
        )  #initializing weights using a normal distribution.
        self.autoencoder.train()
        print('Training the autoencoder...')
        for epoch in range(30):
            total_loss = 0
            for x, _ in self.dataloader:
                optimizer.zero_grad()
                x = x.to(self.device)
                x_hat = self.autoencoder(x)
                loss = F.binary_cross_entropy(
                    x_hat, x, reduction='mean')  # just reconstruction
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
            print('Training Autoencoder... Epoch: {}, Loss: {}'.format(
                epoch, total_loss))
        self.train_GMM()  #training a GMM to initialize the VaDE
        self.save_weights_for_VaDE()  #saving weights for the VaDE

    def train_GMM(self):
        """It is possible to fit a Gaussian Mixture Model (GMM) using the latent space 
        generated by the stacked autoencoder. This way, we generate an initialization for 
        the priors (pi, mu, var) of the VaDE model.
        """
        print('Fitting Gaussian Mixture Model...')
        x = torch.cat([data[0] for data in self.dataloader
                       ]).view(-1, 784).to(self.device)  #all x samples.
        z = self.autoencoder.encode(x)
        self.gmm = GaussianMixture(n_components=10, covariance_type='diag')
        self.gmm.fit(z.cpu().detach().numpy())

    def save_weights_for_VaDE(self):
        """Saving the pretrained weights for the encoder, decoder, pi, mu, var.
        """
        print('Saving weights.')
        state_dict = self.autoencoder.state_dict()

        self.VaDE.load_state_dict(state_dict, strict=False)
        self.VaDE.pi_prior.data = torch.from_numpy(
            self.gmm.weights_).float().to(self.device)
        self.VaDE.mu_prior.data = torch.from_numpy(self.gmm.means_).float().to(
            self.device)
        self.VaDE.log_var_prior.data = torch.log(
            torch.from_numpy(self.gmm.covariances_)).float().to(self.device)
        torch.save(self.VaDE.state_dict(), self.args.pretrained_path)

    def train(self):
        """
        """
        if self.args.pretrain == True:
            self.VaDE.load_state_dict(
                torch.load(self.args.pretrained_path,
                           map_location=self.device))
        else:
            self.VaDE.apply(weights_init_normal)
        self.optimizer = optim.Adam(self.VaDE.parameters(), lr=self.args.lr)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer,
                                                       step_size=10,
                                                       gamma=0.9)
        print('Training VaDE...')
        for epoch in range(self.args.epochs):
            self.train_VaDE(epoch)
            self.test_VaDE(epoch)
            lr_scheduler.step()

    def train_VaDE(self, epoch):
        self.VaDE.train()

        total_loss = 0
        for x, _ in self.dataloader:
            self.optimizer.zero_grad()
            x = x.to(self.device)
            x_hat, mu, log_var, z = self.VaDE(x)
            #print('Before backward: {}'.format(self.VaDE.pi_prior))
            loss = self.compute_loss(x, x_hat, mu, log_var, z)
            loss.backward()
            self.optimizer.step()
            total_loss += loss.item()
            #print('After backward: {}'.format(self.VaDE.pi_prior))
        print('Training VaDE... Epoch: {}, Loss: {}'.format(epoch, total_loss))

    def test_VaDE(self, epoch):
        self.VaDE.eval()
        with torch.no_grad():
            total_loss = 0
            y_true, y_pred = [], []
            for x, true in self.dataloader:
                x = x.to(self.device)
                x_hat, mu, log_var, z = self.VaDE(x)
                gamma = self.compute_gamma(z, self.VaDE.pi_prior)
                pred = torch.argmax(gamma, dim=1)
                loss = self.compute_loss(x, x_hat, mu, log_var, z)
                total_loss += loss.item()
                y_true.extend(true.numpy())
                y_pred.extend(pred.cpu().detach().numpy())

            acc = self.cluster_acc(np.array(y_true), np.array(y_pred))
            print('Testing VaDE... Epoch: {}, Loss: {}, Acc: {}'.format(
                epoch, total_loss, acc[0]))

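    # Up to additive constants, compute_loss below is the negative VaDE ELBO
    # (Jiang et al. 2017), averaged over the batch:
    #   -ELBO = E_q[-log p(x|z)]                 (reconstruction BCE)
    #         + E_q[-log p(z|c)] - E_q[log p(c)]
    #         + E_q[log q(c|x)] - H[q(z|x)]
    # with gamma = q(c|x) given in closed form by compute_gamma.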
    def compute_loss(self, x, x_hat, mu, log_var, z):
        p_c = self.VaDE.pi_prior
        gamma = self.compute_gamma(z, p_c)

        log_p_x_given_z = F.binary_cross_entropy(x_hat, x, reduction='sum')
        h = log_var.exp().unsqueeze(1) + (mu.unsqueeze(1) -
                                          self.VaDE.mu_prior).pow(2)
        h = torch.sum(self.VaDE.log_var_prior +
                      h / self.VaDE.log_var_prior.exp(),
                      dim=2)
        log_p_z_given_c = 0.5 * torch.sum(gamma * h)
        log_p_c = torch.sum(gamma * torch.log(p_c + 1e-9))
        log_q_c_given_x = torch.sum(gamma * torch.log(gamma + 1e-9))
        log_q_z_given_x = 0.5 * torch.sum(1 + log_var)

        loss = log_p_x_given_z + log_p_z_given_c - log_p_c + log_q_c_given_x - log_q_z_given_x
        loss /= x.size(0)
        return loss

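    # Cluster responsibilities: gamma_ic = q(c|x_i) is proportional to
    # pi_c * N(z_i; mu_c, diag(var_c)). h stacks, per dimension, twice the
    # negative Gaussian log-density, so 0.5 * sum(h, dim=2) is -log N; the
    # 1e-9 terms guard the log and the normalization against zeros.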
    def compute_gamma(self, z, p_c):
        h = (z.unsqueeze(1) -
             self.VaDE.mu_prior).pow(2) / self.VaDE.log_var_prior.exp()
        h += self.VaDE.log_var_prior
        h += torch.Tensor([np.log(np.pi * 2)]).to(self.device)
        p_z_c = torch.exp(
            torch.log(p_c + 1e-9).unsqueeze(0) -
            0.5 * torch.sum(h, dim=2)) + 1e-9
        gamma = p_z_c / torch.sum(p_z_c, dim=1, keepdim=True)
        return gamma

    def cluster_acc(self, real, pred):
        """Unsupervised clustering accuracy: best one-to-one mapping between
        predicted clusters and true labels (Hungarian algorithm)."""
        D = max(pred.max(), real.max()) + 1
        w = np.zeros((D, D), dtype=np.int64)
        for i in range(pred.size):
            w[pred[i], real[i]] += 1  #confusion matrix
        row, col = linear_sum_assignment(w.max() - w)  #maximize matched counts
        return w[row, col].sum() * 100.0 / pred.size, w
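
A hypothetical driver for this second trainer, which takes a single dataloader: the SimpleNamespace carries only the fields the class reads (pretrain, pretrained_path, lr, epochs), and the flattening transform is again an assumption about the expected input.

import torch
from types import SimpleNamespace
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

args = SimpleNamespace(pretrain=False,
                       pretrained_path='weights/pretrained.pth',  # hypothetical path
                       lr=2e-3, epochs=50)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
to_vec = transforms.Compose([transforms.ToTensor(),
                             transforms.Lambda(lambda t: t.view(-1))])  # flatten 28x28 -> 784 (assumed)
mnist = datasets.MNIST('data', train=True, download=True, transform=to_vec)
loader = DataLoader(mnist, batch_size=128, shuffle=True)

trainer = TrainerVaDE(args, device, loader)
trainer.pretrain()  # stacked AE + GMM initialization of the priors
trainer.train()     # joint VaDE training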