def _memory_net_forward(self, batch):
    """Forward pass for the memory network to look up facts.

    1. Encodes facts via the fact RNN.
    2. Computes attention between fact and utterance encodings.
    3. Attended fact vector and question encoding -> new encoding.

    Args:
        batch: Dict with fact, fact_len, and utterance_enc.
    """
    # Note (kwon): a fact is the previous utterance + response concatenated
    # as one, for example, "What is the color of the couch? A: Red."
    batch_size, num_rounds, enc_time_steps = batch["fact"].shape
    # Upper-triangular mask (excluding the diagonal) so that round i can
    # only attend to facts from rounds j <= i.
    all_ones = np.full((num_rounds, num_rounds), 1)
    fact_mask = np.triu(all_ones, 1)
    fact_mask = np.expand_dims(np.expand_dims(fact_mask, -1), 0)
    fact_mask = torch.BoolTensor(fact_mask)
    if self.params["use_gpu"]:
        fact_mask = fact_mask.cuda()
    fact_mask.requires_grad_(False)

    fact_in = support.flatten(batch["fact"], batch_size, num_rounds)
    fact_len = support.flatten(batch["fact_len"], batch_size, num_rounds)
    fact_embeds = self.word_embed_net(fact_in)
    # Encode facts and unflatten the last hidden state.
    _, (hidden_state, _) = rnn.dynamic_rnn(
        self.fact_unit, fact_embeds, fact_len, return_states=True
    )
    fact_encode = support.unflatten(hidden_state[-1], batch_size, num_rounds)
    fact_encode = fact_encode.unsqueeze(1).expand(-1, num_rounds, -1, -1)
    utterance_enc = batch["utterance_enc"].unsqueeze(2)
    utterance_enc = utterance_enc.expand(-1, -1, num_rounds, -1)
    # Combine, compute attention, mask, and weight the fact encodings.
    combined_encode = torch.cat([utterance_enc, fact_encode], dim=-1)
    attention = self.fact_attention_net(combined_encode)
    attention.masked_fill_(fact_mask, float("-Inf"))
    attention = nn.functional.softmax(attention, dim=2)
    attended_fact = (attention * fact_encode).sum(2)
    return attended_fact
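# --- Illustrative sketch (not part of the model) ---------------------------
# A minimal demo of how the fact mask above behaves, reusing the module-level
# numpy/torch/nn imports: np.triu(all_ones, 1) puts a 1 at [i, j] exactly when
# j > i, so after masked_fill_ round i can only attend to facts from rounds
# j <= i, and the softmax renormalizes over the remaining rounds.
def _fact_mask_demo(num_rounds=4):
    fact_mask = torch.tensor(
        np.triu(np.full((num_rounds, num_rounds), 1), 1), dtype=torch.bool
    )
    logits = torch.zeros(num_rounds, num_rounds)  # uniform logits for clarity
    logits.masked_fill_(fact_mask, float("-Inf"))
    # Row i has non-zero weight only on columns 0..i, e.g. for num_rounds=4:
    # [[1.00, 0.00, 0.00, 0.00],
    #  [0.50, 0.50, 0.00, 0.00],
    #  [0.33, 0.33, 0.33, 0.00],
    #  [0.25, 0.25, 0.25, 0.25]]
    return nn.functional.softmax(logits, dim=1)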
def forward(self, batch):
    """Forward pass through the encoder.

    Args:
        batch: Dict of batch variables.

    Returns:
        encoder_outputs: Dict of outputs from the forward pass.
    """
    encoder_out = {}
    # Flatten for the history-agnostic encoder.
    batch_size, num_rounds, max_length = batch["user_utt"].shape
    encoder_in = support.flatten(batch["user_utt"], batch_size, num_rounds)
    encoder_len = support.flatten(batch["user_utt_len"], batch_size, num_rounds)
    word_embeds_enc = self.word_embed_net(encoder_in)
    # Text encoder: LSTM or Transformer.
    if self.params["text_encoder"] == "lstm":
        all_enc_states, enc_states = rnn.dynamic_rnn(
            self.encoder_unit, word_embeds_enc, encoder_len, return_states=True
        )
        encoder_out["hidden_states_all"] = all_enc_states
        encoder_out["hidden_state"] = enc_states
    elif self.params["text_encoder"] == "transformer":
        enc_embeds = self.pos_encoder(word_embeds_enc).transpose(0, 1)
        enc_pad_mask = batch["user_utt"] == batch["pad_token"]
        enc_pad_mask = support.flatten(enc_pad_mask, batch_size, num_rounds)
        enc_states = self.encoder_unit(
            enc_embeds, src_key_padding_mask=enc_pad_mask
        )
        encoder_out["hidden_states_all"] = enc_states.transpose(0, 1)
    return encoder_out
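# --- Illustrative sketch (not part of the model) ---------------------------
# The real flatten/unflatten helpers live in `support`; the stand-ins below
# are a hedged sketch of the contract this file assumes: flatten merges the
# batch and round axes so per-round sequences can be fed to the RNN or
# Transformer as one large batch, and unflatten restores the original shape.
def _flatten_sketch(tensor, batch_size, num_rounds):
    # (batch_size, num_rounds, ...) -> (batch_size * num_rounds, ...)
    return tensor.reshape(batch_size * num_rounds, *tensor.shape[2:])


def _unflatten_sketch(tensor, batch_size, num_rounds):
    # (batch_size * num_rounds, ...) -> (batch_size, num_rounds, ...)
    return tensor.reshape(batch_size, num_rounds, *tensor.shape[1:])


def _flatten_demo():
    # Example: 8 dialogs x 5 rounds x 20 tokens -> 40 sequences of 20 tokens.
    utt = torch.zeros(8, 5, 20, dtype=torch.long)
    flat = _flatten_sketch(utt, 8, 5)
    assert flat.shape == (40, 20)
    assert _unflatten_sketch(flat, 8, 5).shape == utt.shape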
def forward(self, batch, encoder_output):
    """Forward pass through the decoder.

    Args:
        batch: Dict of batch variables.
        encoder_output: Dict of outputs from the encoder.

    Returns:
        decoder_outputs: Dict of outputs from the forward pass.
    """
    # Flatten for the history-agnostic encoder.
    batch_size, num_rounds, max_length = batch["assist_in"].shape
    decoder_in = support.flatten(batch["assist_in"], batch_size, num_rounds)
    decoder_out = support.flatten(batch["assist_out"], batch_size, num_rounds)
    decoder_len = support.flatten(batch["assist_in_len"], batch_size, num_rounds)
    word_embeds_dec = self.word_embed_net(decoder_in)
    if self.params["encoder"] in self.DIALOG_CONTEXT_ENCODERS:
        dialog_context = support.flatten(
            encoder_output["dialog_context"], batch_size, num_rounds
        ).unsqueeze(1)
        dialog_context = dialog_context.expand(-1, max_length, -1)
        decoder_steps_in = torch.cat([dialog_context, word_embeds_dec], -1)
    else:
        decoder_steps_in = word_embeds_dec

    # Condition the encoder states on action outputs, if needed.
    if self.params["use_action_output"]:
        action_out = encoder_output["action_output_all"].unsqueeze(1)
        time_steps = encoder_output["hidden_states_all"].shape[1]
        fusion_out = torch.cat(
            [
                encoder_output["hidden_states_all"],
                action_out.expand(-1, time_steps, -1),
            ],
            dim=-1,
        )
        encoder_output["hidden_states_all"] = self.action_fusion_net(fusion_out)

    if self.params["text_encoder"] == "transformer":
        # Rebuild the no-peek (causal) mask if it is missing or stale.
        if self.no_peek_mask is None or self.no_peek_mask.size(0) != max_length:
            self.no_peek_mask = self._generate_no_peek_mask(max_length)

        hidden_state = encoder_output["hidden_states_all"]
        enc_pad_mask = batch["user_utt"] == batch["pad_token"]
        enc_pad_mask = support.flatten(enc_pad_mask, batch_size, num_rounds)
        dec_pad_mask = batch["assist_in"] == batch["pad_token"]
        dec_pad_mask = support.flatten(dec_pad_mask, batch_size, num_rounds)
        if self.params["encoder"] != "pretrained_transformer":
            dec_embeds = self.pos_encoder(decoder_steps_in).transpose(0, 1)
            outputs = self.decoder_unit(
                dec_embeds,
                hidden_state.transpose(0, 1),
                memory_key_padding_mask=enc_pad_mask,
                tgt_mask=self.no_peek_mask,
                tgt_key_padding_mask=dec_pad_mask,
            )
            outputs = outputs.transpose(0, 1)
        else:
            outputs = self.decoder_unit(
                inputs_embeds=decoder_steps_in,
                attention_mask=~dec_pad_mask,
                encoder_hidden_states=hidden_state,
                encoder_attention_mask=~enc_pad_mask,
            )
            outputs = outputs[0]
    else:
        hidden_state = encoder_output["hidden_state"]
        if self.params["encoder"] == "tf_idf":
            hidden_state = None
        # Use Bahdanau attention over encoder states, if enabled.
        if (
            self.params["use_bahdanau_attention"]
            and self.params["encoder"] != "tf_idf"
        ):
            encoder_states = encoder_output["hidden_states_all"]
            max_decoder_len = min(
                decoder_in.shape[1], self.params["max_decoder_len"]
            )
            encoder_states_proj = self.attention_net(encoder_states)
            enc_mask = (batch["user_utt"] == batch["pad_token"]).unsqueeze(-1)
            enc_mask = support.flatten(enc_mask, batch_size, num_rounds)
            outputs = []
            for step in range(max_decoder_len):
                previous_state = hidden_state[0][-1].unsqueeze(1)
                att_logits = previous_state * encoder_states_proj
                att_logits = att_logits.sum(dim=-1, keepdim=True)
                # Use the encoder mask to replace <pad> logits with -Inf.
                att_logits.masked_fill_(enc_mask, float("-Inf"))
                att_wts = nn.functional.softmax(att_logits, dim=1)
                context = (encoder_states * att_wts).sum(1, keepdim=True)
                # Run the attended context and current input through the LSTM.
                concat_in = [context, decoder_steps_in[:, step : step + 1, :]]
                step_in = torch.cat(concat_in, dim=-1)
                decoder_output, hidden_state = self.decoder_unit(
                    step_in, hidden_state
                )
                concat_out = torch.cat([decoder_output, context], dim=-1)
                outputs.append(concat_out)
            outputs = torch.cat(outputs, dim=1)
        else:
            outputs = rnn.dynamic_rnn(
                self.decoder_unit,
                decoder_steps_in,
                decoder_len,
                init_state=hidden_state,
            )

    if self.params["encoder"] == "pretrained_transformer":
        output_logits = outputs
    else:
        # Logits over the vocabulary.
        output_logits = self.inv_word_net(outputs)
    # Zero out the token loss at <pad> positions before any aggregation.
    pad_mask = support.flatten(batch["assist_mask"], batch_size, num_rounds)
    loss_token = self.criterion(output_logits.transpose(1, 2), decoder_out)
    loss_token.masked_fill_(pad_mask, 0.0)
    return {"loss_token": loss_token, "pad_mask": pad_mask}
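# --- Illustrative sketch (not part of the model) ---------------------------
# The transformer branch above calls self._generate_no_peek_mask, which is
# defined elsewhere in this class. A hedged sketch of a typical implementation
# is given below, following the additive float-mask convention that
# nn.TransformerDecoder accepts for tgt_mask (0.0 where attention is allowed,
# -inf where a position would peek at the future); the actual helper may
# differ.
def _generate_no_peek_mask_sketch(size):
    # Upper-triangular True above the diagonal marks "future" positions.
    future = torch.triu(torch.ones(size, size), diagonal=1).bool()
    # Convert to an additive float mask, e.g. for size=3:
    # [[0., -inf, -inf],
    #  [0., 0., -inf],
    #  [0., 0., 0.]]
    return torch.zeros(size, size).masked_fill(future, float("-inf"))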