Ejemplo n.º 1
0
 def expanded_sample(self):
     """
     Draw ``n`` categories per row of ``self.ps`` with replacement.

     The trial count is read from ``self.n``. Only its first element is used,
     so ``n`` may be stored as a 0-dim, 1-D or 2-D Variable/Tensor.

     :return: the output of ``torch_multinomial(self.ps.data, n, replacement=True)``
         wrapped in a ``Variable`` (presumably sampled category indices, as with
         ``torch.multinomial`` -- confirm against the helper)
     """
     # Flatten before indexing: this returns the same element as the old
     # ``[0]`` (1-D) / ``[0][0]`` (2-D) lookups, and also works for a 0-dim
     # tensor, where ``[0]`` raises.
     n = int(self.n.data.cpu().contiguous().view(-1)[0])
     return Variable(torch_multinomial(self.ps.data, n, replacement=True))
Ejemplo n.º 2
0
 def expanded_sample(self):
     """
     Draw ``n`` categories per row of ``self.ps`` with replacement.

     The trial count is read from ``self.n``. Only its first element is used,
     so ``n`` may be stored as a 0-dim, 1-D or 2-D Variable/Tensor.

     :return: the output of ``torch_multinomial(self.ps.data, n, replacement=True)``
         wrapped in a ``Variable`` (presumably sampled category indices, as with
         ``torch.multinomial`` -- confirm against the helper)
     """
     # Flatten before indexing: this returns the same element as the old
     # ``[0]`` (1-D) / ``[0][0]`` (2-D) lookups, and also works for a 0-dim
     # tensor, where ``[0]`` raises.
     n = int(self.n.data.cpu().contiguous().view(-1)[0])
     return Variable(torch_multinomial(self.ps.data, n, replacement=True))
Ejemplo n.º 3
0
    def __init__(self,
                 input_dim,
                 hidden_dim,
                 output_dim_multiplier=1,
                 mask_encoding=None,
                 permutation=None):
        """
        Build a masked two-layer network with an autoregressive dependency structure.

        Hidden unit ``k`` is connected to the first ``m_k = mask_encoding[k]`` inputs
        in permuted order. It feeds every output whose permuted position is
        ``>= m_k``, and this pattern is repeated for each of the
        ``output_dim_multiplier`` output groups.

        :param input_dim: number of input dimensions
        :param hidden_dim: number of hidden units
        :param output_dim_multiplier: the output layer has
            ``input_dim * output_dim_multiplier`` units
        :param mask_encoding: per-hidden-unit integers ``m_k``. When ``None`` they are
            drawn uniformly at random from ``1 .. input_dim - 1``
        :param permutation: ordering of the input dimensions. When ``None`` it is
            ``torch.randperm(input_dim)``
        :raises ValueError: if ``mask_encoding`` is ``None`` and ``input_dim < 2``
            (there is no valid random mask to draw)
        """
        super(AutoRegressiveNN, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim_multiplier = output_dim_multiplier

        if mask_encoding is None:
            if input_dim < 2:
                # Otherwise torch.ones(0) / 0 is handed to the sampler and it
                # fails with an unrelated-looking error.
                raise ValueError(
                    "input_dim must be at least 2 to draw a random mask_encoding, "
                    "got {}".format(input_dim))
            # the dependency structure is chosen at random: m_k in 1 .. input_dim - 1
            self.mask_encoding = 1 + torch_multinomial(
                torch.ones(input_dim - 1) / (input_dim - 1),
                num_samples=hidden_dim,
                replacement=True)
        else:
            # the dependency structure is given by the user
            self.mask_encoding = mask_encoding

        if permutation is None:
            # a permutation is chosen at random
            self.permutation = torch.randperm(input_dim)
        else:
            # the permutation is chosen by the user
            self.permutation = permutation

        # these masks control the autoregressive structure
        self.mask1 = Variable(torch.zeros(hidden_dim, input_dim))
        self.mask2 = Variable(
            torch.zeros(input_dim * self.output_dim_multiplier, hidden_dim))

        for k in range(hidden_dim):
            # Cast so torch.ones/zeros always receive a plain int. Indexing the
            # encoding may give a 0-dim tensor or a numpy scalar instead.
            m_k = int(self.mask_encoding[k])
            # mask1 row: 1 for the first m_k positions of the permuted order
            sees_input = torch.cat(
                [torch.ones(m_k),
                 torch.zeros(input_dim - m_k)])
            # mask2 column: 1 for permuted positions m_k and later (complement)
            feeds_output = torch.cat(
                [torch.zeros(m_k),
                 torch.ones(input_dim - m_k)])
            for j in range(input_dim):
                self.mask1[k, self.permutation[j]] = sees_input[j]
            for r in range(self.output_dim_multiplier):
                offset = r * input_dim
                for j in range(input_dim):
                    self.mask2[offset + self.permutation[j], k] = feeds_output[j]

        self.lin1 = MaskedLinear(input_dim, hidden_dim, self.mask1)
        self.lin2 = MaskedLinear(hidden_dim, input_dim * output_dim_multiplier,
                                 self.mask2)
        self.relu = nn.ReLU()
Ejemplo n.º 4
0
    def sample(self):
        """
        Draw one event and encode it as a one-hot tensor shaped like ``ps``.

        An event index is sampled from ``ps`` and expanded to ``self.shape()``.
        It is then scattered as ``1`` into a zero tensor of the same
        shape/type as ``ps``.

        :return: one-hot sample from the OneHotCategorical distribution
        :rtype: torch.autograd.Variable
        """
        probs = self.ps.data
        event_index = torch_multinomial(probs, 1, replacement=True)
        event_index = event_index.expand(*self.shape())
        one_hot = torch_zeros_like(probs)
        # scatter_ writes in place and returns the same tensor
        one_hot.scatter_(-1, event_index, 1)
        return Variable(one_hot)
Ejemplo n.º 5
0
    def sample(self):
        """
        Draw one event from ``ps``, optionally mapped through the values ``vs``.

        Without ``vs``, the sampled indices (expanded to ``self.shape()``) are
        returned as a ``Variable``. With a numpy ``vs``, the matching entries are
        gathered with a boolean mask and reshaped to ``self.shape()``. With any
        other ``vs``, the entries are gathered with ``masked_select``.

        :return: sample from the Categorical distribution
        :rtype: torch.autograd.Variable, numpy.ndarray, or whatever
            ``vs.masked_select`` returns
        """
        probs = self.ps.data
        index = torch_multinomial(probs, 1, replacement=True).expand(*self.shape())
        one_hot = torch_zeros_like(probs).scatter_(-1, index, 1)

        if self.vs is None:
            return Variable(index)

        if not isinstance(self.vs, np.ndarray):
            # NOTE(review): masked_select yields a flattened result, whereas the
            # numpy branch below reshapes to self.shape(). Confirm callers expect this.
            return self.vs.masked_select(one_hot.byte())

        selector = one_hot.cpu().numpy().astype(bool)
        return self.vs[selector].reshape(*self.shape())