Example #1
    def word_p(self, val):
        if isinstance(val, float):
            val = self.beta.new_tensor(val)

        if not isinstance(val, torch.FloatTensor) and not isinstance(val, torch.cuda.FloatTensor):
            raise InvalidArgumentError("word_p should be a float or FloatTensor, not {}.".format(type(val)))

        if val > self.min_word_p and val <= 1.:
            self._word_p = val
        elif val > 1.:
            self._word_p = self.min_word_p.new_tensor(1.)
        elif val <= self.min_word_p:
            self._word_p = self.min_word_p.new_tensor(self.min_word_p.item())
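A minimal self-contained sketch of the clamping behaviour this setter implements. The stand-in class `_WordPDemo`, its attribute values, and the omission of the type checks are illustrative assumptions, not part of the original code:

import torch

class _WordPDemo:
    """Hypothetical stand-in for the class that owns the setter above."""

    def __init__(self):
        self.beta = torch.tensor(1.0)        # assumed tensor attribute used for new_tensor
        self.min_word_p = torch.tensor(0.1)  # assumed lower bound for word_p

    @property
    def word_p(self):
        return self._word_p

    @word_p.setter
    def word_p(self, val):
        # Same clamping rule as above (type checks omitted): keep val in (min_word_p, 1].
        if isinstance(val, float):
            val = self.beta.new_tensor(val)
        if val > 1.:
            val = self.min_word_p.new_tensor(1.)
        elif val <= self.min_word_p:
            val = self.min_word_p.clone()
        self._word_p = val

demo = _WordPDemo()
demo.word_p = 1.5   # above 1.0: clamped to tensor(1.)
demo.word_p = 0.05  # at or below min_word_p: clamped to tensor(0.1000)
demo.word_p = 0.6   # inside (min_word_p, 1.]: kept as tensor(0.6000)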
Example #2
    def min_rate(self, val):
        if isinstance(val, float):
            val = torch.tensor(val, device=self.device)

        if not isinstance(val, torch.FloatTensor) and not isinstance(
                val, torch.cuda.FloatTensor):
            raise InvalidArgumentError(
                "min_rate should be a float or FloatTensor.")

        if val > 0:
            self._min_rate = val
        else:
            self._min_rate = torch.tensor(0., device=self.device)
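The if/else floor in this setter is equivalent to clamping at zero; a tiny sketch of that equivalence (the example values are illustrative, and note that the original setter additionally validates the type and moves plain floats to self.device):

import torch

val = torch.tensor(-0.5)
min_rate = torch.clamp(val, min=0.)                 # tensor(0.): negative rates collapse to zero
min_rate = torch.clamp(torch.tensor(2.5), min=0.)   # tensor(2.5000): positive rates pass through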
Example #3
    def _unpack_data(self, data, N):
        """Unpacks the input data to the forward pass and supplies missing data.

        Args:
            data(list of torch.Tensor): data provided to forward pass. We assume the following ordering
                [input, length(optional), mask(optional), reversed(optional), reversed_length(optional)]
            N(int): the number of data entries to return. Can be 1-4; only the first N of the fields below are returned, with missing entries set to None.

        Returns:
            x_in(torch.Tensor): batch of input sequences.
            x_len(torch.Tensor): lengths of input sequences or None.
            x_mask(torch.Tensor): mask over the padding of input sequences that are not of max length or None.
            x_reverse(torch.Tensor): batch of reversed input sequences or None.
            x_len_reverse(torch.Tensor): lengths of reversed input sequences or None.
        """
        # Checks and padding of data, so we have N tensors or None to process
        if not isinstance(data[0], torch.Tensor):
            raise InvalidArgumentError("Data should contain a torch Tensor with data at the first position.")
        if N < 1 or N > 4:
            raise InvalidArgumentError("N should be between 1 and 4.")
        data = (data + [None, ] * N)[:N]
        for d in data:
            if not isinstance(d, torch.Tensor) and d is not None:
                raise InvalidArgumentError("Data should contain only torch Tensors or None.")

        # If no mask is given, we create an empty mask as placeholder.
        if N > 2 and data[2] is None:
            data[2] = torch.ones(data[0].shape).to(self.device)
            if data[1] is not None:
                warn("Data length is given without mask. Assuming all sentences are of the same length. Sentences shorter than {} words will not be masked.".format(self.seq_len))

        # When the reversed data is not given, we assume no padding and reverse the sequence ourselves
        if N > 3 and data[3] is None:
            warn("Reversed data not provided. We assume no padding and reverse the data cheaply.")
            indices = torch.arange(data[0].shape[1] - 1, -1, -1, device=data[0].device)
            data[3] = data[0].index_select(1, indices)

        return data
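A small sketch of the two fallbacks applied above, reproduced outside the class for illustration; the batch `x`, its shape, and the absence of a device argument are assumptions here:

import torch

x = torch.randint(0, 100, (2, 5))   # assumed batch of token ids, shape (batch, seq_len)
data = [x]                          # only the input tensor is provided
N = 4
data = (data + [None] * N)[:N]      # pad the list with None up to N entries

# Placeholder mask: all ones, i.e. every position is treated as a real token.
if data[2] is None:
    data[2] = torch.ones(data[0].shape)

# Cheap reversal: flip the time dimension; only valid when there is no padding.
if data[3] is None:
    indices = torch.arange(data[0].shape[1] - 1, -1, -1)
    data[3] = data[0].index_select(1, indices)

x_in, x_len, x_mask, x_reverse = data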
Example #4
 def _gaussian_kl_divergence(self, mu_1, var_1, mu_2, var_2, mask, dim):
     """Computes the batch KL-divergence between two Gaussian distributions with diagonal covariance."""
     if mu_2 is None and var_2 is None:
         return 0.5 * torch.sum((-torch.log(var_1) + var_1 + mu_1**2 - 1) *
                                mask.unsqueeze(dim),
                                dim=dim)
     elif mu_2 is not None and var_2 is not None:
         return 0.5 * torch.sum(
             (torch.log(var_2) - torch.log(var_1) + var_1 / var_2 +
              (mu_2 - mu_1)**2 / var_2 - 1) * mask.unsqueeze(dim),
             dim=dim)
     else:
         raise InvalidArgumentError(
             "Either provide mu_2 and var_2 or neither.")
Example #5
    def constraint(self, vals):
        if not isinstance(vals, list):
            if isinstance(vals, str):
                vals = [vals]
            else:
                raise InvalidArgumentError(
                    'constraint should be a list or str')
        for val in vals:
            if val not in ['mdr', 'mmd', 'elbo']:
                raise UnknownArgumentError(
                    'constraint {} unknown. Please choose [mdr, mmd, elbo].'.format(
                        val))

        self._constraint = vals
Example #6
 def _gaussian_sample_z(self, mu, var, shape, det):
     """Sample from a Gaussian distribution with mean mu and variance var."""
     if mu is None and var is None and shape is not None:
         if det:
             return torch.zeros(shape, device=self.device)
         else:
             return self.error.sample(shape)
     elif mu is not None and var is not None:
         if det:
             return mu
         else:
             return mu + torch.sqrt(var) * self.error.sample(var.shape)
     else:
         raise InvalidArgumentError("Provide either mu and var or neither with a shape.")
Example #7
    def __init__(self, device, seq_len, word_p, word_p_enc, parameter_p,
                 encoder_p, drop_type, min_rate, unk_index, css, N, rnn_type,
                 kl_step, beta, lamb, mmd, ann_mode, rate_mode, posterior,
                 hinge_weight, ann_word, word_step, v_dim, x_dim, h_dim, s_dim,
                 z_dim, l_dim, h_dim_enc, l_dim_enc, lagrangian, constraint,
                 max_mmd, max_elbo, alpha):
        super(GenerativeDecoder,
              self).__init__(device, seq_len, word_p, parameter_p, drop_type,
                             unk_index, css, N, rnn_type, v_dim, x_dim, h_dim,
                             s_dim, l_dim)
        # Recurrent dropout was never implemented for the VAEs because it does not work well
        if self.drop_type == "recurrent":
            raise InvalidArgumentError(
                "Recurrent dropout not implemented for this model. Please choose ['varied', 'shared']"
            )
        # LSTMs are not supported because GRUs work equally well (with fewer parameters)
        if self.rnn_type == "LSTM":
            raise InvalidArgumentError(
                "LSTM not implemented for this model. Please choose ['GRU']")

        # Choose between the vMF-autoencoder and Gauss-autoencoder
        self.posterior = posterior

        # Encoder architecture settings
        self.encoder_p = torch.tensor(encoder_p,
                                      device=self.device,
                                      dtype=torch.float)
        self.h_dim_enc = h_dim_enc
        self.l_dim_enc = l_dim_enc
        self.word_p_enc = word_p_enc

        # Optimization hyperparameters
        self.min_rate = torch.tensor(
            min_rate, device=self.device,
            dtype=torch.float)  # minimum Rate of hinge/FB
        self.beta = torch.tensor(beta, device=self.device,
                                 dtype=torch.float)  # beta value of beta-VAE
        self.alpha = torch.tensor(alpha, device=self.device,
                                  dtype=torch.float)  # alpha value of InfoVAE
        self.lamb = torch.tensor(lamb, device=self.device,
                                 dtype=torch.float)  # lambda value of InfoVAE
        self.kl_step = torch.tensor(
            kl_step, device=self.device,
            dtype=torch.float)  # Step size of KL annealing
        self.hinge_weight = torch.tensor(
            hinge_weight, device=self.device,
            dtype=torch.float)  # Weight of hinge loss
        # Step size of word dropout annealing
        self.word_step = torch.tensor(word_step,
                                      device=self.device,
                                      dtype=torch.float)
        self.max_mmd = torch.tensor(max_mmd,
                                    device=self.device,
                                    dtype=torch.float)  # Maximum MMD
        self.max_elbo = torch.tensor(max_elbo,
                                     device=self.device,
                                     dtype=torch.float)  # Maximum ELBO

        #  Optimization modes
        self.mmd = mmd  # When true, we add the maximum mean discrepancy to the loss, and optimize the InfoVAE
        self.ann_mode = ann_mode  # The mode of annealing
        self.rate_mode = rate_mode  # How to force the VAE to encode a minimum rate
        self.ann_word = ann_word  # Whether to anneal word dropout

        # The weight of the constraint in the Lagrangian dual function
        # Hardcoded start at 1.01
        self.lagrangian = lagrangian
        self.constraint = constraint
        self.lag_weight = Parameter(torch.tensor([1.01] *
                                                 len(self.constraint)))

        if self.ann_word:
            self.word_p = 1.

        self.z_dim = z_dim

        # The scale factor starts at a single KL step (kl_step * beta) and is incremented linearly every forward pass
        if self.ann_mode == "linear":
            self.scale = torch.tensor(self.kl_step.item() * self.beta.item(),
                                      device=self.device)
        # Or we start the scale at 10% of beta, to be increased or decreased in 10% steps based on a desired rate
        elif self.ann_mode == "sfb":
            self.scale = torch.tensor(0.1 * self.beta.item(),
                                      device=self.device)

        # This switch should be managed manually from the training/testing scripts to select whether to generate from the prior or the posterior
        self.use_prior = False

        # N(0, I) error distribution to sample from latent spaces with reparameterized gradient
        self.error = Normal(torch.tensor(0., device=device),
                            torch.tensor(1., device=device))
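The "linear" branch above only sets the starting value of the KL weight; how the weight is advanced lives elsewhere in the model. A minimal sketch of a linear schedule of this kind, where the cap at beta and the per-step increment are assumptions for illustration:

import torch

beta = torch.tensor(1.0)      # target KL weight (beta of the beta-VAE)
kl_step = torch.tensor(1e-3)  # step size of the KL annealing
scale = kl_step * beta        # starting value, as in the "linear" branch above

def anneal_step(scale):
    # Hypothetical linear schedule: raise the KL weight by one step per
    # forward pass until it saturates at beta.
    return torch.clamp(scale + kl_step * beta, max=beta.item())

for _ in range(3):
    scale = anneal_step(scale)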
Example #8
 def seq_len(self, val):
     if not isinstance(val, int):
         raise InvalidArgumentError("seq_len should be an integer.")
     self._seq_len = val