 def gumbel_sample(self, data):
     """Pick one mixture component per row with the Gumbel-max trick.

     Independent Gumbel(0, 1) noise is added to the log of every weight and
     the arg-max along ``dim=1`` is returned.

     Args:
         data: tensor of non-negative mixture weights; components are
             indexed along ``dim=1`` (the arg-max is taken there).

     Returns:
         LongTensor with the index of the selected component for each row.
     """
     distribution = Gumbel(loc=0, scale=1)
     # One independent draw per entry: a single shared scalar would shift all
     # log-weights equally and could never change which entry is the largest.
     z = distribution.sample(data.shape).to(data.device)
     return (torch.log(data) + z).argmax(dim=1)
 def __init__(self,
              start_tokens: torch.LongTensor,
              end_token: Union[int, torch.LongTensor],
              tau: float,
              straight_through: bool = False,
              stop_gradient: bool = False,
              use_finish: bool = True):
     """Initialise the parent helper and create the Gumbel noise source.

     ``start_tokens``, ``end_token``, ``tau``, ``stop_gradient`` and
     ``use_finish`` are forwarded positionally to the parent constructor;
     ``straight_through`` is kept on the instance.
     """
     super().__init__(start_tokens, end_token, tau,
                      stop_gradient, use_finish)
     # Standard Gumbel: location 0, scale 1.
     zero_loc = torch.tensor(0.0)
     unit_scale = torch.tensor(1.0)
     self._gumbel = Gumbel(loc=zero_loc, scale=unit_scale)
     self._straight_through = straight_through
    def __init__(self, do_entropy=False):
        """Create the standard Gumbel noise source and record the entropy flag.

        :param do_entropy: stored as ``self.do_entropy``; when truthy,
            ``self.entropy`` is additionally initialised to ``None``
        """
        self.gumbel = Gumbel(loc=0, scale=1)
        self.do_entropy = do_entropy
        if not self.do_entropy:
            return
        self.entropy = None
def gumbel_with_maximum(phi, T, dim=-1):
    """
    Samples a set of gumbels which are conditioned on having a maximum along a dimension
    phi.max(dim)[0] should be broadcastable with the desired maximum T

    :param phi: tensor of Gumbel locations
    :param T: the desired maximum, handed to ``_shift_gumbel_maximum``
    :param dim: dimension along which the maximum is taken
    :return: tuple ``(g, argmax)`` where ``g`` is the output of
        ``_shift_gumbel_maximum`` and ``argmax`` is the index of the largest
        unconditioned sample along ``dim``
    """
    # Gumbel with location phi, use PyTorch distributions so you cannot get -inf or inf (which causes trouble)
    g_phi = Gumbel(phi, torch.ones_like(phi)).rsample()
    Z, argmax = g_phi.max(dim)
    g = _shift_gumbel_maximum(g_phi, T, dim, Z=Z)
    return g, argmax
    def __init__(self, bias=None):
        """
        :param bias: a vector of log frequencies, to bias the sampling towards what we want
        """
        self.gumbel = Gumbel(loc=0, scale=1)
        if bias is not None:
            # stored as a float32 tensor with a leading dimension of size 1
            self.bias = torch.from_numpy(np.array(bias)).to(dtype=torch.float32, device=device).unsqueeze(0)
        else:
            # without this branch the attribute was either missing or always None
            self.bias = None
    def __init__(self, temperature=1.0, eps=0.0, do_entropy=True):
        """
        :param temperature: stored in ``self.temperature`` as a tensor that does not require grad
        :param eps: stored unchanged in ``self.eps``
        :param do_entropy: stored in ``self.do_entropy``; when truthy,
            ``self.entropy`` is initialised to ``None``
        """
        self.gumbel = Gumbel(loc=0, scale=1)
        self.temperature = torch.tensor(temperature, requires_grad=False)
        self.eps = eps
        self.do_entropy = do_entropy
        if self.do_entropy:
            self.entropy = None
# File: swor.py  Project: HEmile/storchastic
def cond_gumbel_sample(all_joint_log_probs,
                       perturbed_log_probs) -> torch.Tensor:
    """Sample Gumbel variables located at ``all_joint_log_probs`` and shift
    them using ``perturbed_log_probs`` as the target maximum over the last
    dimension.
    """
    # plates x k? x |D_yv| reparameterised Gumbel draws
    unconditioned = Gumbel(loc=all_joint_log_probs, scale=1.0).rsample()

    # plates x k: largest draw along the last dimension
    row_max, _ = unconditioned.max(dim=-1)
    target = perturbed_log_probs
    shifted = target - unconditioned + log1mexp(unconditioned - row_max.unsqueeze(-1))

    # plates (x k) x |D_yv|
    softplus = torch.nn.Softplus()
    return target - shifted.relu() - softplus(-shifted.abs())
# File: frechet.py  Project: IronOnet/ReAgent
    def __init__(
        self,
        shape: float = 1.0,
        topk: Optional[int] = None,
        equiv_len: Optional[int] = None,
        log_scores: bool = False,
    ):
        """FréchetSort is a softer version of descending sort which samples all possible
        orderings of items favoring orderings which resemble descending sort. This can
        be used to convert descending sort by rank score into a differentiable,
        stochastic policy amenable to policy gradient algorithms.

        :param shape: parameter of Frechet Distribution. Lower values correspond to
            aggressive deviations from descending sort.
        :param topk: If specified, only the first topk actions are specified.
        :param equiv_len: Orders are considered equivalent if the top equiv_len match. Used
            in probability computations.
            Essentially specifies the action space.
        :param log_scores: Scores passed in are already log-transformed. In this case, we would
            simply add Gumbel noise.
            For LearnVM, we set this to be True because we expect input and output scores
            to be in the log space.
        :raises ValueError: if ``topk`` is given and the equivalence length exceeds it.

        Example:

        Consider the sampler:

        sampler = FrechetSort(shape=3, topk=5, equiv_len=3)

        Given a set of scores, this sampler will produce indices of items roughly
        resembling a argsort by scores in descending order. The higher the shape,
        the more it would resemble a descending argsort. `topk=5` means only the top
        5 ranks will be output. The `equiv_len` determines what orders are considered
        equivalent for probability computation. In this example, the sampler will
        produce probability for the top 3 items appearing in a given order for the
        `log_prob` call.
        """
        self.shape = shape
        self.topk = topk
        self.upto = equiv_len
        if topk is not None:
            if equiv_len is None:
                # default the equivalence length to the number of ranks returned
                self.upto = topk
            # pyre-fixme[58]: `>` is not supported for operand types `Optional[int]`
            #  and `Optional[int]`.
            if self.upto > self.topk:
                raise ValueError(
                    f"Equiv length {equiv_len} cannot exceed topk={topk}.")
        # Gumbel noise with scale 1/shape is added to the log-scores
        self.gumbel_noise = Gumbel(0, 1.0 / shape)
        self.log_scores = log_scores
class SoftmaxRandomSamplePolicy(SimplePolicy):
    """
    Randomly samples from the softmax of the logits
    # https://hips.seas.harvard.edu/blog/2013/04/06/the-gumbel-max-trick-for-discrete-distributions/
    TODO: should probably switch that to something more like
    """
    def __init__(self, bias=None):
        """
        :param bias: a vector of log frequencies, to bias the sampling towards what we want
        """
        # NOTE(review): no super().__init__() call is visible in this copy;
        # confirm whether SimplePolicy requires one.
        self.gumbel = Gumbel(loc=0, scale=1)
        if bias is not None:
            self.bias = torch.from_numpy(np.array(bias)).to(dtype=torch.float32, device=device).unsqueeze(0)
        else:
            self.bias = None

    def forward(self, logits: Variable):
        """
        :param logits: Logits to generate probabilities from, batch_size x out_dim float32
        :return: index of the largest Gumbel-perturbed logit along the last dimension
        """
        eff_logits = self.effective_logits(logits)
        x = self.gumbel.sample(logits.shape).to(device) + eff_logits
        _, out = torch.max(x, -1)
        return out

    def effective_logits(self, logits):
        """Return ``logits`` with the bias added when one was configured."""
        if self.bias is not None:
            # Out-of-place add: ``+=`` would silently modify the caller's tensor.
            logits = logits + self.bias
        return logits
 def log_sample(self, x, u, n, alpha=1., xg=None, ug=None):
     """Draw ``n`` row indices by Gumbel-max over the ``alpha``-scaled costs.

     Returns the chosen indices (one per sample column) together with the
     scaled costs as a column vector.
     """
     raw_cost = self.c.cost(x, u, xg, ug)
     # unnormalized log probabilities, one row per entry
     scaled = alpha * raw_cost.reshape(-1, 1)
     noise = Gumbel(loc=0., scale=1.).sample((x.shape[0], n))
     choices = torch.max(scaled + noise, 0)[1]
     return choices, scaled
# Builds a loss from `n_samples` classes drawn without replacement (Gumbel
# top-k), combining an advantage term and a direct cost term.
# NOTE(review): this copy is damaged -- everything in the signature after
# `conditional_loss_fun,` is missing, so `log_class_weights`, `n_samples`,
# `baseline_constant`, `baseline_separate`, `baseline_n_samples` and
# `baseline_deterministic` are used below without a visible definition.
def reinforce_unordered(conditional_loss_fun,

    # Sample without replacement using Gumbel top-k trick
    # (the location is detached, so the perturbation itself carries no gradient)
    phi = log_class_weights.detach()
    g_phi = Gumbel(phi, torch.ones_like(phi)).rsample()
    _, ind = g_phi.topk(n_samples, -1)

    # log-weights of the sampled classes
    log_p = log_class_weights.gather(-1, ind)
    n_classes = log_class_weights.shape[1]
    # NOTE(review): the beginning of the list-comprehension element is missing
    # on the next lines; presumably each cost is `conditional_loss_fun` applied
    # to an encoding of `z_sample` -- TODO confirm against the upstream file.
    costs = torch.stack([
            z_sample, n_classes)) for z_sample in ind.t()
    ], -1)

    with torch.no_grad():  # Don't compute gradients for advantage and ratio
        # log_R_s, log_R_ss = compute_log_R(log_p)
        log_R_s, log_R_ss = compute_log_R_O_nfac(log_p)

        if baseline_constant is not None:
            bl_vals = baseline_constant
        elif baseline_separate:
            bl_vals = get_baseline(conditional_loss_fun, log_class_weights,
                                   baseline_n_samples, baseline_deterministic)
            # Same bl for all samples, so add dimension
            bl_vals = bl_vals[:, None]
        elif log_p.size(-1) > 1:
            # Compute built in baseline
            bl_vals = ((log_p[:, None, :] + log_R_ss).exp() *
                       costs[:, None, :]).sum(-1)
            # NOTE(review): an `else:` appears to be missing before the next
            # line; as written the built-in baseline is overwritten with 0 and
            # `bl_vals` is undefined when none of the branches is taken.
            bl_vals = 0.  # No bl
        adv = costs - bl_vals
    # Also add the costs (with the unordered estimator) in case there is a direct dependency on the parameters
    loss = ((log_p + log_R_s).exp() * adv.detach() +
            (log_p + log_R_s).exp().detach() * costs).sum(-1)
    return loss
# NOTE(review): several statements appear to be missing from this copy: `g`,
# `y` and `one_hot_symb` are used but never assigned, and `K` and `out` are
# never read in the visible code.
def sample_stgs(proba_no_softmax, K, temp):
    # sample with GS
    out=(proba_no_softmax + g) / temp
    # soft approximation
    # use argmax at forward pass
    # go through one_hot for backward pass
    # (the forward value equals `one_hot_symb`; gradients flow only through `y`)
    sample=(one_hot_symb - y).detach() + y
    return sample
    def __init__(
        self,
        shape: float = 1,
        upto: Optional[int] = None
    ):
        """Samples an ordering

        Args:
            shape: 1/scale parameter of Gumbel distribution
                At lower temperatures, order is closer to argmax(scores)
            upto [Default: None]: If provided, position upto which log_prob is computed

        Returns:
            An action sampler which produces ordering
        """
        self.shape = shape
        self.upto = upto
        # noise added to the log-scores; its scale is the inverse of ``shape``
        self.gumbel_noise = Gumbel(0, 1.0 / shape)
def sample_gumbel_softmax(logits, temperature, n_sample):
    """
    Input:
        logits: Tensor of log probs, shape = BS x k
        temperature = scalar
        n_sample: number of independent samples to draw

    Output: Tensor of values sampled from Gumbel softmax.
            These will tend towards a one-hot representation in the limit of temp -> 0
            shape = n_sample x BS x k
    """
    # Create the noise directly on the logits' device instead of relying on a
    # module-level ``device`` that may not match where ``logits`` lives.
    loc = torch.zeros(logits.shape, device=logits.device)
    scale = torch.ones(logits.shape, device=logits.device)
    g = Gumbel(loc, scale).rsample((n_sample, ))
    h = (g + logits) / temperature
    y = F.softmax(h, dim=-1)
    return y
class FrechetOrderSampler:
    def __init__(
        self,
        shape: float = 1,
        upto: Optional[int] = None
    ):
        """Samples an ordering

        Args:
            shape: 1/scale parameter of Gumbel distribution
                At lower temperatures, order is closer to argmax(scores)
            upto [Default: None]: If provided, position upto which log_prob is computed

        Returns:
            An action sampler which produces ordering
        """
        self.shape = shape
        self.upto = upto
        self.gumbel_noise = Gumbel(0, 1.0 / shape)

    def sample(
        self,
        scores: torch.Tensor
    ):
        """Sample an ordering given scores"""
        perturbed = torch.log(scores) + self.gumbel_noise.sample((len(scores),))
        # descending order of the perturbed log-scores
        return torch.argsort(-perturbed.detach())

    def log_prob(self, scores : torch.Tensor, permutations):
        """Compute log probability given scores and an action (a permutation).
        The formula uses the equivalence of sorting with Gumbel noise and
        Plackett-Luce model (See Yellot 1977)

        Args:
            scores: scores of different items
            permutations: prescribed (permutation) order of the items
        """
        s = torch.log(select_indices(scores, permutations))
        n = len(scores)
        p = self.upto if self.upto is not None else n - 1
        # logsumexp is the numerically stable form of log(sum(exp(.)))
        return -sum(
            torch.logsumexp((s[k:] - s[k]) * self.shape, dim=0)
            for k in range(p))
    def sample_model_spec(self, num):
        """
        Override, sample the alpha via gumbel softmax instead of normal softmax.

        :param num: forwarded unchanged to ``self._sample_model_spec``
        :return: the result of ``self._sample_model_spec(num, sample_archs, sample_ops)``
        """
        alpha_topology = self.alpha_topology.detach().clone()
        alpha_ops = self.alpha_ops.detach().clone()
        sample_archs = []
        sample_ops = []
        gumbel_dist = Gumbel(torch.tensor([.0]), torch.tensor([1.0]))
        with torch.no_grad():
            for i in range(self.num_intermediate_nodes):
                # align with topology weights
                probs = gumbel_softmax(alpha_topology[: i+2, i], self.temperature(), gumbel_dist)
                probs_op = gumbel_softmax(alpha_ops[:, i], self.temperature(), gumbel_dist)
                # NOTE(review): `probs` / `probs_op` are never stored, so
                # `sample_archs` and `sample_ops` stay empty; the statements
                # that fill them appear to be missing from this copy --
                # restore them from the upstream file.

            return self._sample_model_spec(num, sample_archs, sample_ops)
class SoftmaxRandomSamplePolicySparse(SimplePolicy):
    """
    Randomly samples from the softmax of the logits
    # https://hips.seas.harvard.edu/blog/2013/04/06/the-gumbel-max-trick-for-discrete-distributions/
    TODO: should probably switch that to something more like
    """
    def __init__(self, do_entropy=False):
        # NOTE(review): no super().__init__() call is visible in this copy;
        # confirm whether SimplePolicy requires one.
        self.gumbel = Gumbel(loc=0, scale=1)
        self.do_entropy = do_entropy
        if self.do_entropy:
            self.entropy = None

    def forward(self, all_logits_list, all_action_inds_list):
        """
        :param all_logits_list: list of tensors of len batch_size
        :param all_action_inds_list: list of long tensors with indices of the actions corresponding to the logits
        :return: stacked tensor with one chosen action index per batch item; the
            matching log-probabilities are stored in ``self.logp``
        """
        # The number of feasible actions is different for every item in the batch, so for loop is the simplest
        logp = []
        out_actions = []
        for logits, action_inds in zip(all_logits_list, all_action_inds_list):
            if len(logits):
                x = self.gumbel.sample(logits.shape).to(device=device, dtype=logits.dtype) + logits
                _, out = torch.max(x, -1)
                this_logp = F.log_softmax(logits, dim=-1)[out]
                # NOTE(review): the two appends and the `else:` below were
                # reconstructed -- the damaged copy discarded `out` and
                # `this_logp` and always appended the zero placeholders.
                out_actions.append(action_inds[out])
                logp.append(this_logp)
            else:
                # nothing to choose from: record placeholder values
                out_actions.append(torch.tensor(0, device=logits.device, dtype=action_inds.dtype))
                logp.append(torch.tensor(0.0, device=logits.device, dtype=logits.dtype))
        self.logp = torch.stack(logp)
        out = torch.stack(out_actions)
        return out
 def ops_weights(self, node):
     """Pass the operation alphas of ``node`` through ``gumbel_softmax`` at the current temperature."""
     unit_gumbel = Gumbel(torch.tensor([.0]), torch.tensor([1.0]))
     node_alphas = self.alpha_ops[:, node]
     return gumbel_softmax(node_alphas, self.temperature(), unit_gumbel)
 def topology_weights(self, node):
     """Return the soft weights for topology aggregation of ``node``."""
     unit_gumbel = Gumbel(torch.tensor([.0]), torch.tensor([1.0]))
     incoming_alphas = self.alpha_topology[: node + 2, node]
     weights = gumbel_softmax(incoming_alphas, self.temperature(), unit_gumbel)
     # the first entry is dropped from the returned weights
     return weights[1:]
# File: util.py  Project: HEmile/storchastic
def rsample_gumbel(
    distr: Distribution,
    n: int,
) -> torch.Tensor:
    """Draw ``n`` reparameterised samples from a unit-scale Gumbel located at ``distr.logits``."""
    sample_shape = (n, )
    return Gumbel(distr.logits, 1).rsample(sample_shape)
# Computes one loss value per row: rows whose sampled log-weights cover
# (almost) all probability mass use `compute_summed_terms`, the others use
# `compute_sum_and_sample_loss`.
# NOTE(review): this copy is damaged -- everything in the signature after
# `conditional_loss_fun,` is missing, so `log_class_weights`, `n_samples`,
# `baseline_separate`, `baseline_n_samples`, `baseline_deterministic` and
# `rao_blackwellize` are used below without a visible definition.
def reinforce_sum_and_sample(conditional_loss_fun,

    # Sample without replacement using Gumbel top-k trick
    phi = log_class_weights.detach()
    g_phi = Gumbel(phi, torch.ones_like(phi)).rsample()

    _, ind = g_phi.topk(n_samples, -1)

    log_p = log_class_weights.gather(-1, ind)
    n_classes = log_class_weights.shape[1]
    # NOTE(review): the beginning of the list-comprehension element is missing
    # on the next lines -- restore it from the upstream file.
    costs = torch.stack([
            z_sample, n_classes)) for z_sample in ind.t()
    ], -1)

    with torch.no_grad():  # Don't compute gradients for advantage and ratio
        # NOTE(review): an `else:` seems to be missing between the two baseline
        # computations below, and the `log1mexp(` call is cut off mid-expression.
        if baseline_separate:
            bl_vals = get_baseline(conditional_loss_fun, log_class_weights,
                                   baseline_n_samples, baseline_deterministic)
            # Same bl for all samples, so add dimension
            bl_vals = bl_vals[:, None]
            assert baseline_n_samples < n_samples
            bl_sampled_weight = log1mexp(
                log_p[:, :baseline_n_samples -
            bl_vals = (log_p[:, :baseline_n_samples - 1].exp() * costs[:, :baseline_n_samples -1]).sum(-1)\
                      + bl_sampled_weight * costs[:, baseline_n_samples - 1]
            bl_vals = bl_vals[:, None]

    # We compute an 'exact' gradient if the sum of probabilities is roughly more than 1 - 1e-5
    # in which case we can simply sum all the terms and the relative error will be < 1e-5
    use_exact = log_p.logsumexp(-1) > -1e-5
    not_use_exact = use_exact == 0

    cost_exact = costs[use_exact]
    exact_loss = compute_summed_terms(log_p[use_exact], cost_exact,
                                      cost_exact - bl_vals[use_exact])

    log_p_est = log_p[not_use_exact]
    costs_est = costs[not_use_exact]
    bl_vals_est = bl_vals[not_use_exact]

    # NOTE(review): inside this branch the `all_perms(` and
    # `compute_sum_and_sample_loss(` calls are missing their final arguments,
    # and an `else:` appears to be missing before the last `losses = ...`
    # line (as written it overwrites the Rao-Blackwellised result).
    if rao_blackwellize:
        ap = all_perms(torch.arange(n_samples, dtype=torch.long),
        log_p_ap = log_p_est[:, ap]
        bl_vals_ap = bl_vals_est.expand_as(costs_est)[:, ap]
        costs_ap = costs_est[:, ap]
        cond_losses = compute_sum_and_sample_loss(log_p_ap, costs_ap,

        # Compute probabilities for permutations
        log_probs_perms = log_pl_rec(log_p_ap, -1)
        # normalise so the permutation probabilities sum to one per row
        cond_log_probs_perms = log_probs_perms - log_probs_perms.logsumexp(
            -1, keepdim=True)
        losses = (cond_losses * cond_log_probs_perms.exp()).sum(-1)
        losses = compute_sum_and_sample_loss(log_p_est, costs_est, bl_vals_est)

    # If they are summed we can simply concatenate but for consistency it is best to place them in order
    all_losses = log_p.new_zeros(log_p.size(0))
    all_losses[use_exact] = exact_loss
    all_losses[not_use_exact] = losses
    return all_losses
# File: frechet.py  Project: IronOnet/ReAgent
class FrechetSort(Sampler):
    EPS = 1e-12

    def __init__(
        self,
        shape: float = 1.0,
        topk: Optional[int] = None,
        equiv_len: Optional[int] = None,
        log_scores: bool = False,
    ):
        """FréchetSort is a softer version of descending sort which samples all possible
        orderings of items favoring orderings which resemble descending sort. This can
        be used to convert descending sort by rank score into a differentiable,
        stochastic policy amenable to policy gradient algorithms.

        :param shape: parameter of Frechet Distribution. Lower values correspond to
            aggressive deviations from descending sort.
        :param topk: If specified, only the first topk actions are specified.
        :param equiv_len: Orders are considered equivalent if the top equiv_len match. Used
            in probability computations.
            Essentially specifies the action space.
        :param log_scores: Scores passed in are already log-transformed. In this case, we would
            simply add Gumbel noise.
            For LearnVM, we set this to be True because we expect input and output scores
            to be in the log space.
        :raises ValueError: if ``topk`` is given and the equivalence length exceeds it.

        Example:

        Consider the sampler:

        sampler = FrechetSort(shape=3, topk=5, equiv_len=3)

        Given a set of scores, this sampler will produce indices of items roughly
        resembling a argsort by scores in descending order. The higher the shape,
        the more it would resemble a descending argsort. `topk=5` means only the top
        5 ranks will be output. The `equiv_len` determines what orders are considered
        equivalent for probability computation. In this example, the sampler will
        produce probability for the top 3 items appearing in a given order for the
        `log_prob` call.
        """
        self.shape = shape
        self.topk = topk
        self.upto = equiv_len
        if topk is not None:
            if equiv_len is None:
                self.upto = topk
            # pyre-fixme[58]: `>` is not supported for operand types `Optional[int]`
            #  and `Optional[int]`.
            if self.upto > self.topk:
                raise ValueError(
                    f"Equiv length {equiv_len} cannot exceed topk={topk}.")
        self.gumbel_noise = Gumbel(0, 1.0 / shape)
        self.log_scores = log_scores

    def sample_action(self, scores: torch.Tensor) -> rlt.ActorOutput:
        """Sample a ranking according to Frechet sort. Note that possible_actions_mask
        is ignored as the list of rankings scales exponentially with slate size and
        number of items and it can be difficult to enumerate them."""
        assert scores.dim() == 2, "sample_action only accepts batches"
        log_scores = scores if self.log_scores else torch.log(scores)
        perturbed = log_scores + self.gumbel_noise.sample(scores.shape)
        action = torch.argsort(perturbed.detach(), descending=True)
        log_prob = self.log_prob(scores, action)
        # Only truncate the action before returning
        # NOTE(review): `scores` is asserted to be 2-D, so this slice cuts the
        # first (batch) dimension rather than the rank dimension -- confirm
        # whether `action[:, :self.topk]` was intended.
        if self.topk is not None:
            action = action[:self.topk]
        return rlt.ActorOutput(action, log_prob)

    def log_prob(
        self,
        scores: torch.Tensor,
        action: torch.Tensor,
        equiv_len_override: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        What is the probability of a given set of scores producing the given
        list of permutations only considering the top `equiv_len` ranks?

        We may want to override the default equiv_len here when we know the having larger
        action space doesn't matter. i.e. in Reels

        :raises ValueError: if the override exceeds ``topk`` or ``action`` has
            more columns than ``scores``
        :raises NotImplementedError: if ``action`` has fewer columns than ``scores``
        """
        upto = self.upto
        if equiv_len_override is not None:
            # NOTE(review): the expected shape was missing from the damaged
            # copy; one entry per batch row is assumed because the override is
            # later compared element-wise against the rank index.
            assert equiv_len_override.shape == (
                scores.shape[0],
            ), f"Invalid shape {equiv_len_override.shape}, compared to scores {scores.shape}. equiv_len_override {equiv_len_override}"
            upto = equiv_len_override.long()
            if self.topk is not None and torch.any(
                    equiv_len_override > self.topk):
                raise ValueError(
                    f"Override {equiv_len_override} cannot exceed topk={self.topk}."
                )

        # NOTE(review): `squeeze` is recorded but never read in the visible code.
        squeeze = False
        if len(scores.shape) == 1:
            squeeze = True
            scores = scores.unsqueeze(0)
            action = action.unsqueeze(0)

        assert len(action.shape) == len(
            scores.shape) == 2, "scores should be batch"
        if action.shape[1] > scores.shape[1]:
            raise ValueError(
                f"action cardinality ({action.shape[1]}) is larger than the number of scores ({scores.shape[1]})"
            )
        elif action.shape[1] < scores.shape[1]:
            raise NotImplementedError(
                f"This semantic is ambiguous. If you have shorter slate, pad it with scores.shape[1] ({scores.shape[1]})"
            )

        log_scores = scores if self.log_scores else torch.log(scores)
        n = log_scores.shape[-1]
        # Add scores for the padding value
        # NOTE(review): the fill value and the remaining `torch.cat` arguments
        # were missing from the damaged copy; -inf is assumed because the sums
        # below map -inf log-softmax entries to 0 via `neginf=0.0` -- confirm.
        log_scores = torch.cat(
            [
                log_scores,
                torch.full((log_scores.shape[0], 1),
                           float("-inf"),
                           dtype=log_scores.dtype,
                           device=log_scores.device),
            ],
            dim=1,
        )
        s = torch.gather(log_scores, 1, action) * self.shape

        p = upto if upto is not None else n
        # We should unsqueeze here
        if isinstance(p, int):
            probs = sum(
                torch.nan_to_num(F.log_softmax(s[:, i:], dim=1)[:, 0],
                                 neginf=0.0) for i in range(p))
        elif isinstance(p, torch.Tensor):
            # do masked sum
            probs = sum(
                torch.nan_to_num(F.log_softmax(s[:, i:], dim=1)[:, 0],
                                 neginf=0.0) * (i < p).float()
                for i in range(n))
        else:
            raise RuntimeError(f"p is {p}")
        return probs
    # Returns the arg-max along dim 0 of a Gumbel-softmax sample, computed
    # without gradient tracking.
    # NOTE(review): the arguments of the `gumbel_softmax_sample(` call and its
    # closing parenthesis are missing from this copy -- restore them from the
    # upstream file.
    def sample_genotype_index(self):
        with torch.no_grad():
            sampled_alphas = gumbel_softmax_sample(
            best_sampled_alphas = torch.argmax(sampled_alphas, dim=0)
            return best_sampled_alphas.detach()

class MixedSequential(torch.nn.Sequential):
    """Sequential container whose ``MixedModule`` children return ``(output, cost)`` pairs."""

    def forward(self, input):
        """Feed ``input`` through every child in order and return the final output.

        The costs reported by ``MixedModule`` children are summed locally but
        are not part of the return value.
        """
        accumulated_cost = 0
        for layer in self._modules.values():
            result = layer(input)
            if isinstance(layer, MixedModule):
                input, layer_cost = result
                accumulated_cost += layer_cost
            else:
                input = result
        return input

# Shared standard Gumbel(0, 1) noise source.
gumbel_dist = Gumbel(0.0, 1.0)


def gumbel_softmax_sample(logits, temperature, dim=None, std=1.0):
    """Add standard Gumbel noise to ``logits`` and apply a tempered softmax.

    :param logits: tensor of unnormalised log-probabilities
    :param temperature: divisor applied to the perturbed logits before the softmax
    :param dim: dimension passed to ``F.softmax``
    :param std: not used by the current implementation; kept so existing
        callers keep working
    :return: tensor with the same shape as ``logits``
    """
    # NOTE(review): the end of this `.to(...)` call was cut off in the damaged
    # copy; matching the logits' dtype is assumed -- confirm upstream.
    noise = gumbel_dist.sample(logits.shape).to(device=logits.device,
                                                dtype=logits.dtype)
    y = logits + noise
    return F.softmax(y / temperature, dim=dim)
class FrechetSort(Sampler):
    def __init__(
        self,
        shape: float = 1.0,
        topk: Optional[int] = None,
        equiv_len: Optional[int] = None,
        log_scores: bool = False,
    ):
        """FréchetSort is a softer version of descending sort which samples all possible
        orderings of items favoring orderings which resemble descending sort. This can
        be used to convert descending sort by rank score into a differentiable,
        stochastic policy amenable to policy gradient algorithms.

        :param shape: parameter of Frechet Distribution. Lower values correspond to
            aggressive deviations from descending sort.
        :param topk: If specified, only the first topk actions are specified.
        :param equiv_len: Orders are considered equivalent if the top equiv_len match. Used
            in probability computations
        :param log_scores: Scores passed in are already log-transformed. In this case, we would
            simply add Gumbel noise.
        :raises ValueError: if ``topk`` is given and the equivalence length exceeds it.

        Example:

        Consider the sampler:

        sampler = FrechetSort(shape=3, topk=5, equiv_len=3)

        Given a set of scores, this sampler will produce indices of items roughly
        resembling a argsort by scores in descending order. The higher the shape,
        the more it would resemble a descending argsort. `topk=5` means only the top
        5 ranks will be output. The `equiv_len` determines what orders are considered
        equivalent for probability computation. In this example, the sampler will
        produce probability for the top 3 items appearing in a given order for the
        `log_prob` call.
        """
        self.shape = shape
        self.topk = topk
        self.upto = equiv_len
        if topk is not None:
            if equiv_len is None:
                self.upto = topk
            # pyre-fixme[58]: `>` is not supported for operand types `Optional[int]`
            #  and `Optional[int]`.
            if self.upto > self.topk:
                raise ValueError(f"Equiv length {equiv_len} cannot exceed topk={topk}.")
        self.gumbel_noise = Gumbel(0, 1.0 / shape)
        self.log_scores = log_scores

    # Called as ``self.select_indices(...)`` but declared without ``self``,
    # so it has to be a static method.
    @staticmethod
    def select_indices(scores: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
        """Helper for scores[actions] that also works for batched tensors"""
        if len(actions.shape) > 1:
            num_rows = scores.size(0)
            row_indices = torch.arange(num_rows).unsqueeze(0).T  # pyre-ignore[ 16 ]
            # transposed so that ranks run along dim 0
            return scores[row_indices, actions].T
        else:
            return scores[actions]

    def sample_action(self, scores: torch.Tensor) -> rlt.ActorOutput:
        """Sample a ranking according to Frechet sort. Note that possible_actions_mask
        is ignored as the list of rankings scales exponentially with slate size and
        number of items and it can be difficult to enumerate them."""
        assert scores.dim() == 2, "sample_action only accepts batches"
        log_scores = scores if self.log_scores else torch.log(scores)
        # Independent noise for every entry: a vector of length scores.shape[1]
        # would be broadcast and give every batch row the same perturbation.
        perturbed = log_scores + self.gumbel_noise.sample(scores.shape)
        action = torch.argsort(perturbed.detach(), descending=True)
        # Score the complete ordering first; only the returned action is truncated.
        log_prob = self.log_prob(scores, action)
        # NOTE(review): `scores` is asserted to be 2-D, so this slice cuts the
        # first (batch) dimension rather than the rank dimension -- confirm
        # whether `action[:, : self.topk]` was intended.
        if self.topk is not None:
            action = action[: self.topk]
        return rlt.ActorOutput(action, log_prob)

    def log_prob(self, scores: torch.Tensor, action) -> torch.Tensor:
        """What is the probability of a given set of scores producing the given
        list of permutations only considering the top `equiv_len` ranks?"""
        log_scores = scores if self.log_scores else torch.log(scores)
        s = self.select_indices(log_scores, action)
        # Ranks run along dim 0 of `s` for batched and unbatched input alike;
        # len(log_scores) would be the batch size for batched input.
        n = s.shape[0]
        p = self.upto if self.upto is not None else n
        # logsumexp is the numerically stable form of log(sum(exp(.)))
        return -sum(
            torch.logsumexp((s[k:] - s[k]) * self.shape, dim=0)
            for k in range(p)  # pyre-ignore
        )
def gumbel_softmax_sample(logits, temp=1.):
    """Perturb ``logits`` with standard Gumbel noise and return the softmax of
    the result divided by ``temp``, taken over the last dimension."""
    unit_gumbel = Gumbel(0, 1)
    perturbed = unit_gumbel.sample(logits.shape) + logits
    return (perturbed / temp).softmax(dim=-1)
class SoftmaxRandomSamplePolicy(SimplePolicy):
    """
    Randomly samples from the softmax of the logits
    # https://hips.seas.harvard.edu/blog/2013/04/06/the-gumbel-max-trick-for-discrete-distributions/
    TODO: should probably switch that to something more like
    """
    def __init__(self, temperature=1.0, eps=0.0, do_entropy=True):
        """
        :param temperature: divisor applied to the model logits in ``forward``
        :param eps: probability with which ``forward`` ignores the model logits
            (epsilon-greediness)
        :param do_entropy: when truthy, ``forward`` stores an entropy value in
            ``self.entropy``
        """
        # NOTE(review): no super().__init__() call is visible in this copy;
        # confirm whether SimplePolicy requires one.
        self.gumbel = Gumbel(loc=0, scale=1)
        self.temperature = torch.tensor(temperature, requires_grad=False)
        self.eps = eps
        self.do_entropy = do_entropy
        if self.do_entropy:
            self.entropy = None

    def set_temperature(self, new_temperature):
        """Replace the stored temperature when it differs from ``new_temperature``."""
        if self.temperature != new_temperature:
            # NOTE(review): the tail of this call was cut off in the damaged
            # copy; it is completed to mirror the constructor -- confirm.
            self.temperature = torch.tensor(new_temperature,
                                            requires_grad=False)

    def forward(self, logits: Variable, priors=None):
        """
        :param logits: Logits to generate probabilities from, batch_size x out_dim float32
        :param priors: optional tensor; the temperature is applied to
            ``logits - priors`` only and the priors are added back afterwards
        :return: sampled index per row; the log-probability of each sampled
            index is stored in ``self.logp``
        """
        # epsilon-greediness
        if random.random() < self.eps:
            if priors is None:
                new_logits = torch.zeros_like(logits)
                # keep entries that were far below the maximum strongly suppressed
                new_logits[logits < logits.max()-1000] = -1e4
            else:
                new_logits = priors
            if self.do_entropy:
                self.entropy = torch.zeros_like(logits).sum(dim=1)
        else:
            if self.temperature.dtype != logits.dtype or self.temperature.device != logits.device:
                # NOTE(review): the tail of the original `torch.tensor(` call
                # was cut off; moving the temperature to the logits' dtype and
                # device is what the guarding condition implies.
                self.temperature = self.temperature.to(dtype=logits.dtype,
                                                       device=logits.device)

            # temperature is applied to model only, not priors!
            if priors is None:
                new_logits = logits/self.temperature
                raw_logits = new_logits
            else:
                raw_logits = (logits-priors)/self.temperature
                new_logits = priors + raw_logits
            if self.do_entropy:
                raw_logits_normalized = F.log_softmax(raw_logits, dim=1)
                self.entropy = torch.sum(-raw_logits_normalized*torch.exp(raw_logits_normalized), dim=1)

        eff_logits = new_logits
        x = self.gumbel.sample(logits.shape).to(device=device, dtype=eff_logits.dtype) + eff_logits
        _, out = torch.max(x, -1)
        all_logp = F.log_softmax(eff_logits, dim=1)
        # pick, for every row, the log-probability of the sampled index
        self.logp = all_logp.gather(1, out.unsqueeze(1)).squeeze(1)
        return out

    def effective_logits(self, logits):
        """Return ``logits`` unchanged."""
        return logits
class GumbelSoftmaxEmbeddingHelper(SoftmaxEmbeddingHelper):
    r"""A helper that feeds Gumbel softmax sample to the next step.

    Uses the Gumbel softmax vector to pass through word embeddings to
    get the next input (i.e., a mixed word embedding).

    A subclass of :class:`~texar.modules.Helper`. Used as a helper to
    :class:`~texar.modules.RNNDecoderBase` in inference mode.

    Same as :class:`~texar.modules.SoftmaxEmbeddingHelper` except that here
    Gumbel softmax (instead of softmax) is used.

    Args:
        embedding: A callable or the ``params`` argument for an embedding
            lookup. (NOTE(review): the name of the referenced lookup
            function was missing from this docstring, and ``embedding`` is
            not a parameter of ``__init__`` below -- confirm where it is
            supplied.)
            If a callable, it can take a vector tensor of ``ids`` (argmax
            ids), or take two arguments (``ids``, ``times``), where ``ids``
            is a vector of argmax ids, and ``times`` is a vector of current
            time steps (i.e., position ids). The latter case can be used
            when :attr:`embedding` is a combination of word embedding and
            position embedding.
            The returned tensor will be passed to the decoder input.
        start_tokens: 1D :tensor:`LongTensor` shaped ``[batch_size]``,
            representing the start tokens for each sequence in batch.
        end_token: Python int or scalar :tensor:`LongTensor`, denoting the
            token that marks end of decoding.
        tau: A float scalar tensor, the softmax temperature.
        straight_through (bool): Whether to use straight through gradient
            between time steps. If `True`, a single token with highest
            probability (i.e., greedy sample) is fed to the next step and
            gradient is computed using straight through. If `False`
            (default), the soft Gumbel-softmax distribution is fed to the
            next step.
        stop_gradient (bool): Whether to stop the gradient backpropagation
            when feeding softmax vector to the next step.
        use_finish (bool): Whether to stop decoding once :attr:`end_token`
            is generated. If `False`, decoding will continue until
            :attr:`max_decoding_length` of the decoder is reached.

    Raises:
        ValueError: if :attr:`start_tokens` is not a 1D tensor or
            :attr:`end_token` is not a scalar.
    """

    def __init__(self, start_tokens: torch.LongTensor,
                 end_token: Union[int, torch.LongTensor], tau: float,
                 straight_through: bool = False,
                 stop_gradient: bool = False, use_finish: bool = True):
        super().__init__(start_tokens, end_token, tau,
                         stop_gradient, use_finish)
        self._straight_through = straight_through
        # unit-scale, zero-location Gumbel distribution
        self._gumbel = Gumbel(loc=torch.tensor(0.0), scale=torch.tensor(1.0))

    def sample(self, time: int, outputs: torch.Tensor) -> torch.Tensor:
        r"""Returns ``sample_id`` of shape ``[batch_size, vocab_size]``. If
        :attr:`straight_through` is `False`, this contains the Gumbel softmax
        distributions over vocabulary with temperature :attr:`tau`. If
        :attr:`straight_through` is `True`, this contains one-hot vectors of
        the greedy samples.

        Args:
            time: The current step; not used by this method.
            outputs: Scores over the vocabulary (last dimension) to which
                the Gumbel noise is added.
        """
        # Draw noise on the default device/dtype of the distribution, then
        # move it next to ``outputs`` so the addition below is valid.
        gumbel_samples = self._gumbel.sample(outputs.size()).to(
            device=outputs.device, dtype=outputs.dtype)
        sample_ids = torch.softmax(
            (outputs + gumbel_samples) / self._tau, dim=-1)
        if self._straight_through:
            # ``keepdim=True`` keeps the index aligned with the last axis;
            # for ``[batch_size, vocab_size]`` input it equals the previous
            # ``.unsqueeze(1)`` form.
            argmax_ids = torch.argmax(sample_ids, dim=-1, keepdim=True)
            sample_ids_hard = torch.zeros_like(sample_ids).scatter_(
                dim=-1, index=argmax_ids, value=1.0)  # one-hot vectors
            # Forward pass yields the one-hot vector; the gradient flows
            # through the soft sample only (straight-through estimator).
            sample_ids = (sample_ids_hard - sample_ids).detach() + sample_ids
        return sample_ids
# File: latent_cat_rnn_agent.py  Project: drdh/pymarl
    def forward(self, inputs, hidden_state):
        """Compute Q-values, the next hidden state and the latent loss.

        A categorical latent is derived from the trailing ``n_agents``
        columns of the first ``n_agents`` input rows and stored on
        ``self.latent``. A second categorical distribution is inferred from
        the (detached) hidden state and the remaining input columns. The
        returned loss is the cross-entropy between the two, averaged over
        ``bs * n_agents``.

        :param inputs: reshaped to ``(-1, self.input_shape)``; presumably one
            row per (batch, agent) pair -- confirm against the caller
        :param hidden_state: reshaped to ``(-1, self.hidden_dim)``
        :return: tuple ``(q, h, loss)`` with ``q`` viewed as
            ``(-1, n_actions)`` and ``h`` as ``(-1, rnn_hidden_dim)``
        """
        inputs = inputs.reshape(-1, self.input_shape)
        h_in = hidden_state.reshape(-1, self.hidden_dim)

        # compute latent
        self.latent = F.softmax(self.embed_fc(inputs[:self.n_agents, - self.n_agents:]), dim=1)  # (n, pi)

        latent_embed = self.latent.unsqueeze(0).expand(self.bs, self.n_agents, self.latent_dim).reshape(
            self.bs * self.n_agents, self.latent_dim)  # (bs*n,pi)

        # The hidden state is detached so the inference head does not
        # back-propagate into the recurrent state.
        latent_infer = F.relu(self.inference_fc1(th.cat([h_in.detach(), inputs[:, :-self.n_agents]], dim=1)))
        latent_infer = F.softmax(self.inference_fc2(latent_infer), dim=1)  # (bs*n,pi)

        # loss: cross-entropy of the embedded latent under the inferred one
        loss = -(latent_infer * th.log(latent_embed + self.eps)).sum() / (self.bs * self.n_agents)

        # sample: perturb the log-probabilities with noise drawn from the
        # module-level ``g`` (defined outside this method). The noise is moved
        # to the device of ``latent_embed`` rather than unconditionally to
        # CUDA, so the method also runs on CPU or a non-default GPU.
        noise = g.sample(latent_embed.size()).to(latent_embed.device)
        latent_embed = F.softmax(th.log(latent_embed + self.eps) + noise, dim=1)  # softmax onehot

        # NOTE(review): ``latent`` is never assigned in this method as shown;
        # the only assignments present in the original were commented out
        # (``latent = gaussian_embed.rsample()`` and a
        # ``latent_fc1``/``latent_fc2`` projection). Presumably it should be
        # derived from the sampled ``latent_embed`` above -- confirm before use.
        fc2_w = self.fc2_w_nn(latent)
        fc2_b = self.fc2_b_nn(latent)
        fc2_w = fc2_w.reshape(-1, self.args.rnn_hidden_dim, self.args.n_actions)
        fc2_b = fc2_b.reshape((-1, 1, self.args.n_actions))

        x = F.relu(self.fc1(inputs))  # (bs*n,(obs+act+id)) at time t

        h = self.rnn(x, h_in)
        h = h.reshape(-1, 1, self.args.rnn_hidden_dim)

        # Per-row output layer: weights and bias are generated from the latent.
        q = th.bmm(h, fc2_w) + fc2_b

        return q.view(-1, self.args.n_actions), h.view(-1, self.args.rnn_hidden_dim), loss
    # Beam search over event sequences, optionally stochastic (scores are
    # perturbed with Gumbel(0, 1) noise before the top-k selection). Starting
    # from the hidden state produced by ``init_to_hidden(init)``, each step
    # runs ``gen_forward``, divides the logits by ``temperature``, adds the
    # running per-beam score and keeps ``beam_size`` candidates per batch item.
    # Returns ``best``, one selected event sequence per batch item, transposed.
    #
    # NOTE(review): this copy of the method is incomplete. The parameter list
    # after ``self,`` is missing (the body reads ``init``, ``steps``,
    # ``beam_size``, ``temperature``, ``stochastic`` and ``verbose``), and
    # several statements below are cut off mid-call. Restore them from the
    # upstream source before relying on this code.
    def beam_search(self,
        assert len(init.shape) == 2 and init.shape[1] == self.init_dim
        assert self.event_dim >= beam_size > 0 and steps > 0

        batch_size = init.shape[0]
        current_beam_size = beam_size

        # Initial hidden weights
        hidden = self.init_to_hidden(
            init)  # [gru_layers, batch_size, hidden_size]
        hidden = hidden[:, :,
                        None, :]  # [gru_layers, batch_size, 1, hidden_size]
        hidden = hidden.repeat(
            1, 1, current_beam_size,
            1)  # [gru_layers, batch_size, beam_size, hidden_dim]

        # Initial event
        event = self.get_primary_event(batch_size)  # [1, batch]
        event = event[:, :, None].repeat(1, 1,
                                         current_beam_size)  # [1, batch, 1]

        # [batch, beam, 1]   event sequences of beams
        beam_events = event[0, :, None, :].repeat(1, current_beam_size, 1)

        # [batch, beam] log probs sum of beams
        beam_log_prob = torch.zeros(batch_size, current_beam_size).to(device)

        if stochastic:
            # [batch, beam] Gumbel perturbed log probs of beams
            # NOTE(review): the next statement is truncated (unclosed call).
            beam_log_prob_perturbed = torch.zeros(batch_size,
            # NOTE(review): ``beam_z`` is not read anywhere in the visible body.
            beam_z = torch.full((batch_size, beam_size), float('inf'))
            gumbel_dist = Gumbel(0, 1)

        step_iter = range(steps)
        if verbose:
            # Wrap the step range in a progress bar; the label is prefixed
            # with 'Stochastic ' when ``stochastic`` is truthy.
            step_iter = Bar(['', 'Stochastic '][stochastic] +
                            'Beam Search').iter(step_iter)

        for step in step_iter:

            # Flatten (batch, beam) into one axis for the forward pass.
            event = event.view(1, batch_size *
                               current_beam_size)  # [1, batch*beam0]
            hidden = hidden.view(self.rnn_layers,
                                 batch_size * current_beam_size,
                                 self.hidden_dim)  # [grus, batch*beam, hid]

            logits, hidden = self.gen_forward(event, hidden)
            # NOTE(review): this view passes three sizes, while the trailing
            # comment and the 4-D gather on ``hidden`` further down suggest a
            # separate beam axis is expected -- confirm.
            hidden = hidden.view(self.rnn_layers, batch_size,
                                 self.hidden_dim)  # [grus, batch, cbeam, hid]
            logits = (logits / temperature).view(
                1, batch_size, current_beam_size,
                self.event_dim)  # [1, batch, cbeam, out]

            # Score of every (beam, next event) continuation.
            beam_log_prob_expand = logits + beam_log_prob[
                None, :, :, None]  # [1, batch, cbeam, out]
            beam_log_prob_expand_batch = beam_log_prob_expand.view(
                1, batch_size, -1)  # [1, batch, cbeam*out]

            if stochastic:
                # NOTE(review): the next statement is truncated (the argument
                # of ``gumbel_dist.sample`` is missing).
                beam_log_prob_expand_perturbed = beam_log_prob_expand + gumbel_dist.sample(
                # ``beam_log_prob_Z`` is only referenced by the commented-out
                # normalisation below.
                beam_log_prob_Z, _ = beam_log_prob_expand_perturbed.max(
                    -1)  # [1, batch, cbeam]
                # print(beam_log_prob_Z)
                beam_log_prob_expand_perturbed_normalized = beam_log_prob_expand_perturbed
                # beam_log_prob_expand_perturbed_normalized = -torch.log(
                #     torch.exp(-beam_log_prob_perturbed[None, :, :, None])
                #     - torch.exp(-beam_log_prob_Z[:, :, :, None])
                #     + torch.exp(-beam_log_prob_expand_perturbed)) # [1, batch, cbeam, out]
                # beam_log_prob_expand_perturbed_normalized = beam_log_prob_perturbed[None, :, :, None] + beam_log_prob_expand_perturbed # [1, batch, cbeam, out]

                beam_log_prob_expand_perturbed_normalized_batch = \
                    beam_log_prob_expand_perturbed_normalized.view(1, batch_size, -1)  # [1, batch, cbeam*out]
                _, top_indices = beam_log_prob_expand_perturbed_normalized_batch.topk(
                    beam_size, -1)  # [1, batch, cbeam]

                beam_log_prob_perturbed = \
                    torch.gather(beam_log_prob_expand_perturbed_normalized_batch, -1, top_indices)[0]  # [batch, beam]

                # NOTE(review): as written this overwrites the ``top_indices``
                # chosen from the perturbed scores just above; presumably an
                # ``else:`` line is missing before it -- confirm.
                _, top_indices = beam_log_prob_expand_batch.topk(beam_size, -1)


            # Unperturbed scores of the selected continuations.
            beam_log_prob = torch.gather(beam_log_prob_expand_batch, -1,
                                         top_indices)[0]  # [batch, beam]

            # For every flattened (beam, event) slot, the index of the beam it
            # came from; gathering with ``top_indices`` gives the parent beam
            # of each selected continuation.
            beam_index_old = torch.arange(
                current_beam_size, device=device)[None, None, :,
                                                  None]  # [1, 1, cbeam, 1]
            beam_index_old = beam_index_old.repeat(
                1, batch_size, 1, self.output_dim)  # [1, batch, cbeam, out]
            beam_index_old = beam_index_old.view(1, batch_size,
                                                 -1)  # [1, batch, cbeam*out]
            # beam_index_old.to(device)
            # print(device)
            # print(beam_index_old.device)
            # print(top_indices.device)
            beam_index_new = torch.gather(beam_index_old, -1, top_indices)

            # NOTE(review): 4 and 1024 are hard-coded here; presumably they
            # stand for ``self.rnn_layers`` and ``self.hidden_dim`` -- confirm.
            hidden = torch.gather(
                hidden, 2, beam_index_new[:, :, :, None].repeat(4, 1, 1, 1024))

            # Same construction for the event id of each flattened slot.
            event_index = torch.arange(
                self.output_dim, device=device)[None, None,
                                                None, :]  # [1, 1, 1, out]
            event_index = event_index.repeat(1, batch_size, current_beam_size,
                                             1)  # [1, batch, cbeam, out]
            event_index = event_index.view(1, batch_size,
                                           -1)  # [1, batch, cbeam*out]
            event = torch.gather(event_index, -1,
                                 top_indices)  # [1, batch, cbeam*out]

            # Reorder the stored sequences by parent beam, then append the
            # newly chosen event.
            # NOTE(review): the next statement is truncated (unclosed call).
            beam_events = torch.gather(
                beam_events[None], 2,
                beam_index_new.unsqueeze(-1).repeat(1, 1, 1,
            beam_events = torch.cat([beam_events, event.unsqueeze(-1)], -1)[0]

            current_beam_size = beam_size

        # NOTE(review): the indexing expression below is truncated; the second
        # index (which beam is picked per batch item) is not visible.
        best = beam_events[torch.arange(batch_size).long(),
        best = best.contiguous().t()
        return best
# Builds a pseudo-loss from two parts: a term summed over ``topk`` selected
# categories (each weighted by its class weight) and a single sampled term
# weighted by the class-weight mass outside that set.
#
# NOTE(review): this copy of the function is incomplete. Most of the
# parameter list is missing (the body reads ``conditional_loss_fun``,
# ``log_class_weights``, ``topk``, ``sample_topk``, ``grad_estimator``,
# ``epoch`` and ``data``), the docstring delimiters and section headings are
# gone, and several calls are cut off. Restore them from the upstream source
# before relying on this code.
def get_raoblackwell_ps_loss(
        # NOTE(review): the default is a dict literal, i.e. one object shared
        # across calls; the visible body does not mutate it.
        grad_estimator_kwargs={'grad_estimator_kwargs': None},
    Returns a pseudo_loss, such that the gradient obtained by calling
    pseudo_loss.backwards() is unbiased for the true loss

    conditional_loss_fun : function
        A function that returns the loss conditional on an instance of the
        categorical random variable. It must take in a one-hot-encoding
        matrix (batchsize x n_categories) and return a vector of
        losses, one for each observation in the batch.
    log_class_weights : torch.Tensor
        A tensor of shape batchsize x n_categories of the log class weights
    topk : Integer
        The number of categories to sum over
    grad_estimator : function
        A function that returns the pseudo loss, that is, the loss which
        gives a gradient estimator when .backwards() is called.
        See baselines_lib for details.
    grad_estimator_kwargs : dict
        keyword arguments to gradient estimator
    epoch : int
        The epoch of the optimizer (for Gumbel-softmax, which has an annealing rate)
    data : torch.Tensor
        The data at which we evaluate the loss (for NVIl and RELAX, which have
        a data dependent baseline)

    ps_loss :
        a value such that ps_loss.backward() returns an
        estimate of the gradient.
        In general, ps_loss might not equal the actual loss.

    # class weights from the variational distribution
    # The assert checks that every log weight is non-positive.
    assert np.all(log_class_weights.detach().cpu().numpy() <= 0)
    class_weights = torch.exp(log_class_weights.detach())

    if sample_topk:
        # perturb the log_class_weights
        phi = log_class_weights.detach()
        g_phi = Gumbel(phi, torch.ones_like(phi)).rsample()

        # Take topk + 1 perturbed entries: the first topk form the summed
        # set, the last one is used as the sample.
        _, ind = g_phi.topk(topk + 1, dim=-1)

        topk_domain = ind[..., :-1]
        # NOTE(review): the next statement is truncated (unclosed call).
        concentrated_mask = torch.zeros_like(phi).scatter(-1, topk_domain,
        sample_ind = ind[..., -1]  # Last sample we use as real sample
        # NOTE(review): the next statement is truncated (unclosed call).
        seq_tensor = torch.arange(class_weights.size(0),
        # this is the indicator C_k
        # NOTE(review): presumably an ``else:`` line is missing before this
        # comment, making the call below the non-sampled alternative -- confirm.
        concentrated_mask, topk_domain, seq_tensor = \
            get_concentrated_mask(class_weights, topk)
        concentrated_mask = concentrated_mask.float().detach()

    # compute the summed term
    summed_term = 0.0

    for i in range(topk):
        # get categories to be summed
        summed_indx = topk_domain[:, i]

        # compute gradient estimate
        # NOTE(review): the call below is truncated after ``data = data,``.
        grad_summed = \
                grad_estimator(conditional_loss_fun, log_class_weights,
                                class_weights, seq_tensor, \
                                z_sample = summed_indx,
                                epoch = epoch,
                                data = data,

        # sum
        # Weight each per-observation estimate by the class weight of the
        # category it was evaluated at.
        summed_weights = class_weights[seq_tensor, summed_indx].squeeze()
        summed_term = summed_term + \
                        (grad_summed * summed_weights).sum()

    # compute sampled term
    # NOTE(review): the next statement is truncated (unclosed call).
    sampled_weight = torch.sum(class_weights * (1 - concentrated_mask),

    if not (topk == class_weights.shape[1]):
        # if we didn't sum everything
        # we sample from the remaining terms

        if not sample_topk:
            # class weights conditioned on being in the diffuse set
            conditional_class_weights = (class_weights + 1e-12) * \
                        (1 - concentrated_mask)  / (sampled_weight + 1e-12)

            # sample from conditional distribution
            # NOTE(review): the next statement is truncated, and presumably an
            # ``else:`` line is missing before the assignment that follows it.
            conditional_z_sample = sample_class_weights(
            conditional_z_sample = sample_ind  # We have already sampled it

        # NOTE(review): the call below is truncated; presumably the
        # ``grad_sampled = 0.`` after it belongs to a missing ``else:`` branch
        # for the case where every category was summed -- confirm.
        grad_sampled = grad_estimator(conditional_loss_fun,

        grad_sampled = 0.

    return (grad_sampled * sampled_weight.squeeze()).sum() + summed_term