Example #1
def gumbel_softmax(logits, dim=-1, tau=1, hard=False, eps=1e-10):
    """
    Sample from the Gumbel-Softmax distribution and optionally discretize.
    Args:
      logits: [batch_size, n_class] unnormalized log-probs
      dim: along which dim the softmax is performed
      tau: non-negative scalar temperature
      hard: if True, take argmax, but differentiate w.r.t. soft sample y
      eps: small constant added for numerical stability
    Returns:
      [batch_size, n_class] sample from the Gumbel-Softmax distribution.
      If hard=True, then the returned sample will be one-hot, otherwise it will
      be a probability distribution that sums to 1 across classes
    Constraints:
    - this implementation only works on a batch_size x num_features tensor for now
    based on
    https://github.com/ericjang/gumbel-softmax/blob/3c8584924603869e90ca74ac20a6a03d99a91ef9/Categorical%20VAE.ipynb ,
    (MIT license)
    """
    y_soft = _gumbel_softmax_sample(logits, dim=dim, tau=tau, eps=eps)
    if hard:
        _, k = y_soft.data.max(dim=dim)
        # this bit is based on
        # https://discuss.pytorch.org/t/stop-gradients-for-st-gumbel-softmax/530/5
        y_hard = torch.zeros_like(as_tensor(logits))
        set_index_one_hot_(y_hard, dim, k, 1.0)
        # this cool bit of code achieves two things:
        # - makes the output value exactly one-hot (since we add then
        #   subtract y_soft value)
        # - makes the gradient equal to y_soft gradient (since we strip
        #   all other gradients)
        y = var_with(y_hard - as_tensor(y_soft), y_soft) + y_soft
    else:
        y = y_soft
    return y
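For reference, recent PyTorch versions ship this operation as torch.nn.functional.gumbel_softmax; a minimal usage sketch, independent of the helper above:

import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)
# hard=True returns an exactly one-hot sample whose gradient flows through the soft sample
y = F.gumbel_softmax(logits, tau=1.0, hard=True, dim=-1)
assert y.shape == (4, 10)
assert torch.allclose(y.sum(dim=-1), torch.ones(4))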
Example #2
def reversed(x, dim=-1):
    # https://github.com/pytorch/pytorch/issues/229#issuecomment-350041662
    xsize = x.size()
    dim = x.dim() + dim if dim < 0 else dim  # normalize a negative dim
    x = x.view(-1, *xsize[dim:])  # flatten all leading dims into one batch dim
    inds = var_with(torch.arange(x.size(1) - 1, -1, -1).long(), x)  # descending indices along `dim`
    x = x.view(x.size(0), x.size(1), -1)[:, inds, :]  # gather in reverse order
    return x.view(xsize)  # restore the original shape
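On PyTorch 0.4 and later the same reversal is available as torch.flip; a minimal sketch checking the gather-based approach above against it:

import torch

x = torch.arange(24).view(2, 3, 4)
idx = torch.arange(2, -1, -1)  # descending indices along dim 1
assert torch.equal(x[:, idx, :], torch.flip(x, dims=[1]))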
Example #3
    def _extract_sent_feature(self, sent, length, gru):
        sent = self.embedding(sent)
        batch_size = sent.size(0)

        state_shape = (1, batch_size, self.hidden_dim)
        initial_state = var_with(torch.zeros(state_shape), sent)
        rnn_output, _ = rnn_with_length(gru, sent, length, initial_state)
        # select the output at the last valid timestep (position length - 1) of each sequence
        rnn_result = index_one_hot_ellipsis(rnn_output, 1, length - 1)

        return rnn_result
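rnn_with_length and index_one_hot_ellipsis are not shown here; presumably the latter selects, for each sequence, the GRU output at its last valid timestep (position length - 1). A self-contained sketch of that gather, with hypothetical shapes:

import torch

rnn_output = torch.randn(3, 5, 8)              # [batch, max_len, hidden]
length = torch.tensor([5, 2, 4])
idx = (length - 1).view(-1, 1, 1).expand(-1, 1, rnn_output.size(-1))
last = rnn_output.gather(1, idx).squeeze(1)    # [batch, hidden]: output at timestep length - 1
assert torch.equal(last[1], rnn_output[1, 1])  # sequence 1 has length 2, so timestep index 1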
Example #4
def _gumbel_softmax_sample(logits, dim=-1, tau=1, eps=1e-10):
    """
    Draw a sample from the Gumbel-Softmax distribution
    based on
    https://github.com/ericjang/gumbel-softmax/blob/3c8584924603869e90ca74ac20a6a03d99a91ef9/Categorical%20VAE.ipynb
    (MIT license)
    """
    gumbel_noise = _sample_gumbel(logits.size(),
                                  eps=eps,
                                  out=as_tensor(logits).new())
    y = logits + var_with(gumbel_noise, logits)  # perturb the logits with Gumbel(0, 1) noise
    return F.softmax(y / tau, dim=dim)  # tempered softmax: lower tau gives a sharper, more one-hot sample
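_sample_gumbel is not shown here; standard Gumbel(0, 1) noise is typically drawn as g = -log(-log(u)) with u ~ Uniform(0, 1). A minimal sketch of that draw (an assumption about the helper, not its actual code):

import torch

eps = 1e-10
u = torch.rand(4, 10)                                 # u ~ Uniform(0, 1)
gumbel_noise = -torch.log(-torch.log(u + eps) + eps)  # g ~ Gumbel(0, 1)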
Example #5
def weighted_loss(loss, target, weight, ignore_index):
    if weight is not None:
        weight = var_with(weight, target)
        weight = weight[target]  # gather the per-class weight for every target element
    else:
        weight = 1
    if ignore_index is not None:
        weight *= (target.ne(ignore_index).float())  # zero out positions equal to ignore_index

    if type(weight) is int and weight == 1:
        return loss.mean()  # no class weights and no ignored positions: plain mean
    else:
        return masked_average(loss, weight)
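masked_average is not shown here; presumably it computes the weighted mean sum(loss * weight) / sum(weight). A self-contained sketch of the weighting idea, with hypothetical values:

import torch

loss = torch.rand(6)                        # per-element losses
target = torch.tensor([0, 1, 2, 1, 0, 2])
weight = torch.tensor([1.0, 2.0, 0.5])      # per-class weights
w = weight[target] * target.ne(1).float()   # gather by target index, zero out ignore_index = 1
weighted_mean = (loss * w).sum() / w.sum()  # what masked_average presumably computes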
Example #6
def meshgrid_exclude_self(input, dim=1):
    """
    Exclude self from the grid. Specifically, given an array a[i, j] of n * n, it produces
    a new array with size n * (n - 1) where only a[i, j] (i != j) is preserved.

    The operation is performed over dim and dim +1 axes.
    """
    n = input.size(dim)
    assert n == input.size(dim + 1)

    # build a pairwise mask that is True exactly where i != j (i.e., exclude "self")
    rng = var_with(torch.arange(0, n), input)
    rng_n1 = rng.unsqueeze(1).expand((n, n))
    rng_1n = rng.unsqueeze(0).expand((n, n))
    mask_self = (rng_n1 != rng_1n)

    for i in range(dim):
        mask_self.unsqueeze_(0)
    for j in range(input.dim() - dim - 2):
        mask_self.unsqueeze_(-1)
    target_shape = concat_shape(input.size()[:dim], n, n - 1, input.size()[dim + 2:])

    return input.masked_select(mask_self).view(target_shape)
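A minimal sketch of the same masking on a plain n x n matrix (the dim = 0 case), without the var_with and concat_shape helpers:

import torch

n = 4
a = torch.arange(n * n).view(n, n)
mask = ~torch.eye(n, dtype=torch.bool)           # True everywhere except the diagonal
off_diag = a.masked_select(mask).view(n, n - 1)  # row i keeps a[i, j] for all j != i
assert off_diag.size() == (n, n - 1)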
Example #7
def inverse_permutation(perm):
    length = perm.size(0)
    inv = var_with(perm.data.new(length).long().zero_(), perm)
    inv.scatter_(0, perm, var_with(torch.arange(0, length).long(), perm))
    return inv.long()
    def enc_txt(self, caps):
        sents, lengths, _, inv = _prepare_batch(caps, self.projector)
        inv = var_with(as_variable(inv), sents)
        out, x = self.model.txt_enc.forward(sents, lengths, True)
        # reorder the encoder outputs back to the original batch order via the inverse permutation
        return out[inv], x
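A self-contained sketch of the property implemented by inverse_permutation above, without the Variable-era wrappers:

import torch

perm = torch.randperm(5)
inv = torch.empty_like(perm)
inv[perm] = torch.arange(5)                     # scatter positions into the slots given by perm
assert torch.equal(perm[inv], torch.arange(5))  # composing with the inverse gives the identity
assert torch.equal(inv[perm], torch.arange(5))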