Example #1
def sample(batch_size, max_steps=50, params=None):
    """Sample a batch of episodes from the MDP
    """
    ep_returns = []
    if params is None:
        params = pi

    batch_states = torch.zeros([batch_size, max_steps + 1])
    batch_pi_taken = torch.zeros([batch_size, max_steps])
    batch_a_taken = torch.zeros([batch_size, max_steps])
    batch_r = torch.zeros([batch_size, max_steps])
    batch_dones = torch.zeros([batch_size, max_steps])
    batch_mask = torch.ones([batch_size, max_steps + 1])

    state = Categorical(S0.unsqueeze(0).repeat(batch_size, 1)).sample()

    for t in range(max_steps):
        r = R[state]
        # policy distribution over actions for the current batch of states
        probs = params[:, state].softmax(0).transpose(0, 1)
        actions = Categorical(probs).sample()
        # probability the policy assigned to each sampled action
        pi_taken = probs.gather(1, actions.unsqueeze(1)).squeeze(1)
        # transition to the next state
        p_next = P[actions, state]
        next_state = Categorical(p_next).sample()

        batch_states[:, t] = state
        batch_a_taken[:, t] = actions
        batch_pi_taken[:, t] = pi_taken
        batch_r[:, t] = r
        state = next_state
    batch_states[:, -1] = state

    return batch_states, batch_pi_taken, batch_a_taken, batch_r, batch_dones, batch_mask, ep_returns
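This snippet depends on module-level tensors pi, S0, R and P that are defined elsewhere; a hedged sketch of a toy setup with the shapes the indexing implies (all names, sizes and values below are illustrative only):

import torch
from torch.distributions import Categorical

n_states, n_actions = 4, 2                            # illustrative sizes
S0 = torch.ones(n_states) / n_states                  # initial-state distribution, shape [n_states]
R = torch.randn(n_states)                             # reward per state, shape [n_states]
P = torch.rand(n_actions, n_states, n_states)         # P[a, s] = distribution over next states
P = P / P.sum(-1, keepdim=True)
pi = torch.zeros(n_actions, n_states)                 # policy logits, softmaxed over actions inside sample()

states, pi_taken, actions, rewards, dones, mask, ep_returns = sample(batch_size=8, max_steps=10)
print(states.shape)                                   # torch.Size([8, 11]): one extra slot for the final state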
Example #2
    def top_k(self, out, k=5):
        probs = self.softmax(out)
        sorted_probs, sorted_idxs = torch.sort(probs, descending=True)

        # keep only the k most probable tokens, then renormalise
        sorted_probs[:, k:] = 0
        sorted_probs /= sorted_probs.sum(dim=1, keepdim=True)

        sample = Categorical(sorted_probs).sample()
        # map the position within the sorted top-k back to a vocabulary index
        sample_id = sorted_idxs.gather(1, sample.unsqueeze(1)).squeeze(1)
        return sample_id
Example #3
    def top_p(self, out, p=0.9):
        probs = self.softmax(out)
        sorted_probs, sorted_idxs = torch.sort(probs, descending=True)

        cumulative_probs = torch.cumsum(sorted_probs, dim=1)
        # zero out tokens outside the nucleus; subtracting sorted_probs keeps the token that
        # crosses the threshold (and always the most likely one), so the renormalisation
        # below can never divide by zero
        sorted_probs[cumulative_probs - sorted_probs > p] = 0.
        sorted_probs /= sorted_probs.sum(dim=1, keepdim=True)

        sample = Categorical(sorted_probs).sample()
        sample_id = sorted_idxs.gather(1, sample.unsqueeze(1)).squeeze(1)
        return sample_id
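Neither method runs on its own because self.softmax comes from the surrounding class, which is not shown. A minimal standalone sketch of the same truncated-sampling pattern, assuming softmax over the last dimension (names below are not from the original):

import torch
from torch.distributions import Categorical

def truncated_sample(logits, k=None, p=None):
    """Sample token ids from logits restricted to a top-k and/or top-p (nucleus) set."""
    probs = torch.softmax(logits, dim=-1)
    sorted_probs, sorted_idxs = torch.sort(probs, descending=True)
    if k is not None:
        sorted_probs[..., k:] = 0.0                   # drop everything outside the top k
    if p is not None:
        # cumulative mass *before* each token, so the most likely token is always kept
        before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
        sorted_probs = sorted_probs.masked_fill(before > p, 0.0)
    sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
    pos = Categorical(sorted_probs).sample()          # position within the sorted list
    return sorted_idxs.gather(-1, pos.unsqueeze(-1)).squeeze(-1)

logits = torch.randn(3, 100)                          # 3 rows over a 100-token vocabulary
print(truncated_sample(logits, k=5))                  # top-k sampling
print(truncated_sample(logits, p=0.9))                # nucleus sampling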
Example #4
    def calculate_token_gumbel_softmax(self, p, tau, sentence_probability, batch_size):
        if self.training:
            token = self.utils_helper.calculate_gumbel_softmax(p, tau, hard=True)
        else:
            sentence_probability += p.detach()

            if self.greedy:
                _, token = torch.max(p, -1)
            else:
                token = Categorical(p).sample()
            token = to_one_hot(token, n_dims=self.vocab_size)

            if batch_size == 1:
                token = token.unsqueeze(0)
        return token, sentence_probability
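to_one_hot is a helper defined elsewhere in the original source; a hypothetical stand-in consistent with how it is called above (torch.nn.functional.one_hot(indices, n_dims).float() would behave similarly):

import torch

def to_one_hot(indices, n_dims):
    # hypothetical stand-in, not the original implementation:
    # place a 1.0 at each sampled index along a new last dimension
    one_hot = torch.zeros(*indices.shape, n_dims, device=indices.device)
    return one_hot.scatter_(-1, indices.long().unsqueeze(-1), 1.0)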
Example #5
    def forward(self, hidden_state=None, messages=None, tau=1.2):
        """
        Merged version of Sender and Receiver
        """

        if messages is None:
            hidden_state = self.input_module(hidden_state)
            state, batch_size = self._init_state(hidden_state, type(self.rnn))

            # Init output
            if self.training:
                output = [
                    torch.zeros(
                        (batch_size, self.vocab_size),
                        dtype=torch.float32,
                        device=self.device,
                    )
                ]
                output[0][:, self.sos_id] = 1.0
            else:
                output = [
                    torch.full(
                        (batch_size, ),
                        fill_value=self.sos_id,
                        dtype=torch.int64,
                        device=self.device,
                    )
                ]

            # Keep track of sequence lengths
            initial_length = self.output_len + 1  # add the sos token
            seq_lengths = (torch.ones(
                [batch_size], dtype=torch.int64, device=self.device) *
                           initial_length)

            embeds = []  # keep track of the embedded sequence
            entropy = 0.0
            sentence_probability = torch.zeros((batch_size, self.vocab_size),
                                               device=self.device)

            for i in range(self.output_len):
                if self.training:
                    emb = torch.matmul(output[-1], self.embedding)
                else:
                    emb = self.embedding[output[-1]]

                embeds.append(emb)
                state = self.rnn(emb, state)

                if type(self.rnn) is nn.LSTMCell:
                    h, c = state
                else:
                    h = state

                p = F.softmax(self.linear_out(h), dim=1)
                entropy += Categorical(p).entropy()

                if self.training:
                    token = self.utils_helper.calculate_gumbel_softmax(
                        p, tau, hard=True)
                else:
                    sentence_probability += p.detach()
                    if self.greedy:
                        _, token = torch.max(p, -1)

                    else:
                        token = Categorical(p).sample()

                    if batch_size == 1:
                        token = token.unsqueeze(0)

                output.append(token)

                self._calculate_seq_len(seq_lengths,
                                        token,
                                        initial_length,
                                        seq_pos=i + 1)

            return (
                torch.stack(output, dim=1),
                seq_lengths,
                torch.mean(entropy) / self.output_len,
                torch.stack(embeds, dim=1),
                sentence_probability,
            )

        else:
            batch_size = messages.shape[0]

            emb = (torch.matmul(messages, self.embedding)
                   if self.training else self.embedding[messages])

            # initialize hidden
            h = torch.zeros([batch_size, self.hidden_size], device=self.device)
            if self.cell_type == "lstm":
                c = torch.zeros([batch_size, self.hidden_size],
                                device=self.device)
                h = (h, c)

            # make sequence_length be first dim
            seq_iterator = emb.transpose(0, 1)
            for w in seq_iterator:
                h = self.rnn(w, h)

            if self.cell_type == "lstm":
                h = h[0]  # keep only hidden state

            out = self.output_module(h)

            return out, emb
Example #6
    def forward(self, hidden_state, tau=1.2):
        """
        Performs a forward pass. If training, use Gumbel Softmax (hard) for sampling, else use
        discrete sampling.
        Hidden state here represents the encoded image/metadata - initializes the RNN from it.
        """

        hidden_state = self.input_module(hidden_state)
        state, batch_size = self._init_state(hidden_state, type(self.rnn))

        # Init output
        # the output distribution is over the full vocabulary, which also contains the extra
        # sos (and pad) tokens; the agent itself should never be able to choose them
        if self.training:
            output = [
                torch.zeros((batch_size, self.full_vocab_size),
                            dtype=torch.float32,
                            device=device)
            ]
            output[0][:, self.sos_id] = 1.0
        else:
            output = [
                torch.full(
                    (batch_size, ),
                    fill_value=self.sos_id,
                    dtype=torch.int64,
                    device=device,
                )
            ]

        # Keep track of sequence lengths
        initial_length = self.output_len + 1  # add the sos token
        seq_lengths = (
            torch.ones([batch_size], dtype=torch.int64, device=device) *
            initial_length)

        embeds = []  # keep track of the embedded sequence
        entropy = 0.0

        # loop through the entire output length
        for i in range(self.output_len):

            # During training the previous token is a one-hot (Gumbel-Softmax) vector, so it is
            # embedded with a matmul; during validation it is an index, so it is looked up directly.
            # Either way, the last emitted token is fed back in as input for the next character.
            if self.training:
                emb = torch.matmul(output[-1], self.embedding)
            else:
                emb = self.embedding[output[-1]]

            # feed the embedded token to the RNN
            embeds.append(emb)
            state = self.rnn(emb, state)

            if type(self.rnn) is nn.LSTMCell:
                h, c = state
            else:
                h = state

            # get a probability for a given token from the vocabulary
            p = F.softmax(self.linear_out(h), dim=1)
            entropy += Categorical(p).entropy()

            # gumbel softmax returns one hot vectors
            if self.training:
                token = _gumbel_softmax(p, tau, hard=True)

                # append zero-probability columns for the sos and pad tokens,
                # since the agent can never choose them
                sos_index = torch.zeros(batch_size, 2, device=device)
                token = torch.cat((token, sos_index), dim=1)

            else:
                # during validation we return index values of vocabulary
                if self.greedy:
                    _, token = torch.max(p, -1)

                else:
                    token = Categorical(p).sample()

                if batch_size == 1:
                    token = token.unsqueeze(0)

            output.append(token)

            # calculate the sequence lengths for messages
            self._calculate_seq_len(seq_lengths,
                                    token,
                                    initial_length,
                                    seq_pos=i + 1)

        return (
            torch.stack(output, dim=1),
            seq_lengths,
            torch.mean(entropy) / self.output_len,
            hidden_state,
            torch.stack(embeds, dim=1),
        )
Example #7
    def forward(self, tau=1.2, hidden_state=None, device=None):
        """
        Performs a forward pass. If training, use Gumbel Softmax (hard) for sampling, else use
        discrete sampling.
        Hidden state here represents the encoded image/metadata - initializes the RNN from it.
        """
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        hidden_state = self.input_module(hidden_state)
        state, batch_size = self._init_state(hidden_state, type(self.rnn), device)

        # Init output
        if self.training:
            output = [
                torch.zeros(
                    (batch_size, self.vocab_size), dtype=torch.float32, device=device
                )
            ]
            output[0][:, self.sos_id] = 1.0
        else:
            output = [
                torch.full(
                    (batch_size,),
                    fill_value=self.sos_id,
                    dtype=torch.int64,
                    device=device,
                )
            ]

        # Keep track of sequence lengths
        initial_length = self.output_len + 1  # add the sos token
        seq_lengths = (
            torch.ones([batch_size], dtype=torch.int64, device=device) * initial_length
        )

        embeds = []  # keep track of the embedded sequence
        entropy = 0.0
        sentence_probability = torch.zeros((batch_size, self.vocab_size), device=device)

        for i in range(self.output_len):
            if self.training:
                emb = torch.matmul(output[-1], self.embedding)
            else:
                emb = self.embedding[output[-1]]

            embeds.append(emb)
            state = self.rnn(emb, state)

            if type(self.rnn) is nn.LSTMCell:
                h, c = state
            else:
                h = state

            p = F.softmax(self.linear_out(h), dim=1)
            entropy += Categorical(p).entropy()

            if self.training:
                token = gumbel_softmax(p, tau, hard=True)
            else:
                sentence_probability += p.detach()
                if self.greedy:
                    _, token = torch.max(p, -1)

                else:
                    token = Categorical(p).sample()

                if batch_size == 1:
                    token = token.unsqueeze(0)

            output.append(token)

            self._calculate_seq_len(seq_lengths, token, initial_length, seq_pos=i + 1)

        return (
            torch.stack(output, dim=1),
            seq_lengths,
            torch.mean(entropy) / self.output_len,
            torch.stack(embeds, dim=1),
            sentence_probability,
        )
Example #8
def sample_mdn(mu, sigma, pi):
    """Draw one sample per row from a mixture density network output."""
    # pick a mixture component from the logits, then sample from the matching Gaussian
    idx = Categorical(logits=pi).sample()
    return torch.normal(mu, sigma).gather(-1, idx.unsqueeze(-1)).squeeze(-1)
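A hedged sketch of the shapes sample_mdn expects, with the mixture components on the last dimension (all values below are made up):

import torch
from torch.distributions import Categorical

batch, n_components = 16, 5
mu = torch.randn(batch, n_components)                 # component means
sigma = torch.rand(batch, n_components) + 0.1         # component standard deviations (positive)
pi = torch.randn(batch, n_components)                 # unnormalised mixture logits

samples = sample_mdn(mu, sigma, pi)                   # picks one component per row, then draws from it
print(samples.shape)                                  # torch.Size([16])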
Example #9
    def forward(self, tau, hidden_state=None):
        """
		Performs a forward pass. If training, use Gumbel Softmax (hard) for sampling, else use 
		discrete sampling.
		"""

        state, batch_size = self._init_state(hidden_state, type(self.rnn))

        # Init output
        if self.training:
            output = [
                torch.zeros((batch_size, self.vocab_size),
                            dtype=torch.float32,
                            device=device)
            ]
            output[0][:, self.sos_id] = 1.0
        else:
            output = [
                torch.full((batch_size, ),
                           fill_value=self.sos_id,
                           dtype=torch.int64,
                           device=device)
            ]

        # Keep track of sequence lengths
        if self.compute_lengths:
            n_sos_symbols = 1
            initial_length = self.output_len + n_sos_symbols
            seq_lengths = torch.ones([batch_size],
                                     dtype=torch.int64,
                                     device=device) * initial_length

        for i in range(self.output_len):
            if self.training:
                emb = torch.matmul(output[-1], self.embedding)
            else:
                emb = self.embedding[output[-1]]

            state = self.rnn(emb, state)

            if type(self.rnn) is nn.LSTMCell:
                h, c = state
            else:
                h = state

            p = F.softmax(self.linear_out(h), dim=1)

            if self.training:
                token = gumbel_softmax(p, tau, hard=True)
            else:
                if self.greedy:
                    _, token = torch.max(p, -1)
                else:
                    token = Categorical(p).sample()

                if batch_size == 1:
                    token = token.unsqueeze(0)

            output.append(token)

            if self.compute_lengths:
                self._calculate_seq_len(seq_lengths,
                                        token,
                                        initial_length,
                                        seq_pos=i + 1,
                                        n_sos_symbols=n_sos_symbols,
                                        is_discrete=not self.training)

        outputs = torch.stack(output, dim=1)

        if self.compute_lengths:
            return (outputs, seq_lengths)
        else:
            return outputs
Example #10
    def forward(self, image_representation, messages=None, tau=1.2, device=None):
        """
        Performs a forward pass. If training, use Gumbel Softmax (hard) for sampling, else use
        discrete sampling.
        Hidden state here represents the encoded image/metadata/features - initializes the RNN from it.
        """

        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        image_representation = self.input_module(image_representation)

        # receiver role
        if messages is None:

            # initialize the rnn using the obverter module
            state, batch_size = self._init_state(
                image_representation, type(self.rnn), device
            )

            # Init output
            if self.training:
                output = [
                    torch.zeros(
                        (batch_size, self.vocab_size),
                        dtype=torch.float32,
                        device=device,
                    )
                ]
                output[0][:, self.sos_id] = 1.0
            else:
                output = [
                    torch.full(
                        (batch_size,),
                        fill_value=self.sos_id,
                        dtype=torch.int64,
                        device=device,
                    )
                ]

            # Keep track of sequence lengths
            initial_length = self.output_len + 1  # add the sos token
            seq_lengths = (
                torch.ones([batch_size], dtype=torch.int64, device=device)
                * initial_length
            )

            embeds = []  # keep track of the embedded sequence
            entropy = 0.0

            sentence_probability = torch.zeros(
                (batch_size, self.vocab_size), device=device
            )

            for i in range(self.output_len):
                if self.training:
                    emb = torch.matmul(output[-1], self.embedding)
                else:
                    emb = self.embedding[output[-1]]

                embeds.append(emb)
                state = self.rnn(emb, state)

                if type(self.rnn) is nn.LSTMCell:
                    h, c = state
                else:
                    h = state

                p = F.softmax(self.linear_out(h), dim=1)
                entropy += Categorical(p).entropy()

                if self.training:
                    token = gumbel_softmax(p, tau, hard=True)
                else:
                    sentence_probability += p.detach()
                    if self.greedy:
                        _, token = torch.max(p, -1)
                    else:
                        token = Categorical(p).sample()

                    if batch_size == 1:
                        token = token.unsqueeze(0)

                output.append(token)

                self._calculate_seq_len(
                    seq_lengths, token, initial_length, seq_pos=i + 1
                )

            return (
                torch.stack(output, dim=1),
                seq_lengths,
                torch.mean(entropy) / self.output_len,
                torch.stack(embeds, dim=1),
                sentence_probability,
            )
        else:
            batch_size = messages.shape[0]

            emb = (
                torch.matmul(messages, self.embedding)
                if self.training
                else self.embedding[messages]
            )

            # initialize hidden
            h = torch.zeros([batch_size, self.hidden_size], device=device)
            if self.cell_type == "lstm":
                c = torch.zeros([batch_size, self.hidden_size], device=device)
                h = (h, c)

            # make sequence_length be first dim
            seq_iterator = emb.transpose(0, 1)
            for w in seq_iterator:
                h = self.rnn(w, h)

            if self.cell_type == "lstm":
                h = h[0]  # keep only hidden state

            combined = torch.cat((h, image_representation), dim=1)
            prediction = self.output_layer(combined)

            return prediction, emb