Ejemplo n.º 1
0
    def forward(self, x):
        """Unroll the recurrent cells for ``self.max_len`` steps.

        The first layer's hidden state is ``self.agent(x)``; the remaining
        hidden states and all cell states start as zeros of the same shape.
        At every step a symbol (Categorical) and a stop flag (Bernoulli) are
        produced from the last layer's output. In training mode both are
        sampled and noise ``noise_loc + noise_scale * N(0, 1)`` is added to
        the recurrent state; otherwise the noise is multiplied by zero and
        the choices are greedy (argmax / 0.5 threshold).

        Returns:
            A tuple ``(symb_seq, stop_seq, logits, entropy)`` where
            ``logits`` is ``(symbol log-probs, stop log-probs)`` of the
            chosen values and ``entropy`` is ``(symbol entropies, stop
            entropies)``. Every tensor is stacked over steps and permuted so
            that the step axis comes second; ``stop_seq`` is cast to long.
        """
        first_hidden = self.agent(x)
        hidden = [first_hidden]
        hidden += [torch.zeros_like(first_hidden)
                   for _ in range(self.num_layers - 1)]
        # Cell states are only consumed by LSTM cells.
        cell = [torch.zeros_like(first_hidden)
                for _ in range(self.num_layers)]

        step_input = torch.stack([self.sos_embedding] * x.size(0))

        history = {
            "symb": [],
            "stop": [],
            "symb_logp": [],
            "stop_logp": [],
            "symb_ent": [],
            "stop_ent": [],
        }

        for _ in range(self.max_len):
            for idx, rnn_cell in enumerate(self.cells):
                # Multiplying by float(self.training) zeroes the noise in
                # evaluation mode while still drawing it, so the random
                # stream is consumed the same way in both modes.
                noise = float(self.training) * (
                    self.noise_loc
                    + self.noise_scale * torch.randn_like(hidden[0]))
                if isinstance(rnn_cell, nn.LSTMCell):
                    # LSTM: the noise perturbs the cell state only.
                    top, new_cell = rnn_cell(
                        step_input, (hidden[idx], cell[idx]))
                    cell[idx] = new_cell + noise
                else:
                    # Other cells: the noise perturbs the hidden state.
                    top = rnn_cell(step_input, hidden[idx]) + noise
                hidden[idx] = top
                step_input = top

            symb_probs = F.softmax(self.output_symbol(top), dim=1)
            stop_probs = torch.sigmoid(self.whether_to_stop(top).squeeze(1))
            symb_distr = Categorical(probs=symb_probs)
            stop_distr = Bernoulli(probs=stop_probs)

            if self.training:
                symb = symb_distr.sample()
                stop = stop_distr.sample()
            else:
                symb = symb_probs.argmax(dim=1)
                stop = (stop_probs > 0.5).float()

            history["symb"].append(symb)
            history["stop"].append(stop)
            history["symb_logp"].append(symb_distr.log_prob(symb))
            history["stop_logp"].append(stop_distr.log_prob(stop))
            history["symb_ent"].append(symb_distr.entropy())
            history["stop_ent"].append(stop_distr.entropy())

            # The chosen symbol is fed back as the next step's input.
            step_input = self.embedding(symb)

        def batch_major(per_step):
            # Stack over steps, then move the step axis to second position.
            return torch.stack(per_step).permute(1, 0)

        symb_seq = batch_major(history["symb"])
        stop_seq = batch_major(history["stop"]).long()
        logits = (batch_major(history["symb_logp"]),
                  batch_major(history["stop_logp"]))
        entropy = (batch_major(history["symb_ent"]),
                   batch_major(history["stop_ent"]))

        return symb_seq, stop_seq, logits, entropy
Ejemplo n.º 2
0
    def forward(self, *args, **kwargs):
        """Sample independent Bernoulli variables from the agent's scores.

        All arguments are forwarded to ``self.agent``; its output is used as
        the logits of a Bernoulli distribution.

        Returns:
            Tuple ``(sample, scores, entropy)``: a sample from the
            distribution, the raw agent output, and the entropy summed over
            dimension 1.
        """
        logits = self.agent(*args, **kwargs)
        bernoulli = Bernoulli(logits=logits)
        summed_entropy = bernoulli.entropy().sum(dim=1)
        return bernoulli.sample(), logits, summed_entropy
Ejemplo n.º 3
0
    def act(self, state):
        """Sample a binary action for ``state`` from a Bernoulli policy.

        Args:
            state: Observation accepted by ``Tensor(...)``. A leading axis of
                size one is added before it is passed to ``self.brain.model``.

        Returns:
            Tuple ``(action, log_prob, values_per_action)``: the sampled
            action as a Python ``int``, its log-probability under the
            Bernoulli distribution, and the second output of
            ``self.brain.model`` returned unchanged.
        """
        current_state = Tensor(state).unsqueeze(0)
        prob_per_action, values_per_action = self.brain.model(current_state)
        policy = Bernoulli(prob_per_action)
        sampled_action = policy.sample()
        # ``sample()`` already yields a tensor, so it is scored directly
        # instead of being copied into a new tensor first.
        log_prob = policy.log_prob(sampled_action)
        # ``.item()`` only works on a single-element tensor, so the model is
        # expected to emit exactly one probability here.
        action = int(sampled_action.item())

        return action, log_prob, values_per_action
Ejemplo n.º 4
0
class BernoulliDistribution(Distribution):
    """
    Bernoulli distribution for MultiBinary action spaces.

    Wraps a ``Bernoulli`` built from logits; per-action quantities are
    reduced by summing over dimension 1.

    :param action_dims: Number of binary actions
    """

    def __init__(self, action_dims: int):
        super().__init__()
        self.action_dims = action_dims

    def proba_distribution_net(self, latent_dim: int) -> nn.Module:
        """
        Build the layer that outputs the Bernoulli logits.

        :param latent_dim: Dimension of the last layer
            of the policy network (before the action layer)
        :return: A linear layer mapping ``latent_dim`` to ``action_dims``
        """
        return nn.Linear(latent_dim, self.action_dims)

    def proba_distribution(
            self, action_logits: th.Tensor) -> "BernoulliDistribution":
        """
        Set the underlying distribution from ``action_logits``.

        :param action_logits: Logits of the Bernoulli distribution
        :return: ``self``, so calls can be chained
        """
        self.distribution = Bernoulli(logits=action_logits)
        return self

    def log_prob(self, actions: th.Tensor) -> th.Tensor:
        """Log-probability of ``actions``, summed over dimension 1."""
        per_action = self.distribution.log_prob(actions)
        return per_action.sum(dim=1)

    def entropy(self) -> th.Tensor:
        """Entropy of the distribution, summed over dimension 1."""
        per_action = self.distribution.entropy()
        return per_action.sum(dim=1)

    def probabilities(self) -> th.Tensor:
        """Per-action probabilities summed over dimension 1.

        NOTE(review): a sum of probabilities is not itself a probability;
        confirm that callers expect this reduction.
        """
        per_action = self.distribution.probs
        return per_action.sum(dim=1)

    def sample(self) -> th.Tensor:
        """Draw a sample from the underlying Bernoulli distribution."""
        return self.distribution.sample()

    def mode(self) -> th.Tensor:
        """Return the per-action probabilities rounded with ``th.round``."""
        probs = self.distribution.probs
        return th.round(probs)

    def actions_from_params(self,
                            action_logits: th.Tensor,
                            deterministic: bool = False) -> th.Tensor:
        """
        Refresh the distribution from ``action_logits`` and return actions.

        :param action_logits: Logits of the Bernoulli distribution
        :param deterministic: Forwarded to ``get_actions`` (defined outside
            this class, presumably on the base class)
        :return: The result of ``get_actions``
        """
        # Update the proba distribution
        self.proba_distribution(action_logits)
        return self.get_actions(deterministic=deterministic)

    def log_prob_from_params(
            self, action_logits: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
        """
        Get actions from ``action_logits`` together with their log-probability.

        :param action_logits: Logits of the Bernoulli distribution
        :return: Tuple ``(actions, log_prob)``
        """
        actions = self.actions_from_params(action_logits)
        return actions, self.log_prob(actions)
Ejemplo n.º 5
0
    def forward(self, seq, mask):
        """Encode ``seq``, predict actions and score them under a Bernoulli.

        Args:
            seq: Input passed to ``self.encoder`` together with ``mask``.
            mask: Mask passed to the encoder and to ``apply_mask``.

        Returns:
            Tuple ``(actions, log_probas, entropy, dist_params)``. The
            distribution parameters and actions come from ``self.predictor``
            and are transposed with ``.t()``; the log-probabilities and the
            entropy are each passed through ``apply_mask``.
        """
        encoded = self.encoder(seq, mask)
        dist_params, actions = self.predictor(encoded)
        dist_params, actions = dist_params.t(), actions.t()
        sampler = Bernoulli(dist_params)
        # Compute LogProba
        log_probas = sampler.log_prob(actions)
        log_probas = apply_mask(log_probas, mask)

        # Compute Entropy
        entropy = sampler.entropy()
        # Mask the entropy itself; masking ``log_probas`` here would return
        # the log-probabilities a second time in place of the entropy.
        entropy = apply_mask(entropy, mask)

        return actions, log_probas, entropy, dist_params
Ejemplo n.º 6
0
        def forward(self, x):
            """Predict Bernoulli parameters per position and sample from them.

            ``x`` is flattened over its first two axes before ``self.net`` is
            applied, and the output is reshaped back to those two axes.

            Returns:
                Tuple ``(pred, logits, entropy, params)``: the sample, its
                log-probability (the second element, despite the name), the
                entropy summed over axis 0, and the reshaped network output.
            """
            n_steps, n_batch = x.size(0), x.size(1)

            flat = x.view(n_batch * n_steps, -1)
            probs = self.net(flat).view(n_steps, n_batch)

            bernoulli = Bernoulli(probs)
            drawn = bernoulli.sample()

            log_probs = bernoulli.log_prob(drawn)
            summed_entropy = bernoulli.entropy().sum(0)

            return drawn, log_probs, summed_entropy, probs
Ejemplo n.º 7
0
    def forward(self, x, mask):
        """Encode and decode ``x``, predict actions and score them.

        ``self.x_sizes`` holds ``x.size()`` while the encode, decode and
        predict calls run, and is reset to ``None`` afterwards.

        Args:
            x: Input passed to ``self.encode``.
            mask: Mask passed to ``apply_mask``.

        Returns:
            Tuple ``(actions, log_probas, entropy, dist_params)``. The
            distribution parameters and actions come from ``self.predict``;
            the log-probabilities and the entropy are each passed through
            ``apply_mask``.
        """
        self.x_sizes = x.size()
        encoded = self.encode(x)
        decoded = self.decode(encoded)
        dist_params, actions = self.predict(decoded)
        self.x_sizes = None
        sampler = Bernoulli(dist_params)
        # Compute LogProba
        log_probas = sampler.log_prob(actions)
        log_probas = apply_mask(log_probas, mask)

        # Compute Entropy
        entropy = sampler.entropy()
        # Mask the entropy itself; masking ``log_probas`` here would return
        # the log-probabilities a second time in place of the entropy.
        entropy = apply_mask(entropy, mask)

        return actions, log_probas, entropy, dist_params
Ejemplo n.º 8
0
    def forward(self, embedded_message, bits, _aux_input=None):
        """Predict a Bernoulli distribution over bits and pick a value.

        ``bits`` is cast to float and embedded with ``self.emb_column``, then
        concatenated with ``embedded_message`` along dimension 1 and passed
        through ``fc1`` -> leaky ReLU -> ``fc2`` -> sigmoid.

        Args:
            embedded_message: Tensor concatenated after the bit embedding.
            bits: Tensor of bits; converted with ``.float()``.
            _aux_input: Unused.

        Returns:
            Tuple ``(sample, log_prob, entropy)``. ``sample`` is drawn from
            the distribution in training mode and thresholded at 0.5
            otherwise; ``log_prob`` is summed over dimension 1; ``entropy``
            is returned per element, without reduction.
        """
        column_embedding = self.emb_column(bits.float())
        joint = torch.cat([column_embedding, embedded_message], dim=1)
        hidden = F.leaky_relu(self.fc1(joint))
        probs = torch.sigmoid(self.fc2(hidden))

        distr = Bernoulli(probs=probs)
        entropy = distr.entropy()

        sample = distr.sample() if self.training else (probs > 0.5).float()
        log_prob = distr.log_prob(sample).sum(dim=1)
        return sample, log_prob, entropy
Ejemplo n.º 9
0
class CommonDistribution:
    """Pair of a Categorical (``cd``) and a Bernoulli (``bd``) distribution.

    ``intent_probs`` parameterises the Categorical and ``slot_sigms`` the
    Bernoulli; both are passed positionally to the torch constructors.
    """

    def __init__(self, intent_probs, slot_sigms):
        self.cd = Categorical(intent_probs)
        self.bd = Bernoulli(slot_sigms)

    def sample(self):
        """Return ``(categorical sample, bernoulli sample)``."""
        intent = self.cd.sample()
        slots = self.bd.sample()
        return intent, slots

    def log_prob(self, intent, slots):
        """Joint log-probability of ``intent`` and ``slots``.

        ``intent`` is squeezed before scoring. The Categorical term is
        concatenated as one extra column next to the Bernoulli terms and
        everything is summed over dimension 1.
        """
        intent_term = self.cd.log_prob(intent.squeeze()).unsqueeze(1)
        slot_terms = self.bd.log_prob(slots)
        stacked = torch.cat([intent_term, slot_terms], dim=1)
        return stacked.sum(dim=1)

    def entropy(self):
        """Mean Bernoulli entropy over dimension 1 plus Categorical entropy."""
        return self.bd.entropy().mean(dim=1) + self.cd.entropy()
Ejemplo n.º 10
0
    def forward(self, *args, **kwargs):
        """Run the agent and sample independent Bernoulli bits (SFE sample).

        Every argument is forwarded to ``self.agent``; its output is treated
        as the logits of a Bernoulli distribution.

        Returns:
            sample {torch.Tensor} -- a draw from the distribution.
                Size: [batch_size, n_bits]
            scores {torch.Tensor} -- the raw output of the agent, needed for
                the policy component of the SFE loss.
                Size: [batch_size, n_bits]
            entropy {torch.Tensor} -- entropy of the independent Bernoulli
                variables, summed over dimension 1.
                Size: [batch_size]
        """
        logits = self.agent(*args, **kwargs)
        bernoulli = Bernoulli(logits=logits)
        per_example_entropy = bernoulli.entropy().sum(dim=1)
        return bernoulli.sample(), logits, per_example_entropy