Example #1
class SsVae(nn.Module):
    """
    This class encapsulates the parameters (neural networks) and models & guides needed to train a
    semi-supervised variational auto-encoder on the MNIST image dataset

    :param z_dim: size of the tensor representing the latent random variable z
                  (handwriting style for our MNIST dataset)
    :param h_dims: a tuple (or list) of MLP layers to be used in the neural networks
                   representing the parameters of the distributions in our model
    :param eps: a small float value used to scale down the output of Softmax and Sigmoid
                operations in pytorch for numerical stability
    :param enum_discrete: if True, sum out the discrete latent variables to reduce variance of
                          the ELBO gradient
    :param aux_loss: whether to use the auxiliary loss from model variant 3
                     (http://pyro.ai/examples/ss-vae.html#Third-Variant:-Adding-a-Term-to-the-Objective)
    :param aux_loss_multiplier: the multiplier to use with the auxiliary loss
    :param use_cuda: use GPUs for faster training
    :param batch_size: batch size used for training and evaluation
    :param init_lr: initial learning rate used to set up the optimizer
    :param continue_from: path of a model file from which to load saved model states
    """
    def __init__(self, x_dim=p.NUM_PIXELS, y_dim=p.NUM_LABELS, z_dim=p.NUM_STYLE, h_dims=p.NUM_HIDDEN,
                 eps=p.EPS, enum_discrete=True, aux_loss=True, aux_loss_multiplier=300,
                 use_cuda=False, batch_size=100, init_lr=0.001, continue_from=None,
                 *args, **kwargs):
        super().__init__()

        # initialize the class with all arguments provided to the constructor
        self.x_dim = x_dim
        self.y_dim = y_dim

        self.use_cuda = use_cuda
        self.batch_size = batch_size
        self.init_lr = init_lr
        self.epoch = 1

        if continue_from is None:
            self.z_dim = z_dim
            self.h_dims = h_dims
            self.eps = eps
            self.enum_discrete = enum_discrete
            self.aux_loss = aux_loss
            self.aux_loss_multiplier = aux_loss_multiplier
            # define and instantiate the neural networks representing
            # the parameters of various distributions in the model
            self.__setup_networks()
        else:
            self.load(continue_from)

        # using GPUs for faster training of the networks
        if self.use_cuda:
            self.cuda()

    def __setup_networks(self):
        # define the neural networks used later in the model and the guide.
        self.encoder_y = ConvEncoderY(x_dim=self.x_dim, y_dim=self.y_dim, eps=self.eps)
        self.encoder_z = MlpEncoderZ(x_dim=self.x_dim, y_dim=self.y_dim, z_dim=self.z_dim,
                                     h_dims=self.h_dims, eps=self.eps)
        self.decoder = ConvDecoder(x_dim=self.x_dim, y_dim=self.y_dim, z_dim=self.z_dim, eps=self.eps)

        # setup the optimizer
        params = {"lr": self.init_lr, "betas": (0.9, 0.999)}
        self.optimizer = Adam(params)

        # set up the loss(es) for inference; setting the enum_discrete parameter builds the loss
        # as a sum by enumerating each class label for the sampled discrete categorical
        # distribution in the model
        loss_basic = SVI(self.model, self.guide, self.optimizer, loss="ELBO",
                         enum_discrete=self.enum_discrete)
        self.losses = [loss_basic]

        # aux_loss: whether to use the auxiliary loss from the NIPS 2014 paper (Kingma et al.)
        if self.aux_loss:
            loss_aux = SVI(self.model_classify, self.guide_classify, self.optimizer, loss="ELBO")
            self.losses.append(loss_aux)

    def model(self, xs, ys=None):
        """
        The model corresponds to the following generative process:
        p(z) = normal(0,I)              # handwriting style (latent)
        p(y|x) = categorical(I/10.)     # which digit (semi-supervised)
        p(x|y,z) = bernoulli(mu(y,z))   # an image
        mu is given by a neural network  `decoder`

        :param xs: a batch of scaled vectors of pixels from an image
        :param ys: (optional) a batch of the class labels i.e.
                   the digit corresponding to the image(s)
        :return: None
        """
        # register this pytorch module and all of its sub-modules with pyro
        pyro.module("ss_vae", self)

        # inform Pyro that the variables in the batch of xs, ys are conditionally independent
        batch_size = xs.size(0)
        with pyro.iarange("independent"):
            # sample the handwriting style from the constant prior distribution
            prior_mu = Variable(torch.zeros([batch_size, self.z_dim]))
            prior_sigma = Variable(torch.ones([batch_size, self.z_dim]))
            zs = pyro.sample("z", dist.Normal(prior_mu, prior_sigma).reshape(extra_event_dims=1))
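            # note: reshape(extra_event_dims=1) treats the rightmost (z_dim) dimension as part of
            # the event, so each z vector is scored as a single sample; the Bernoulli reshape on
            # the image pixels below works the same way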

            # if the label y (which digit to write) is unsupervised, sample it from the
            # constant prior; otherwise, observe the given value (i.e. score it against the constant prior)
            alpha_prior = Variable(torch.ones([batch_size, self.y_dim]) / (1.0 * self.y_dim))
            if ys is None:
                ys = pyro.sample("y", dist.OneHotCategorical(alpha_prior))
            else:
                pyro.sample("y", dist.OneHotCategorical(alpha_prior), obs=ys)

            # finally, score the image (x) using the handwriting style (z) and
            # the class label y (which digit to write) against the
            # parametrized distribution p(x|y,z) = bernoulli(decoder(y,z))
            # where `decoder` is a neural network
            mu = self.decoder.forward(zs, ys)
            pyro.sample("x", dist.Bernoulli(mu).reshape(extra_event_dims=1), obs=xs)

    def guide(self, xs, ys=None):
        """
        The guide corresponds to the following:
        q(y|x) = categorical(alpha(x))              # infer digit from an image
        q(z|x,y) = normal(mu(x,y),sigma(x,y))       # infer handwriting style from an image and the digit
        mu, sigma are given by a neural network `encoder_z`
        alpha is given by a neural network `encoder_y`

        :param xs: a batch of scaled vectors of pixels from an image
        :param ys: (optional) a batch of the class labels i.e.
                   the digit corresponding to the image(s)
        :return: None
        """
        # inform Pyro that the variables in the batch of xs, ys are conditionally independent
        with pyro.iarange("independent"):
            # if the class label (the digit) is not supervised, sample
            # (and score) the digit with the variational distribution
            # q(y|x) = categorical(alpha(x))
            if ys is None:
                alpha = self.encoder_y.forward(xs)
                ys = pyro.sample("y", dist.OneHotCategorical(alpha))

            # sample (and score) the latent handwriting-style with the variational
            # distribution q(z|x,y) = normal(mu(x,y),sigma(x,y))
            mu, sigma = self.encoder_z.forward(xs, ys)
            zs = pyro.sample("z", dist.Normal(mu, sigma).reshape(extra_event_dims=1))

    def classifier(self, xs):
        """
        classify an image (or a batch of images)

        :param xs: a batch of scaled vectors of pixels from an image
        :return: a batch of the corresponding class labels (as one-hots)
        """
        # use the trained model q(y|x) = categorical(alpha(x))
        # compute all class probabilities for the image(s)
        alpha = self.encoder_y.forward(xs)

        # get the index (digit) that corresponds to
        # the maximum predicted class probability
        res, ind = torch.topk(alpha, 1)
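        # e.g. for a single image with alpha = [[0.1, 0.7, 0.2, ...]], topk returns ind = [[1]]
        # and the scatter_ below turns it into the one-hot [[0., 1., 0., ...]]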

        # convert the digit(s) to one-hot tensor(s)
        ys = Variable(torch.zeros(alpha.size()))
        ys = ys.scatter_(1, ind, 1.0)
        return ys

    def model_classify(self, xs, ys=None):
        """
        this model is used to add an auxiliary (supervised) loss as described in the
        NIPS 2014 paper by Kingma et al titled
        "Semi-Supervised Learning with Deep Generative Models"
        """
        # register all pytorch (sub)modules with pyro
        pyro.module("ss_vae", self)

        # inform Pyro that the variables in the batch of xs, ys are conditionally independent
        with pyro.iarange("independent"):
            # this is the extra term that yields the auxiliary loss we do gradient descent on,
            # similar to the NIPS 2014 paper (Kingma et al.)
            if ys is not None:
                alpha = self.encoder_y.forward(xs)
                with pyro.poutine.scale(None, self.aux_loss_multiplier):
                    pyro.sample("y_aux", dist.OneHotCategorical(alpha), obs=ys)

    def guide_classify(self, xs, ys=None):
        """
        dummy guide function to accompany model_classify in inference
        """
        pass

    def model_sample(self, ys, batch_size=1):
        with torch.no_grad():
            # sample the handwriting style from the constant prior distribution
            prior_mu = Variable(torch.zeros([batch_size, self.z_dim]))
            prior_sigma = Variable(torch.ones([batch_size, self.z_dim]))
            zs = pyro.sample("z", dist.Normal(prior_mu, prior_sigma).reshape(extra_event_dims=1))

            # sample an image using the decoder
            mu = self.decoder.forward(zs, ys)
            xs = pyro.sample("sample", dist.Bernoulli(mu).reshape(extra_event_dims=1))
            return xs, mu

    def guide_sample(self, xs, ys, batch_size=1):
        with torch.no_grad():
            # obtain z using `encoder_z`
            xs, ys = Variable(xs), Variable(ys)
            z_mu, z_sigma = self.encoder_z(xs, ys)
            return z_mu, z_sigma

    def train_epoch(self, data_loaders):
        """
        runs the inference algorithm for an epoch
        returns the values of all losses separately on supervised and unsupervised parts
        """
        # how often a supervised batch is encountered during inference,
        # e.g. if sup_num is 3000, every 16th = int(50000/3000) batch is supervised
        # until all the supervised batches have been traversed
        unsup_num = len(data_loaders["train_unsup"])
        sup_num = len(data_loaders["train_sup"])

        train_data_size = unsup_num + sup_num
        periodic_interval_batches = int(train_data_size // (1.0 * sup_num))
        sup_num = int(train_data_size / periodic_interval_batches)
        train_data_size = unsup_num + sup_num

        # number of losses being tracked
        num_losses = len(self.losses)

        # initialize variables to store loss values
        epoch_losses_sup = [0.] * num_losses
        epoch_losses_unsup = [0.] * num_losses

        # setup the iterators for training data loaders
        sup_iter = iter(data_loaders["train_sup"])
        unsup_iter = iter(data_loaders["train_unsup"])

        # count the number of supervised batches seen in this epoch
        cnt_sup = 0
        for i in tqdm(range(train_data_size), desc="training  "):
            # whether this batch is supervised or not
            is_supervised = (i % periodic_interval_batches == 1) and cnt_sup < sup_num

            # extract the corresponding batch
            if is_supervised:
                xs, ys = next(sup_iter)
                xs, ys = Variable(xs), Variable(ys)
                cnt_sup += 1
            else:
                xs = next(unsup_iter)
                xs = Variable(xs)

            # run the inference for each loss with supervised or un-supervised
            # data as arguments
            for loss_id in range(num_losses):
                if is_supervised:
                    new_loss = self.losses[loss_id].step(xs, ys)
                    epoch_losses_sup[loss_id] += new_loss
                else:
                    new_loss = self.losses[loss_id].step(xs)
                    epoch_losses_unsup[loss_id] += new_loss

        # compute average epoch losses, i.e. losses per example
        avg_losses_sup = [x / sup_num for x in epoch_losses_sup]
        avg_losses_unsup = [x / unsup_num for x in epoch_losses_unsup]

        # return the values of all losses
        return avg_losses_sup, avg_losses_unsup

    def get_accuracy(self, data_loader, desc=None):
        """
        compute the accuracy over the supervised training set or the testing set
        """
        predictions, actuals = [], []
        for i, (data) in tqdm(enumerate(data_loader), total=len(data_loader), desc=desc):
            xs, ys = data
            xs, ys = Variable(xs), Variable(ys)
            # use classification function to compute all predictions for each batch
            with torch.no_grad():
                predictions.append(self.classifier(xs))
            actuals.append(ys)

        # compute the number of accurate predictions
        accurate_preds = 0
        for pred, act in zip(predictions, actuals):
            for i in range(pred.size(0)):
                v = torch.sum(pred[i] == act[i])
                accurate_preds += (v.item() == pred.size(1))

        # calculate the accuracy between 0 and 1
        accuracy = (accurate_preds * 1.0) / (len(predictions) * self.batch_size)
        return accuracy

    def save(self, file_path, **kwargs):
        Path(file_path).parent.mkdir(mode=0o755, parents=True, exist_ok=True)
        logger.info(f"saving the model to {file_path}")
        states = kwargs
        states["epoch"] = self.epoch
        states["ss_vae"] = self.state_dict()
        states.update({
            "z_dim": self.z_dim,
            "h_dims": self.h_dims,
            "eps": self.eps,
            "enum_discrete": self.enum_discrete,
            "aux_loss": self.aux_loss,
            "aux_loss_multiplier": self.aux_loss_multiplier,
            "optimizer": self.optimizer.get_state(),
        })
        torch.save(states, file_path)

    def load(self, file_path):
        if isinstance(file_path, str):
            file_path = Path(file_path)
        if not file_path.exists():
            logger.error(f"no such file {file_path} exists")
            sys.exit(1)
        logger.info(f"loading the model from {file_path}")
        states = torch.load(file_path)

        self.z_dim = states["z_dim"]
        self.h_dims = states["h_dims"]
        self.eps = states["eps"]
        self.enum_discrete = states["enum_discrete"]
        self.aux_loss = states["aux_loss"]
        self.aux_loss_multiplier = states["aux_loss_multiplier"]
        self.epoch = states["epoch"]

        self.__setup_networks()
        self.load_state_dict(states["ss_vae"])
        self.optimizer.set_state(states["optimizer"])
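

# A minimal usage sketch (not part of the original example). It assumes a dict of
# MNIST DataLoaders keyed "train_sup", "train_unsup" and "test" has been built
# elsewhere (the `data_loaders` argument below is hypothetical); only the SsVae
# calls themselves come from the class defined above.
def run_ss_vae_sketch(data_loaders, num_epochs=10):
    ss_vae = SsVae(use_cuda=torch.cuda.is_available())
    for epoch in range(1, num_epochs + 1):
        # one pass over the interleaved supervised / unsupervised training batches
        avg_losses_sup, avg_losses_unsup = ss_vae.train_epoch(data_loaders)
        logger.info(f"epoch {epoch:03d}: sup losses {avg_losses_sup}, unsup losses {avg_losses_unsup}")
        # classification accuracy of q(y|x) on the held-out set
        accuracy = ss_vae.get_accuracy(data_loaders["test"], desc="evaluating")
        logger.info(f"epoch {epoch:03d}: test accuracy {accuracy:.4f}")
        # checkpoint the networks, optimizer state and hyperparameters
        ss_vae.epoch = epoch
        ss_vae.save("checkpoints/ss_vae.pth", accuracy=accuracy)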

Example #2

    def train(self, epochs, lr=3.0e-5, tf=2):
        """Train the DLGM for some number of epochs."""
        # Set up the optimizer.
        optimizer = Adam({"lr": lr})
        train_elbo = {}
        test_elbo = {}
        start_epoch = 0
        # Load cached state, if given.
        if self.load_dir is not None:
            filename = self.load_dir + 'checkpoint.tar'
            checkpoint = torch.load(filename)
            self.encoder.load_state_dict(checkpoint['encoder_state_dict'])
            self.decoder.load_state_dict(checkpoint['decoder_state_dict'])
            optimizer.set_state(checkpoint['optimizer_state'])
            train_elbo = checkpoint['train_elbo']
            test_elbo = checkpoint['test_elbo']
            start_epoch = checkpoint['epoch'] + 1
            self.partition = checkpoint['partition']
            self.train_loader, self.test_loader = get_data_loaders(
                self.partition, self.p)

        # Set up the inference algorithm.
        elbo = Trace_ELBO()
        svi = SVI(self.model, self.guide, optimizer, loss=elbo)

        print("dataset length: ", len(self.train_loader.dataset))
        for epoch in range(start_epoch, start_epoch + epochs + 1, 1):
            train_loss = 0.0
            # Iterate over the training data.
            for i, temp in enumerate(self.train_loader):
                x = temp['spec'].cuda().view(-1, self.input_dim)
                train_loss += svi.step(x)
            # Report training diagnostics.
            normalizer_train = len(self.train_loader.dataset)
            total_epoch_loss_train = train_loss / normalizer_train
            train_elbo[epoch] = total_epoch_loss_train
            print("[epoch %03d]  average train loss: %.4f" %
                  (epoch, total_epoch_loss_train))

            if (epoch + 1) % tf == 0:
                test_loss = 0.0
                # Iterate over the test set.
                for i, temp in enumerate(self.test_loader):
                    x = temp['spec'].cuda().view(-1, self.input_dim)
                    test_loss += svi.evaluate_loss(x)
                # Report test diagnostics.
                normalizer_test = len(self.test_loader.dataset)
                total_epoch_loss_test = test_loss / normalizer_test
                test_elbo[epoch] = total_epoch_loss_test
                print("[epoch %03d]  average test loss: %.4f" %
                      (epoch, total_epoch_loss_test))
                self.visualize()

                if self.save_dir is not None:
                    filename = self.save_dir + 'checkpoint.tar'
                    state = {
                        'train_elbo': train_elbo,
                        'test_elbo': test_elbo,
                        'epoch': epoch,
                        'encoder_state_dict': self.encoder.state_dict(),
                        'decoder_state_dict': self.decoder.state_dict(),
                        'optimizer_state': optimizer.get_state(),
                        'partition': self.partition,
                    }
                    torch.save(state, filename)
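

# A self-contained sketch (not part of the original examples) of the SVI / Trace_ELBO
# pattern used in the train() method above: svi.step() takes a gradient step and returns
# the loss for that batch, and the per-batch losses are accumulated into an epoch total
# that is normalized by the dataset size. The toy model, guide and data below are
# hypothetical, and assume a Pyro release that provides pyro.plate and Trace_ELBO.
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam


def toy_model(x):
    # a single learnable location parameter; p(x) = Normal(loc, 1)
    loc = pyro.param("loc", torch.tensor(0.0))
    with pyro.plate("data", x.size(0)):
        pyro.sample("obs", dist.Normal(loc, 1.0), obs=x)


def toy_guide(x):
    # the toy model has no latent sample sites, so the guide is empty
    pass


def toy_training_sketch(epochs=5, steps_per_epoch=50):
    pyro.clear_param_store()
    data = torch.randn(256) + 3.0   # pretend "dataset" centred at 3.0
    svi = SVI(toy_model, toy_guide, Adam({"lr": 0.05}), loss=Trace_ELBO())
    for epoch in range(epochs):
        epoch_loss = 0.0
        for _ in range(steps_per_epoch):
            epoch_loss += svi.step(data)            # gradient step, returns the batch loss
        avg_loss = epoch_loss / (steps_per_epoch * data.size(0))
        print("[epoch %03d]  average train loss: %.4f" % (epoch, avg_loss))
    return pyro.param("loc").item()                 # should approach 3.0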