def forward(self, batch): """Forward pass through the encoder. Args: batch: Dict of batch variables. Returns: encoder_outputs: Dict of outputs from the forward pass. """ encoder_out = {} # Flatten for history_agnostic encoder. batch_size, num_rounds, max_length = batch["user_utt"].shape encoder_in = support.flatten(batch["user_utt"], batch_size, num_rounds) encoder_len = support.flatten(batch["user_utt_len"], batch_size, num_rounds) word_embeds_enc = self.word_embed_net(encoder_in) # Text encoder: LSTM or Transformer. if self.params["text_encoder"] == "lstm": all_enc_states, enc_states = rnn.dynamic_rnn(self.encoder_unit, word_embeds_enc, encoder_len, return_states=True) encoder_out["hidden_states_all"] = all_enc_states encoder_out["hidden_state"] = enc_states elif self.params["text_encoder"] == "transformer": enc_embeds = self.pos_encoder(word_embeds_enc).transpose(0, 1) enc_pad_mask = batch["user_utt"] == batch["pad_token"] enc_pad_mask = support.flatten(enc_pad_mask, batch_size, num_rounds) enc_states = self.encoder_unit(enc_embeds, src_key_padding_mask=enc_pad_mask) encoder_out["hidden_states_all"] = enc_states.transpose(0, 1) return encoder_out
def forward(self, batch): """Forward pass through the encoder. Args: batch: Dict of batch variables. Returns: encoder_outputs: Dict of outputs from the forward pass. """ encoder_out = {} device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') # Flatten to encode sentences. batch_size, num_rounds, _ = batch["user_utt"].shape encoder_in = support.flatten(batch["user_utt"], batch_size, num_rounds) encoder_len = batch["user_utt_len"].reshape(-1) word_embeds_enc = self.word_embed_net(encoder_in) # Fake encoder_len to be non-zero even for utterances out of dialog. fake_encoder_len = encoder_len.eq(0).long() + encoder_len all_enc_states, enc_states = rnn.dynamic_rnn(self.encoder_unit, word_embeds_enc, fake_encoder_len, return_states=True) encoder_out["hidden_states_all"] = all_enc_states encoder_out["hidden_state"] = enc_states utterance_enc = enc_states[0][-1] new_size = (batch_size, num_rounds, utterance_enc.shape[-1]) utterance_enc = utterance_enc.reshape(new_size) encoder_out["dialog_context"], _ = rnn.dynamic_rnn(self.dialog_unit, utterance_enc, batch["dialog_len"], return_states=True) return encoder_out
def forward(self, batch): """Forward pass through the encoder. Args: batch: Dict of batch variables. Returns: encoder_outputs: Dict of outputs from the forward pass. """ encoder_out = {} # Flatten to encode sentences. batch_size, num_rounds, _ = batch["user_utt"].shape encoder_in = support.flatten(batch["user_utt"], batch_size, num_rounds) encoder_len = batch["user_utt_len"].reshape(-1) word_embeds_enc = self.word_embed_net(encoder_in) # Fake encoder_len to be non-zero even for utterances out of dialog. fake_encoder_len = encoder_len.eq(0).long() + encoder_len all_enc_states, enc_states = rnn.dynamic_rnn(self.encoder_unit, word_embeds_enc, fake_encoder_len, return_states=True) encoder_out["hidden_states_all"] = all_enc_states encoder_out["hidden_state"] = enc_states utterance_enc = enc_states[0][-1] batch["utterance_enc"] = support.unflatten(utterance_enc, batch_size, num_rounds) encoder_out["dialog_context"] = self._memory_net_forward(batch) return encoder_out
def forward(self, multimodal_state, encoder_state, encoder_size):
    """Multimodal Embedding.

    Args:
        multimodal_state: Dict with memory, database, and focus images.
        encoder_state: State of the encoder.
        encoder_size: (batch_size, num_rounds)

    Returns:
        multimodal_encode: Encoder state with multimodal information.
    """
    # Setup category states if None.
    if self.category_state is None:
        self._setup_category_states()

    # Attend to multimodal memory using encoder states.
    batch_size, num_rounds = encoder_size
    memory_images = multimodal_state["memory_images"]
    memory_images = memory_images.unsqueeze(1).expand(-1, num_rounds, -1, -1)
    focus_images = multimodal_state["focus_images"][:, :num_rounds, :]
    focus_images = focus_images.unsqueeze(2)
    all_images = torch.cat([focus_images, memory_images], dim=2)
    all_images_flat = support.flatten(all_images, batch_size, num_rounds)
    category_state = self.category_state.expand(batch_size * num_rounds, -1, -1)
    cat_images = torch.cat([all_images_flat, category_state], dim=-1)
    multimodal_memory = self.multimodal_embed_net(cat_images)

    # Key (L, N, E), value (L, N, E), query (S, N, E).
    multimodal_memory = multimodal_memory.transpose(0, 1)
    query = encoder_state.unsqueeze(0)
    attended_query, attended_wts = self.multimodal_attend(
        query, multimodal_memory, multimodal_memory
    )
    multimodal_encode = torch.cat(
        [attended_query.squeeze(0), encoder_state], dim=-1
    )
    return multimodal_encode
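# --- Editor's note (assumption) ---
# The (L, N, E) / (S, N, E) shape comment above suggests `multimodal_attend` is
# an nn.MultiheadAttention module with batch_first=False, e.g. in __init__:
#   self.multimodal_attend = nn.MultiheadAttention(embed_dim, num_heads=1)
# which returns (attn_output, attn_output_weights), matching the unpacking above.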
def _memory_net_forward(self, batch):
    """Forward pass for memory network to look up facts.

    1. Encodes facts via the fact RNN.
    2. Computes attention between fact and utterance encodings.
    3. Attended fact vector + question encoding -> new encoding.

    Args:
        batch: Dict of fact, fact_len, utterance_enc.

    Returns:
        attended_fact: Attention-weighted fact encoding per round.
    """
    # NOTE: A fact is the previous utterance and its response concatenated
    # into one string, e.g., "What is the color of the couch? A: Red."
    batch_size, num_rounds, enc_time_steps = batch["fact"].shape

    # Upper-triangular mask so a round only attends to current/earlier facts.
    all_ones = np.full((num_rounds, num_rounds), 1)
    fact_mask = np.triu(all_ones, 1)
    fact_mask = np.expand_dims(np.expand_dims(fact_mask, -1), 0)
    fact_mask = torch.BoolTensor(fact_mask)
    if self.params["use_gpu"]:
        fact_mask = fact_mask.cuda()
    fact_mask.requires_grad_(False)

    fact_in = support.flatten(batch["fact"], batch_size, num_rounds)
    fact_len = support.flatten(batch["fact_len"], batch_size, num_rounds)
    fact_embeds = self.word_embed_net(fact_in)
    # Encode facts and unflatten the last hidden state.
    _, (hidden_state, _) = rnn.dynamic_rnn(
        self.fact_unit, fact_embeds, fact_len, return_states=True
    )
    fact_encode = support.unflatten(hidden_state[-1], batch_size, num_rounds)
    fact_encode = fact_encode.unsqueeze(1).expand(-1, num_rounds, -1, -1)
    utterance_enc = batch["utterance_enc"].unsqueeze(2)
    utterance_enc = utterance_enc.expand(-1, -1, num_rounds, -1)

    # Combine, compute attention, mask, and weight the fact encodings.
    combined_encode = torch.cat([utterance_enc, fact_encode], dim=-1)
    attention = self.fact_attention_net(combined_encode)
    attention.masked_fill_(fact_mask, float("-Inf"))
    attention = nn.functional.softmax(attention, dim=2)
    attended_fact = (attention * fact_encode).sum(2)
    return attended_fact
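# --- Editor's note (illustration, not part of the original code) ---
# The upper-triangular mask above blocks attention to future facts. For
# num_rounds = 3, np.triu(np.ones((3, 3)), 1) gives
#   [[0, 1, 1],   # round 0 may only use fact 0
#    [0, 0, 1],   # round 1 may use facts 0-1
#    [0, 0, 0]]   # round 2 may use facts 0-2
# and entries equal to 1 are filled with -inf before the softmax.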
def forward(self, multimodal_state, encoder_state, encoder_size):
    """Multimodal Embedding.

    Args:
        multimodal_state: Dict with memory, database, and focus images.
        encoder_state: State of the encoder.
        encoder_size: (batch_size, num_rounds)

    Returns:
        multimodal_encode: Encoder state with multimodal information.
    """
    # Setup category states if None.
    if self.category_state is None:
        self._setup_category_states()

    # Attend to multimodal memory using encoder states.
    batch_size, num_rounds = encoder_size
    memory_images = multimodal_state["memory_images"]
    memory_images = memory_images.unsqueeze(1).expand(-1, num_rounds, -1, -1)
    focus_images = multimodal_state["focus_images"][:, :num_rounds, :]
    focus_images = focus_images.unsqueeze(2)
    all_images = torch.cat([focus_images, memory_images], dim=2)
    all_images_flat = support.flatten(all_images, batch_size, num_rounds)
    category_state = self.category_state.expand(batch_size * num_rounds, -1, -1)
    cat_images = torch.cat([all_images_flat, category_state], dim=-1)
    multimodal_memory = self.multimodal_embed_net(cat_images)

    # Key (L, N, E), value (L, N, E), query (S, N, E).
    multimodal_memory = multimodal_memory.transpose(0, 1)
    query = encoder_state.unsqueeze(0)

    if self.params["gate_type"] == "MAG":
        # Multimodal adaptation gate.
        multimodal_encode = torch.cat(
            [self.MAG(query, multimodal_memory).squeeze(0), encoder_state],
            dim=-1,
        )
    elif self.params["gate_type"] == "none":
        attended_query, attended_wts = self.multimodal_attend(
            query, multimodal_memory, multimodal_memory
        )
        multimodal_encode = torch.cat(
            [attended_query.squeeze(0), encoder_state], dim=-1
        )
    elif self.params["gate_type"] == "MMI":
        # MMI: the query attends over the visual memory, the pooled memory
        # attends back over the query, and a visual gate fuses the two.
        attended_query_q, attended_wts_q = self.multimodal_attend(
            query, multimodal_memory, multimodal_memory
        )
        v_input = multimodal_memory.mean(dim=0)
        v_input = v_input.unsqueeze(0)
        attended_query_p, attended_wts_p = self.multimodal_attend(
            v_input, query, query
        )
        attended_query_a, attended_wts_a = self.multimodal_attend(
            query, attended_query_p, attended_query_p
        )
        b = self.Visualgate(attended_query_a, attended_query_q)
        multimodal_encode = torch.cat(
            [attended_query_a.squeeze(0), b.squeeze(0)], dim=-1
        )
    return multimodal_encode
def forward(self, batch, prev_outputs):
    """Forward pass for a given batch.

    Args:
        batch: Batch to forward pass.
        prev_outputs: Output from previous modules.

    Returns:
        outputs: Dict of expected outputs.
    """
    outputs = {}
    if (self.params["use_action_attention"]
            and self.params["encoder"] != "tf_idf"):
        encoder_state = prev_outputs["hidden_states_all"]
        batch_size, num_rounds, max_len = batch["user_mask"].shape
        encoder_mask = batch["user_utt"].eq(batch["pad_token"])
        encoder_mask = support.flatten(encoder_mask, batch_size, num_rounds)
        encoder_state = self.attention_net(encoder_state, encoder_mask)
    else:
        encoder_state = prev_outputs["hidden_state"][0][-1]
    encoder_state_old = encoder_state

    # Multimodal state.
    if self.params["use_multimodal_state"]:
        if self.params["domain"] == "furniture":
            encoder_state = self.multimodal_embed(
                batch["carousel_state"], encoder_state,
                batch["dialog_mask"].shape[:2])
        elif self.params["domain"] == "fashion":
            multimodal_state = {}
            for ii in ["memory_images", "focus_images"]:
                multimodal_state[ii] = batch[ii]
            encoder_state = self.multimodal_embed(
                multimodal_state, encoder_state,
                batch["dialog_mask"].shape[:2])

    # Belief state.
    if self.params["use_belief_state"]:
        encoder_state = torch.cat((encoder_state, belief_state), dim=1)

    # Predict and execute actions.
    action_logits = self.action_net(encoder_state)
    dialog_mask = batch["dialog_mask"]
    batch_size, num_rounds = dialog_mask.shape
    loss_action = self.criterion(action_logits, batch["action"].view(-1))
    loss_action.masked_fill_((~dialog_mask).view(-1), 0.0)
    loss_action_sum = loss_action.sum() / dialog_mask.sum().item()
    outputs["action_loss"] = loss_action_sum

    if not self.training:
        # Check for action accuracy.
        action_logits = support.unflatten(action_logits, batch_size, num_rounds)
        actions = action_logits.argmax(dim=-1)
        action_logits = nn.functional.log_softmax(action_logits, dim=-1)
        action_list = self.action_map.get_vocabulary_state()
        # Convert predictions to dictionary.
        action_preds_dict = [
            {
                "dialog_id": batch["dialog_id"][ii].item(),
                "predictions": [
                    {
                        "action": self.action_map.word(actions[ii, jj].item()),
                        "action_log_prob": {
                            action_token: action_logits[ii, jj, kk].item()
                            for kk, action_token in enumerate(action_list)
                        },
                        "attributes": {},
                    }
                    for jj in range(batch["dialog_len"][ii])
                ],
            }
            for ii in range(batch_size)
        ]
        outputs["action_preds"] = action_preds_dict
    else:
        actions = batch["action"]

    # Run classifiers based on the action, record supervision if training.
    if self.training:
        assert "action_super" in batch, (
            "Need supervision to learn action attributes")
    attr_logits = collections.defaultdict(list)
    attr_loss = collections.defaultdict(list)
    encoder_state_unflat = support.unflatten(encoder_state, batch_size,
                                             num_rounds)
    host = torch.cuda if self.params["use_gpu"] else torch
    for inst_id in range(batch_size):
        for round_id in range(num_rounds):
            # Turn out of dialog length.
            if not dialog_mask[inst_id, round_id]:
                continue
            cur_action_ind = actions[inst_id, round_id].item()
            cur_action = self.action_map.word(cur_action_ind)
            cur_state = encoder_state_unflat[inst_id, round_id]
            supervision = batch["action_super"][inst_id][round_id]
            # If there is no supervision, ignore and move on to the next round.
            if supervision is None:
                continue

            # Run classifiers on attributes.
            # Attributes overlap completely with GT when training.
            if self.training:
                classifier_list = self.action_metainfo[cur_action]["attributes"]
                if self.params["domain"] == "furniture":
                    for key in classifier_list:
                        cur_gt = (supervision.get(key, None)
                                  if supervision is not None else None)
                        new_entry = (cur_state, cur_gt, inst_id, round_id)
                        attr_logits[key].append(new_entry)
                elif self.params["domain"] == "fashion":
                    for key in classifier_list:
                        cur_gt = supervision.get(key, None)
                        gt_indices = host.FloatTensor(
                            len(self.attribute_vocab[key])).fill_(0.)
                        gt_indices[cur_gt] = 1
                        new_entry = (cur_state, gt_indices, inst_id, round_id)
                        attr_logits[key].append(new_entry)
                else:
                    raise ValueError("Domain neither of furniture/fashion!")
            else:
                classifier_list = self.action_metainfo[cur_action]["attributes"]
                action_pred_datum = (
                    action_preds_dict[inst_id]["predictions"][round_id])
                if self.params["domain"] == "furniture":
                    # Predict attributes based on the predicted action.
                    for key in classifier_list:
                        classifier = self.classifiers[key]
                        model_pred = classifier(cur_state).argmax(dim=-1)
                        attr_pred = self.attribute_vocab[key][model_pred.item()]
                        action_pred_datum["attributes"][key] = attr_pred
                elif self.params["domain"] == "fashion":
                    # Predict attributes based on the predicted action.
                    for key in classifier_list:
                        classifier = self.classifiers[key]
                        model_pred = classifier(cur_state) > 0
                        attr_pred = [
                            self.attribute_vocab[key][index]
                            for index, ii in enumerate(model_pred) if ii
                        ]
                        action_pred_datum["attributes"][key] = attr_pred
                else:
                    raise ValueError("Domain neither of furniture/fashion!")

    # Compute losses if training, else predict.
    if self.training:
        for key, values in attr_logits.items():
            classifier = self.classifiers[key]
            prelogits = [ii[0] for ii in values if ii[1] is not None]
            if not prelogits:
                continue
            logits = classifier(torch.stack(prelogits, dim=0))
            if self.params["domain"] == "furniture":
                gt_labels = [ii[1] for ii in values if ii[1] is not None]
                gt_labels = host.LongTensor(gt_labels)
                attr_loss[key] = self.criterion_mean(logits, gt_labels)
            elif self.params["domain"] == "fashion":
                gt_labels = torch.stack(
                    [ii[1] for ii in values if ii[1] is not None], dim=0)
                attr_loss[key] = self.criterion_multi(logits, gt_labels)
            else:
                raise ValueError("Domain neither of furniture/fashion!")
        total_attr_loss = host.FloatTensor([0.0])
        if len(attr_loss.values()):
            total_attr_loss = sum(attr_loss.values()) / len(attr_loss.values())
        outputs["action_attr_loss"] = total_attr_loss

    # Obtain action outputs as memory cells to attend over.
    if self.params["use_action_output"]:
        if self.params["domain"] == "furniture":
            encoder_state_out = self.action_output_embed(
                batch["action_output"],
                encoder_state_old,
                batch["dialog_mask"].shape[:2],
            )
        elif self.params["domain"] == "fashion":
            multimodal_state = {}
            for ii in ["memory_images", "focus_images"]:
                multimodal_state[ii] = batch[ii]
            # For action output, advance focus_images by one time step.
            # Output at step t is input at step t+1.
            feature_size = batch["focus_images"].shape[-1]
            zero_tensor = host.FloatTensor(batch_size, 1, feature_size).fill_(0.)
            multimodal_state["focus_images"] = torch.cat(
                [batch["focus_images"][:, 1:, :], zero_tensor], dim=1)
            encoder_state_out = self.multimodal_embed(
                multimodal_state, encoder_state_old,
                batch["dialog_mask"].shape[:2])
        else:
            raise ValueError("Domain neither furniture/fashion!")
        outputs["action_output_all"] = encoder_state_out

    outputs.update({
        "action_logits": action_logits,
        "action_attr_loss_dict": attr_loss,
    })
    return outputs
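# --- Editor's note (assumption) ---
# The element-wise masking of `loss_action` above only works if the action
# criterion returns per-element losses. The loss modules are presumably defined
# in __init__ along these lines:
#   self.criterion = nn.CrossEntropyLoss(reduction="none")  # per-round action loss
#   self.criterion_mean = nn.CrossEntropyLoss()             # furniture attributes
#   self.criterion_multi = nn.BCEWithLogitsLoss()           # multi-label fashion attributes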
def load_one_batch(self, sample_ids):
    """Loads a batch, given the sample ids.

    Args:
        sample_ids: List of instance ids to load data for.

    Returns:
        batch: Dictionary with relevant fields for training/evaluation.
    """
    batch = {
        "pad_token": self.pad_token,
        "start_token": self.start_token,
        "sample_ids": sample_ids,
    }
    batch["dialog_len"] = self.raw_data["dialog_len"][sample_ids]
    batch["dialog_id"] = self.raw_data["dialog_id"][sample_ids]
    max_dialog_len = max(batch["dialog_len"])

    user_utt_id = self.raw_data["user_utt_id"][sample_ids]
    batch["user_utt"], batch["user_utt_len"] = self._sample_utterance_pool(
        user_utt_id,
        self.raw_data["user_sent"],
        self.raw_data["user_sent_len"],
        self.params["max_encoder_len"],
    )
    for key in ("assist_in", "assist_out"):
        batch[key], batch[key + "_len"] = self._sample_utterance_pool(
            self.raw_data["assist_utt_id"][sample_ids],
            self.raw_data[key],
            self.raw_data["assist_sent_len"],
            self.params["max_decoder_len"],
        )
    actions = self.raw_data["action"][sample_ids]
    batch["action"] = np.vectorize(lambda x: self.action_map[x])(actions)

    # Construct user, assistant, and dialog masks.
    batch["dialog_mask"] = user_utt_id != -1
    batch["user_mask"] = (batch["user_utt"] == batch["pad_token"]) | (
        batch["user_utt"] == batch["start_token"])
    batch["assist_mask"] = (batch["assist_out"] == batch["pad_token"]) | (
        batch["assist_out"] == batch["start_token"])

    # Get retrieval candidates if needed.
    if self.params["get_retrieval_candidates"]:
        retrieval_inds = self.raw_data["retrieval_candidates"][sample_ids]
        batch_size, num_rounds, _ = retrieval_inds.shape
        flat_inds = torch_support.flatten(retrieval_inds, batch_size, num_rounds)
        for key in ("assist_in", "assist_out"):
            new_key = key.replace("assist", "candidate")
            cands, cands_len = self._sample_utterance_pool(
                flat_inds,
                self.raw_data[key],
                self.raw_data["assist_sent_len"],
                self.params["max_decoder_len"],
            )
            batch[new_key] = torch_support.unflatten(
                cands, batch_size, num_rounds)
            batch[new_key + "_len"] = torch_support.unflatten(
                cands_len, batch_size, num_rounds)
        batch["candidate_mask"] = (
            (batch["candidate_out"] == batch["pad_token"])
            | (batch["candidate_out"] == batch["start_token"]))

    # Action supervision.
    batch["action_super"] = [
        self.raw_data["action_supervision"][ii] for ii in sample_ids
    ]

    # Fetch facts if required.
    if self.params["encoder"] == "memory_network":
        batch["fact"] = self.raw_data["fact"][sample_ids]
        batch["fact_len"] = self.raw_data["fact_len"][sample_ids]

    # Trim to the maximum dialog length.
    for key in ("assist_in", "assist_out", "candidate_in", "candidate_out",
                "user_utt", "fact", "user_mask", "assist_mask",
                "candidate_mask"):
        if key in batch:
            batch[key] = batch[key][:, :max_dialog_len]
    for key in (
        "action",
        "assist_in_len",
        "assist_out_len",
        "candidate_in_len",
        "candidate_out_len",
        "user_utt_len",
        "dialog_mask",
        "fact_len",
    ):
        if key in batch:
            batch[key] = batch[key][:, :max_dialog_len]

    # TF-IDF features.
    if self.params["encoder"] == "tf_idf":
        batch["user_tf_idf"] = self.compute_tf_features(
            batch["user_utt"], batch["user_utt_len"])

    # Domain-specific processing.
    if self.params["domain"] == "furniture":
        # Carousel states.
        if self.params["use_multimodal_state"]:
            batch["carousel_state"] = [
                self.raw_data["carousel_state"][ii] for ii in sample_ids
            ]
        # Action output.
        if self.params["use_action_output"]:
            batch["action_output"] = [
                self.raw_data["action_output_state"][ii] for ii in sample_ids
            ]
    elif self.params["domain"] == "fashion":
        # Asset embeddings -- memory, database, focus images.
        for dtype in ["memory", "database", "focus"]:
            indices = self.raw_data["{}_inds".format(dtype)][sample_ids]
            image_embeds = self.embed_data["embedding"][indices]
            batch["{}_images".format(dtype)] = image_embeds
    else:
        raise ValueError("Domain must be either furniture/fashion!")
    return self._ship_torch_batch(batch)
def forward(self, batch, encoder_output):
    """Forward pass through the decoder.

    Args:
        batch: Dict of batch variables.
        encoder_output: Dict of outputs from the encoder.

    Returns:
        decoder_outputs: Dict of outputs from the forward pass.
    """
    # Flatten the (batch, round) dimensions for decoding.
    batch_size, num_rounds, max_length = batch["assist_in"].shape
    decoder_in = support.flatten(batch["assist_in"], batch_size, num_rounds)
    decoder_out = support.flatten(batch["assist_out"], batch_size, num_rounds)
    decoder_len = support.flatten(batch["assist_in_len"], batch_size, num_rounds)
    word_embeds_dec = self.word_embed_net(decoder_in)

    if self.params["encoder"] in self.DIALOG_CONTEXT_ENCODERS:
        dialog_context = support.flatten(
            encoder_output["dialog_context"], batch_size, num_rounds
        ).unsqueeze(1)
        dialog_context = dialog_context.expand(-1, max_length, -1)
        decoder_steps_in = torch.cat([dialog_context, word_embeds_dec], -1)
    else:
        decoder_steps_in = word_embeds_dec

    # Encoder states conditioned on action outputs, if need be.
    if self.params["use_action_output"]:
        action_out = encoder_output["action_output_all"].unsqueeze(1)
        time_steps = encoder_output["hidden_states_all"].shape[1]
        fusion_out = torch.cat(
            [
                encoder_output["hidden_states_all"],
                action_out.expand(-1, time_steps, -1),
            ],
            dim=-1,
        )
        encoder_output["hidden_states_all"] = self.action_fusion_net(fusion_out)

    if self.params["text_encoder"] == "transformer":
        # Check the status of no_peek_mask.
        if (self.no_peek_mask is None
                or self.no_peek_mask.size(0) != max_length):
            self.no_peek_mask = self._generate_no_peek_mask(max_length)
        hidden_state = encoder_output["hidden_states_all"]
        enc_pad_mask = batch["user_utt"] == batch["pad_token"]
        enc_pad_mask = support.flatten(enc_pad_mask, batch_size, num_rounds)
        dec_pad_mask = batch["assist_in"] == batch["pad_token"]
        dec_pad_mask = support.flatten(dec_pad_mask, batch_size, num_rounds)
        if self.params["encoder"] != "pretrained_transformer":
            dec_embeds = self.pos_encoder(decoder_steps_in).transpose(0, 1)
            outputs = self.decoder_unit(
                dec_embeds,
                hidden_state.transpose(0, 1),
                memory_key_padding_mask=enc_pad_mask,
                tgt_mask=self.no_peek_mask,
                tgt_key_padding_mask=dec_pad_mask,
            )
            outputs = outputs.transpose(0, 1)
        else:
            outputs = self.decoder_unit(
                inputs_embeds=decoder_steps_in,
                attention_mask=~dec_pad_mask,
                encoder_hidden_states=hidden_state,
                encoder_attention_mask=~enc_pad_mask,
            )
            outputs = outputs[0]
    else:
        hidden_state = encoder_output["hidden_state"]
        if self.params["encoder"] == "tf_idf":
            hidden_state = None
        # If Bahdanau attention is to be used.
        if (self.params["use_bahdanau_attention"]
                and self.params["encoder"] != "tf_idf"):
            encoder_states = encoder_output["hidden_states_all"]
            max_decoder_len = min(decoder_in.shape[1],
                                  self.params["max_decoder_len"])
            encoder_states_proj = self.attention_net(encoder_states)
            enc_mask = (batch["user_utt"] == batch["pad_token"]).unsqueeze(-1)
            enc_mask = support.flatten(enc_mask, batch_size, num_rounds)
            outputs = []
            for step in range(max_decoder_len):
                previous_state = hidden_state[0][-1].unsqueeze(1)
                att_logits = previous_state * encoder_states_proj
                att_logits = att_logits.sum(dim=-1, keepdim=True)
                # Use encoder mask to replace <pad> with -Inf.
                att_logits.masked_fill_(enc_mask, float("-Inf"))
                att_wts = nn.functional.softmax(att_logits, dim=1)
                context = (encoder_states * att_wts).sum(1, keepdim=True)
                # Run through LSTM.
                concat_in = [context, decoder_steps_in[:, step:step + 1, :]]
                step_in = torch.cat(concat_in, dim=-1)
                decoder_output, hidden_state = self.decoder_unit(
                    step_in, hidden_state)
                concat_out = torch.cat([decoder_output, context], dim=-1)
                outputs.append(concat_out)
            outputs = torch.cat(outputs, dim=1)
        else:
            outputs = rnn.dynamic_rnn(
                self.decoder_unit,
                decoder_steps_in,
                decoder_len,
                init_state=hidden_state,
            )

    if self.params["encoder"] == "pretrained_transformer":
        output_logits = outputs
    else:
        # Logits over vocabulary.
        output_logits = self.inv_word_net(outputs)
    # Mask out the criterion while summing.
    pad_mask = support.flatten(batch["assist_mask"], batch_size, num_rounds)
    loss_token = self.criterion(output_logits.transpose(1, 2), decoder_out)
    loss_token.masked_fill_(pad_mask, 0.0)
    return {"loss_token": loss_token, "pad_mask": pad_mask}
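# --- Editor's note (assumption) ---
# As with the action loss, masking `loss_token` element-wise requires a
# per-token criterion, e.g. nn.CrossEntropyLoss(reduction="none"); the caller
# is then expected to reduce `loss_token` with `pad_mask` when aggregating.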
def forward_beamsearch_single(self, batch, encoder_output, mode_params):
    """Evaluates the model using beam search with batch size 1.

    NOTE: The current implementation only supports beam search with batch
    size 1 and the RNN text_encoder (to be extended for transformers).

    Args:
        batch: Dictionary of inputs, with batch size of 1.
        encoder_output: Dict of outputs from the encoder.
        mode_params: Dict of beam search parameters (contains beam_size).

    Returns:
        top_beams: Dictionary of top beams.
    """
    # Initializations and aliases.
    # Tensors are either on CPU or GPU.
    LENGTH_NORM = True
    self.host = torch.cuda if self.params["use_gpu"] else torch
    end_token = self.params["end_token"]
    start_token = self.params["start_token"]
    beam_size = mode_params["beam_size"]
    max_decoder_len = self.params["max_decoder_len"]

    if self.params["text_encoder"] == "transformer":
        hidden_state = encoder_output["hidden_states_all"].transpose(0, 1)
        max_enc_len, batch_size, enc_embed_size = hidden_state.shape
        hidden_state_expand = hidden_state.expand(
            max_enc_len, beam_size, enc_embed_size)
        enc_pad_mask = batch["user_utt"] == batch["pad_token"]
        enc_pad_mask = support.flatten(enc_pad_mask, 1, 1)
        enc_pad_mask_expand = enc_pad_mask.expand(beam_size, max_enc_len)
        if (self.no_peek_mask is None
                or self.no_peek_mask.size(0) != max_decoder_len):
            self.no_peek_mask = self._generate_no_peek_mask(max_decoder_len)
    elif self.params["text_encoder"] == "lstm":
        hidden_state = encoder_output["hidden_state"]

    if (self.params["use_bahdanau_attention"]
            and self.params["encoder"] != "tf_idf"
            and self.params["text_encoder"] == "lstm"):
        encoder_states = encoder_output["hidden_states_all"]
        encoder_states_proj = self.attention_net(encoder_states)
        enc_mask = (batch["user_utt"] == batch["pad_token"]).unsqueeze(-1)
        enc_mask = support.flatten(enc_mask, 1, 1)

    # Per-instance initializations.
    # Copy the hidden state beam_size number of times.
    if hidden_state is not None:
        hidden_state = [ii.repeat(1, beam_size, 1) for ii in hidden_state]
    beams = {-1: self.host.LongTensor(1, beam_size).fill_(start_token)}
    beam_scores = self.host.FloatTensor(beam_size, 1).fill_(0.)
    finished_beams = self.host.ByteTensor(beam_size, 1).fill_(False)
    zero_tensor = self.host.LongTensor(beam_size, 1).fill_(end_token)
    reverse_inds = {}

    # Generate beams until max_len time steps.
    for step in range(max_decoder_len - 1):
        if self.params["text_encoder"] == "transformer":
            beams, tokens_list = self._backtrack_beams(beams, reverse_inds)
            beam_tokens = torch.cat(tokens_list, dim=0).transpose(0, 1)
            beam_tokens_embed = self.word_embed_net(beam_tokens)
            if self.params["encoder"] != "pretrained_transformer":
                dec_embeds = self.pos_encoder(beam_tokens_embed).transpose(0, 1)
                output = self.decoder_unit(
                    dec_embeds,
                    hidden_state_expand,
                    tgt_mask=self.no_peek_mask[:step + 1, :step + 1],
                    memory_key_padding_mask=enc_pad_mask_expand,
                )
                logits = self.inv_word_net(output[-1])
            else:
                outputs = self.decoder_unit(
                    inputs_embeds=beam_tokens_embed,
                    encoder_hidden_states=hidden_state_expand.transpose(0, 1),
                    encoder_attention_mask=~enc_pad_mask_expand,
                )
                logits = outputs[0][:, -1, :]
        elif self.params["text_encoder"] == "lstm":
            beam_tokens = beams[step - 1].t()
            beam_tokens_embed = self.word_embed_net(beam_tokens)
            # Append dialog context if it exists.
            if self.params["encoder"] in self.DIALOG_CONTEXT_ENCODERS:
                dialog_context = encoder_output["dialog_context"]
                beam_tokens_embed = torch.cat(
                    [dialog_context.repeat(beam_size, 1, 1), beam_tokens_embed],
                    dim=-1,
                )
            # Use Bahdanau attention over encoder hidden states.
            if (self.params["use_bahdanau_attention"]
                    and self.params["encoder"] != "tf_idf"):
                previous_state = hidden_state[0][-1].unsqueeze(1)
                att_logits = previous_state * encoder_states_proj
                att_logits = att_logits.sum(dim=-1, keepdim=True)
                # Use encoder mask to replace <pad> with -Inf.
                att_logits.masked_fill_(enc_mask, float("-Inf"))
                att_wts = nn.functional.softmax(att_logits, dim=1)
                context = (encoder_states * att_wts).sum(1, keepdim=True)
                # Run through LSTM.
                step_in = torch.cat([context, beam_tokens_embed], dim=-1)
                decoder_output, new_state = self.decoder_unit(
                    step_in, hidden_state)
                output = torch.cat([decoder_output, context], dim=-1)
            else:
                output, new_state = self.decoder_unit(
                    beam_tokens_embed, hidden_state)
            logits = self.inv_word_net(output).squeeze(1)

        log_probs = nn.functional.log_softmax(logits, dim=-1)

        # Compute the new beam scores.
        alive = finished_beams.eq(0).float()
        if LENGTH_NORM:
            # Add (current log probs / (step + 1)).
            cur_weight = alive / (step + 1)
            # Add (previous log probs * (t / (t + 1))) <- mean update.
            prev_weight = alive * step / (step + 1)
        else:
            # No length normalization.
            cur_weight = alive
            prev_weight = alive

        # Compute the new beam extensions.
        if step == 0:
            # For the first step, make all but the first beam
            # probabilities -inf.
            log_probs[1:, :] = float("-inf")
        new_scores = log_probs * cur_weight + beam_scores * prev_weight
        finished_beam_scores = beam_scores * finished_beams.float()
        new_scores.scatter_add_(1, zero_tensor, finished_beam_scores)
        # Finished beam scores are set to -inf for all words but one.
        new_scores.masked_fill_(new_scores.eq(0), float("-inf"))

        num_candidates = new_scores.shape[-1]
        new_scores_flat = new_scores.view(1, -1)
        beam_scores, top_inds_flat = torch.topk(new_scores_flat, beam_size)
        beam_scores = beam_scores.t()
        top_beam_inds = (top_inds_flat // num_candidates).squeeze(0)
        top_tokens = top_inds_flat % num_candidates

        # Prepare for the next step.
        beams[step] = top_tokens
        reverse_inds[step] = top_beam_inds
        finished_beams = finished_beams[top_beam_inds]
        if self.params["text_encoder"] == "lstm":
            hidden_state = tuple(
                ii.index_select(1, top_beam_inds) for ii in new_state)

        # Update if any of the latest beams are finished, i.e. have <END>.
        new_finished_beams = beams[step].eq(end_token).type(
            self.host.ByteTensor)
        finished_beams = finished_beams | new_finished_beams.t()
        if torch.sum(finished_beams).item() == beam_size:
            break

    # Backtrack the beams through indices.
    beams, tokens_list = self._backtrack_beams(beams, reverse_inds)
    # Add an <END> token at the end.
    tokens_list.append(self.host.LongTensor(1, beam_size).fill_(end_token))
    sorted_beam_tokens = torch.cat(tokens_list, 0).t()
    sorted_beam_lengths = sorted_beam_tokens.ne(end_token).long().sum(dim=1)
    # Trim all the top beams.
    top_beams = []
    for index in range(beam_size):
        beam_length = sorted_beam_lengths[index].view(-1)
        beam = sorted_beam_tokens[index].view(-1, 1)[1:beam_length]
        top_beams.append(beam)
    return {"top_beams": top_beams}