Example #1
    def forward(self, instructions_batch):
        token_lists, _ = instructions_batch
        batch_size = len(token_lists)
        dims = (self.num_layers, batch_size, self.hidden_dim)
        hidden = (Variable(cuda_tensor(torch.zeros(*dims)), requires_grad=False),
                  Variable(cuda_tensor(torch.zeros(*dims)), requires_grad=False))

        # pad text tokens with 0's; the batch is assumed sorted by decreasing
        # length, so text_lengths[0] is the maximum length
        text_lengths = np.array([len(tokens) for tokens in token_lists])
        tokens_batch = [[] for _ in range(batch_size)]
        for i in range(batch_size):
            num_zeros = text_lengths[0] - text_lengths[i]
            tokens_batch[i] = token_lists[i] + [0] * num_zeros
        tokens_batch = cuda_var(torch.from_numpy(np.array(tokens_batch)))

        # swap so batch dimension is second, sequence dimension is first
        tokens_batch = tokens_batch.transpose(0, 1)
        emb_sentence = self.embedding(tokens_batch)
        packed_input = pack_padded_sequence(emb_sentence, text_lengths)
        lstm_out_packed, _ = self.lstm(packed_input, hidden)
        # return average output embedding
        lstm_out, seq_lengths = pad_packed_sequence(lstm_out_packed)
        lstm_out = lstm_out.transpose(0, 1)
        mean_emb_list = []
        for i, seq_out in enumerate(lstm_out):
            seq_len = seq_lengths[i]
            mean_emb = torch.sum(seq_out[:seq_len], 0) / seq_len
            mean_emb_list.append(mean_emb.view(1, -1))
        return torch.cat(mean_emb_list)
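As the comment in the padding loop notes, the batch must already be sorted by decreasing instruction length, which is also what pack_padded_sequence requires. A minimal self-contained sketch of the same pad/pack/mean-pool pipeline, with illustrative toy sizes:

    import torch
    import torch.nn as nn
    from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

    token_lists = [[4, 2, 7, 1], [3, 5], [9]]        # already sorted, longest first
    lengths = [len(t) for t in token_lists]
    padded = torch.tensor([t + [0] * (lengths[0] - len(t)) for t in token_lists])

    embedding = nn.Embedding(num_embeddings=10, embedding_dim=8, padding_idx=0)
    lstm = nn.LSTM(input_size=8, hidden_size=16)

    emb = embedding(padded.transpose(0, 1))          # seq_len x batch x emb
    out_packed, _ = lstm(pack_padded_sequence(emb, lengths))
    out, out_lens = pad_packed_sequence(out_packed)  # seq_len x batch x hidden
    out = out.transpose(0, 1)                        # batch x seq_len x hidden

    # mean over the valid timesteps of each sequence, as in the loop above
    mean_emb = torch.stack([out[i, :l].mean(0) for i, l in enumerate(out_lens)])
    print(mean_emb.shape)                            # torch.Size([3, 16])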
Example #2
    def forward_old(self, image_seq_batch, seq_lengths):
        b, n, d = image_seq_batch.data.shape
        lengths = [(l, i) for i, l in enumerate(seq_lengths.cpu().numpy())]
        lengths.sort(reverse=True)
        sort_idx = [i for _, i in lengths]
        sort_idx_reverse = [sort_idx.index(i) for i in range(len(sort_idx))]

        image_seq_list = [images for images in image_seq_batch]
        image_seq_batch = torch.cat(
            [image_seq_list[i].view(1, n, d) for i in sort_idx])
        lengths_np = np.array([l for l, _ in lengths])

        batch_size = int(image_seq_batch.data.shape[0])
        dims = (self.num_layers, batch_size, self.output_emb_dim)
        hidden = (Variable(cuda_tensor(torch.zeros(*dims)),
                           requires_grad=False),
                  Variable(cuda_tensor(torch.zeros(*dims)),
                           requires_grad=False))

        # swap so batch dimension is second, sequence dimension is first
        image_seq_batch = image_seq_batch.transpose(0, 1)
        packed_input = pack_padded_sequence(image_seq_batch, lengths_np)
        lstm_out_packed, _ = self.lstm(packed_input, hidden)
        # return the final output embedding of each sequence
        lstm_out, seq_lengths = pad_packed_sequence(lstm_out_packed)
        lstm_out = lstm_out.transpose(0, 1)
        final_vectors = [
            lstm_out[i][int(seq_len) - 1]
            for i, seq_len in enumerate(seq_lengths)
        ]
        final_vectors = [final_vectors[i] for i in sort_idx_reverse]
        return torch.cat(
            [vec.view(1, self.output_emb_dim) for vec in final_vectors])
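The list-based sort/unsort bookkeeping above (sort_idx and sort_idx_reverse) can be expressed as a pair of argsort calls; a short equivalent sketch with made-up lengths:

    import numpy as np

    seq_lengths = np.array([2, 5, 3])
    sort_idx = np.argsort(-seq_lengths)      # order that sorts lengths descending
    sort_idx_reverse = np.argsort(sort_idx)  # maps sorted positions back to input order

    sorted_lengths = seq_lengths[sort_idx]   # [5, 3, 2]
    assert (sorted_lengths[sort_idx_reverse] == seq_lengths).all()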
Example #3
    def forward(self, instructions_batch):
        token_lists, text_pointers = instructions_batch
        batch_size = len(token_lists)
        text_lengths = np.array([len(tokens) for tokens in token_lists])
        dims = (self.num_layers, batch_size, self.hidden_dim)
        hidden_f = (Variable(cuda_tensor(torch.zeros(*dims)),
                             requires_grad=False),
                    Variable(cuda_tensor(torch.zeros(*dims)),
                             requires_grad=False))
        hidden_b = (Variable(cuda_tensor(torch.zeros(*dims)),
                             requires_grad=False),
                    Variable(cuda_tensor(torch.zeros(*dims)),
                             requires_grad=False))

        # pad text tokens with 0's
        tokens_batch_f = [[] for _ in range(batch_size)]
        tokens_batch_b = [[] for _ in range(batch_size)]
        for i in range(batch_size):
            num_zeros = text_lengths[0] - text_lengths[i]
            tokens_batch_f[i] = token_lists[i] + [0] * num_zeros
            tokens_batch_b[i] = token_lists[i][::-1] + [0] * num_zeros
        tokens_batch_f = cuda_var(torch.from_numpy(np.array(tokens_batch_f)))
        tokens_batch_b = cuda_var(torch.from_numpy(np.array(tokens_batch_b)))

        # swap so batch dimension is second, sequence dimension is first
        tokens_batch_f = tokens_batch_f.transpose(0, 1)
        tokens_batch_b = tokens_batch_b.transpose(0, 1)
        emb_sentence_f = self.embedding(tokens_batch_f)
        emb_sentence_b = self.embedding(tokens_batch_b)
        packed_input_f = pack_padded_sequence(emb_sentence_f, text_lengths)
        packed_input_b = pack_padded_sequence(emb_sentence_b, text_lengths)
        lstm_out_packed_f, _ = self.lstm_f(packed_input_f, hidden_f)
        lstm_out_packed_b, _ = self.lstm_b(packed_input_b, hidden_b)

        # gather boundary LSTM states around each text span
        lstm_out_f, _ = pad_packed_sequence(lstm_out_packed_f)
        lstm_out_b, _ = pad_packed_sequence(lstm_out_packed_b)
        lstm_out_f = lstm_out_f.transpose(0, 1)
        lstm_out_b = lstm_out_b.transpose(0, 1)
        embeddings_list = []
        for i, (start_i, end_i) in enumerate(text_pointers):
            embeddings = []
            if start_i > 0:
                embeddings.append(lstm_out_f[i][start_i - 1])
            else:
                embeddings.append(cuda_var(torch.zeros(self.hidden_dim)))
            embeddings.append(lstm_out_f[i][end_i - 1])
            embeddings.append(lstm_out_b[i][start_i])
            if end_i < text_lengths[i]:
                embeddings.append(lstm_out_b[i][end_i])
            else:
                embeddings.append(cuda_var(torch.zeros(self.hidden_dim)))
            embeddings_list.append(torch.cat(embeddings).view(1, -1))

        embeddings_batch = torch.cat(embeddings_list)
        return embeddings_batch
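The four vectors concatenated per example above are boundary LSTM states around the span [start_i, end_i): the forward state just before the span, the forward state at its last token, and two backward states at its edges, with zeros substituted at the sentence boundaries. A toy sketch of the gather, with hypothetical sizes and random tensors standing in for the LSTM outputs:

    import torch

    seq_len, hidden_dim = 7, 4
    lstm_out_f = torch.randn(seq_len, hidden_dim)   # forward states per token
    lstm_out_b = torch.randn(seq_len, hidden_dim)   # states over the reversed tokens
    start_i, end_i = 2, 5                           # span covers tokens 2..4

    parts = [
        lstm_out_f[start_i - 1] if start_i > 0 else torch.zeros(hidden_dim),
        lstm_out_f[end_i - 1],
        lstm_out_b[start_i],
        lstm_out_b[end_i] if end_i < seq_len else torch.zeros(hidden_dim),
    ]
    span_emb = torch.cat(parts)                     # 4 * hidden_dim features
    print(span_emb.shape)                           # torch.Size([16])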
Example #4
    def forward(self, input_vector, hidden_vectors):
        """
        @param image_vector: batch of sequence of image embedding
        @param hidden_vectors: hidden vectors for each batch """

        if hidden_vectors is None:
            dims = (1, self.output_emb_dim)
            hidden_vectors = (Variable(cuda_tensor(torch.zeros(*dims)),
                                       requires_grad=False),
                              Variable(cuda_tensor(torch.zeros(*dims)),
                                       requires_grad=False))
        new_hidden_vector = self.lstm(input_vector, hidden_vectors)

        return new_hidden_vector
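This forward performs a single recurrent step. A hedged sketch of how such a step-wise module is typically driven over a sequence, using nn.LSTMCell as a stand-in for self.lstm (an assumption, since its type is not shown):

    import torch
    import torch.nn as nn

    cell = nn.LSTMCell(input_size=32, hidden_size=64)
    inputs = torch.randn(10, 1, 32)        # seq_len x batch x input_dim

    hidden = None
    for x_t in inputs:                     # one timestep at a time
        if hidden is None:
            zeros = torch.zeros(1, 64)
            hidden = (zeros, zeros.clone())
        hidden = cell(x_t, hidden)         # returns the new (h_t, c_t)

    final_h, _ = hidden
    print(final_h.shape)                   # torch.Size([1, 64])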
Example #5
    def get_probs(self, agent_observed_state, model_state, mode=None, volatile=False):

        assert isinstance(agent_observed_state, AgentObservedState)
        agent_observed_state_list = [agent_observed_state]

        image_seq_lens = [1]
        image_seq_lens_batch = cuda_tensor(
            torch.from_numpy(np.array(image_seq_lens)))
        # max_len = max(image_seq_lens)
        # image_seqs = [aos.get_image()[:max_len]
        #               for aos in agent_observed_state_list]
        image_seqs = [[aos.get_last_image()]
                      for aos in agent_observed_state_list]
        image_batch = cuda_var(torch.from_numpy(np.array(image_seqs)).float(), volatile)

        instructions = [aos.get_instruction()
                        for aos in agent_observed_state_list]
        read_pointers = [aos.get_read_pointers()
                         for aos in agent_observed_state_list]
        instructions_batch = (instructions, read_pointers)

        prev_actions_raw = [aos.get_previous_action()
                            for aos in agent_observed_state_list]
        prev_actions = [self.none_action if a is None else a
                        for a in prev_actions_raw]
        prev_actions_batch = cuda_var(torch.from_numpy(np.array(prev_actions)), volatile)

        probs_batch, new_model_state, image_emb_seq, state_feature = self.final_module(
            image_batch, image_seq_lens_batch, instructions_batch, prev_actions_batch, mode, model_state)
        return probs_batch, new_model_state, image_emb_seq, state_feature
Example #6
    def forward(self, image, image_lens, instructions, prev_action, mode, model_state):

        image_emb_seq = self.image_module(image)
        num_states = image_emb_seq.size()[0]
        image_emb = image_emb_seq.view(num_states, -1)

        if model_state is None:
            text_emb = self.text_module(instructions)
            image_hidden_states = None
            dims = (num_states, self.image_recurrence_module.output_emb_dim)
            prev_image_memory_emb = Variable(cuda_tensor(torch.zeros(*dims)), requires_grad=False)
        else:
            text_emb, image_hidden_states, prev_image_memory_emb = model_state

        new_image_memory_emb, new_image_hidden_states = \
            self.image_recurrence_module(image_emb_seq, image_lens, image_hidden_states)

        new_model_state = (text_emb, new_image_hidden_states, new_image_memory_emb)
        action_emb = self.action_module(prev_action)
        x = torch.cat([prev_image_memory_emb, image_emb, text_emb, action_emb], dim=1)
        x = F.leaky_relu(self.dense1(x))
        if mode is None or mode == ReadPointerAgent.ACT_MODE:
            return F.log_softmax(self.dense2(x)), new_model_state, image_emb_seq, x
        elif mode == ReadPointerAgent.READ_MODE:
            return F.log_softmax(self.dense_read(x)), new_model_state, image_emb_seq, x
        else:
            raise ValueError("invalid mode for model: %r" % mode)
Example #7
    def get_probs_symbolic_text(self, agent_observed_state, symbolic_text, model_state, mode=None, volatile=False):
        """ Same as get_probs instead forces the model to use the given symbolic text """

        assert isinstance(agent_observed_state, AgentObservedState)
        agent_observed_state_list = [agent_observed_state]

        image_seq_lens = [1]
        image_seq_lens_batch = cuda_tensor(
            torch.from_numpy(np.array(image_seq_lens)))
        image_seqs = [[aos.get_last_image()]
                      for aos in agent_observed_state_list]
        image_batch = cuda_var(torch.from_numpy(np.array(image_seqs)).float(), volatile)

        instructions_batch = [symbolic_text]

        prev_actions_raw = [aos.get_previous_action()
                            for aos in agent_observed_state_list]
        prev_actions = [self.none_action if a is None else a
                        for a in prev_actions_raw]
        prev_actions_batch = cuda_var(torch.from_numpy(np.array(prev_actions)), volatile)

        probs_batch, new_model_state, image_emb_seq, state_feature = self.final_module(image_batch, image_seq_lens_batch,
                                                                        instructions_batch, prev_actions_batch,
                                                                        mode, model_state)
        return probs_batch, new_model_state, image_emb_seq, state_feature
Example #8
    def get_probs_batch(self, agent_observed_state_list, mode=None):
        for aos in agent_observed_state_list:
            assert isinstance(aos, AgentObservedState)
        # print "batch size:", len(agent_observed_state_list)

        # sort list by instruction length
        agent_observed_state_list = sorted(
            agent_observed_state_list,
            key=lambda aos_: len(aos_.get_instruction()),
            reverse=True
        )

        image_seq_lens = [aos.get_num_images()
                          for aos in agent_observed_state_list]
        image_seq_lens_batch = cuda_tensor(
            torch.from_numpy(np.array(image_seq_lens)))
        max_len = max(image_seq_lens)
        image_seqs = [aos.get_image()[:max_len]
                      for aos in agent_observed_state_list]
        image_batch = cuda_var(torch.from_numpy(np.array(image_seqs)).float())

        instructions_batch = [aos.get_symbolic_instruction()
                              for aos in agent_observed_state_list]

        prev_actions_raw = [aos.get_previous_action()
                            for aos in agent_observed_state_list]
        prev_actions = [self.none_action if a is None else a
                        for a in prev_actions_raw]
        prev_actions_batch = cuda_var(torch.from_numpy(np.array(prev_actions)))

        probs_batch = self.final_module(image_batch, image_seq_lens_batch,
                                        instructions_batch, prev_actions_batch,
                                        mode)
        return probs_batch
Example #9
    def get_probs(self, agent_observed_state, model_state, mode=None):

        assert isinstance(agent_observed_state, AgentObservedState)
        agent_observed_state_list = [agent_observed_state]

        image_seq_lens = [1]
        image_seq_lens_batch = cuda_tensor(
            torch.from_numpy(np.array(image_seq_lens)))
        image_seqs = [[aos.get_last_image()]
                      for aos in agent_observed_state_list]
        image_batch = cuda_var(torch.from_numpy(np.array(image_seqs)).float())

        goal_image_seqs = [[aos.get_goal_image()] for aos in agent_observed_state_list]
        goal_image_batch = cuda_var(torch.from_numpy(np.array(goal_image_seqs)).float())

        prev_actions_raw = [aos.get_previous_action()
                            for aos in agent_observed_state_list]
        prev_actions = [self.none_action if a is None else a
                        for a in prev_actions_raw]
        prev_actions_batch = cuda_var(torch.from_numpy(np.array(prev_actions)))

        probs_batch, new_model_state, image_emb_seq = self.final_module(image_batch, image_seq_lens_batch,
                                                                        goal_image_batch, prev_actions_batch,
                                                                        mode, model_state)
        return probs_batch, new_model_state, image_emb_seq
Example #10
    def forward_1(self, instructions, prev_instructions, next_instructions):
        # Assume there is only 1 instruction
        text_emb = self.text_module(instructions)
        if prev_instructions[0][0] is None:
            prev_text_emb = Variable(cuda_tensor(torch.zeros(text_emb.size())), requires_grad=False)
        else:
            prev_text_emb = self.text_module(prev_instructions)

        if next_instructions[0][0] is None:
            next_text_emb = Variable(cuda_tensor(torch.zeros(text_emb.size())), requires_grad=False)
        else:
            next_text_emb = self.text_module(next_instructions)
        x = torch.cat([prev_text_emb, text_emb, next_text_emb], dim=1)
        x = F.relu(self.dense_1(x))
        return F.log_softmax(self.dense_landmark(x)), \
               F.log_softmax(self.dense_theta_1(x)), \
               F.log_softmax(self.dense_theta_2(x)), \
               F.log_softmax(self.dense_r(x))
Example #11
    def forward(self, image, image_lens, instructions, prev_action, mode,
                model_state):

        image_emb_seq = self.image_module(image)
        num_states = image_emb_seq.size()[0]
        image_emb = image_emb_seq.view(num_states, -1)

        if model_state is None:
            text_emb = self.text_module(instructions)
            image_hidden_states = None
            dims = (num_states, self.image_recurrence_module.output_emb_dim)
            prev_image_memory_emb = Variable(cuda_tensor(torch.zeros(*dims)),
                                             requires_grad=False)
        else:
            text_emb, image_hidden_states, prev_image_memory_emb = model_state

        action_emb = self.action_module(prev_action)
        image_action_embedding = torch.cat([image_emb, action_emb], dim=1)
        image_action_embedding = image_action_embedding.view(num_states, 1, -1)

        # new_image_memory_emb, new_image_hidden_states = \
        #     self.image_recurrence_module(image_emb_seq, image_lens, image_hidden_states)
        new_image_memory_emb, new_image_hidden_states = \
            self.image_recurrence_module(image_action_embedding, image_lens, image_hidden_states)

        new_model_state = (text_emb, new_image_hidden_states,
                           new_image_memory_emb)
        x_input = torch.cat(
            [prev_image_memory_emb, image_emb, text_emb, action_emb], dim=1)

        x_1 = F.leaky_relu(self.dense1(x_input))
        x_2 = F.leaky_relu(self.dense2(torch.cat([x_input, x_1], dim=1)))
        x_3 = F.leaky_relu(self.dense3(torch.cat([x_input, x_1, x_2], dim=1)))

        if mode is None or mode == ReadPointerAgent.ACT_MODE:

            block_logits = self.dense_block(x_3)
            direction_logits = self.dense_direction(x_3)

            block_logprob = F.log_softmax(block_logits)  # 1 x num_block
            direction_logprob = F.log_softmax(
                direction_logits)  # 1 x num_direction

            action_logprob = block_logprob.transpose(
                0, 1) + direction_logprob[:, :4]  # num_block x num_direction
            action_logprob = action_logprob.view(1, -1)
            stop_logprob = direction_logprob[:, 4:5]

            action_logprob = torch.cat([action_logprob, stop_logprob], dim=1)

            # val = torch.clamp(val, min=-2, max=2)
            return action_logprob, new_model_state, image_emb_seq, x_3
        elif mode == ReadPointerAgent.READ_MODE:
            return F.log_softmax(
                self.dense_read(x_3)), new_model_state, image_emb_seq, x_3
        else:
            raise ValueError("invalid mode for model: %r" % mode)
Example #12
    def forward(self, actions_batch, hidden_vectors):
        """
        @param image_seq_batch: batch of sequence of image embedding
        @param seq_lengths: length of the sequence for each sequence in the batch
        @param hidden_vectors: hidden vectors for each batch """

        b, d = actions_batch.data.shape
        actions_batch = actions_batch.view(b, 1, d)
        batch_size = int(actions_batch.data.shape[0])
        if hidden_vectors is None:
            dims = (self.num_layers, batch_size, self.output_emb_dim)
            hidden_vectors = (Variable(cuda_tensor(torch.zeros(*dims)), requires_grad=False),
                              Variable(cuda_tensor(torch.zeros(*dims)), requires_grad=False))

        # swap so batch dimension is second, sequence dimension is first
        actions_batch = actions_batch.view(1, b, d)
        lstm_out, new_hidden_vector = self.lstm(actions_batch, hidden_vectors)
        # return output embeddings
        return lstm_out.view(batch_size, -1), new_hidden_vector
Example #13
    def forward(self, instructions_batch):
        token_lists, _ = instructions_batch
        batch_size = len(token_lists)
        text_lengths = np.array([len(tokens) for tokens in token_lists])
        dims = (self.num_layers, batch_size, self.hidden_dim)
        hidden_f = (Variable(cuda_tensor(torch.zeros(*dims)),
                             requires_grad=False),
                    Variable(cuda_tensor(torch.zeros(*dims)),
                             requires_grad=False))
        hidden_b = (Variable(cuda_tensor(torch.zeros(*dims)),
                             requires_grad=False),
                    Variable(cuda_tensor(torch.zeros(*dims)),
                             requires_grad=False))

        # pad text tokens with 0's
        tokens_batch_f = [[] for _ in range(batch_size)]
        tokens_batch_b = [[] for _ in range(batch_size)]
        for i in range(batch_size):
            num_zeros = text_lengths[0] - text_lengths[i]
            tokens_batch_f[i] = token_lists[i] + [0] * num_zeros
            tokens_batch_b[i] = token_lists[i][::-1] + [0] * num_zeros
        tokens_batch_f = cuda_var(torch.from_numpy(np.array(tokens_batch_f)))
        tokens_batch_b = cuda_var(torch.from_numpy(np.array(tokens_batch_b)))

        # swap so batch dimension is second, sequence dimension is first
        tokens_batch_f = tokens_batch_f.transpose(0, 1)
        tokens_batch_b = tokens_batch_b.transpose(0, 1)
        emb_sentence_f = self.embedding(tokens_batch_f)
        emb_sentence_b = self.embedding(tokens_batch_b)
        packed_input_f = pack_padded_sequence(emb_sentence_f, text_lengths)
        packed_input_b = pack_padded_sequence(emb_sentence_b, text_lengths)
        lstm_out_packed_f, _ = self.lstm_f(packed_input_f, hidden_f)
        lstm_out_packed_b, _ = self.lstm_b(packed_input_b, hidden_b)

        # return attention-weighted mean embeddings
        lstm_out_f, seq_lengths = pad_packed_sequence(lstm_out_packed_f)
        lstm_out_b, _ = pad_packed_sequence(lstm_out_packed_b)
        # transpose again so batch dimension is first
        lstm_out_f = lstm_out_f.transpose(0, 1)
        lstm_out_b = lstm_out_b.transpose(0, 1)
        embeddings_list = []
        emb_len = self.hidden_dim * 2
        for i, seq_len in enumerate(seq_lengths):
            reverse_indices = torch.linspace(seq_len - 1, 0, seq_len).long()
            f_states = lstm_out_f[i][:seq_len]
            b_states = lstm_out_b[i].index_select(0, cuda_var(reverse_indices))
            joined_states = torch.cat([f_states, b_states], dim=1)
            key_input = torch.cat(
                [f_states[seq_len - 1], b_states[seq_len - 1]])
            mean_embedding_list = []
            # iterate over heads to produce each mean embedding
            for j in range(self.num_heads):
                dense_1 = self.key_layers_1[j]
                dense_2 = self.key_layers_2[j]
                weights = dense_2(F.tanh(dense_1(joined_states)))
                probs = F.softmax(weights.view(-1))
                probs_mask = probs.repeat(emb_len).view(emb_len,
                                                        seq_len).transpose(
                                                            0, 1)
                mean_state = (probs_mask * joined_states).sum(0)
                mean_embedding_list.append(mean_state)

            total_embedding = torch.cat(mean_embedding_list)
            embeddings_list.append(total_embedding.view(1, -1))

        embeddings_batch = torch.cat(embeddings_list)
        return embeddings_batch
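The repeat/view/transpose construction of probs_mask above computes an attention-weighted mean of the joined states; with broadcasting or a matrix product it collapses to one line. An equivalence sketch with assumed toy sizes:

    import torch
    import torch.nn.functional as F

    seq_len, emb_len = 6, 8
    joined_states = torch.randn(seq_len, emb_len)
    probs = F.softmax(torch.randn(seq_len), dim=0)   # attention weights

    # original formulation
    probs_mask = probs.repeat(emb_len).view(emb_len, seq_len).transpose(0, 1)
    mean_state = (probs_mask * joined_states).sum(0)

    # equivalent broadcast and matmul formulations
    assert torch.allclose(mean_state, (probs.unsqueeze(1) * joined_states).sum(0))
    assert torch.allclose(mean_state, probs @ joined_states)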
Example #14
    def forward(self, instructions_batch):
        token_lists, _ = instructions_batch
        batch_size = len(token_lists)
        text_lengths = np.array([len(tokens) for tokens in token_lists])
        dims = (self.num_layers, batch_size, self.hidden_dim)
        hidden_f = (Variable(cuda_tensor(torch.zeros(*dims)), requires_grad=False),
                    Variable(cuda_tensor(torch.zeros(*dims)), requires_grad=False))
        hidden_b = (Variable(cuda_tensor(torch.zeros(*dims)), requires_grad=False),
                    Variable(cuda_tensor(torch.zeros(*dims)), requires_grad=False))

        # pad text tokens with 0's
        tokens_batch_f = [[] for _ in range(batch_size)]
        tokens_batch_b = [[] for _ in range(batch_size)]
        for i in range(batch_size):
            num_zeros = text_lengths[0] - text_lengths[i]
            tokens_batch_f[i] = token_lists[i] + [0] * num_zeros
            tokens_batch_b[i] = token_lists[i][::-1] + [0] * num_zeros
        tokens_batch_f = cuda_var(torch.from_numpy(np.array(tokens_batch_f)))
        tokens_batch_b = cuda_var(torch.from_numpy(np.array(tokens_batch_b)))

        # swap so batch dimension is second, sequence dimension is first
        tokens_batch_f = tokens_batch_f.transpose(0, 1)
        tokens_batch_b = tokens_batch_b.transpose(0, 1)
        emb_sentence_f = self.embedding(tokens_batch_f)
        emb_sentence_b = self.embedding(tokens_batch_b)
        packed_input_f = pack_padded_sequence(emb_sentence_f, text_lengths)
        packed_input_b = pack_padded_sequence(emb_sentence_b, text_lengths)
        lstm_out_packed_f, _ = self.lstm_f(packed_input_f, hidden_f)
        lstm_out_packed_b, _ = self.lstm_b(packed_input_b, hidden_b)

        # compute attention-weighted factor embeddings
        lstm_out_f, seq_lengths = pad_packed_sequence(lstm_out_packed_f)
        lstm_out_b, _ = pad_packed_sequence(lstm_out_packed_b)
        # transpose again so batch dimension is first
        lstm_out_f = lstm_out_f.transpose(0, 1)
        lstm_out_b = lstm_out_b.transpose(0, 1)
        embeddings_list = []
        batch_mean_entropy = []
        emb_len = self.hidden_dim * 2
        for i, seq_len in enumerate(seq_lengths):
            reverse_indices = torch.linspace(seq_len - 1, 0, seq_len).long()
            f_states = lstm_out_f[i][:seq_len]
            b_states = lstm_out_b[i].index_select(0, cuda_var(reverse_indices))
            joined_states = torch.cat([f_states, b_states], dim=1)
            # key_input = torch.cat([f_states[seq_len - 1],
            #                        b_states[seq_len - 1]])
            mean_factor_list = []
            sum_entropy = None
            # iterate over heads to produce each mean embedding
            for j in range(self.num_factors):
                dense_1 = self.key_layers_1[j]
                dense_2 = self.key_layers_2[j]
                weights = dense_2(F.tanh(dense_1(joined_states)))
                probs = F.softmax(weights.view(-1))
                probs_mask = probs.repeat(emb_len).view(emb_len, seq_len).transpose(0, 1)
                mean_state = (probs_mask * joined_states).sum(0)

                # Mean state is of size Batch x Dimension
                factor_logit_weight = self.factors_logits_weights[j]
                factor_emb = factor_logit_weight(mean_state)    # Batch x factor-vocab
                # use a distinct name so the attention probs above are not shadowed
                factor_probs = F.softmax(factor_emb)    # Batch x factor-vocab
                factor_embeddings = self.factors_vocabulary[j]   # factor-vocab x factor-dim
                mean_factor = factor_embeddings(factor_probs)  # Batch x factor-dim
                factor_distribution_entropy = -torch.sum(factor_probs * torch.log(factor_probs))

                # Compute the mean entropy of the probs
                if sum_entropy is None:
                    sum_entropy = factor_distribution_entropy
                else:
                    sum_entropy += factor_distribution_entropy
                mean_factor_list.append(mean_factor)

            total_embedding = torch.cat(mean_factor_list)
            mean_entropy = sum_entropy / self.num_factors
            batch_mean_entropy.append(mean_entropy)
            embeddings_list.append(total_embedding.view(1, -1))

        embeddings_batch = torch.cat(embeddings_list)
        self.mean_factor_entropy = torch.mean(torch.cat(batch_mean_entropy))
        return embeddings_batch
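The per-factor entropy above is -sum p log p over the softmax output; a standalone sketch of the same quantity, using log_softmax rather than log(softmax(...)) for numerical safety (sizes are illustrative):

    import torch
    import torch.nn.functional as F

    factor_logits = torch.randn(1, 10)       # Batch x factor-vocab
    log_probs = F.log_softmax(factor_logits, dim=1)
    entropy = -(log_probs.exp() * log_probs).sum()
    print(entropy)                           # lies in [0, log(10)]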