Example #1
    def _memory_net_forward(self, batch):
        """Forward pass for memory network to look up fact.

        1. Encodes fact via fact rnn.
        2. Computes attention with fact and utterance encoding.
        3. Attended fact vector and question encoding -> new encoding.

        Args:
          batch: Dict of hist, hist_len, hidden_state
        """
        # Note: a fact is the previous utterance and its response concatenated
        # as one string, e.g., 'What is the color of the couch? A: Red.'
        batch_size, num_rounds, enc_time_steps = batch["fact"].shape
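        # Upper-triangular round mask: True where fact round j > utterance
        # round i, so each round only attends to facts up to its own round.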
        all_ones = np.full((num_rounds, num_rounds), 1)
        fact_mask = np.triu(all_ones, 1)
        fact_mask = np.expand_dims(np.expand_dims(fact_mask, -1), 0)
        fact_mask = torch.BoolTensor(fact_mask)
        if self.params["use_gpu"]:
            fact_mask = fact_mask.cuda()
        fact_mask.requires_grad_(False)

        fact_in = support.flatten(batch["fact"], batch_size, num_rounds)
        fact_len = support.flatten(batch["fact_len"], batch_size, num_rounds)
        fact_embeds = self.word_embed_net(fact_in)

        # Encode the fact and unflatten the last hidden state.
        _, (hidden_state, _) = rnn.dynamic_rnn(self.fact_unit,
                                               fact_embeds,
                                               fact_len,
                                               return_states=True)
        fact_encode = support.unflatten(hidden_state[-1], batch_size,
                                        num_rounds)
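        # Tile fact encodings to (batch, num_rounds, num_rounds, hidden) so
        # every utterance round can score all fact rounds.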
        fact_encode = fact_encode.unsqueeze(1).expand(-1, num_rounds, -1, -1)

        utterance_enc = batch["utterance_enc"].unsqueeze(2)
        utterance_enc = utterance_enc.expand(-1, -1, num_rounds, -1)

        # Combine, compute attention, mask, and weight the fact encodings.
        combined_encode = torch.cat([utterance_enc, fact_encode], dim=-1)
        attention = self.fact_attention_net(combined_encode)
        attention.masked_fill_(fact_mask, float("-Inf"))
        attention = self.softmax(attention, dim=2)
        attended_fact = (attention * fact_encode).sum(2)
        return attended_fact
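
The core step above is a softmax over fact encodings, masked with an upper-triangular round mask so that no round attends to future facts. The minimal sketch below reproduces just that masking step on toy tensors; the shapes and the triangular mask mirror the example, but the attention scorer is replaced by a plain dot product and all names here are illustrative rather than taken from the repository.

    # Minimal sketch of round-masked attention over fact encodings.
    import torch
    import torch.nn.functional as F

    batch_size, num_rounds, hidden = 2, 4, 8
    utterance_enc = torch.randn(batch_size, num_rounds, hidden)
    fact_encode = torch.randn(batch_size, num_rounds, hidden)

    # Pairwise scores: scores[b, i, j] compares utterance round i with fact round j.
    scores = torch.einsum("bih,bjh->bij", utterance_enc, fact_encode)

    # Upper-triangular mask (j > i): a round may not attend to future facts.
    fact_mask = torch.triu(torch.ones(num_rounds, num_rounds), diagonal=1).bool()
    scores = scores.masked_fill(fact_mask.unsqueeze(0), float("-inf"))

    # Softmax over fact rounds, then a weighted sum of the fact encodings.
    attention = F.softmax(scores, dim=2)
    attended_fact = torch.einsum("bij,bjh->bih", attention, fact_encode)
    print(attended_fact.shape)  # torch.Size([2, 4, 8])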
Example #2
    def forward(self, batch):
        """Forward pass through the encoder.

        Args:
            batch: Dict of batch variables.

        Returns:
            encoder_outputs: Dict of outputs from the forward pass.
        """
        encoder_out = {}
        # Flatten for history_agnostic encoder.
        batch_size, num_rounds, max_length = batch["user_utt"].shape
        encoder_in = support.flatten(batch["user_utt"], batch_size, num_rounds)
        encoder_len = support.flatten(batch["user_utt_len"], batch_size,
                                      num_rounds)
        word_embeds_enc = self.word_embed_net(encoder_in)
        # Text encoder: LSTM or Transformer.
        if self.params["text_encoder"] == "lstm":
            all_enc_states, enc_states = rnn.dynamic_rnn(self.encoder_unit,
                                                         word_embeds_enc,
                                                         encoder_len,
                                                         return_states=True)
            encoder_out["hidden_states_all"] = all_enc_states
            encoder_out["hidden_state"] = enc_states

        elif self.params["text_encoder"] == "transformer":
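            # The transformer layers expect (seq_len, batch, embed_dim) inputs,
            # hence the transpose after positional encoding.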
            enc_embeds = self.pos_encoder(word_embeds_enc).transpose(0, 1)
            enc_pad_mask = batch["user_utt"] == batch["pad_token"]
            enc_pad_mask = support.flatten(enc_pad_mask, batch_size,
                                           num_rounds)
            enc_states = self.encoder_unit(enc_embeds,
                                           src_key_padding_mask=enc_pad_mask)
            encoder_out["hidden_states_all"] = enc_states.transpose(0, 1)
        return encoder_out
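
Both branches above first collapse the dialog rounds into the batch dimension and, for the transformer branch, build a key-padding mask from the pad token. The sketch below reproduces those two steps on toy data; `flatten` here is a hypothetical stand-in for `support.flatten`, assumed to merge the first two dimensions.

    # Toy version of flattening rounds into the batch and masking padding.
    import torch
    import torch.nn as nn

    def flatten(tensor, batch_size, num_rounds):
        # Assumed behaviour of support.flatten: merge (batch, rounds) dims.
        return tensor.reshape(batch_size * num_rounds, *tensor.shape[2:])

    batch_size, num_rounds, max_length, embed_dim, pad_token = 2, 3, 6, 16, 0
    user_utt = torch.randint(1, 100, (batch_size, num_rounds, max_length))
    user_utt[:, :, 4:] = pad_token  # pretend the last two positions are padding

    encoder_in = flatten(user_utt, batch_size, num_rounds)                # (B*R, T)
    enc_pad_mask = flatten(user_utt == pad_token, batch_size, num_rounds)

    embed = nn.Embedding(100, embed_dim)
    layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=4)
    encoder = nn.TransformerEncoder(layer, num_layers=1)

    # PyTorch transformer modules default to (seq_len, batch, dim) inputs.
    enc_embeds = embed(encoder_in).transpose(0, 1)
    enc_states = encoder(enc_embeds, src_key_padding_mask=enc_pad_mask)
    print(enc_states.transpose(0, 1).shape)  # (B*R, T, embed_dim)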
Example #3
    def forward(self, batch, encoder_output):
        """Forward pass through the decoder.

        Args:
            batch: Dict of batch variables.
            encoder_output: Dict of outputs from the encoder.

        Returns:
            decoder_outputs: Dict of outputs from the forward pass.
        """
        # Flatten for history_agnostic encoder.
        batch_size, num_rounds, max_length = batch["assist_in"].shape
        decoder_in = support.flatten(batch["assist_in"], batch_size,
                                     num_rounds)
        decoder_out = support.flatten(batch["assist_out"], batch_size,
                                      num_rounds)
        decoder_len = support.flatten(batch["assist_in_len"], batch_size,
                                      num_rounds)
        word_embeds_dec = self.word_embed_net(decoder_in)

        if self.params["encoder"] in self.DIALOG_CONTEXT_ENCODERS:
            dialog_context = support.flatten(encoder_output["dialog_context"],
                                             batch_size,
                                             num_rounds).unsqueeze(1)
            dialog_context = dialog_context.expand(-1, max_length, -1)
            decoder_steps_in = torch.cat([dialog_context, word_embeds_dec], -1)
        else:
            decoder_steps_in = word_embeds_dec

        # Encoder states conditioned on action outputs, if need be.
        if self.params["use_action_output"]:
            action_out = encoder_output["action_output_all"].unsqueeze(1)
            time_steps = encoder_output["hidden_states_all"].shape[1]
            fusion_out = torch.cat(
                [
                    encoder_output["hidden_states_all"],
                    action_out.expand(-1, time_steps, -1),
                ],
                dim=-1,
            )
            encoder_output["hidden_states_all"] = self.action_fusion_net(
                fusion_out)

        if self.params["text_encoder"] == "transformer":
            # Check the status of no_peek_mask.
            if self.no_peek_mask is None or self.no_peek_mask.size(
                    0) != max_length:
                self.no_peek_mask = self._generate_no_peek_mask(max_length)

            hidden_state = encoder_output["hidden_states_all"]
            enc_pad_mask = batch["user_utt"] == batch["pad_token"]
            enc_pad_mask = support.flatten(enc_pad_mask, batch_size,
                                           num_rounds)
            dec_pad_mask = batch["assist_in"] == batch["pad_token"]
            dec_pad_mask = support.flatten(dec_pad_mask, batch_size,
                                           num_rounds)
            if self.params["encoder"] != "pretrained_transformer":
                dec_embeds = self.pos_encoder(decoder_steps_in).transpose(0, 1)
                outputs = self.decoder_unit(
                    dec_embeds,
                    hidden_state.transpose(0, 1),
                    memory_key_padding_mask=enc_pad_mask,
                    tgt_mask=self.no_peek_mask,
                    tgt_key_padding_mask=dec_pad_mask,
                )
                outputs = outputs.transpose(0, 1)
            else:
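                # The pretrained decoder expects attention masks where True
                # marks tokens to keep, hence the inverted padding masks.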
                outputs = self.decoder_unit(
                    inputs_embeds=decoder_steps_in,
                    attention_mask=~dec_pad_mask,
                    encoder_hidden_states=hidden_state,
                    encoder_attention_mask=~enc_pad_mask,
                )
                outputs = outputs[0]
        else:
            hidden_state = encoder_output["hidden_state"]
            if self.params["encoder"] == "tf_idf":
                hidden_state = None

            # If Bahdanau attention is to be used.
            if (self.params["use_bahdanau_attention"]
                    and self.params["encoder"] != "tf_idf"):
                encoder_states = encoder_output["hidden_states_all"]
                max_decoder_len = min(decoder_in.shape[1],
                                      self.params["max_decoder_len"])
                encoder_states_proj = self.attention_net(encoder_states)
                enc_mask = (
                    batch["user_utt"] == batch["pad_token"]).unsqueeze(-1)
                enc_mask = support.flatten(enc_mask, batch_size, num_rounds)
                outputs = []
                for step in range(max_decoder_len):
                    previous_state = hidden_state[0][-1].unsqueeze(1)
                    att_logits = previous_state * encoder_states_proj
                    att_logits = att_logits.sum(dim=-1, keepdim=True)
                    # Use encoder mask to replace <pad> with -Inf.
                    att_logits.masked_fill_(enc_mask, float("-Inf"))
                    att_wts = nn.functional.softmax(att_logits, dim=1)
                    context = (encoder_states * att_wts).sum(1, keepdim=True)
                    # Run through LSTM.
                    concat_in = [
                        context, decoder_steps_in[:, step:step + 1, :]
                    ]
                    step_in = torch.cat(concat_in, dim=-1)
                    decoder_output, hidden_state = self.decoder_unit(
                        step_in, hidden_state)
                    concat_out = torch.cat([decoder_output, context], dim=-1)
                    outputs.append(concat_out)
                outputs = torch.cat(outputs, dim=1)
            else:
                outputs = rnn.dynamic_rnn(
                    self.decoder_unit,
                    decoder_steps_in,
                    decoder_len,
                    init_state=hidden_state,
                )
        if self.params["encoder"] == "pretrained_transformer":
            output_logits = outputs
        else:
            # Logits over vocabulary.
            output_logits = self.inv_word_net(outputs)
        # Zero out the token loss at padded positions.
        pad_mask = support.flatten(batch["assist_mask"], batch_size,
                                   num_rounds)
        loss_token = self.criterion(output_logits.transpose(1, 2), decoder_out)
        loss_token.masked_fill_(pad_mask, 0.0)
        return {"loss_token": loss_token, "pad_mask": pad_mask}