Esempio n. 1
0
 def gumbel_sample(self, data):
     """Pick one mixture per row with the Gumbel-max trick.

     Independent Gumbel(0, 1) noise is added to every log-weight and the
     index of the largest perturbed value along ``dim=1`` is returned, which
     samples each index with probability proportional to its weight.

     :param data: mixture weights; must have at least two dimensions because
         the argmax is taken over ``dim=1`` (presumably ``batch x n_mixtures``
         -- confirm against the caller).
     :return: long tensor of selected indices (``dim=1`` reduced away).
     """
     log_weights = torch.log(data)
     # One noise draw per entry: a single shared scalar would shift every
     # candidate equally and leave the argmax deterministic.
     noise = Gumbel(loc=0, scale=1).sample(data.shape)
     noise = noise.to(device=log_weights.device, dtype=log_weights.dtype)
     return (log_weights + noise).argmax(dim=1)
Esempio n. 2
0
class SoftmaxRandomSamplePolicy(SimplePolicy):
    '''
    Randomly samples from the softmax of the logits using the Gumbel-max trick:
    argmax(logits + Gumbel noise) is distributed as softmax(logits).
    # https://hips.seas.harvard.edu/blog/2013/04/06/the-gumbel-max-trick-for-discrete-distributions/
    TODO: should probably switch that to something more like
    http://pytorch.org/docs/master/distributions.html
    '''
    def __init__(self, bias=None):
        '''

        :param bias: a vector of log frequencies, to bias the sampling towards what we want.
            It is stored as a float32 tensor of shape 1 x len(bias) on the module-level
            ``device`` so that it broadcasts over the batch dimension.
        '''
        super().__init__()
        self.gumbel = Gumbel(loc=0, scale=1)
        if bias is not None:
            self.bias = torch.from_numpy(np.array(bias)).to(dtype=torch.float32, device=device).unsqueeze(0)
        else:
            self.bias = None

    def forward(self, logits: Variable):
        '''

        :param logits: Logits to generate probabilities from, batch_size x out_dim float32
        :return: long tensor with the sampled index along the last dimension for each row
        '''
        eff_logits = self.effective_logits(logits)
        # Put the noise wherever the logits live instead of on a global device.
        noise = self.gumbel.sample(logits.shape).to(eff_logits.device)
        _, out = torch.max(noise + eff_logits, -1)
        return out

    def effective_logits(self, logits):
        '''
        Return the logits with the optional bias added.

        The addition is out-of-place so the caller's tensor is never modified.
        '''
        if self.bias is not None:
            return logits + self.bias
        return logits
Esempio n. 3
0
class FrechetOrderSampler:
    def __init__(
        self,
        shape: float = 1,
        upto: Optional[int] = None
    ):
        """Samples an ordering

        Args:
            shape: 1/scale parameter of Gumbel distribution
                At higher shape (lower noise scale), order is closer to a
                descending argsort of the scores
            upto [Default: None]: If provided, position upto which log_prob is computed

        Returns:

        An action sampler which produces ordering
        """
        self.shape = shape
        self.upto = upto
        self.gumbel_noise = Gumbel(0, 1.0 / shape)

    def sample(
        self,
        scores: torch.Tensor
    ):
        """Sample an ordering given scores.

        Independent Gumbel noise is added to every log-score and the indices
        are returned sorted by descending perturbed value (along the last
        dimension).

        Args:
            scores: positive scores of the items (the log is taken here)
        """
        # One noise value per score entry; identical to the previous
        # (len(scores),) shape for 1-D input.
        noise = self.gumbel_noise.sample(scores.shape)
        perturbed = torch.log(scores) + noise
        return torch.argsort(-perturbed.detach())

    def log_prob(self, scores : torch.Tensor, permutations):
        """Compute log probability given scores and an action (a permutation).
        The formula uses the equivalence of sorting with Gumbel noise and
        Plackett-Luce model (See Yellot 1977)

        Args:
            scores: scores of different items
            permutations: prescribed (permutation) order of the items

        Returns:
            Sum over the first ``upto`` positions (``len(scores) - 1`` when
            ``upto`` is None) of the log-probability of picking the item at
            that position among the items not yet placed.
        """
        s = torch.log(select_indices(scores, permutations))
        n = len(scores)
        # The last position is forced once the others are placed, hence n - 1.
        p = self.upto if self.upto is not None else n - 1
        # logsumexp is the overflow-safe form of log(sum(exp(.))).
        return -sum(
            torch.logsumexp((s[k:] - s[k]) * self.shape, dim=0)
            for k in range(p))
Esempio n. 4
0
class SoftmaxRandomSamplePolicySparse(SimplePolicy):
    '''
    Randomly samples from the softmax of the logits, where every batch item has
    its own (possibly empty) list of feasible actions.
    # https://hips.seas.harvard.edu/blog/2013/04/06/the-gumbel-max-trick-for-discrete-distributions/
    TODO: should probably switch that to something more like
    http://pytorch.org/docs/master/distributions.html
    '''
    def __init__(self, do_entropy=False):
        '''

        :param do_entropy: if True, an ``entropy`` attribute is created (initialised to None);
            nothing in this class fills it in.
        '''
        super().__init__()
        self.gumbel = Gumbel(loc=0, scale=1)
        self.do_entropy = do_entropy
        if self.do_entropy:
            self.entropy = None

    def forward(self, all_logits_list, all_action_inds_list):
        '''

        :param all_logits_list: list of tensors of len batch_size
        :param all_action_inds_list: list of long tensors with indices of the actions corresponding to the logits
        :return: stacked tensor with the chosen action index for every batch item;
            the matching log-probabilities are stored in ``self.logp``
        '''
        # The number of feasible actions is different for every item in the batch, so a for loop is the simplest
        logp = []
        out_actions = []
        for logits, action_inds in zip(all_logits_list, all_action_inds_list):
            if len(logits):
                # Gumbel-max trick; noise is created on the logits' own device/dtype.
                noise = self.gumbel.sample(logits.shape).to(device=logits.device, dtype=logits.dtype)
                _, out = torch.max(noise + logits, -1)
                # Normalise over the same (last) axis that was sampled from.
                this_logp = F.log_softmax(logits, dim=-1)[out]
                out_actions.append(action_inds[out])
                logp.append(this_logp)
            else:
                # No feasible action: emit action 0 with log-probability 0 as a placeholder.
                out_actions.append(torch.tensor(0, device=logits.device, dtype=action_inds.dtype))
                logp.append(torch.tensor(0.0, device=logits.device, dtype=logits.dtype))
        self.logp = torch.stack(logp)
        out = torch.stack(out_actions)
        return out
Esempio n. 5
0
    def beam_search(self,
                    init,
                    steps,
                    beam_size,
                    temperature=1.0,
                    stochastic=False,
                    verbose=False):
        """Generate event sequences with (optionally Gumbel-perturbed) beam search.

        Args:
            init: 2-D tensor of shape [batch, self.init_dim] used to build the
                initial hidden state.
            steps: number of generation steps (> 0).
            beam_size: number of beams kept per batch item
                (0 < beam_size <= self.event_dim).
            temperature: divisor applied to the model outputs before they are
                added to the running beam scores.
            stochastic: if True, Gumbel(0, 1) noise is added to the candidate
                scores before the top-k selection; the accumulated beam scores
                themselves stay unperturbed.
            verbose: if True, wrap the step iterator in a ``Bar`` progress bar.

        Returns:
            The event sequence of the highest-scoring beam for every batch
            item, transposed to [sequence, batch].
        """
        assert len(init.shape) == 2 and init.shape[1] == self.init_dim
        assert self.event_dim >= beam_size > 0 and steps > 0

        batch_size = init.shape[0]
        # NOTE(review): all beams start as identical copies of the primary
        # event, and beam_events below therefore starts with shape
        # [batch, beam, beam] (beam_size copies of the primary event per
        # sequence). Starting from a single beam looks like the intent --
        # confirm before changing, since it alters the returned length.
        current_beam_size = beam_size

        # Initial hidden weights
        hidden = self.init_to_hidden(init)  # [gru_layers, batch, hidden]
        hidden = hidden[:, :, None, :]  # [gru_layers, batch, 1, hidden]
        hidden = hidden.repeat(
            1, 1, current_beam_size, 1)  # [gru_layers, batch, cbeam, hidden]

        # Initial event
        event = self.get_primary_event(batch_size)  # [1, batch]
        event = event[:, :, None].repeat(
            1, 1, current_beam_size)  # [1, batch, cbeam]

        # Event sequences of the beams, grown by one entry per step
        beam_events = event[0, :, None, :].repeat(1, current_beam_size, 1)

        # [batch, cbeam] accumulated scores of the beams
        beam_log_prob = torch.zeros(batch_size, current_beam_size).to(device)

        if stochastic:
            gumbel_dist = Gumbel(0, 1)

        step_iter = range(steps)
        if verbose:
            step_iter = Bar(['', 'Stochastic '][stochastic] +
                            'Beam Search').iter(step_iter)

        for _ in step_iter:

            event = event.view(1, batch_size *
                               current_beam_size)  # [1, batch*cbeam]
            hidden = hidden.view(self.rnn_layers,
                                 batch_size * current_beam_size,
                                 self.hidden_dim)  # [grus, batch*cbeam, hid]

            logits, hidden = self.gen_forward(event, hidden)
            hidden = hidden.view(self.rnn_layers, batch_size,
                                 current_beam_size,
                                 self.hidden_dim)  # [grus, batch, cbeam, hid]
            logits = (logits / temperature).view(
                1, batch_size, current_beam_size,
                self.event_dim)  # [1, batch, cbeam, out]

            # Score of every (beam, next event) candidate
            beam_log_prob_expand = logits + beam_log_prob[
                None, :, :, None]  # [1, batch, cbeam, out]
            beam_log_prob_expand_batch = beam_log_prob_expand.view(
                1, batch_size, -1)  # [1, batch, cbeam*out]

            if stochastic:
                # Rank the candidates by Gumbel-perturbed score; the noise is
                # moved to the scores' device so GPU models work.
                noise = gumbel_dist.sample(beam_log_prob_expand.shape).to(
                    beam_log_prob_expand.device)
                perturbed_batch = (beam_log_prob_expand + noise).view(
                    1, batch_size, -1)  # [1, batch, cbeam*out]
                _, top_indices = perturbed_batch.topk(
                    beam_size, -1)  # [1, batch, beam]
            else:
                _, top_indices = beam_log_prob_expand_batch.topk(beam_size, -1)

            # Unperturbed scores of the selected candidates
            beam_log_prob = torch.gather(beam_log_prob_expand_batch, -1,
                                         top_indices)[0]  # [batch, beam]

            # Which old beam each selected candidate came from
            beam_index_old = torch.arange(
                current_beam_size, device=device)[None, None, :,
                                                  None]  # [1, 1, cbeam, 1]
            beam_index_old = beam_index_old.repeat(
                1, batch_size, 1, self.output_dim)  # [1, batch, cbeam, out]
            beam_index_old = beam_index_old.view(1, batch_size,
                                                 -1)  # [1, batch, cbeam*out]
            beam_index_new = torch.gather(beam_index_old, -1, top_indices)

            # Reorder the hidden states to follow the surviving beams; the
            # index is expanded to the model's own layer count / hidden size.
            hidden = torch.gather(
                hidden, 2, beam_index_new[:, :, :, None].repeat(
                    self.rnn_layers, 1, 1, self.hidden_dim))

            # Which event each selected candidate emits
            event_index = torch.arange(
                self.output_dim, device=device)[None, None,
                                                None, :]  # [1, 1, 1, out]
            event_index = event_index.repeat(1, batch_size, current_beam_size,
                                             1)  # [1, batch, cbeam, out]
            event_index = event_index.view(1, batch_size,
                                           -1)  # [1, batch, cbeam*out]
            event = torch.gather(event_index, -1,
                                 top_indices)  # [1, batch, beam]

            # Carry over the history of the surviving beams and append the
            # newly chosen event.
            beam_events = torch.gather(
                beam_events[None], 2,
                beam_index_new.unsqueeze(-1).repeat(1, 1, 1,
                                                    beam_events.shape[-1]))
            beam_events = torch.cat([beam_events, event.unsqueeze(-1)], -1)[0]

            current_beam_size = beam_size

        best = beam_events[torch.arange(batch_size).long(),
                           beam_log_prob.argmax(-1)]
        best = best.contiguous().t()
        return best
Esempio n. 6
0
class GumbelSoftmaxEmbeddingHelper(SoftmaxEmbeddingHelper):
    r"""Decoding helper that feeds a Gumbel-softmax sample to the next step.

    Behaves like its parent :class:`~texar.modules.SoftmaxEmbeddingHelper`,
    except that standard Gumbel noise is added to the decoder outputs before
    the temperature-scaled softmax is taken.

    Args:
        start_tokens: 1D :tensor:`LongTensor` shaped ``[batch_size]`` with the
            start token of every sequence in the batch.
        end_token: Python int or scalar :tensor:`LongTensor` marking the end
            of decoding.
        tau: softmax temperature.
        straight_through (bool): If `True`, :meth:`sample` returns one-hot
            vectors of the arg-max entries whose gradient is that of the soft
            Gumbel-softmax distribution (straight-through estimator). If
            `False` (default), the soft distribution itself is returned.
        stop_gradient (bool): forwarded to the parent constructor.
        use_finish (bool): forwarded to the parent constructor.

    ``start_tokens``, ``end_token``, ``tau``, ``stop_gradient`` and
    ``use_finish`` are handled by the parent constructor, which is not
    visible here.
    """

    def __init__(self, start_tokens: torch.LongTensor,
                 end_token: Union[int, torch.LongTensor], tau: float,
                 straight_through: bool = False,
                 stop_gradient: bool = False, use_finish: bool = True):
        super().__init__(start_tokens, end_token, tau,
                         stop_gradient, use_finish)
        self._straight_through = straight_through
        # Standard Gumbel: location 0, scale 1.
        zero = torch.tensor(0.0)
        one = torch.tensor(1.0)
        self._gumbel = Gumbel(loc=zero, scale=one)

    def sample(self, time: int, outputs: torch.Tensor) -> torch.Tensor:
        r"""Return ``sample_id`` with the same shape as ``outputs``
        (``[batch_size, vocab_size]``).

        Without straight-through this is the Gumbel-softmax distribution over
        the vocabulary at temperature :attr:`tau`; with straight-through it is
        the one-hot encoding of the arg-max of that distribution.
        """
        noise = self._gumbel.sample(outputs.size())
        noise = noise.to(device=outputs.device, dtype=outputs.dtype)
        soft = torch.softmax((outputs + noise) / self._tau, dim=-1)
        if not self._straight_through:
            return soft

        winners = torch.argmax(soft, dim=-1).unsqueeze(1)
        hard = torch.zeros_like(soft).scatter_(
            dim=-1, index=winners, value=1.0)  # one-hot vectors
        # Forward value is the one-hot vector, gradient flows through `soft`.
        return (hard - soft).detach() + soft
Esempio n. 7
0
class FrechetSort(Sampler):
    EPS = 1e-12

    @resolve_defaults
    def __init__(
        self,
        shape: float = 1.0,
        topk: Optional[int] = None,
        equiv_len: Optional[int] = None,
        log_scores: bool = False,
    ):
        """FréchetSort is a softer version of descending sort which samples all possible
        orderings of items favoring orderings which resemble descending sort. This can
        be used to convert descending sort by rank score into a differentiable,
        stochastic policy amenable to policy gradient algorithms.

        :param shape: parameter of Frechet Distribution. Lower values correspond to
        aggressive deviations from descending sort.
        :param topk: If specified, only the first topk ranks are returned by
        `sample_action`.
        :param equiv_len: Orders are considered equivalent if the top equiv_len match. Used
            in probability computations.
            Essentially specifies the action space. Defaults to topk when topk is given.
        :param log_scores Scores passed in are already log-transformed. In this case, we would
        simply add Gumbel noise.
        For LearnVM, we set this to be True because we expect input and output scores
        to be in the log space.

        :raises ValueError: if equiv_len exceeds topk.

        Example:

        Consider the sampler:

        sampler = FrechetSort(shape=3, topk=5, equiv_len=3)

        Given a set of scores, this sampler will produce indices of items roughly
        resembling a argsort by scores in descending order. The higher the shape,
        the more it would resemble a descending argsort. `topk=5` means only the top
        5 ranks will be output. The `equiv_len` determines what orders are considered
        equivalent for probability computation. In this example, the sampler will
        produce probability for the top 3 items appearing in a given order for the
        `log_prob` call.
        """
        self.shape = shape
        self.topk = topk
        self.upto = equiv_len
        if topk is not None:
            if equiv_len is None:
                self.upto = topk
            # pyre-fixme[58]: `>` is not supported for operand types `Optional[int]`
            #  and `Optional[int]`.
            if self.upto > self.topk:
                raise ValueError(
                    f"Equiv length {equiv_len} cannot exceed topk={topk}.")
        self.gumbel_noise = Gumbel(0, 1.0 / shape)
        self.log_scores = log_scores

    def sample_action(self, scores: torch.Tensor) -> rlt.ActorOutput:
        """Sample a ranking according to Frechet sort. Note that possible_actions_mask
        is ignored as the list of rankings scales exponentially with slate size and
        number of items and it can be difficult to enumerate them.

        :param scores: 2-D batch of scores, one row per ranking problem.
        :return: ActorOutput holding the sampled ranking (truncated to the first
            topk ranks of every row when topk is set) and its log-probability.
        """
        assert scores.dim() == 2, "sample_action only accepts batches"
        log_scores = scores if self.log_scores else torch.log(scores)
        perturbed = log_scores + self.gumbel_noise.sample(scores.shape)
        action = torch.argsort(perturbed.detach(), descending=True)
        # log_prob needs the full permutation, so it is computed before truncation.
        log_prob = self.log_prob(scores, action)
        # Only truncate the action before returning. Ranks live on dim 1;
        # dim 0 is the batch and must be left intact.
        if self.topk is not None:
            action = action[:, :self.topk]
        return rlt.ActorOutput(action, log_prob)

    def log_prob(
        self,
        scores: torch.Tensor,
        action: torch.Tensor,
        equiv_len_override: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        What is the probability of a given set of scores producing the given
        list of permutations only considering the top `equiv_len` ranks?

        We may want to override the default equiv_len here when we know the having larger
        action space doesn't matter. i.e. in Reels

        :param scores: scores of the items, 1-D or 2-D (batch x items).
        :param action: item indices in ranked order, same number of columns as
            scores. The index `scores.shape[1]` refers to a padding item with
            log-score -inf.
        :param equiv_len_override: optional per-row number of leading ranks to
            score, of shape (scores.shape[0],).
        :return: log-probabilities with one entry per batch row (a 1-D input is
            treated as a batch of one and the batch dimension is kept).
        :raises ValueError: if the override exceeds topk or action has more
            columns than scores.
        :raises NotImplementedError: if action has fewer columns than scores.
        """
        upto = self.upto
        if equiv_len_override is not None:
            assert equiv_len_override.shape == (
                scores.shape[0],
            ), f"Invalid shape {equiv_len_override.shape}, compared to scores {scores.shape}. equiv_len_override {equiv_len_override}"
            upto = equiv_len_override.long()
            if self.topk is not None and torch.any(
                    equiv_len_override > self.topk):
                raise ValueError(
                    f"Override {equiv_len_override} cannot exceed topk={self.topk}."
                )

        if len(scores.shape) == 1:
            # Treat a single ranking as a batch of one.
            scores = scores.unsqueeze(0)
            action = action.unsqueeze(0)

        assert len(action.shape) == len(
            scores.shape) == 2, "scores should be batch"
        if action.shape[1] > scores.shape[1]:
            raise ValueError(
                f"action cardinality ({action.shape[1]}) is larger than the number of scores ({scores.shape[1]})"
            )
        elif action.shape[1] < scores.shape[1]:
            raise NotImplementedError(
                f"This semantic is ambiguous. If you have shorter slate, pad it with scores.shape[1] ({scores.shape[1]})"
            )

        log_scores = scores if self.log_scores else torch.log(scores)
        n = log_scores.shape[-1]
        # Add scores for the padding value
        log_scores = torch.cat(
            [
                log_scores,
                torch.full((log_scores.shape[0], 1),
                           -math.inf,
                           dtype=log_scores.dtype,
                           device=log_scores.device),
            ],
            dim=1,
        )
        s = torch.gather(log_scores, 1, action) * self.shape

        p = upto if upto is not None else n
        # Plackett-Luce: at rank i the chosen item competes with every item
        # not yet placed. Padding entries yield -inf/nan and contribute 0.
        if isinstance(p, int):
            probs = sum(
                torch.nan_to_num(F.log_softmax(s[:, i:], dim=1)[:, 0],
                                 neginf=0.0) for i in range(p))
        elif isinstance(p, torch.Tensor):
            # do masked sum: each row only counts its first p[row] ranks
            probs = sum(
                torch.nan_to_num(F.log_softmax(s[:, i:], dim=1)[:, 0],
                                 neginf=0.0) * (i < p).float()
                for i in range(n))
        else:
            raise RuntimeError(f"p is {p}")
        return probs
Esempio n. 8
0
class FrechetSort(Sampler):
    @resolve_defaults
    def __init__(
        self,
        shape: float = 1.0,
        topk: Optional[int] = None,
        equiv_len: Optional[int] = None,
        log_scores: bool = False,
    ):
        """FréchetSort is a softer version of descending sort which samples all possible
        orderings of items favoring orderings which resemble descending sort. This can
        be used to convert descending sort by rank score into a differentiable,
        stochastic policy amenable to policy gradient algorithms.

        :param shape: parameter of Frechet Distribution. Lower values correspond to
        aggressive deviations from descending sort.
        :param topk: If specified, only the first topk ranks are returned by
        `sample_action`.
        :param equiv_len: Orders are considered equivalent if the top equiv_len match. Used
            in probability computations. Defaults to topk when topk is given.
        :param log_scores Scores passed in are already log-transformed. In this case, we would
        simply add Gumbel noise.

        :raises ValueError: if equiv_len exceeds topk.

        Example:

        Consider the sampler:

        sampler = FrechetSort(shape=3, topk=5, equiv_len=3)

        Given a set of scores, this sampler will produce indices of items roughly
        resembling a argsort by scores in descending order. The higher the shape,
        the more it would resemble a descending argsort. `topk=5` means only the top
        5 ranks will be output. The `equiv_len` determines what orders are considered
        equivalent for probability computation. In this example, the sampler will
        produce probability for the top 3 items appearing in a given order for the
        `log_prob` call.
        """
        self.shape = shape
        self.topk = topk
        self.upto = equiv_len
        if topk is not None:
            if equiv_len is None:
                self.upto = topk
            # pyre-fixme[58]: `>` is not supported for operand types `Optional[int]`
            #  and `Optional[int]`.
            if self.upto > self.topk:
                raise ValueError(f"Equiv length {equiv_len} cannot exceed topk={topk}.")
        self.gumbel_noise = Gumbel(0, 1.0 / shape)
        self.log_scores = log_scores

    @staticmethod
    def select_indices(scores: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
        """Helper for scores[actions] that also works for batched tensors.

        For batched `actions` the result is transposed, i.e. it has shape
        (ranks, batch), so that indexing its first dimension walks the ranks.
        """
        if len(actions.shape) > 1:
            num_rows = scores.size(0)
            # Column vector of row ids, created on the same device as scores.
            row_indices = torch.arange(num_rows, device=scores.device).unsqueeze(1)
            return scores[row_indices, actions].T
        else:
            return scores[actions]

    def sample_action(self, scores: torch.Tensor) -> rlt.ActorOutput:
        """Sample a ranking according to Frechet sort. Note that possible_actions_mask
        is ignored as the list of rankings scales exponentially with slate size and
        number of items and it can be difficult to enumerate them.

        :param scores: 2-D batch of scores, one row per ranking problem.
        :return: ActorOutput holding the sampled ranking (truncated to the first
            topk ranks of every row when topk is set) and its log-probability.
        """
        assert scores.dim() == 2, "sample_action only accepts batches"
        log_scores = scores if self.log_scores else torch.log(scores)
        # Independent noise for every entry; a (n_items,) sample would be
        # broadcast and shared by all rows of the batch.
        perturbed = log_scores + self.gumbel_noise.sample(scores.shape)
        action = torch.argsort(perturbed.detach(), descending=True)
        # Score the full permutation first: the Plackett-Luce normaliser at each
        # rank needs every item that has not been placed yet.
        log_prob = self.log_prob(scores, action)
        if self.topk is not None:
            # Ranks live on dim 1; dim 0 is the batch.
            action = action[:, : self.topk]
        return rlt.ActorOutput(action, log_prob)

    def log_prob(self, scores: torch.Tensor, action) -> torch.Tensor:
        """What is the probability of a given set of scores producing the given
        list of permutations only considering the top `equiv_len` ranks?

        :param scores: scores of the items, 1-D or 2-D (batch x items).
        :param action: item indices in ranked order, shaped like `scores`
            (possibly with fewer ranks).
        :return: log-probability; one value per batch row for batched input.
        """
        log_scores = scores if self.log_scores else torch.log(scores)
        s = self.select_indices(log_scores, action)
        # Ranks are indexed by the first dimension of `s`, for both 1-D and
        # batched input.
        n = len(s)
        p = self.upto if self.upto is not None else n
        # logsumexp is the overflow-safe form of log(sum(exp(.))).
        return -sum(
            torch.logsumexp((s[k:] - s[k]) * self.shape, dim=0)
            for k in range(p)  # pyre-ignore
        )
Esempio n. 9
0
    def forward(self, inputs, hidden_state):
        """Compute Q-values with an output layer generated from a sampled latent.

        A categorical latent is produced per agent from the last
        ``self.n_agents`` input columns (presumably the one-hot agent id --
        confirm against the caller), relaxed with Gumbel-softmax noise, and fed
        to hyper-networks that generate the weights and bias of the final
        layer.

        :param inputs: reshaped to (-1, self.input_shape).
        :param hidden_state: reshaped to (-1, self.hidden_dim).
        :return: tuple ``(q, h, loss)`` where ``q`` has shape
            (-1, n_actions), ``h`` has shape (-1, rnn_hidden_dim) and ``loss``
            is the cross-entropy between the inferred and the embedded latent
            distributions, averaged over ``self.bs * self.n_agents``.
        """
        inputs = inputs.reshape(-1, self.input_shape)
        h_in = hidden_state.reshape(-1, self.hidden_dim)

        # Latent distribution per agent, from the last n_agents input columns
        # of the first n_agents rows.
        self.latent = F.softmax(self.embed_fc(inputs[:self.n_agents, - self.n_agents:]), dim=1)  # (n, pi)

        # Share the per-agent latent across the batch.
        latent_embed = self.latent.unsqueeze(0).expand(self.bs, self.n_agents, self.latent_dim).reshape(
            self.bs * self.n_agents, self.latent_dim)  # (bs*n, pi)

        # Latent distribution inferred from the (detached) hidden state and the
        # remaining input columns.
        latent_infer = F.relu(self.inference_fc1(th.cat([h_in.detach(), inputs[:, :-self.n_agents]], dim=1)))
        latent_infer = F.softmax(self.inference_fc2(latent_infer), dim=1)  # (bs*n, pi)

        # Cross-entropy of the embedded latent under the inferred distribution.
        loss = -(latent_infer * th.log(latent_embed + self.eps)).sum() / (self.bs * self.n_agents)

        # Gumbel-softmax relaxation of sampling from latent_embed. The noise is
        # moved to the latent's own device, so this works on CPU and on any GPU.
        g = Gumbel(0.0, 1.0)
        noise = g.sample(latent_embed.size()).to(latent_embed.device)
        latent_embed = F.softmax(th.log(latent_embed + self.eps) + noise, dim=1)

        latent = self.latent_fc1(latent_embed)

        # Hyper-network: generate the output layer from the latent.
        fc2_w = self.fc2_w_nn(latent)
        fc2_b = self.fc2_b_nn(latent)
        fc2_w = fc2_w.reshape(-1, self.args.rnn_hidden_dim, self.args.n_actions)
        fc2_b = fc2_b.reshape((-1, 1, self.args.n_actions))

        x = F.relu(self.fc1(inputs))

        h = self.rnn(x, h_in)
        h = h.reshape(-1, 1, self.args.rnn_hidden_dim)

        # Per-row linear layer: (., 1, hid) x (., hid, n_actions) + bias.
        q = th.bmm(h, fc2_w) + fc2_b

        return q.view(-1, self.args.n_actions), h.view(-1, self.args.rnn_hidden_dim), loss
Esempio n. 10
0
class SoftmaxRandomSamplePolicy(SimplePolicy):
    '''
    Randomly samples from the softmax of the logits, with optional temperature,
    priors and epsilon-greedy exploration.
    # https://hips.seas.harvard.edu/blog/2013/04/06/the-gumbel-max-trick-for-discrete-distributions/
    TODO: should probably switch that to something more like
    http://pytorch.org/docs/master/distributions.html
    '''
    def __init__(self, temperature=1.0, eps=0.0, do_entropy=True):
        '''

        :param temperature: divisor applied to the model logits before sampling
        :param eps: probability of ignoring the model for a call and sampling from the
            priors (or from a flat distribution when no priors are given)
        :param do_entropy: if True, forward() stores the entropy of the tempered model
            distribution in ``self.entropy``
        '''
        super().__init__()
        self.gumbel = Gumbel(loc=0, scale=1)
        self.temperature = torch.tensor(temperature, requires_grad=False)
        self.eps = eps
        self.do_entropy = do_entropy
        if self.do_entropy:
            self.entropy = None

    def set_temperature(self, new_temperature):
        '''Replace the temperature, keeping the dtype and device of the current one.'''
        if self.temperature != new_temperature:
            self.temperature = torch.tensor(new_temperature,
                                            dtype=self.temperature.dtype,
                                            device=self.temperature.device)

    def forward(self, logits: Variable, priors=None):
        '''

        :param logits: Logits to generate probabilities from, batch_size x out_dim float32
        :param priors: optional prior logits of the same shape; the temperature is applied
            only to the difference between logits and priors
        :return: long tensor of sampled indices, one per row. The log-probability of each
            sample is stored in ``self.logp`` and, when ``do_entropy`` is set, the
            per-row entropy in ``self.entropy``.
        '''
        # epsilon-greediness
        if random.random() < self.eps:
            if priors is None:
                # Flat distribution over everything that is not (practically) masked out.
                new_logits = torch.zeros_like(logits)
                new_logits[logits < logits.max()-1000] = -1e4
            else:
                new_logits = priors
            if self.do_entropy:
                self.entropy = torch.zeros_like(logits).sum(dim=1)
        else:
            if self.temperature.dtype != logits.dtype or self.temperature.device != logits.device:
                # Move the existing tensor rather than copy-constructing a tensor from a tensor.
                self.temperature = self.temperature.detach().to(dtype=logits.dtype,
                                                                device=logits.device)

            # temperature is applied to model only, not priors!
            if priors is None:
                new_logits = logits/self.temperature
                raw_logits = new_logits
            else:
                raw_logits = (logits-priors)/self.temperature
                new_logits = priors + raw_logits
            if self.do_entropy:
                raw_logits_normalized = F.log_softmax(raw_logits, dim=1)
                self.entropy = torch.sum(-raw_logits_normalized*torch.exp(raw_logits_normalized), dim=1)

        eff_logits = new_logits
        # Gumbel-max trick; the noise follows the logits' device and dtype.
        noise = self.gumbel.sample(logits.shape).to(device=eff_logits.device, dtype=eff_logits.dtype)
        _, out = torch.max(noise + eff_logits, -1)
        all_logp = F.log_softmax(eff_logits, dim=1)
        # Log-probability of the sampled index of every row.
        self.logp = all_logp.gather(1, out.unsqueeze(1)).squeeze(1)
        return out

    def effective_logits(self, logits):
        '''Identity: this policy does not adjust the logits here.'''
        return logits