def forward(self, batch): """Forward pass through the encoder. Args: batch: Dict of batch variables. Returns: encoder_outputs: Dict of outputs from the forward pass. """ encoder_out = {} # Flatten to encode sentences. batch_size, num_rounds, _ = batch["user_utt"].shape encoder_in = support.flatten(batch["user_utt"], batch_size, num_rounds) encoder_len = batch["user_utt_len"].reshape(-1) word_embeds_enc = self.word_embed_net(encoder_in) # Fake encoder_len to be non-zero even for utterances out of dialog. fake_encoder_len = encoder_len.eq(0).long() + encoder_len all_enc_states, enc_states = rnn.dynamic_rnn(self.encoder_unit, word_embeds_enc, fake_encoder_len, return_states=True) encoder_out["hidden_states_all"] = all_enc_states encoder_out["hidden_state"] = enc_states utterance_enc = enc_states[0][-1] batch["utterance_enc"] = support.unflatten(utterance_enc, batch_size, num_rounds) encoder_out["dialog_context"] = self._memory_net_forward(batch) return encoder_out
def _memory_net_forward(self, batch):
    """Forward pass for the memory network to look up facts.

    1. Encodes facts via the fact RNN.
    2. Computes attention between the fact and utterance encodings.
    3. Attended fact vector and question encoding -> new encoding.

    Args:
        batch: Dict with fact, fact_len, and utterance_enc.
    """
    # A fact is the previous utterance and response concatenated as one,
    # for example: "What is the color of the couch? A: Red."
    batch_size, num_rounds, enc_time_steps = batch["fact"].shape
    all_ones = np.full((num_rounds, num_rounds), 1)
    fact_mask = np.triu(all_ones, 1)
    fact_mask = np.expand_dims(np.expand_dims(fact_mask, -1), 0)
    fact_mask = torch.BoolTensor(fact_mask)
    if self.params["use_gpu"]:
        fact_mask = fact_mask.cuda()
    fact_mask.requires_grad_(False)

    fact_in = support.flatten(batch["fact"], batch_size, num_rounds)
    fact_len = support.flatten(batch["fact_len"], batch_size, num_rounds)
    fact_embeds = self.word_embed_net(fact_in)
    # Encode facts and unflatten the last hidden state.
    _, (hidden_state, _) = rnn.dynamic_rnn(
        self.fact_unit, fact_embeds, fact_len, return_states=True
    )
    fact_encode = support.unflatten(hidden_state[-1], batch_size, num_rounds)
    fact_encode = fact_encode.unsqueeze(1).expand(-1, num_rounds, -1, -1)
    utterance_enc = batch["utterance_enc"].unsqueeze(2)
    utterance_enc = utterance_enc.expand(-1, -1, num_rounds, -1)

    # Combine, compute attention, mask, and weight the fact encodings.
    combined_encode = torch.cat([utterance_enc, fact_encode], dim=-1)
    attention = self.fact_attention_net(combined_encode)
    attention.masked_fill_(fact_mask, float("-Inf"))
    attention = self.softmax(attention, dim=2)
    attended_fact = (attention * fact_encode).sum(2)
    return attended_fact
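# Why np.triu(..., 1): for the current round i, every fact from a *future*
# round j > i is filled with -inf before the softmax, so attention only covers
# facts up to the current round. A tiny sketch with num_rounds = 3 (a toy
# value, not from any config):
#
#   import numpy as np
#   num_rounds = 3
#   fact_mask = np.triu(np.full((num_rounds, num_rounds), 1), 1)
#   # fact_mask == [[0, 1, 1],
#   #               [0, 0, 1],
#   #               [0, 0, 0]]
#   # Row i = current round, column j = fact round; a 1 marks a future fact
#   # whose attention logit is set to -inf.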
def forward(self, batch, prev_outputs):
    """Forward pass a given batch.

    Args:
        batch: Batch to forward pass.
        prev_outputs: Outputs from previous modules.

    Returns:
        outputs: Dict of expected outputs.
    """
    outputs = {}
    if (self.params["use_action_attention"]
            and self.params["encoder"] != "tf_idf"):
        encoder_state = prev_outputs["hidden_states_all"]
        batch_size, num_rounds, max_len = batch["user_mask"].shape
        encoder_mask = batch["user_utt"].eq(batch["pad_token"])
        encoder_mask = support.flatten(encoder_mask, batch_size, num_rounds)
        encoder_state = self.attention_net(encoder_state, encoder_mask)
    else:
        encoder_state = prev_outputs["hidden_state"][0][-1]
    encoder_state_old = encoder_state

    # Multimodal state.
    if self.params["use_multimodal_state"]:
        if self.params["domain"] == "furniture":
            encoder_state = self.multimodal_embed(
                batch["carousel_state"], encoder_state,
                batch["dialog_mask"].shape[:2])
        elif self.params["domain"] == "fashion":
            multimodal_state = {}
            for ii in ["memory_images", "focus_images"]:
                multimodal_state[ii] = batch[ii]
            encoder_state = self.multimodal_embed(
                multimodal_state, encoder_state,
                batch["dialog_mask"].shape[:2])

    # Belief state.
    if self.params["use_belief_state"]:
        # Assumption: the flattened belief state encoding is provided in the
        # batch (the original code referenced an undefined `belief_state`).
        encoder_state = torch.cat((encoder_state, batch["belief_state"]), dim=1)

    # Predict and execute actions.
    action_logits = self.action_net(encoder_state)
    dialog_mask = batch["dialog_mask"]
    batch_size, num_rounds = dialog_mask.shape
    loss_action = self.criterion(action_logits, batch["action"].view(-1))
    loss_action.masked_fill_((~dialog_mask).view(-1), 0.0)
    loss_action_sum = loss_action.sum() / dialog_mask.sum().item()
    outputs["action_loss"] = loss_action_sum

    if not self.training:
        # Check for action accuracy.
        action_logits = support.unflatten(action_logits, batch_size, num_rounds)
        actions = action_logits.argmax(dim=-1)
        action_logits = nn.functional.log_softmax(action_logits, dim=-1)
        action_list = self.action_map.get_vocabulary_state()
        # Convert predictions to a dictionary.
        action_preds_dict = [
            {
                "dialog_id": batch["dialog_id"][ii].item(),
                "predictions": [
                    {
                        "action": self.action_map.word(actions[ii, jj].item()),
                        "action_log_prob": {
                            action_token: action_logits[ii, jj, kk].item()
                            for kk, action_token in enumerate(action_list)
                        },
                        "attributes": {}
                    }
                    for jj in range(batch["dialog_len"][ii])
                ]
            }
            for ii in range(batch_size)
        ]
        outputs["action_preds"] = action_preds_dict
    else:
        actions = batch["action"]

    # Run classifiers based on the action, record supervision if training.
    if self.training:
        assert "action_super" in batch, (
            "Need supervision to learn action attributes")
    attr_logits = collections.defaultdict(list)
    attr_loss = collections.defaultdict(list)
    encoder_state_unflat = support.unflatten(
        encoder_state, batch_size, num_rounds)
    host = torch.cuda if self.params["use_gpu"] else torch
    for inst_id in range(batch_size):
        for round_id in range(num_rounds):
            # Turn out of dialog length.
            if not dialog_mask[inst_id, round_id]:
                continue
            cur_action_ind = actions[inst_id, round_id].item()
            cur_action = self.action_map.word(cur_action_ind)
            cur_state = encoder_state_unflat[inst_id, round_id]
            supervision = batch["action_super"][inst_id][round_id]
            # If there is no supervision, ignore and move on to the next round.
            if supervision is None:
                continue

            # Run classifiers on attributes.
            # Attributes overlap completely with ground truth when training.
            if self.training:
                classifier_list = self.action_metainfo[cur_action]["attributes"]
                if self.params["domain"] == "furniture":
                    for key in classifier_list:
                        cur_gt = (supervision.get(key, None)
                                  if supervision is not None else None)
                        new_entry = (cur_state, cur_gt, inst_id, round_id)
                        attr_logits[key].append(new_entry)
                elif self.params["domain"] == "fashion":
                    for key in classifier_list:
                        cur_gt = supervision.get(key, None)
                        gt_indices = host.FloatTensor(
                            len(self.attribute_vocab[key])).fill_(0.)
                        gt_indices[cur_gt] = 1
                        new_entry = (cur_state, gt_indices, inst_id, round_id)
                        attr_logits[key].append(new_entry)
                else:
                    raise ValueError("Domain neither of furniture/fashion!")
            else:
                classifier_list = self.action_metainfo[cur_action]["attributes"]
                action_pred_datum = (
                    action_preds_dict[inst_id]["predictions"][round_id])
                if self.params["domain"] == "furniture":
                    # Predict attributes based on the predicted action.
                    for key in classifier_list:
                        classifier = self.classifiers[key]
                        model_pred = classifier(cur_state).argmax(dim=-1)
                        attr_pred = self.attribute_vocab[key][model_pred.item()]
                        action_pred_datum["attributes"][key] = attr_pred
                elif self.params["domain"] == "fashion":
                    # Predict attributes based on the predicted action.
                    for key in classifier_list:
                        classifier = self.classifiers[key]
                        model_pred = classifier(cur_state) > 0
                        attr_pred = [
                            self.attribute_vocab[key][index]
                            for index, ii in enumerate(model_pred)
                            if ii
                        ]
                        action_pred_datum["attributes"][key] = attr_pred
                else:
                    raise ValueError("Domain neither of furniture/fashion!")

    # Compute losses if training, else predict.
    if self.training:
        for key, values in attr_logits.items():
            classifier = self.classifiers[key]
            prelogits = [ii[0] for ii in values if ii[1] is not None]
            if not prelogits:
                continue
            logits = classifier(torch.stack(prelogits, dim=0))
            if self.params["domain"] == "furniture":
                gt_labels = [ii[1] for ii in values if ii[1] is not None]
                gt_labels = host.LongTensor(gt_labels)
                attr_loss[key] = self.criterion_mean(logits, gt_labels)
            elif self.params["domain"] == "fashion":
                gt_labels = torch.stack(
                    [ii[1] for ii in values if ii[1] is not None], dim=0)
                attr_loss[key] = self.criterion_multi(logits, gt_labels)
            else:
                raise ValueError("Domain neither of furniture/fashion!")

        total_attr_loss = host.FloatTensor([0.0])
        if len(attr_loss.values()):
            total_attr_loss = sum(attr_loss.values()) / len(attr_loss.values())
        outputs["action_attr_loss"] = total_attr_loss

    # Obtain action outputs as memory cells to attend over.
    if self.params["use_action_output"]:
        if self.params["domain"] == "furniture":
            encoder_state_out = self.action_output_embed(
                batch["action_output"],
                encoder_state_old,
                batch["dialog_mask"].shape[:2],
            )
        elif self.params["domain"] == "fashion":
            multimodal_state = {}
            for ii in ["memory_images", "focus_images"]:
                multimodal_state[ii] = batch[ii]
            # For action output, advance focus_images by one time step.
            # Output at step t is input at step t + 1.
            feature_size = batch["focus_images"].shape[-1]
            zero_tensor = host.FloatTensor(
                batch_size, 1, feature_size).fill_(0.)
            multimodal_state["focus_images"] = torch.cat(
                [batch["focus_images"][:, 1:, :], zero_tensor], dim=1)
            encoder_state_out = self.multimodal_embed(
                multimodal_state, encoder_state_old,
                batch["dialog_mask"].shape[:2])
        else:
            raise ValueError("Domain neither furniture/fashion!")
        outputs["action_output_all"] = encoder_state_out

    outputs.update({
        "action_logits": action_logits,
        "action_attr_loss_dict": attr_loss
    })
    return outputs
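# Hedged sketch of how the multi-hot fashion attribute targets above pair with
# a multi-label loss. `criterion_multi` is assumed to behave like
# nn.BCEWithLogitsLoss (not confirmed by this file); the target construction
# mirrors the gt_indices logic in forward(). Names here are illustrative.
import torch
import torch.nn as nn


def _fashion_attr_loss_sketch(logits, gt_attribute_ids, num_attributes):
    # logits: (num_samples, num_attributes) raw classifier outputs.
    # gt_attribute_ids: list of lists of ground-truth attribute indices.
    targets = torch.zeros(len(gt_attribute_ids), num_attributes)
    for row, ids in enumerate(gt_attribute_ids):
        targets[row, ids] = 1.0  # multi-hot target, as with gt_indices above
    return nn.BCEWithLogitsLoss()(logits, targets)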
def load_one_batch(self, sample_ids):
    """Loads a batch, given the sample ids.

    Args:
        sample_ids: List of instance ids to load data for.

    Returns:
        batch: Dictionary with relevant fields for training/evaluation.
    """
    batch = {
        "pad_token": self.pad_token,
        "start_token": self.start_token,
        "sample_ids": sample_ids,
    }
    batch["dialog_len"] = self.raw_data["dialog_len"][sample_ids]
    batch["dialog_id"] = self.raw_data["dialog_id"][sample_ids]
    max_dialog_len = max(batch["dialog_len"])

    user_utt_id = self.raw_data["user_utt_id"][sample_ids]
    batch["user_utt"], batch["user_utt_len"] = self._sample_utterance_pool(
        user_utt_id,
        self.raw_data["user_sent"],
        self.raw_data["user_sent_len"],
        self.params["max_encoder_len"],
    )
    for key in ("assist_in", "assist_out"):
        batch[key], batch[key + "_len"] = self._sample_utterance_pool(
            self.raw_data["assist_utt_id"][sample_ids],
            self.raw_data[key],
            self.raw_data["assist_sent_len"],
            self.params["max_decoder_len"],
        )
    actions = self.raw_data["action"][sample_ids]
    batch["action"] = np.vectorize(lambda x: self.action_map[x])(actions)

    # Construct user, assistant, and dialog masks.
    batch["dialog_mask"] = user_utt_id != -1
    batch["user_mask"] = (batch["user_utt"] == batch["pad_token"]) | (
        batch["user_utt"] == batch["start_token"])
    batch["assist_mask"] = (batch["assist_out"] == batch["pad_token"]) | (
        batch["assist_out"] == batch["start_token"])

    # Get retrieval candidates if needed.
    if self.params["get_retrieval_candidates"]:
        retrieval_inds = self.raw_data["retrieval_candidates"][sample_ids]
        batch_size, num_rounds, _ = retrieval_inds.shape
        flat_inds = torch_support.flatten(retrieval_inds, batch_size, num_rounds)
        for key in ("assist_in", "assist_out"):
            new_key = key.replace("assist", "candidate")
            cands, cands_len = self._sample_utterance_pool(
                flat_inds,
                self.raw_data[key],
                self.raw_data["assist_sent_len"],
                self.params["max_decoder_len"],
            )
            batch[new_key] = torch_support.unflatten(
                cands, batch_size, num_rounds)
            batch[new_key + "_len"] = torch_support.unflatten(
                cands_len, batch_size, num_rounds)
        batch["candidate_mask"] = (
            (batch["candidate_out"] == batch["pad_token"])
            | (batch["candidate_out"] == batch["start_token"]))

    # Action supervision.
    batch["action_super"] = [
        self.raw_data["action_supervision"][ii] for ii in sample_ids
    ]

    # Fetch facts if required.
    if self.params["encoder"] == "memory_network":
        batch["fact"] = self.raw_data["fact"][sample_ids]
        batch["fact_len"] = self.raw_data["fact_len"][sample_ids]

    # Trim to the maximum dialog length.
    for key in ("assist_in", "assist_out", "candidate_in", "candidate_out",
                "user_utt", "fact", "user_mask", "assist_mask",
                "candidate_mask"):
        if key in batch:
            batch[key] = batch[key][:, :max_dialog_len]
    for key in (
            "action",
            "assist_in_len",
            "assist_out_len",
            "candidate_in_len",
            "candidate_out_len",
            "user_utt_len",
            "dialog_mask",
            "fact_len",
    ):
        if key in batch:
            batch[key] = batch[key][:, :max_dialog_len]

    # TF-IDF features.
    if self.params["encoder"] == "tf_idf":
        batch["user_tf_idf"] = self.compute_tf_features(
            batch["user_utt"], batch["user_utt_len"])

    # Domain-specific processing.
    if self.params["domain"] == "furniture":
        # Carousel states.
        if self.params["use_multimodal_state"]:
            batch["carousel_state"] = [
                self.raw_data["carousel_state"][ii] for ii in sample_ids
            ]
        # Action output.
        if self.params["use_action_output"]:
            batch["action_output"] = [
                self.raw_data["action_output_state"][ii] for ii in sample_ids
            ]
    elif self.params["domain"] == "fashion":
        # Asset embeddings -- memory, database, focus images.
        for dtype in ["memory", "database", "focus"]:
            indices = self.raw_data["{}_inds".format(dtype)][sample_ids]
            image_embeds = self.embed_data["embedding"][indices]
            batch["{}_images".format(dtype)] = image_embeds
    else:
        raise ValueError("Domain must be either furniture/fashion!")
    return self._ship_torch_batch(batch)
def forward(self, batch, mode=None):
    """Forward propagation.

    Args:
        batch: Dict of batch input variables.
        mode: None for training or teacher-forcing evaluation;
            BEAMSEARCH / SAMPLE / MAX to generate text.
    """
    outputs = self.encoder(batch)
    action_output = self.action_executor(batch, outputs)
    outputs.update(action_output)
    decoder_output = self.decoder(batch, outputs)
    if mode:
        generation_output = self.decoder.forward_beamsearch_multiple(
            batch, outputs, mode)
        outputs.update(generation_output)

    # If evaluating by retrieval, construct a fake batch for each candidate.
    # Inputs from batch used in the decoder:
    #   assist_in, assist_out, assist_in_len, assist_mask
    if self.params["retrieval_evaluation"] and not self.training:
        option_scores = []
        batch_size, num_rounds, num_candidates, _ = batch["candidate_in"].shape
        replace_keys = ("assist_in", "assist_out", "assist_in_len", "assist_mask")
        for ii in range(num_candidates):
            for key in replace_keys:
                new_key = key.replace("assist", "candidate")
                batch[key] = batch[new_key][:, :, ii]
            decoder_output = self.decoder(batch, outputs)
            log_probs = torch_support.unflatten(
                decoder_output["loss_token"], batch_size, num_rounds)
            option_scores.append(-1 * log_probs.sum(-1))
        option_scores = torch.stack(option_scores, 2)
        outputs["candidate_scores"] = [
            {
                "dialog_id": batch["dialog_id"][ii].item(),
                "candidate_scores": [
                    list(option_scores[ii, jj].cpu().numpy())
                    for jj in range(batch["dialog_len"][ii])
                ]
            }
            for ii in range(batch_size)
        ]

    # Local aliases.
    loss_token = decoder_output["loss_token"]
    pad_mask = decoder_output["pad_mask"]
    if self.training:
        loss_token = loss_token.sum() / (~pad_mask).sum().item()
        loss_action = action_output["action_loss"]
        loss_action_attr = action_output["action_attr_loss"]
        loss_total = loss_action + loss_token + loss_action_attr
        return {
            "token": loss_token,
            "action": loss_action,
            "action_attr": loss_action_attr,
            "total": loss_total,
        }
    else:
        outputs.update({
            "loss_sum": loss_token.sum(),
            "num_tokens": (~pad_mask).sum()
        })
        return outputs
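# Hedged usage sketch for the forward pass above (model/dataloader names are
# placeholders, not taken from this file). During training the returned dict
# holds scalar losses; during evaluation it holds the summed token loss, token
# counts, and optionally generation / retrieval outputs.
#
#   model.train()
#   losses = model(batch)              # {"token", "action", "action_attr", "total"}
#   losses["total"].backward()
#
#   model.eval()
#   with torch.no_grad():
#       outputs = model(batch, mode="BEAMSEARCH")   # adds generated responses
#       nll, num_tokens = outputs["loss_sum"], outputs["num_tokens"]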