Example #1
import torch
import numpy as np
import torch.optim as optim
import matplotlib.pyplot as plt
from torch.utils.data.sampler import SubsetRandomSampler

# project-local modules; the exact import paths are assumed
from networks import VAENet
from losses import LossFunctions
from metrics import Metrics


class VAE:
    def __init__(self, args):
        self.num_epochs = args.epochs
        self.cuda = args.cuda
        self.verbose = args.verbose

        self.batch_size = args.batch_size
        self.batch_size_val = args.batch_size_val
        self.learning_rate = args.learning_rate
        self.decay_epoch = args.decay_epoch
        self.lr_decay = args.lr_decay
        self.w_cat = args.w_categ
        self.w_gauss = args.w_gauss
        self.w_rec = args.w_rec
        self.rec_type = args.rec_type

        self.gaussian_size = args.gaussian_size
        self.input_size = args.input_size
        self.output_size = args.output_size

        self.network = VAENet(self.input_size, self.gaussian_size,
                              self.output_size)
        self.losses = LossFunctions()
        self.metrics = Metrics()

        if self.cuda:
            self.network = self.network.cuda()

    # ===============================================================================
    # Compute all the loss terms; helper method for train_epoch and test.

    def unlabeled_loss(self, chunk, out_net):
        """Compute the loss terms of the variational lower bound.

    Args:
        chunk: (tensor) reconstruction targets for the current batch
        out_net: (dict) output nodes of the network

    Returns:
        loss_dic: (dict) the value of each loss term
    """
        # obtain network variables
        data_recon = out_net['x_rec']
        mu, var = out_net['mean'], out_net['var']

        # reconstruction loss
        loss_rec = self.losses.reconstruction_loss(chunk, data_recon,
                                                   self.rec_type)

        # gaussian loss
        loss_gauss = self.losses.kl_divergence_v2(mu, var)

        # total loss
        loss_total = self.w_rec * loss_rec + self.w_gauss * loss_gauss

        loss_dic = {
            'total': loss_total,
            'reconstruction': loss_rec,
            'gaussian': loss_gauss,
        }
        return loss_dic

    # ===============================================================================

    # ===============================================================================
    # Train the model for one epoch; helper method for train.

    def train_epoch(self, optimizer, data_loader):

        self.network.train()
        total_loss = 0.
        recon_loss = 0.
        gauss_loss = 0.

        num_batches = 0.

        for (path, chunk) in data_loader:

            if self.cuda == 1:
                path = path.cuda()
                chunk = chunk.cuda()

            optimizer.zero_grad()

            # flatten data
            path = path.view(path.size(0), -1)
            chunk = chunk.view(chunk.size(0), -1)

            # forward call
            out_net = self.network(path)
            unlab_loss_dic = self.unlabeled_loss(chunk, out_net)
            total = unlab_loss_dic['total']

            # accumulate values
            total_loss += total.item()
            recon_loss += unlab_loss_dic['reconstruction'].item()
            gauss_loss += unlab_loss_dic['gaussian'].item()

            # perform backpropagation
            total.backward()
            optimizer.step()

            num_batches += 1.

        # average per batch
        total_loss /= num_batches
        recon_loss /= num_batches
        gauss_loss /= num_batches

        return total_loss, recon_loss, gauss_loss

    # ==============================================================================

    # ==============================================================================
    # Test the model. Essentially the same as train_epoch, except that the
    # network is put in eval mode and no gradients are computed or applied.

    def test(self, data_loader):

        self.network.eval()
        total_loss = 0.
        recon_loss = 0.
        gauss_loss = 0.

        num_batches = 0.

        with torch.no_grad():
            for path, chunk in data_loader:

                if self.cuda == 1:
                    path = path.cuda()
                    chunk = chunk.cuda()

                # flatten data
                path = path.view(path.size(0), -1)
                chunk = chunk.view(chunk.size(0), -1)

                # forward call
                out_net = self.network(path)
                unlab_loss_dic = self.unlabeled_loss(chunk, out_net)

                # accumulate values
                total_loss += unlab_loss_dic['total'].item()
                recon_loss += unlab_loss_dic['reconstruction'].item()
                gauss_loss += unlab_loss_dic['gaussian'].item()

                num_batches += 1.

        # average per batch
        total_loss /= num_batches
        recon_loss /= num_batches
        gauss_loss /= num_batches

        return total_loss, recon_loss, gauss_loss

    # ==============================================================================

    def train(self, train_loader, val_loader):
        """Train the model

    Args:
        train_loader: (DataLoader) corresponding loader containing the training data
        val_loader: (DataLoader) corresponding loader containing the validation data

    Returns:
        output: (dict) contains the history of train/val loss
    """
        optimizer = optim.Adam(self.network.parameters(),
                               lr=self.learning_rate)

        train_total_losses = []
        train_recon_losses = []
        train_gauss_losses = []

        valid_total_losses = []
        valid_recon_losses = []
        valid_gauss_losses = []

        for epoch in range(1, self.num_epochs + 1):
            train_total_loss, train_recon_loss, train_gauss_loss = self.train_epoch(
                optimizer, train_loader)
            valid_total_loss, valid_recon_loss, valid_gauss_loss = self.test(
                val_loader)

            print(
                f"Epoch {epoch:5} / {self.num_epochs} | "
                f"Total loss: {valid_total_loss:7.2f} | "
                f"Recon loss: {valid_recon_loss:7.2f} | "
                f"Gauss loss: {valid_gauss_loss:7.2f}")

            train_total_losses.append(train_total_loss)
            train_recon_losses.append(train_recon_loss)
            train_gauss_losses.append(train_gauss_loss)

            valid_total_losses.append(valid_total_loss)
            valid_recon_losses.append(valid_recon_loss)
            valid_gauss_losses.append(valid_gauss_loss)

        return {
            'train_total_losses': train_total_losses,
            'train_recon_losses': train_recon_losses,
            'train_gauss_losses': train_gauss_losses,
            'valid_total_losses': valid_total_losses,
            'valid_recon_losses': valid_recon_losses,
            'valid_gauss_losses': valid_gauss_losses,
        }

    def latent_features(self, data_loader):
        """Obtain latent features learnt by the model

    Args:
        data_loader: (DataLoader) loader containing the data

    Returns:
       features: (array) array containing the features from the data
    """
        self.network.eval()

        N = len(data_loader.dataset)

        features = np.zeros((N, self.gaussian_size))

        start_ind = 0

        with torch.no_grad():
            for (data, ) in data_loader:

                if self.cuda == 1:
                    data = data.cuda()

                data = data.view(data.size(0), -1)

                out = self.network(data)
                latent_feat = out['mean']

                end_ind = min(start_ind + data.size(0), N)
                features[start_ind:end_ind] = latent_feat.cpu().numpy()

                start_ind += data.size(0)

        return features

    def reconstruct_data(self, data_loader, sample_size=-1):
        """Reconstruct a random sample of size sample_size from the loader."""
        self.network.eval()

        # sample random data from the loader (all of it if sample_size <= 0)
        if sample_size <= 0:
            sample_size = len(data_loader.dataset)
        indices = np.random.randint(0,
                                    len(data_loader.dataset),
                                    size=sample_size)
        test_random_loader = torch.utils.data.DataLoader(
            data_loader.dataset,
            batch_size=sample_size,
            sampler=SubsetRandomSampler(indices))

        # obtain values
        it = iter(test_random_loader)
        test_batch_data, _ = next(it)
        original = test_batch_data.numpy()
        if self.cuda:
            test_batch_data = test_batch_data.cuda()

        # obtain reconstructed data
        out = self.network(test_batch_data)
        reconstructed = out['x_rec']
        return original, reconstructed.detach().cpu().numpy()

    def plot_latent_space(self, data_loader, save=False):
        """Plot the latent space learnt by the model

    Args:
        data_loader: (DataLoader) loader containing the data
        save: (bool) whether to save the latent space plot

    Returns:
        fig: (figure) plot of the latent space
    """
        # obtain the latent features
        features = self.latent_features(data_loader)

        # plot only the first 2 dimensions
        fig = plt.figure(figsize=(8, 6))
        plt.scatter(features[:, 0],
                    features[:, 1],
                    marker='o',
                    edgecolor='none',
                    s=10)
        if save:
            fig.savefig('latent_space.png')
        return fig

    def random_generation(self, num_chunks):
        """Generate num_chunks new samples from the standard-normal prior."""
        latent_vecs = torch.randn(num_chunks, self.gaussian_size)
        if self.cuda:
            latent_vecs = latent_vecs.cuda()
        chunks = self.network.generative.pxz(latent_vecs)

        return chunks.cpu().detach().numpy()
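
A minimal, hypothetical driver for the class above. The attribute names mirror the args fields read in __init__, but every concrete value, the synthetic (path, chunk) data, and the loaders are illustrative assumptions; VAENet itself is project-local and not shown here.

from argparse import Namespace

import torch
from torch.utils.data import DataLoader, TensorDataset

args = Namespace(
    epochs=5, cuda=0, verbose=1,
    batch_size=64, batch_size_val=200,
    learning_rate=1e-3, decay_epoch=-1, lr_decay=0.5,
    w_categ=1.0, w_gauss=1.0, w_rec=1.0, rec_type='mse',
    gaussian_size=64, input_size=784, output_size=784)

vae = VAE(args)

# synthetic (path, chunk) pairs standing in for real data
paths = torch.rand(1024, args.input_size)
chunks = torch.rand(1024, args.output_size)
train_loader = DataLoader(TensorDataset(paths, chunks),
                          batch_size=args.batch_size, shuffle=True)
val_loader = DataLoader(TensorDataset(paths[:200], chunks[:200]),
                        batch_size=args.batch_size_val)

history = vae.train(train_loader, val_loader)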
Example #2
import torch
import numpy as np
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.utils.data.sampler import SubsetRandomSampler

# project-local modules; the exact import paths are assumed
from networks import GMVAENet
from losses import LossFunctions
from metrics import Metrics


class GMVAE:
    def __init__(self, args):
        self.num_epochs = args.epochs
        self.cuda = args.cuda
        self.verbose = args.verbose

        self.batch_size = args.batch_size
        self.batch_size_val = args.batch_size_val
        self.learning_rate = args.learning_rate
        self.decay_epoch = args.decay_epoch
        self.lr_decay = args.lr_decay
        self.w_cat = args.w_categ
        self.w_gauss = args.w_gauss
        self.w_rec = args.w_rec
        self.rec_type = args.rec_type

        self.num_classes = args.num_classes
        self.gaussian_size = args.gaussian_size
        self.input_size = args.input_size

        # gumbel
        self.init_temp = args.init_temp
        self.decay_temp = args.decay_temp
        self.hard_gumbel = args.hard_gumbel
        self.min_temp = args.min_temp
        self.decay_temp_rate = args.decay_temp_rate
        self.gumbel_temp = self.init_temp

        self.network = GMVAENet(self.input_size, self.gaussian_size,
                                self.num_classes)
        self.losses = LossFunctions()
        self.metrics = Metrics()

        if self.cuda:
            self.network = self.network.cuda()

    def unlabeled_loss(self, data, out_net):
        """Method defining the loss functions derived from the variational lower bound
    Args:
        data: (array) corresponding array containing the input data
        out_net: (dict) contains the graph operations or nodes of the network output

    Returns:
        loss_dic: (dict) contains the values of each loss function and predictions
    """
        # obtain network variables
        z, data_recon = out_net['gaussian'], out_net['x_rec']
        logits, prob_cat = out_net['logits'], out_net['prob_cat']
        y_mu, y_var = out_net['y_mean'], out_net['y_var']
        mu, var = out_net['mean'], out_net['var']

        # reconstruction loss
        loss_rec = self.losses.reconstruction_loss(data, data_recon,
                                                   self.rec_type)

        # gaussian loss
        loss_gauss = self.losses.gaussian_loss(z, mu, var, y_mu, y_var)

        # categorical loss: KL(q(y|x) || uniform prior over num_classes)
        loss_cat = -self.losses.entropy(logits, prob_cat) - np.log(
            1.0 / self.num_classes)

        # total loss
        loss_total = self.w_rec * loss_rec + self.w_gauss * loss_gauss + self.w_cat * loss_cat

        # obtain predictions
        _, predicted_labels = torch.max(logits, dim=1)

        loss_dic = {
            'total': loss_total,
            'predicted_labels': predicted_labels,
            'reconstruction': loss_rec,
            'gaussian': loss_gauss,
            'categorical': loss_cat
        }
        return loss_dic

    def train_epoch(self, optimizer, data_loader):
        """Train the model for one epoch

    Args:
        optimizer: (Optim) optimizer to use in backpropagation
        data_loader: (DataLoader) corresponding loader containing the training data

    Returns:
        average of all loss values, accuracy, nmi
    """
        self.network.train()
        total_loss = 0.
        recon_loss = 0.
        cat_loss = 0.
        gauss_loss = 0.

        accuracy = 0.
        nmi = 0.
        num_batches = 0.

        true_labels_list = []
        predicted_labels_list = []

        # iterate over the dataset
        for (data, labels) in data_loader:
            if self.cuda == 1:
                data = data.cuda()

            optimizer.zero_grad()

            # flatten data
            data = data.view(data.size(0), -1)

            # forward call
            out_net = self.network(data, self.gumbel_temp, self.hard_gumbel)
            unlab_loss_dic = self.unlabeled_loss(data, out_net)
            total = unlab_loss_dic['total']

            # accumulate values
            total_loss += total.item()
            recon_loss += unlab_loss_dic['reconstruction'].item()
            gauss_loss += unlab_loss_dic['gaussian'].item()
            cat_loss += unlab_loss_dic['categorical'].item()

            # perform backpropagation
            total.backward()
            optimizer.step()

            # save predicted and true labels
            predicted = unlab_loss_dic['predicted_labels']
            true_labels_list.append(labels)
            predicted_labels_list.append(predicted)

            num_batches += 1.

        # average per batch
        total_loss /= num_batches
        recon_loss /= num_batches
        gauss_loss /= num_batches
        cat_loss /= num_batches

        # concat all true and predicted labels
        true_labels = torch.cat(true_labels_list, dim=0).cpu().numpy()
        predicted_labels = torch.cat(predicted_labels_list,
                                     dim=0).cpu().numpy()

        # compute metrics
        accuracy = 100.0 * self.metrics.cluster_acc(predicted_labels,
                                                    true_labels)
        nmi = 100.0 * self.metrics.nmi(predicted_labels, true_labels)

        return total_loss, recon_loss, gauss_loss, cat_loss, accuracy, nmi

    def test(self, data_loader, return_loss=False):
        """Test the model with new data

    Args:
        data_loader: (DataLoader) corresponding loader containing the test/validation data
        return_loss: (boolean) whether to return the average loss values
          
    Returns:
        accuracy and nmi for the given test data; with return_loss=True,
        the averaged loss values are returned as well

    """
        self.network.eval()
        total_loss = 0.
        recon_loss = 0.
        cat_loss = 0.
        gauss_loss = 0.

        accuracy = 0.
        nmi = 0.
        num_batches = 0.

        true_labels_list = []
        predicted_labels_list = []

        with torch.no_grad():
            for data, labels in data_loader:
                if self.cuda == 1:
                    data = data.cuda()

                # flatten data
                data = data.view(data.size(0), -1)

                # forward call
                out_net = self.network(data, self.gumbel_temp,
                                       self.hard_gumbel)
                unlab_loss_dic = self.unlabeled_loss(data, out_net)

                # accumulate values
                total_loss += unlab_loss_dic['total'].item()
                recon_loss += unlab_loss_dic['reconstruction'].item()
                gauss_loss += unlab_loss_dic['gaussian'].item()
                cat_loss += unlab_loss_dic['categorical'].item()

                # save predicted and true labels
                predicted = unlab_loss_dic['predicted_labels']
                true_labels_list.append(labels)
                predicted_labels_list.append(predicted)

                num_batches += 1.

        # average per batch
        if return_loss:
            total_loss /= num_batches
            recon_loss /= num_batches
            gauss_loss /= num_batches
            cat_loss /= num_batches

        # concat all true and predicted labels
        true_labels = torch.cat(true_labels_list, dim=0).cpu().numpy()
        predicted_labels = torch.cat(predicted_labels_list,
                                     dim=0).cpu().numpy()

        # compute metrics
        accuracy = 100.0 * self.metrics.cluster_acc(predicted_labels,
                                                    true_labels)
        nmi = 100.0 * self.metrics.nmi(predicted_labels, true_labels)

        if return_loss:
            return total_loss, recon_loss, gauss_loss, cat_loss, accuracy, nmi
        else:
            return accuracy, nmi

    def train(self, train_loader, val_loader):
        """Train the model

    Args:
        train_loader: (DataLoader) corresponding loader containing the training data
        val_loader: (DataLoader) corresponding loader containing the validation data

    Returns:
        output: (dict) contains the history of train/val loss
    """
        optimizer = optim.Adam(self.network.parameters(),
                               lr=self.learning_rate)
        train_history_acc, val_history_acc = [], []
        train_history_nmi, val_history_nmi = [], []

        for epoch in range(1, self.num_epochs + 1):
            train_loss, train_rec, train_gauss, train_cat, train_acc, train_nmi = self.train_epoch(
                optimizer, train_loader)
            val_loss, val_rec, val_gauss, val_cat, val_acc, val_nmi = self.test(
                val_loader, True)

            # if verbose then print specific information about training
            if self.verbose == 1:
                print("(Epoch %d / %d)" % (epoch, self.num_epochs))
                print("Train - REC: %.5lf;  Gauss: %.5lf;  Cat: %.5lf;" % \
                      (train_rec, train_gauss, train_cat))
                print("Valid - REC: %.5lf;  Gauss: %.5lf;  Cat: %.5lf;" % \
                      (val_rec, val_gauss, val_cat))
                print("Accuracy=Train: %.5lf; Val: %.5lf   NMI=Train: %.5lf; Val: %.5lf   Total Loss=Train: %.5lf; Val: %.5lf" % \
                      (train_acc, val_acc, train_nmi, val_nmi, train_loss, val_loss))
            else:
                print('(Epoch %d / %d) Train_Loss: %.3lf; Val_Loss: %.3lf   Train_ACC: %.3lf; Val_ACC: %.3lf   Train_NMI: %.3lf; Val_NMI: %.3lf' % \
                      (epoch, self.num_epochs, train_loss, val_loss, train_acc, val_acc, train_nmi, val_nmi))

            # decay gumbel temperature
            if self.decay_temp == 1:
                self.gumbel_temp = np.maximum(
                    self.init_temp * np.exp(-self.decay_temp_rate * epoch),
                    self.min_temp)
                if self.verbose == 1:
                    print("Gumbel Temperature: %.3lf" % self.gumbel_temp)

            train_history_acc.append(train_acc)
            val_history_acc.append(val_acc)
            train_history_nmi.append(train_nmi)
            val_history_nmi.append(val_nmi)
        return {
            'train_history_nmi': train_history_nmi,
            'val_history_nmi': val_history_nmi,
            'train_history_acc': train_history_acc,
            'val_history_acc': val_history_acc
        }

    def latent_features(self, data_loader, return_labels=False):
        """Obtain latent features learnt by the model

    Args:
        data_loader: (DataLoader) loader containing the data
        return_labels: (boolean) whether to return true labels or not

    Returns:
       features: (array) array containing the features from the data
    """
        self.network.eval()
        N = len(data_loader.dataset)
        features = np.zeros((N, self.gaussian_size))
        if return_labels:
            true_labels = np.zeros(N, dtype=np.int64)
        start_ind = 0
        with torch.no_grad():
            for (data, labels) in data_loader:
                if self.cuda == 1:
                    data = data.cuda()
                # flatten data
                data = data.view(data.size(0), -1)
                out = self.network.inference(data, self.gumbel_temp,
                                             self.hard_gumbel)
                latent_feat = out['mean']
                end_ind = min(start_ind + data.size(0), N)

                # return true labels
                if return_labels:
                    true_labels[start_ind:end_ind] = labels.cpu().numpy()
                features[start_ind:end_ind] = latent_feat.cpu().numpy()
                start_ind += data.size(0)
        if return_labels:
            return features, true_labels
        return features

    def reconstruct_data(self, data_loader, sample_size=-1):
        """Reconstruct Data

    Args:
        data_loader: (DataLoader) loader containing the data
        sample_size: (int) size of random data to consider from data_loader
      
    Returns:
        original: (array) the sampled original data
        reconstructed: (array) the corresponding reconstructed data
    """
        self.network.eval()

        # sample random data from the loader (all of it if sample_size <= 0)
        if sample_size <= 0:
            sample_size = len(data_loader.dataset)
        indices = np.random.randint(0,
                                    len(data_loader.dataset),
                                    size=sample_size)
        test_random_loader = torch.utils.data.DataLoader(
            data_loader.dataset,
            batch_size=sample_size,
            sampler=SubsetRandomSampler(indices))

        # obtain values
        it = iter(test_random_loader)
        test_batch_data, _ = next(it)
        original = test_batch_data.numpy()
        if self.cuda:
            test_batch_data = test_batch_data.cuda()

        # obtain reconstructed data
        out = self.network(test_batch_data, self.gumbel_temp, self.hard_gumbel)
        reconstructed = out['x_rec']
        return original, reconstructed.detach().cpu().numpy()

    def plot_latent_space(self, data_loader, save=False):
        """Plot the latent space learnt by the model

    Args:
        data_loader: (DataLoader) loader containing the data
        save: (bool) whether to save the latent space plot

    Returns:
        fig: (figure) plot of the latent space
    """
        # obtain the latent features and true labels
        features, labels = self.latent_features(data_loader,
                                                return_labels=True)

        # plot only the first 2 dimensions
        fig = plt.figure(figsize=(8, 6))
        plt.scatter(features[:, 0],
                    features[:, 1],
                    c=labels,
                    marker='o',
                    edgecolor='none',
                    cmap=plt.cm.get_cmap('jet', self.num_classes),
                    s=10)
        plt.colorbar()
        if save:
            fig.savefig('latent_space.png')
        return fig

    def random_generation(self, num_elements=1):
        """Random generation for each category

    Args:
        num_elements: (int) number of elements to generate

    Returns:
        generated data according to num_elements
    """
        # category indices: num_elements copies of each class, in order
        indices = np.repeat(np.arange(self.num_classes), num_elements)

        categorical = F.one_hot(torch.from_numpy(indices),
                                self.num_classes).float()

        if self.cuda:
            categorical = categorical.cuda()

        # infer the gaussian distribution according to the category
        mean, var = self.network.generative.pzy(categorical)

        # gaussian random sample via the reparameterization trick
        noise = torch.randn_like(var)
        std = torch.sqrt(var)
        gaussian = mean + noise * std

        # generate new samples with the given gaussian
        generated = self.network.generative.pxz(gaussian)

        return generated.cpu().detach().numpy()
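
Both examples delegate to a project-local LossFunctions class that is never shown. The sketch below is a plausible reconstruction for reference, not the project's actual implementation: the rec_type options, the convention that var holds variances rather than log-variances, and the mean-over-batch reduction are all assumptions.

import math

import torch
import torch.nn.functional as F


class LossFunctions:
    def reconstruction_loss(self, real, predicted, rec_type='mse'):
        # per-example sum of the reconstruction error, averaged over the batch
        if rec_type == 'mse':
            loss = (real - predicted).pow(2)
        elif rec_type == 'bce':
            loss = F.binary_cross_entropy(predicted, real, reduction='none')
        else:
            raise ValueError('invalid rec_type: %s' % rec_type)
        return loss.sum(-1).mean()

    def kl_divergence_v2(self, mu, var):
        # closed-form KL( N(mu, var) || N(0, I) ); the plain VAE's gaussian loss
        return -0.5 * torch.sum(
            1 + torch.log(var) - mu.pow(2) - var, dim=-1).mean()

    def log_normal(self, x, mu, var):
        # log density of x under a diagonal N(mu, var), summed over features
        return -0.5 * torch.sum(
            torch.log(2.0 * math.pi * var) + (x - mu).pow(2) / var, dim=-1)

    def gaussian_loss(self, z, z_mu, z_var, y_mu, y_var):
        # Monte Carlo estimate of E_q[log q(z|x) - log p(z|y)], the
        # gaussian term of the GMVAE lower bound
        return (self.log_normal(z, z_mu, z_var) -
                self.log_normal(z, y_mu, y_var)).mean()

    def entropy(self, logits, targets):
        # Shannon entropy of q(y|x), averaged over the batch; GMVAE's
        # unlabeled_loss combines it with the log prior so that
        # -entropy(...) - log(1/K) equals KL(q(y|x) || uniform)
        log_q = F.log_softmax(logits, dim=-1)
        return -torch.sum(targets * log_q, dim=-1).mean()

The Metrics helpers (cluster_acc, nmi) are likewise project-local; scikit-learn's normalized_mutual_info_score is a common stand-in for the latter.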