Exemplo n.º 1
0
 def gumbel_sample(self, data):
     """Pick one mixture component per row using the Gumbel-max trick.

     Adding i.i.d. standard Gumbel noise to the log-weights and taking the
     argmax selects each component with probability proportional to its
     weight.

     :param data: non-negative mixture weights; components lie along dim 1.
     :return: LongTensor of sampled component indices (dim 1 reduced).
     """
     distribution = Gumbel(loc=0, scale=1)
     # One independent draw per weight. A single scalar draw shifts every
     # log-weight by the same amount, so the argmax would always be the mode.
     z = distribution.sample(data.shape).to(data.device)
     return (torch.log(data) + z).argmax(dim=1)
Exemplo n.º 2
0
 def __init__(self, start_tokens: torch.LongTensor,
              end_token: Union[int, torch.LongTensor], tau: float,
              straight_through: bool = False,
              stop_gradient: bool = False, use_finish: bool = True):
     """Initialise the helper.

     ``start_tokens``, ``end_token``, ``tau``, ``stop_gradient`` and
     ``use_finish`` are forwarded positionally to the parent initialiser.
     The straight-through flag is kept on ``self._straight_through`` and a
     zero-location, unit-scale Gumbel distribution on ``self._gumbel``.
     """
     super().__init__(start_tokens, end_token, tau, stop_gradient,
                      use_finish)
     self._straight_through = straight_through
     location, unit_scale = torch.tensor(0.0), torch.tensor(1.0)
     self._gumbel = Gumbel(loc=location, scale=unit_scale)
Exemplo n.º 3
0
    def __init__(self, do_entropy=False):
        '''
        Create the sampler with a standard Gumbel(0, 1) noise source.

        :param do_entropy: stored on ``self.do_entropy``; when truthy an
            ``entropy`` attribute is also created and set to None.
        '''
        super().__init__()
        self.gumbel = Gumbel(loc=0, scale=1)
        self.do_entropy = do_entropy
        if not self.do_entropy:
            return
        self.entropy = None
Exemplo n.º 4
0
def gumbel_with_maximum(phi, T, dim=-1):
    """
    Sample Gumbels with location ``phi`` conditioned on their maximum along
    ``dim`` being ``T``.

    ``phi.max(dim)[0]`` should be broadcastable with the desired maximum T.
    Returns ``(g, argmax)``: the shifted samples and the index of the
    maximum of the unshifted samples along ``dim``.
    """
    # Sampling through torch.distributions avoids -inf/inf values that a
    # naive -log(-log(u)) can produce (they cause trouble downstream).
    unit_scale = torch.ones_like(phi)
    perturbed = Gumbel(phi, unit_scale).rsample()
    current_max, max_index = perturbed.max(dim)
    shifted = _shift_gumbel_maximum(perturbed, T, dim, Z=current_max)
    return shifted, max_index
Exemplo n.º 5
0
    def __init__(self, bias=None):
        '''
        Create the sampler with a standard Gumbel(0, 1) noise source.

        :param bias: a vector of log frequencies, to bias the sampling towards
            what we want. Converted to a float32 tensor on ``device`` with a
            leading dimension of size 1; None disables biasing.
        '''
        super().__init__()
        self.gumbel = Gumbel(loc=0, scale=1)
        if bias is None:
            self.bias = None
        else:
            bias_tensor = torch.from_numpy(np.array(bias))
            bias_tensor = bias_tensor.to(dtype=torch.float32, device=device)
            self.bias = bias_tensor.unsqueeze(0)
Exemplo n.º 6
0
    def __init__(self, temperature=1.0, eps=0.0, do_entropy=True):
        '''
        Create the sampler with a standard Gumbel(0, 1) noise source.

        :param temperature: stored as a (non-grad) tensor on ``self.temperature``.
        :param eps: stored on ``self.eps`` (presumably an epsilon-greedy
            probability — confirm against ``forward``).
        :param do_entropy: stored on ``self.do_entropy``; when truthy an
            ``entropy`` attribute is also created and set to None.
        '''
        super().__init__()
        self.gumbel = Gumbel(loc=0, scale=1)
        self.temperature = torch.tensor(temperature, requires_grad=False)
        self.eps = eps
        self.do_entropy = do_entropy
        if not self.do_entropy:
            return
        self.entropy = None
Exemplo n.º 7
0
def cond_gumbel_sample(all_joint_log_probs,
                       perturbed_log_probs) -> torch.Tensor:
    """Sample Gumbels located at ``all_joint_log_probs`` and shift them so
    that their maximum over the last axis is conditioned on
    ``perturbed_log_probs`` (T).

    The final expression ``T - relu(v) - softplus(-|v|)`` equals
    ``T - log(1 + exp(v))`` evaluated in a numerically stable way.
    """
    # plates x k? x |D_yv| Gumbel variables
    samples = Gumbel(loc=all_joint_log_probs, scale=1.0).rsample()

    # Maximum of the fresh samples: plates x k
    current_max, _ = samples.max(dim=-1)
    target = perturbed_log_probs
    v = target - samples + log1mexp(samples - current_max.unsqueeze(-1))
    # plates (x k) x |D_yv|
    stable_log1pexp = v.relu() + torch.nn.functional.softplus(-v.abs())
    return target - stable_log1pexp
Exemplo n.º 8
0
    def __init__(
        self,
        shape: float = 1.0,
        topk: Optional[int] = None,
        equiv_len: Optional[int] = None,
        log_scores: bool = False,
    ):
        """Configure a FréchetSort sampler.

        FréchetSort is a softer version of descending sort: it samples every
        possible ordering of the items, favouring orderings close to a
        descending sort of the scores. This turns "sort by rank score" into a
        differentiable, stochastic policy usable with policy-gradient methods.

        :param shape: parameter of the Fréchet distribution; the Gumbel noise
            has scale ``1 / shape``, so lower values mean more aggressive
            deviations from descending sort.
        :param topk: if given, only the first ``topk`` actions are output.
        :param equiv_len: orderings are considered equivalent if their first
            ``equiv_len`` entries match; used in probability computations and
            effectively defines the action space. Defaults to ``topk`` when
            ``topk`` is given.
        :param log_scores: scores passed in are already log-transformed, so
            Gumbel noise is simply added to them. For LearnVM this is True
            because input and output scores are expected in log space.
        :raises ValueError: if the effective ``equiv_len`` exceeds ``topk``.

        Example::

            sampler = FrechetSort(shape=3, topk=5, equiv_len=3)

        Given a set of scores, this sampler produces item indices roughly
        resembling an argsort of the scores in descending order; the higher
        the shape, the closer to a descending argsort. ``topk=5`` means only
        the top 5 ranks are output, and ``equiv_len=3`` means ``log_prob``
        gives the probability of the top 3 items appearing in a given order.
        """
        self.shape = shape
        self.topk = topk
        # With a topk but no explicit equiv_len, equivalence uses the topk cut.
        if topk is not None and equiv_len is None:
            self.upto = topk
        else:
            self.upto = equiv_len
        # pyre-fixme[58]: `>` is not supported for operand types `Optional[int]`
        #  and `Optional[int]`.
        if topk is not None and self.upto > topk:
            raise ValueError(
                f"Equiv length {equiv_len} cannot exceed topk={topk}.")
        self.gumbel_noise = Gumbel(0, 1.0 / shape)
        self.log_scores = log_scores
Exemplo n.º 9
0
class SoftmaxRandomSamplePolicy(SimplePolicy):
    '''
    Randomly samples from the softmax of the logits using the Gumbel-max trick
    # https://hips.seas.harvard.edu/blog/2013/04/06/the-gumbel-max-trick-for-discrete-distributions/
    TODO: should probably switch that to something more like
    http://pytorch.org/docs/master/distributions.html
    '''
    def __init__(self, bias=None):
        '''

        :param bias: a vector of log frequencies, to bias the sampling towards
            what we want. Stored as a float32 tensor on ``device`` with a
            leading dimension of size 1; None disables biasing.
        '''
        super().__init__()
        self.gumbel = Gumbel(loc=0, scale=1)
        if bias is not None:
            self.bias = torch.from_numpy(np.array(bias)).to(dtype=torch.float32, device=device).unsqueeze(0)
        else:
            self.bias = None

    def forward(self, logits: Variable):
        '''
        Sample one index along the last dim with probability
        softmax(effective logits).

        :param logits: Logits to generate probabilities from, batch_size x out_dim float32
        :return: LongTensor of sampled indices (last dim reduced)
        '''
        eff_logits = self.effective_logits(logits)
        # Noise lives on the logits' device/dtype rather than the global one.
        noise = self.gumbel.sample(eff_logits.shape).to(
            device=eff_logits.device, dtype=eff_logits.dtype)
        _, out = torch.max(noise + eff_logits, -1)
        return out

    def effective_logits(self, logits):
        '''
        Return ``logits`` plus the optional bias, leaving the input untouched.

        An out-of-place add is used: an in-place ``+=`` would modify the
        caller's tensor and raises on leaf tensors that require grad.
        '''
        if self.bias is None:
            return logits
        return logits + self.bias.to(device=logits.device, dtype=logits.dtype)
Exemplo n.º 10
0
 def log_sample(self, x, u, n, alpha=1., xg=None, ug=None):
     """Draw ``n`` row indices of ``x`` with the Gumbel-max trick.

     ``alpha * cost`` (from ``self.c.cost``) is used as an unnormalised
     log-probability per row. Each of the ``n`` columns gets its own
     Gumbel noise and the argmax is taken along dim 0.

     :return: ``(choices, alpha * costs)``, where ``costs`` is the cost
         reshaped to a column of shape (-1, 1).
     """
     scaled_costs = alpha * self.c.cost(x, u, xg, ug).reshape(-1, 1)
     noise = Gumbel(loc=0., scale=1.).sample((x.shape[0], n))
     choices = torch.max(scaled_costs + noise, 0)[1]
     return choices, scaled_costs
def reinforce_unordered(conditional_loss_fun,
                        log_class_weights,
                        class_weights_detached,
                        seq_tensor,
                        z_sample,
                        epoch,
                        data,
                        n_samples=1,
                        baseline_separate=False,
                        baseline_n_samples=1,
                        baseline_deterministic=False,
                        baseline_constant=None):
    """Unordered-set REINFORCE surrogate loss from samples without replacement.

    ``n_samples`` classes are drawn without replacement with the Gumbel
    top-k trick on the detached ``log_class_weights``. Each drawn class is
    one-hot encoded and scored by ``conditional_loss_fun``. The ratios
    ``log_R_s`` / ``log_R_ss`` come from ``compute_log_R_O_nfac``. The
    baseline is chosen in this order: ``baseline_constant``; a separate
    baseline from ``get_baseline``; the built-in baseline (more than one
    sample); otherwise zero.

    ``class_weights_detached``, ``seq_tensor``, ``z_sample``, ``epoch`` and
    ``data`` are accepted but not used here.

    :return: loss per row, summed over the sample axis.
    """
    # Gumbel top-k on detached log-weights == sampling without replacement.
    phi = log_class_weights.detach()
    _, ind = Gumbel(phi, torch.ones_like(phi)).rsample().topk(n_samples, -1)

    log_p = log_class_weights.gather(-1, ind)
    n_classes = log_class_weights.shape[1]
    sample_costs = [
        conditional_loss_fun(get_one_hot_encoding_from_int(column, n_classes))
        for column in ind.t()
    ]
    costs = torch.stack(sample_costs, -1)

    # Advantage and ratios are treated as constants (no gradient).
    with torch.no_grad():
        log_R_s, log_R_ss = compute_log_R_O_nfac(log_p)

        if baseline_constant is not None:
            bl_vals = baseline_constant
        elif baseline_separate:
            # One baseline per row, broadcast across the sample axis.
            bl_vals = get_baseline(conditional_loss_fun, log_class_weights,
                                   baseline_n_samples,
                                   baseline_deterministic)[:, None]
        elif log_p.size(-1) > 1:
            # Built-in baseline computed from the other samples' costs.
            bl_vals = ((log_p[:, None, :] + log_R_ss).exp() *
                       costs[:, None, :]).sum(-1)
        else:
            bl_vals = 0.  # No baseline available
        adv = costs - bl_vals

    weights = (log_p + log_R_s).exp()
    # The second term also adds the costs (unordered estimator) to cover a
    # direct dependency of the costs on the parameters.
    return (weights * adv.detach() + weights.detach() * costs).sum(-1)
Exemplo n.º 12
0
def sample_stgs(proba_no_softmax, K, temp):
    """Straight-through Gumbel-softmax sample.

    The forward value is the one-hot argmax of
    ``softmax((proba_no_softmax + g) / temp)`` with ``g`` ~ Gumbel(0, 1).
    Gradients flow through the soft (softmax) sample.

    :param proba_no_softmax: (batch, K) unnormalised scores.
    :param K: number of categories.
    :param temp: softmax temperature.
    """
    dev = proba_no_softmax.device
    shape = (proba_no_softmax.size(0), K)
    noise = Gumbel(torch.zeros(shape, device=dev),
                   torch.ones(shape, device=dev)).sample().detach()
    soft = F.softmax((proba_no_softmax + noise) / temp, dim=-1)
    # Hard sample for the forward pass, soft gradient for the backward pass.
    hard = one_hot(soft.argmax(dim=-1), K)
    return (hard - soft).detach() + soft
Exemplo n.º 13
0
    def __init__(
        self,
        shape: float = 1,
        upto: Optional[int] = None
    ):
        """Create an ordering sampler.

        Args:
            shape: inverse scale of the Gumbel noise (scale = 1 / shape).
                Larger values keep the sampled order closer to
                argmax(scores).
            upto [Default: None]: if provided, the position up to which
                log_prob is computed.
        """
        self.shape, self.upto = shape, upto
        self.gumbel_noise = Gumbel(0, 1.0 / shape)
Exemplo n.º 14
0
def sample_gumbel_softmax(logits, temperature, n_sample):
    """
        Input:
        logits: Tensor of log probs, shape = BS x k
        temperature = scalar
        n_sample: number of independent samples to draw

        Output: Tensor of values sampled from Gumbel softmax.
                These will tend towards a one-hot representation in the limit of temp -> 0
                shape = n_sample x BS x k
    """
    # Build the noise directly on logits' device/dtype instead of creating it
    # on the CPU and moving it to a module-level ``device`` global.
    g = Gumbel(torch.zeros_like(logits), torch.ones_like(logits)).rsample(
        (n_sample, ))
    h = (g + logits) / temperature
    return F.softmax(h, dim=-1)
Exemplo n.º 15
0
class FrechetOrderSampler:
    """Sample orderings by sorting Gumbel-perturbed log-scores.

    Sorting ``log(scores)`` plus Gumbel noise of scale ``1 / shape`` in
    descending order is equivalent to sampling from a Plackett-Luce model
    (Yellot 1977); ``log_prob`` evaluates that model.
    """

    def __init__(
        self,
        shape: float = 1,
        upto: Optional[int] = None
    ):
        """Create an ordering sampler.

        Args:
            shape: inverse scale of the Gumbel noise (scale = 1 / shape).
                Larger values keep the sampled order closer to
                argmax(scores).
            upto [Default: None]: if provided, the position up to which
                log_prob is computed.
        """
        self.shape = shape
        self.upto = upto
        self.gumbel_noise = Gumbel(0, 1.0 / shape)

    def sample(
        self,
        scores: torch.Tensor
    ):
        """Sample an ordering given (positive) scores.

        Returns item indices in descending order of the perturbed log-scores.
        """
        perturbed = torch.log(scores) + self.gumbel_noise.sample((len(scores),))
        return torch.argsort(-perturbed.detach())

    def log_prob(self, scores : torch.Tensor, permutations):
        """Compute log probability given scores and an action (a permutation).
        The formula uses the equivalence of sorting with Gumbel noise and
        Plackett-Luce model (See Yellot 1977)

        Args:
            scores: scores of different items
            permutations: prescribed (permutation) order of the items
        """
        s = torch.log(select_indices(scores, permutations))
        n = len(scores)
        # The last position contributes log(1) = 0, so n - 1 terms suffice.
        p = self.upto if self.upto is not None else n - 1
        # logsumexp == log(sum(exp(.))) but does not overflow for large
        # shape * score gaps.
        return -sum(
            torch.logsumexp((s[k:] - s[k]) * self.shape, dim=0)
            for k in range(p))
Exemplo n.º 16
0
    def sample_model_spec(self, num):
        """
        Override: sample the alphas via Gumbel-softmax instead of a plain softmax.

        For every intermediate node, Gumbel-softmax probabilities over its
        topology slice and over its operation column are wrapped in
        ``Categorical`` distributions, and the lists are handed to
        ``self._sample_model_spec``. Runs without gradient tracking.

        :param num: forwarded to ``self._sample_model_spec``.
        :return: the result of ``self._sample_model_spec``.
        """
        topology = self.alpha_topology.detach().clone()
        operations = self.alpha_ops.detach().clone()
        arch_dists, op_dists = [], []
        noise = Gumbel(torch.tensor([.0]), torch.tensor([1.0]))
        with torch.no_grad():
            for node in range(self.num_intermediate_nodes):
                # Rows 0 .. node+1 of column `node`, aligned with the
                # topology weights.
                edge_probs = gumbel_softmax(topology[: node + 2, node],
                                            self.temperature(), noise)
                arch_dists.append(Categorical(edge_probs))
                op_probs = gumbel_softmax(operations[:, node],
                                          self.temperature(), noise)
                op_dists.append(Categorical(op_probs))

            return self._sample_model_spec(num, arch_dists, op_dists)
Exemplo n.º 17
0
class SoftmaxRandomSamplePolicySparse(SimplePolicy):
    '''
    Randomly samples from the softmax of the logits using the Gumbel-max
    trick, one variable-length logits vector per batch item.
    # https://hips.seas.harvard.edu/blog/2013/04/06/the-gumbel-max-trick-for-discrete-distributions/
    TODO: should probably switch that to something more like
    http://pytorch.org/docs/master/distributions.html
    '''
    def __init__(self, do_entropy=False):
        '''
        :param do_entropy: stored on ``self.do_entropy``; when truthy an
            ``entropy`` attribute is created and set to None.
        '''
        super().__init__()
        self.gumbel = Gumbel(loc=0, scale=1)
        self.do_entropy = do_entropy
        if self.do_entropy:
            self.entropy = None

    def forward(self, all_logits_list, all_action_inds_list):
        '''
        Sample one action per batch item.

        :param all_logits_list: list of logits tensors, one per batch item
        :param all_action_inds_list: list of long tensors with indices of the
            actions corresponding to the logits
        :return: stacked sampled action indices; the matching log-probs are
            stored on ``self.logp``. Items with empty logits yield action 0
            and log-prob 0.
        '''
        # The number of feasible actions is different for every item in the batch, so a for loop is the simplest
        logp = []
        out_actions = []
        for logits, action_inds in zip(all_logits_list, all_action_inds_list):
            if len(logits):
                # Noise on the logits' own device/dtype, not a global device.
                noise = self.gumbel.sample(logits.shape).to(
                    device=logits.device, dtype=logits.dtype)
                _, out = torch.max(noise + logits, -1)
                # Explicit dim: the sampling above is along the last dim, and
                # omitting dim triggers PyTorch's implicit-dim deprecation.
                this_logp = F.log_softmax(logits, dim=-1)[out]
                out_actions.append(action_inds[out])
                logp.append(this_logp)
            else:
                out_actions.append(torch.tensor(0, device=logits.device, dtype=action_inds.dtype))
                logp.append(torch.tensor(0.0, device=logits.device, dtype=logits.dtype))
        self.logp = torch.stack(logp)
        out = torch.stack(out_actions)
        return out
Exemplo n.º 18
0
 def ops_weights(self, node):
     """Gumbel-softmax weights over operations for ``node``.

     Uses column ``node`` of ``self.alpha_ops`` with the current
     ``self.temperature()``.
     """
     noise = Gumbel(torch.tensor([.0]), torch.tensor([1.0]))
     column = self.alpha_ops[:, node]
     return gumbel_softmax(column, self.temperature(), noise)
Exemplo n.º 19
0
 def topology_weights(self, node):
     """Soft weights for topology aggregation at ``node``.

     Applies Gumbel-softmax to rows 0 .. node+1 of column ``node`` of
     ``self.alpha_topology`` and drops the first resulting weight.
     """
     noise = Gumbel(torch.tensor([.0]), torch.tensor([1.0]))
     candidates = self.alpha_topology[: node + 2, node]
     weights = gumbel_softmax(candidates, self.temperature(), noise)
     return weights[1:]
Exemplo n.º 20
0
def rsample_gumbel(
    distr: Distribution,
    n: int,
) -> torch.Tensor:
    """Draw ``n`` reparameterised unit-scale Gumbel samples located at
    ``distr.logits``.

    The result has shape ``(n, *distr.logits.shape)``.
    """
    noise_distr = Gumbel(distr.logits, 1)
    return noise_distr.rsample(torch.Size((n,)))
def reinforce_sum_and_sample(conditional_loss_fun,
                             log_class_weights,
                             class_weights_detached,
                             seq_tensor,
                             z_sample,
                             epoch,
                             data,
                             n_samples=1,
                             baseline_separate=False,
                             baseline_n_samples=1,
                             baseline_deterministic=False,
                             rao_blackwellize=False):
    """Sum-and-sample REINFORCE surrogate loss from samples without replacement.

    ``n_samples`` classes are drawn without replacement with the Gumbel
    top-k trick on the detached ``log_class_weights`` and each is scored
    with ``conditional_loss_fun``. Rows whose drawn classes cover almost all
    probability mass (sum > ~1 - 1e-5) use ``compute_summed_terms``. The
    other rows use ``compute_sum_and_sample_loss``. When
    ``rao_blackwellize`` is set, that loss is averaged over all orderings
    of the drawn samples. Each ordering is weighted by its normalised
    ``log_pl_rec`` probability.

    ``class_weights_detached``, ``seq_tensor``, ``z_sample``, ``epoch`` and
    ``data`` are accepted but not used here.

    :return: tensor of per-row losses with length ``log_p.size(0)``, rows
        kept in input order.
    """

    # Sample without replacement using Gumbel top-k trick
    phi = log_class_weights.detach()
    g_phi = Gumbel(phi, torch.ones_like(phi)).rsample()

    _, ind = g_phi.topk(n_samples, -1)

    # Log-probabilities (with gradient) of the drawn classes.
    log_p = log_class_weights.gather(-1, ind)
    n_classes = log_class_weights.shape[1]
    # One cost per drawn sample, stacked along the last axis.
    costs = torch.stack([
        conditional_loss_fun(get_one_hot_encoding_from_int(
            z_sample, n_classes)) for z_sample in ind.t()
    ], -1)

    with torch.no_grad():  # Don't compute gradients for advantage and ratio
        if baseline_separate:
            bl_vals = get_baseline(conditional_loss_fun, log_class_weights,
                                   baseline_n_samples, baseline_deterministic)
            # Same bl for all samples, so add dimension
            bl_vals = bl_vals[:, None]
        else:
            # NOTE(review): assert is stripped under `python -O`; consider an
            # explicit exception if this guards user input.
            assert baseline_n_samples < n_samples
            # Built-in baseline: the first baseline_n_samples - 1 costs weighted
            # by their probabilities, plus the next sample's cost weighted by
            # the remaining mass (log1mexp presumably computes log(1 - exp(x))).
            bl_sampled_weight = log1mexp(
                log_p[:, :baseline_n_samples -
                      1].logsumexp(-1)).exp().detach()
            bl_vals = (log_p[:, :baseline_n_samples - 1].exp() * costs[:, :baseline_n_samples -1]).sum(-1)\
                      + bl_sampled_weight * costs[:, baseline_n_samples - 1]
            bl_vals = bl_vals[:, None]

    # We compute an 'exact' gradient if the sum of probabilities is roughly more than 1 - 1e-5
    # in which case we can simply sum all the terms and the relative error will be < 1e-5
    use_exact = log_p.logsumexp(-1) > -1e-5
    not_use_exact = use_exact == 0

    cost_exact = costs[use_exact]
    exact_loss = compute_summed_terms(log_p[use_exact], cost_exact,
                                      cost_exact - bl_vals[use_exact])

    # Rows that need the sampled (estimated) part of the estimator.
    log_p_est = log_p[not_use_exact]
    costs_est = costs[not_use_exact]
    bl_vals_est = bl_vals[not_use_exact]

    if rao_blackwellize:
        # Evaluate the loss under every ordering of the drawn samples.
        ap = all_perms(torch.arange(n_samples, dtype=torch.long),
                       device=log_p_est.device)
        log_p_ap = log_p_est[:, ap]
        bl_vals_ap = bl_vals_est.expand_as(costs_est)[:, ap]
        costs_ap = costs_est[:, ap]
        cond_losses = compute_sum_and_sample_loss(log_p_ap, costs_ap,
                                                  bl_vals_ap)

        # Compute probabilities for permutations
        log_probs_perms = log_pl_rec(log_p_ap, -1)
        # Normalise over the permutations so the weights sum to 1 per row.
        cond_log_probs_perms = log_probs_perms - log_probs_perms.logsumexp(
            -1, keepdim=True)
        losses = (cond_losses * cond_log_probs_perms.exp()).sum(-1)
    else:
        losses = compute_sum_and_sample_loss(log_p_est, costs_est, bl_vals_est)

    # If they are summed we can simply concatenate but for consistency it is best to place them in order
    all_losses = log_p.new_zeros(log_p.size(0))
    all_losses[use_exact] = exact_loss
    all_losses[not_use_exact] = losses
    return all_losses
Exemplo n.º 22
0
class FrechetSort(Sampler):
    EPS = 1e-12

    @resolve_defaults
    def __init__(
        self,
        shape: float = 1.0,
        topk: Optional[int] = None,
        equiv_len: Optional[int] = None,
        log_scores: bool = False,
    ):
        """FréchetSort is a softer version of descending sort which samples all possible
        orderings of items favoring orderings which resemble descending sort. This can
        be used to convert descending sort by rank score into a differentiable,
        stochastic policy amenable to policy gradient algorithms.

        :param shape: parameter of Frechet Distribution. Lower values correspond to
        aggressive deviations from descending sort (the Gumbel noise scale is 1 / shape).
        :param topk: If specified, only the first topk actions are output.
        :param equiv_len: Orders are considered equivalent if the top equiv_len match. Used
            in probability computations.
            Essentially specifies the action space. Defaults to topk when topk is given.
        :param log_scores Scores passed in are already log-transformed. In this case, we would
        simply add Gumbel noise.
        For LearnVM, we set this to be True because we expect input and output scores
        to be in the log space.
        :raises ValueError: if equiv_len exceeds topk.

        Example:

        Consider the sampler:

        sampler = FrechetSort(shape=3, topk=5, equiv_len=3)

        Given a set of scores, this sampler will produce indices of items roughly
        resembling an argsort by scores in descending order. The higher the shape,
        the more it would resemble a descending argsort. `topk=5` means only the top
        5 ranks will be output. The `equiv_len` determines what orders are considered
        equivalent for probability computation. In this example, the sampler will
        produce probability for the top 3 items appearing in a given order for the
        `log_prob` call.
        """
        self.shape = shape
        self.topk = topk
        self.upto = equiv_len
        if topk is not None:
            if equiv_len is None:
                self.upto = topk
            # pyre-fixme[58]: `>` is not supported for operand types `Optional[int]`
            #  and `Optional[int]`.
            if self.upto > self.topk:
                raise ValueError(
                    f"Equiv length {equiv_len} cannot exceed topk={topk}.")
        self.gumbel_noise = Gumbel(0, 1.0 / shape)
        self.log_scores = log_scores

    def sample_action(self, scores: torch.Tensor) -> rlt.ActorOutput:
        """Sample a ranking according to Frechet sort. Note that possible_actions_mask
        is ignored as the list of rankings scales exponentially with slate size and
        number of items and it can be difficult to enumerate them.

        :param scores: 2-D batch of scores (log-space if ``log_scores``).
        :return: ActorOutput whose action holds, per row, item indices in
            ranked order (only the first ``topk`` ranks when topk is set), and
            whose log_prob is computed on the full ranking.
        """
        assert scores.dim() == 2, "sample_action only accepts batches"
        log_scores = scores if self.log_scores else torch.log(scores)
        perturbed = log_scores + self.gumbel_noise.sample(scores.shape)
        action = torch.argsort(perturbed.detach(), descending=True)
        log_prob = self.log_prob(scores, action)
        # Only truncate the action before returning. Slice the rank axis
        # (dim 1); slicing dim 0 would drop whole batch rows instead.
        if self.topk is not None:
            action = action[:, : self.topk]
        return rlt.ActorOutput(action, log_prob)

    def log_prob(
        self,
        scores: torch.Tensor,
        action: torch.Tensor,
        equiv_len_override: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        What is the probability of a given set of scores producing the given
        list of permutations only considering the top `equiv_len` ranks?

        We may want to override the default equiv_len here when we know the having larger
        action space doesn't matter. i.e. in Reels

        :return: log-probability per batch row; for 1-D ``scores``/``action``
            the added batch dimension is removed again.
        """
        upto = self.upto
        if equiv_len_override is not None:
            assert equiv_len_override.shape == (
                scores.shape[0],
            ), f"Invalid shape {equiv_len_override.shape}, compared to scores {scores.shape}. equiv_len_override {equiv_len_override}"
            upto = equiv_len_override.long()
            if self.topk is not None and torch.any(
                    equiv_len_override > self.topk):
                raise ValueError(
                    f"Override {equiv_len_override} cannot exceed topk={self.topk}."
                )

        squeeze = False
        if len(scores.shape) == 1:
            squeeze = True
            scores = scores.unsqueeze(0)
            action = action.unsqueeze(0)

        assert len(action.shape) == len(
            scores.shape) == 2, "scores should be batch"
        if action.shape[1] > scores.shape[1]:
            raise ValueError(
                f"action cardinality ({action.shape[1]}) is larger than the number of scores ({scores.shape[1]})"
            )
        elif action.shape[1] < scores.shape[1]:
            raise NotImplementedError(
                f"This semantic is ambiguous. If you have shorter slate, pad it with scores.shape[1] ({scores.shape[1]})"
            )

        log_scores = scores if self.log_scores else torch.log(scores)
        n = log_scores.shape[-1]
        # Add scores for the padding value
        log_scores = torch.cat(
            [
                log_scores,
                torch.full((log_scores.shape[0], 1),
                           -math.inf,
                           device=log_scores.device),
            ],
            dim=1,
        )
        s = torch.gather(log_scores, 1, action) * self.shape

        p = upto if upto is not None else n
        # Sum the Plackett-Luce log-probabilities of the first p positions;
        # nan_to_num zeroes the -inf/NaN terms produced by padding entries.
        if isinstance(p, int):
            probs = sum(
                torch.nan_to_num(F.log_softmax(s[:, i:], dim=1)[:, 0],
                                 neginf=0.0) for i in range(p))
        elif isinstance(p, torch.Tensor):
            # do masked sum
            probs = sum(
                torch.nan_to_num(F.log_softmax(s[:, i:], dim=1)[:, 0],
                                 neginf=0.0) * (i < p).float()
                for i in range(n))
        else:
            raise RuntimeError(f"p is {p}")
        # Undo the batch dimension added above for 1-D inputs.
        if squeeze and isinstance(probs, torch.Tensor):
            probs = probs.squeeze(0)
        return probs
Exemplo n.º 23
0
    def sample_genotype_index(self):
        """Return the argmax (along dim 0) of a Gumbel-softmax sample of the
        squeezed architecture parameters.

        The sample uses ``self.gumbel_temperature`` and is taken without
        gradient tracking.
        """
        with torch.no_grad():
            soft_sample = gumbel_softmax_sample(
                self.gumble_arch_params.squeeze(),
                temperature=self.gumbel_temperature,
                dim=0,
            )
            return soft_sample.argmax(dim=0).detach()


class MixedSequential(torch.nn.Sequential):
    """Sequential container that unpacks ``(output, cost)`` from MixedModule children.

    Non-MixedModule children are applied as in ``torch.nn.Sequential``; for
    MixedModule children only the output part is passed to the next module.
    """

    def forward(self, input):
        total_cost = 0
        for layer in self._modules.values():
            result = layer(input)
            if isinstance(layer, MixedModule):
                result, cost = result
                total_cost += cost
            input = result
        # NOTE(review): total_cost is accumulated but neither returned nor
        # stored; confirm whether callers expect the summed cost.
        return input


# Shared standard Gumbel(0, 1) noise source.
gumbel_dist = Gumbel(loc=0.0, scale=1.0)


def gumbel_softmax_sample(logits, temperature, dim=None, std=1.0):
    """Softmax of Gumbel-perturbed logits divided by ``temperature``.

    :param logits: unnormalised log-probabilities.
    :param temperature: softmax temperature.
    :param dim: passed to ``F.softmax`` (None means PyTorch's implicit,
        deprecated dimension choice).
    :param std: unused; kept for interface compatibility.
    """
    noise = gumbel_dist.sample(logits.shape)
    noise = noise.to(device=logits.device, dtype=logits.dtype)
    perturbed = logits + noise
    return F.softmax(perturbed / temperature, dim=dim)
Exemplo n.º 24
0
class FrechetSort(Sampler):
    @resolve_defaults
    def __init__(
        self,
        shape: float = 1.0,
        topk: Optional[int] = None,
        equiv_len: Optional[int] = None,
        log_scores: bool = False,
    ):
        """FréchetSort is a softer version of descending sort which samples all possible
        orderings of items favoring orderings which resemble descending sort. This can
        be used to convert descending sort by rank score into a differentiable,
        stochastic policy amenable to policy gradient algorithms.

        :param shape: parameter of Frechet Distribution. Lower values correspond to
        aggressive deviations from descending sort (the Gumbel noise scale is 1 / shape).
        :param topk: If specified, only the first topk actions are output.
        :param equiv_len: Orders are considered equivalent if the top equiv_len match. Used
            in probability computations. Defaults to topk when topk is given.
        :param log_scores Scores passed in are already log-transformed. In this case, we would
        simply add Gumbel noise.
        :raises ValueError: if equiv_len exceeds topk.

        Example:

        Consider the sampler:

        sampler = FrechetSort(shape=3, topk=5, equiv_len=3)

        Given a set of scores, this sampler will produce indices of items roughly
        resembling an argsort by scores in descending order. The higher the shape,
        the more it would resemble a descending argsort. `topk=5` means only the top
        5 ranks will be output. The `equiv_len` determines what orders are considered
        equivalent for probability computation. In this example, the sampler will
        produce probability for the top 3 items appearing in a given order for the
        `log_prob` call.
        """
        self.shape = shape
        self.topk = topk
        self.upto = equiv_len
        if topk is not None:
            if equiv_len is None:
                self.upto = topk
            # pyre-fixme[58]: `>` is not supported for operand types `Optional[int]`
            #  and `Optional[int]`.
            if self.upto > self.topk:
                raise ValueError(f"Equiv length {equiv_len} cannot exceed topk={topk}.")
        self.gumbel_noise = Gumbel(0, 1.0 / shape)
        self.log_scores = log_scores

    @staticmethod
    def select_indices(scores: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
        """Helper for scores[actions] that also works for batched tensors.

        For 2-D ``actions`` the result is transposed, i.e. laid out as
        (num_actions, batch).
        """
        if len(actions.shape) > 1:
            num_rows = scores.size(0)
            row_indices = torch.arange(num_rows).unsqueeze(0).T  # pyre-ignore[ 16 ]
            return scores[row_indices, actions].T
        else:
            return scores[actions]

    def sample_action(self, scores: torch.Tensor) -> rlt.ActorOutput:
        """Sample a ranking according to Frechet sort. Note that possible_actions_mask
        is ignored as the list of rankings scales exponentially with slate size and
        number of items and it can be difficult to enumerate them.

        The log-probability is computed on the full ranking before the action
        is truncated to ``topk`` ranks, so every Plackett-Luce denominator
        includes all remaining items.
        """
        assert scores.dim() == 2, "sample_action only accepts batches"
        log_scores = scores if self.log_scores else torch.log(scores)
        # Independent noise for every (row, item); a (num_items,) draw would
        # be broadcast and shared by all rows of the batch.
        perturbed = log_scores + self.gumbel_noise.sample(scores.shape)
        action = torch.argsort(perturbed.detach(), descending=True)
        log_prob = self.log_prob(scores, action)
        if self.topk is not None:
            # Truncate the rank axis, not the batch axis.
            action = action[:, : self.topk]
        return rlt.ActorOutput(action, log_prob)

    def log_prob(self, scores: torch.Tensor, action) -> torch.Tensor:
        """What is the probability of a given set of scores producing the given
        list of permutations only considering the top `equiv_len` ranks?"""
        log_scores = scores if self.log_scores else torch.log(scores)
        s = self.select_indices(log_scores, action)
        # Number of ranked positions: the leading axis of `s`. For batched
        # input select_indices returns (num_actions, batch), so
        # len(log_scores) would be the batch size instead.
        n = s.shape[0]
        p = self.upto if self.upto is not None else n
        # logsumexp == log(sum(exp(.))) without overflow.
        return -sum(
            torch.logsumexp((s[k:] - s[k]) * self.shape, dim=0)
            for k in range(p)  # pyre-ignore
        )
Exemplo n.º 25
0
def gumbel_softmax_sample(logits, temp=1.):
    """Return ``softmax((logits + g) / temp)`` over the last dim, with
    ``g`` i.i.d. Gumbel(0, 1) noise of the same shape as ``logits``.

    The noise is moved to ``logits``' device so non-CPU inputs work.
    """
    g = Gumbel(0, 1).sample(logits.shape).to(logits.device)
    y = (g + logits) / temp
    return torch.softmax(y, dim=-1)
Exemplo n.º 26
0
class SoftmaxRandomSamplePolicy(SimplePolicy):
    '''
    Randomly samples from the softmax of the logits using the Gumbel-max trick
    # https://hips.seas.harvard.edu/blog/2013/04/06/the-gumbel-max-trick-for-discrete-distributions/
    TODO: should probably switch that to something more like
    http://pytorch.org/docs/master/distributions.html
    '''
    def __init__(self, temperature=1.0, eps=0.0, do_entropy=True):
        '''
        :param temperature: divides the model logits (not the priors) in forward.
        :param eps: probability that forward takes an epsilon-greedy step,
            ignoring the model logits.
        :param do_entropy: if truthy, ``self.entropy`` is set on each forward call.
        '''
        super().__init__()
        self.gumbel = Gumbel(loc=0, scale=1)
        self.temperature = torch.tensor(temperature, requires_grad=False)
        self.eps = eps
        self.do_entropy = do_entropy
        if self.do_entropy:
            self.entropy = None

    def set_temperature(self, new_temperature):
        '''Replace the temperature, keeping the current dtype and device.'''
        if self.temperature != new_temperature:
            self.temperature = torch.tensor(new_temperature,
                                            dtype=self.temperature.dtype,
                                            device=self.temperature.device)

    def forward(self, logits: Variable, priors=None):
        '''
        Sample one index per row.

        :param logits: Logits to generate probabilities from, batch_size x out_dim float32
        :param priors: optional logits added to the (temperature-scaled)
            model logits; used alone in epsilon-greedy steps.
        :return: LongTensor of sampled indices, one per row. The log-prob of
            each sample is stored on ``self.logp``.
        '''
        # epsilon-greediness
        if random.random() < self.eps:
            if priors is None:
                new_logits = torch.zeros_like(logits)
                # Entries more than 1000 below the global max stay masked out.
                new_logits[logits < logits.max()-1000] = -1e4
            else:
                new_logits = priors
            if self.do_entropy:
                self.entropy = torch.zeros_like(logits).sum(dim=1)
        else:
            if self.temperature.dtype != logits.dtype or self.temperature.device != logits.device:
                # .to() converts in place of torch.tensor(tensor), which copies
                # with a UserWarning.
                self.temperature = self.temperature.to(dtype=logits.dtype,
                                                       device=logits.device)

            # temperature is applied to model only, not priors!
            if priors is None:
                new_logits = logits/self.temperature
                raw_logits = new_logits
            else:
                raw_logits = (logits-priors)/self.temperature
                new_logits = priors + raw_logits
            if self.do_entropy:
                # Entropy of softmax(raw_logits): -sum p * log p per row.
                raw_logits_normalized = F.log_softmax(raw_logits, dim=1)
                self.entropy = torch.sum(-raw_logits_normalized*torch.exp(raw_logits_normalized), dim=1)

        eff_logits = new_logits
        # Noise on the logits' own device/dtype rather than a global device.
        noise = self.gumbel.sample(logits.shape).to(device=eff_logits.device,
                                                    dtype=eff_logits.dtype)
        _, out = torch.max(noise + eff_logits, -1)
        all_logp = F.log_softmax(eff_logits, dim=1)
        # Log-prob of the sampled index in each row.
        self.logp = all_logp.gather(1, out.unsqueeze(1)).squeeze(1)
        return out

    def effective_logits(self, logits):
        '''Identity: this policy applies no extra bias.'''
        return logits
Exemplo n.º 27
0
class GumbelSoftmaxEmbeddingHelper(SoftmaxEmbeddingHelper):
    r"""A helper that feeds a Gumbel-softmax sample to the next step.

    The Gumbel-softmax vector produced at each step is used to mix the word
    embeddings, and that mixed embedding becomes the next step's input.

    A subclass of :class:`~texar.modules.Helper`, used with
    :class:`~texar.modules.RNNDecoderBase` in inference mode. It is identical
    to :class:`~texar.modules.SoftmaxEmbeddingHelper`, except that Gumbel
    softmax takes the place of plain softmax.

    Args:
        embedding: A callable or the ``params`` argument for
            :torch_nn:`functional.embedding`. Note that this is not a
            parameter of this class's constructor; presumably it reaches the
            helper through the decoder or the base class (confirm there).
            If a callable, it can take a vector tensor of ``ids`` (argmax
            ids), or two arguments (``ids``, ``times``), where ``ids`` is a
            vector of argmax ids and ``times`` is a vector of current time
            steps (i.e., position ids). The two-argument form is useful when
            :attr:`embedding` combines word and position embeddings.
            The returned tensor is passed to the decoder input.
        start_tokens: 1D :tensor:`LongTensor` shaped ``[batch_size]``, holding
            the start token of each sequence in the batch.
        end_token: Python int or scalar :tensor:`LongTensor`, the token that
            marks the end of decoding.
        tau: A float (or float scalar tensor), the softmax temperature.
        straight_through (bool): Whether to use straight-through gradients
            between time steps. If `True`, the one-hot vector of the most
            probable token (a greedy sample) is fed to the next step, and
            gradients flow through the soft sample. If `False` (default), the
            soft Gumbel-softmax distribution itself is fed to the next step.
        stop_gradient (bool): Whether to stop gradient backpropagation when
            feeding the softmax vector to the next step.
        use_finish (bool): Whether to stop decoding once :attr:`end_token` is
            generated. If `False`, decoding continues until
            :attr:`max_decoding_length` of the decoder is reached.

    Raises:
        ValueError: if :attr:`start_tokens` is not a 1D tensor or
            :attr:`end_token` is not a scalar. This class does no validation
            itself; presumably the base-class constructor raises it.
    """

    def __init__(self, start_tokens: torch.LongTensor,
                 end_token: Union[int, torch.LongTensor], tau: float,
                 straight_through: bool = False,
                 stop_gradient: bool = False, use_finish: bool = True):
        super().__init__(start_tokens, end_token, tau,
                         stop_gradient, use_finish)
        self._straight_through = straight_through
        # Standard Gumbel(0, 1) built on CPU scalars. `sample` moves its draws
        # to the decoder outputs' device and dtype.
        zero, one = torch.tensor(0.0), torch.tensor(1.0)
        self._gumbel = Gumbel(loc=zero, scale=one)

    def sample(self, time: int, outputs: torch.Tensor) -> torch.Tensor:
        r"""Return ``sample_id`` shaped ``[batch_size, vocab_size]``.

        With :attr:`straight_through` off, each row is the Gumbel-softmax
        distribution over the vocabulary at temperature :attr:`tau`. With it
        on, each row is the one-hot vector of the greedy (argmax) token in the
        forward pass, while gradients follow the soft distribution.
        """
        noise = self._gumbel.sample(outputs.size())
        noise = noise.to(device=outputs.device, dtype=outputs.dtype)
        soft = torch.softmax((outputs + noise) / self._tau, dim=-1)
        if not self._straight_through:
            return soft

        winner = torch.argmax(soft, dim=-1).unsqueeze(1)
        hard = torch.zeros_like(soft).scatter_(dim=-1, index=winner, value=1.0)
        # The forward value equals `hard`; the backward pass sees only `soft`.
        return soft + (hard - soft).detach()
Exemplo n.º 28
0
    def forward(self, inputs, hidden_state):
        """Run one recurrent step of the latent-conditioned agent network.

        Steps:
          1. Compute a categorical latent for each agent from the last
             ``n_agents`` input columns of the first ``n_agents`` rows.
          2. Relax that latent with Gumbel-softmax (temperature 1).
          3. Use the relaxed latent to generate the output layer's weights and
             bias (a hypernetwork for ``fc2``).
          4. Have an inference network predict the latent from the detached
             hidden state and the remaining input columns; their
             cross-entropy is returned as an auxiliary loss.

        Args:
            inputs: reshaped to ``(-1, self.input_shape)``. Presumably rows
                are ``bs * n_agents`` agent inputs, and the last ``n_agents``
                columns are agent-id one-hots (TODO confirm).
            hidden_state: reshaped to ``(-1, self.hidden_dim)``.

        Returns:
            Tuple ``(q, h, loss)``:
              * ``q``: Q-values of shape ``(-1, n_actions)``.
              * ``h``: new hidden state of shape ``(-1, rnn_hidden_dim)``.
              * ``loss``: the scalar auxiliary loss.
        """
        inputs = inputs.reshape(-1, self.input_shape)
        h_in = hidden_state.reshape(-1, self.hidden_dim)

        # Latent distribution per agent: (n_agents, latent_dim).
        self.latent = F.softmax(self.embed_fc(inputs[:self.n_agents, -self.n_agents:]), dim=1)
        # Broadcast to every batch element: (bs * n_agents, latent_dim).
        latent_embed = self.latent.unsqueeze(0).expand(
            self.bs, self.n_agents, self.latent_dim).reshape(
            self.bs * self.n_agents, self.latent_dim)

        # Inference network. h_in is detached so this loss does not train the RNN.
        latent_infer = F.relu(self.inference_fc1(
            th.cat([h_in.detach(), inputs[:, :-self.n_agents]], dim=1)))
        latent_infer = F.softmax(self.inference_fc2(latent_infer), dim=1)  # (bs*n, latent_dim)

        # Cross-entropy H(latent_infer, latent_embed), averaged over bs * n_agents.
        loss = -(latent_infer * th.log(latent_embed + self.eps)).sum() / (self.bs * self.n_agents)

        # Gumbel-softmax relaxation. Noise is matched to latent_embed's
        # device/dtype instead of assuming CUDA.
        gumbel = Gumbel(0.0, 1.0)
        noise = gumbel.sample(latent_embed.size()).to(latent_embed)
        latent_embed = F.softmax(th.log(latent_embed + self.eps) + noise, dim=1)

        latent = self.latent_fc1(latent_embed)

        # Output-layer weights and bias generated from the latent.
        fc2_w = self.fc2_w_nn(latent)
        fc2_b = self.fc2_b_nn(latent)
        fc2_w = fc2_w.reshape(-1, self.args.rnn_hidden_dim, self.args.n_actions)
        fc2_b = fc2_b.reshape((-1, 1, self.args.n_actions))

        x = F.relu(self.fc1(inputs))
        h = self.rnn(x, h_in)
        h = h.reshape(-1, 1, self.args.rnn_hidden_dim)

        q = th.bmm(h, fc2_w) + fc2_b

        return q.view(-1, self.args.n_actions), h.view(-1, self.args.rnn_hidden_dim), loss
Exemplo n.º 29
0
    def beam_search(self,
                    init,
                    steps,
                    beam_size,
                    temperature=1.0,
                    stochastic=False,
                    verbose=False):
        """Decode event sequences with beam search (optionally Gumbel-perturbed).

        A beam's score is the running sum of ``gen_forward`` logits divided by
        ``temperature``. NOTE(review): this is a log-probability only if
        ``gen_forward`` returns normalized log-probs — confirm.

        Args:
            init: 2-D tensor ``[batch, self.init_dim]`` fed to
                ``self.init_to_hidden``.
            steps: number of events to generate; must be positive.
            beam_size: hypotheses kept per batch element; requires
                ``0 < beam_size <= self.event_dim``.
            temperature: divisor applied to the logits at every step.
            stochastic: if True, candidates are ranked by score plus
                independent Gumbel(0, 1) noise rather than by score alone. The
                stored beam scores remain unperturbed.
            verbose: iterate the steps through a ``Bar`` progress display.

        Returns:
            Tensor ``[steps + 1, batch]``. For each batch element it holds the
            events of the beam with the highest final score, starting with
            the primary event from ``self.get_primary_event``.
        """
        assert len(init.shape) == 2 and init.shape[1] == self.init_dim
        assert self.event_dim >= beam_size > 0 and steps > 0

        batch_size = init.shape[0]
        # Start from ONE hypothesis per batch element. Starting from
        # `beam_size` identical copies makes top-k select the same
        # continuation from every copy, collapsing the search to greedy.
        current_beam_size = 1

        hidden = self.init_to_hidden(init)  # [layers, batch, hidden]
        dev = hidden.device
        hidden = hidden[:, :, None, :]  # [layers, batch, 1, hidden]

        event = self.get_primary_event(batch_size)[:, :, None]  # [1, batch, 1]
        beam_events = event[0, :, None, :]  # [batch, 1, 1]: event history per beam
        beam_log_prob = torch.zeros(batch_size, current_beam_size, device=dev)  # [batch, 1]

        if stochastic:
            gumbel_dist = Gumbel(0, 1)

        step_iter = range(steps)
        if verbose:
            title = 'Stochastic Beam Search' if stochastic else 'Beam Search'
            step_iter = Bar(title).iter(step_iter)

        for _ in step_iter:
            n_rows = batch_size * current_beam_size
            logits, hidden = self.gen_forward(
                event.reshape(1, n_rows),
                hidden.reshape(self.rnn_layers, n_rows, self.hidden_dim))
            hidden = hidden.view(self.rnn_layers, batch_size,
                                 current_beam_size, self.hidden_dim)
            logits = (logits / temperature).view(
                1, batch_size, current_beam_size, self.event_dim)

            # Cumulative score of every (beam, next event) candidate,
            # flattened per batch element: [1, batch, cbeam * out].
            cand = (logits + beam_log_prob[None, :, :, None]).view(1, batch_size, -1)

            if stochastic:
                # Noise is sampled on CPU; move it to the scores' device/dtype.
                noise = gumbel_dist.sample(cand.shape).to(cand)
                _, top_indices = (cand + noise).topk(beam_size, -1)
            else:
                _, top_indices = cand.topk(beam_size, -1)  # [1, batch, beam]

            beam_log_prob = torch.gather(cand, -1, top_indices)[0]  # [batch, beam]

            # Lookup tables mapping a flat candidate index (beam * out + event)
            # to its source beam and to its event id.
            beam_table = torch.arange(current_beam_size, device=dev)[None, None, :, None]
            beam_table = beam_table.repeat(1, batch_size, 1, self.output_dim).view(1, batch_size, -1)
            event_table = torch.arange(self.output_dim, device=dev)[None, None, None, :]
            event_table = event_table.repeat(1, batch_size, current_beam_size, 1).view(1, batch_size, -1)

            beam_index_new = torch.gather(beam_table, -1, top_indices)  # [1, batch, beam]
            event = torch.gather(event_table, -1, top_indices)  # [1, batch, beam]

            # Reorder hidden states and histories so they follow their source beam.
            hidden = torch.gather(
                hidden, 2,
                beam_index_new[:, :, :, None].repeat(self.rnn_layers, 1, 1, self.hidden_dim))
            beam_events = torch.gather(
                beam_events[None], 2,
                beam_index_new.unsqueeze(-1).repeat(1, 1, 1, beam_events.shape[-1]))
            beam_events = torch.cat([beam_events, event.unsqueeze(-1)], -1)[0]

            current_beam_size = beam_size

        best = beam_events[torch.arange(batch_size, device=beam_events.device),
                           beam_log_prob.argmax(-1)]  # [batch, steps + 1]
        return best.contiguous().t()
Exemplo n.º 30
0
def get_raoblackwell_ps_loss(
        conditional_loss_fun,
        log_class_weights,
        topk,
        sample_topk,
        grad_estimator,
        grad_estimator_kwargs=None,
        epoch=None,
        data=None):
    """
    Returns a pseudo_loss, such that the gradient obtained by calling
    pseudo_loss.backward() is unbiased for the true loss

    Parameters
    ----------
    conditional_loss_fun : function
        A function that returns the loss conditional on an instance of the
        categorical random variable. It must take in a one-hot-encoding
        matrix (batchsize x n_categories) and return a vector of
        losses, one for each observation in the batch.
    log_class_weights : torch.Tensor
        A tensor of shape batchsize x n_categories of the log class weights.
        Every entry must be <= 0 (checked with an assert).
    topk : Integer
        The number of categories to sum over
    sample_topk : bool
        If True, the summed categories are the top ``topk`` of the
        Gumbel-perturbed log class weights, and the next-ranked category
        (if any remains) is the sampled term. If False, the summed set comes
        from ``get_concentrated_mask``, and the sampled term is drawn from
        the remaining categories after renormalizing their weights.
    grad_estimator : function
        A function that returns the pseudo loss, that is, the loss which
        gives a gradient estimator when .backward() is called.
        See baselines_lib for details.
    grad_estimator_kwargs : dict, optional
        keyword arguments forwarded to the gradient estimator. Defaults to
        ``{'grad_estimator_kwargs': None}``.
    epoch : int
        The epoch of the optimizer (for Gumbel-softmax, which has an annealing rate)
    data : torch.Tensor
        The data at which we evaluate the loss (for NVIL and RELAX, which have
        a data dependent baseline)

    Returns
    -------
    ps_loss :
        a value such that ps_loss.backward() returns an
        estimate of the gradient.
        In general, ps_loss might not equal the actual loss.
    """
    if grad_estimator_kwargs is None:
        # Same effective default as before, without a shared mutable default.
        grad_estimator_kwargs = {'grad_estimator_kwargs': None}

    # class weights from the variational distribution
    assert torch.all(log_class_weights.detach() <= 0)
    class_weights = torch.exp(log_class_weights.detach())
    n_categories = class_weights.shape[1]
    sample_ind = None

    if sample_topk:
        # Gumbel-top-k on the detached log weights. The first `topk` indices
        # are summed; the next one (when a category remains) is the sample.
        # Clamp k so that topk == n_categories does not request n + 1 indices.
        phi = log_class_weights.detach()
        g_phi = Gumbel(phi, torch.ones_like(phi)).rsample()
        _, ind = g_phi.topk(min(topk + 1, n_categories), dim=-1)

        topk_domain = ind[..., :topk]
        if topk < n_categories:
            sample_ind = ind[..., topk]  # Last sample we use as real sample
        concentrated_mask = torch.zeros_like(phi).scatter(-1, topk_domain,
                                                          1).detach()
        seq_tensor = torch.arange(class_weights.size(0),
                                  dtype=torch.long,
                                  device=class_weights.device)
    else:
        # this is the indicator C_k
        concentrated_mask, topk_domain, seq_tensor = \
            get_concentrated_mask(class_weights, topk)
        concentrated_mask = concentrated_mask.float().detach()

    ############################
    # compute the summed term
    summed_term = 0.0

    for i in range(topk):
        # get categories to be summed
        summed_indx = topk_domain[:, i]

        # compute gradient estimate
        grad_summed = grad_estimator(conditional_loss_fun, log_class_weights,
                                     class_weights, seq_tensor,
                                     z_sample=summed_indx,
                                     epoch=epoch,
                                     data=data,
                                     **grad_estimator_kwargs)

        # weight each summed category by its probability
        summed_weights = class_weights[seq_tensor, summed_indx].squeeze()
        summed_term = summed_term + (grad_summed * summed_weights).sum()

    ############################
    # compute sampled term: total probability mass outside the summed set
    sampled_weight = torch.sum(class_weights * (1 - concentrated_mask),
                               dim=1,
                               keepdim=True)

    if topk == n_categories:
        # everything was summed exactly; nothing is left to sample
        grad_sampled = 0.
    else:
        if sample_topk:
            conditional_z_sample = sample_ind  # We have already sampled it
        else:
            # class weights conditioned on being in the diffuse set
            conditional_class_weights = (class_weights + 1e-12) * \
                (1 - concentrated_mask) / (sampled_weight + 1e-12)

            # sample from conditional distribution
            conditional_z_sample = sample_class_weights(
                conditional_class_weights)

        grad_sampled = grad_estimator(conditional_loss_fun,
                                      log_class_weights,
                                      class_weights,
                                      seq_tensor,
                                      z_sample=conditional_z_sample,
                                      epoch=epoch,
                                      data=data,
                                      **grad_estimator_kwargs)

    return (grad_sampled * sampled_weight.squeeze()).sum() + summed_term