def cat_softmax(probs, mode, tau=1, hard=False, dim=-1): if mode == 'REINFORCE' or mode == 'SCST': cat_distr = OneHotCategorical(probs=probs) return cat_distr.sample(), cat_distr.entropy() elif mode == 'GUMBEL': cat_distr = RelaxedOneHotCategorical(tau, probs=probs) y_soft = cat_distr.rsample() if hard: # Straight through. index = y_soft.max(dim, keepdim=True)[1] y_hard = torch.zeros_like(probs, device=DEVICE).scatter_(dim, index, 1.0) ret = y_hard - y_soft.detach() + y_soft else: # Reparametrization trick. ret = y_soft return ret, ret
def decoding_sampler(logits, mode, tau=1, hard=False, dim=-1):
    """Turn unnormalized ``logits`` into a (possibly relaxed) one-hot sample.

    Args:
        logits: Tensor of unnormalized category scores. The torch
            distributions used here treat the last dimension as the category
            axis.
        mode: One of the following.
            ``'REINFORCE'`` / ``'SCST'``: discrete sample via
            ``OneHotCategorical.sample()``. This is returned directly and
            ``hard`` is ignored.
            ``'GUMBEL'``: reparameterized Gumbel-softmax sample.
            ``'SOFTMAX'``: deterministic softmax over ``dim``.
        tau: Temperature of the relaxed (Gumbel) distribution.
        hard: For ``'GUMBEL'`` / ``'SOFTMAX'``, return a one-hot tensor in the
            forward pass while keeping the soft tensor's gradient
            (straight-through estimator).
        dim: Axis for the softmax and the straight-through argmax.
            NOTE(review): the distributions always use the last axis, so
            values other than -1 presumably only make sense for
            ``'SOFTMAX'``; confirm with callers.

    Returns:
        A tensor with the same shape as ``logits``.

    Raises:
        ValueError: If ``mode`` is not one of the supported modes.
    """
    if mode in ('REINFORCE', 'SCST'):
        return OneHotCategorical(logits=logits).sample()
    if mode == 'GUMBEL':
        y_soft = RelaxedOneHotCategorical(tau, logits=logits).rsample()
    elif mode == 'SOFTMAX':
        # Normalize over the same axis the straight-through argmax uses below.
        y_soft = F.softmax(logits, dim=dim)
    else:
        raise ValueError(f"Unknown sampling mode: {mode!r}")

    if not hard:
        # Reparametrization trick: the soft tensor is itself differentiable.
        return y_soft
    # Straight-through: the forward value is the one-hot argmax, but
    # (y_hard - y_soft.detach()) carries no gradient, so backprop sees y_soft.
    index = y_soft.argmax(dim, keepdim=True)
    # zeros_like inherits y_soft's device and dtype; no global device needed.
    y_hard = torch.zeros_like(y_soft).scatter_(dim, index, 1.0)
    return y_hard - y_soft.detach() + y_soft