Example #1
    def forward(self, batch):
        """Forward pass through the encoder.

        Args:
            batch: Dict of batch variables.

        Returns:
            encoder_outputs: Dict of outputs from the forward pass.
        """
        encoder_out = {}
        # Flatten to encode sentences.
        batch_size, num_rounds, _ = batch["user_utt"].shape
        encoder_in = support.flatten(batch["user_utt"], batch_size, num_rounds)
        encoder_len = batch["user_utt_len"].reshape(-1)
        word_embeds_enc = self.word_embed_net(encoder_in)

        # Force encoder_len to be non-zero even for rounds outside the dialog.
        fake_encoder_len = encoder_len.eq(0).long() + encoder_len
        all_enc_states, enc_states = rnn.dynamic_rnn(self.encoder_unit,
                                                     word_embeds_enc,
                                                     fake_encoder_len,
                                                     return_states=True)
        encoder_out["hidden_states_all"] = all_enc_states
        encoder_out["hidden_state"] = enc_states

        utterance_enc = enc_states[0][-1]
        batch["utterance_enc"] = support.unflatten(utterance_enc, batch_size,
                                                   num_rounds)
        encoder_out["dialog_context"] = self._memory_net_forward(batch)
        return encoder_out
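
A minimal sketch of the flatten/unflatten helpers assumed above (the actual support module may differ): they fold the (batch, round) axes into one so sentence-level modules see a flat batch of utterances, and then restore the original shape.

import torch

def flatten(tensor, batch_size, num_rounds):
    # (batch_size, num_rounds, ...) -> (batch_size * num_rounds, ...)
    return tensor.reshape(batch_size * num_rounds, *tensor.shape[2:])

def unflatten(tensor, batch_size, num_rounds):
    # (batch_size * num_rounds, ...) -> (batch_size, num_rounds, ...)
    return tensor.reshape(batch_size, num_rounds, *tensor.shape[2:])

user_utt = torch.randint(0, 100, (2, 5, 20))  # 2 dialogs, 5 rounds, 20 tokens
flat = flatten(user_utt, 2, 5)                # shape (10, 20)
assert torch.equal(unflatten(flat, 2, 5), user_utt)
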
Example #2
    def _memory_net_forward(self, batch):
        """Forward pass for memory network to look up fact.

        1. Encodes fact via fact rnn.
        2. Computes attention with fact and utterance encoding.
        3. Attended fact vector and question encoding -> new encoding.

        Args:
          batch: Dict with fact, fact_len, and utterance_enc entries.

        Returns:
          attended_fact: Per-round fact encodings weighted by attention.
        """
        # A fact is the previous utterance and response concatenated as one,
        # for example: 'What is the color of the couch? A : Red.'
        batch_size, num_rounds, enc_time_steps = batch["fact"].shape
        all_ones = np.full((num_rounds, num_rounds), 1)
        fact_mask = np.triu(all_ones, 1)
        fact_mask = np.expand_dims(np.expand_dims(fact_mask, -1), 0)
        fact_mask = torch.BoolTensor(fact_mask)
        if self.params["use_gpu"]:
            fact_mask = fact_mask.cuda()
        fact_mask.requires_grad_(False)

        fact_in = support.flatten(batch["fact"], batch_size, num_rounds)
        fact_len = support.flatten(batch["fact_len"], batch_size, num_rounds)
        fact_embeds = self.word_embed_net(fact_in)

        # Encode the facts and unflatten the last hidden state.
        _, (hidden_state, _) = rnn.dynamic_rnn(self.fact_unit,
                                               fact_embeds,
                                               fact_len,
                                               return_states=True)
        fact_encode = support.unflatten(hidden_state[-1], batch_size,
                                        num_rounds)
        fact_encode = fact_encode.unsqueeze(1).expand(-1, num_rounds, -1, -1)

        utterance_enc = batch["utterance_enc"].unsqueeze(2)
        utterance_enc = utterance_enc.expand(-1, -1, num_rounds, -1)

        # Combine, compute attention, mask, and weight the fact encodings.
        combined_encode = torch.cat([utterance_enc, fact_encode], dim=-1)
        attention = self.fact_attention_net(combined_encode)
        attention.masked_fill_(fact_mask, float("-Inf"))
        attention = torch.softmax(attention, dim=2)
        attended_fact = (attention * fact_encode).sum(2)
        return attended_fact
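
A tiny standalone check of the fact-mask logic above (illustration only, not the project's code): entry [q, f] of the upper-triangular mask is True when fact f belongs to a later round than q, so after masking and softmax each round attends only to facts up to its own index.

import numpy as np
import torch

num_rounds = 4
all_ones = np.full((num_rounds, num_rounds), 1)
fact_mask = torch.from_numpy(np.triu(all_ones, 1)).bool()
scores = torch.zeros(num_rounds, num_rounds)   # stand-in for attention logits
scores.masked_fill_(fact_mask, float("-Inf"))
weights = torch.softmax(scores, dim=1)
# Row 0 puts all weight on fact 0; row 3 spreads weight over facts 0-3.
print(weights)
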
Example #3
    def forward(self, batch, prev_outputs):
        """Forward pass a given batch.

        Args:
            batch: Batch to forward pass
            prev_outputs: Output from previous modules.

        Returns:
            outputs: Dict of expected outputs
        """
        outputs = {}

        if self.params[
                "use_action_attention"] and self.params["encoder"] != "tf_idf":
            encoder_state = prev_outputs["hidden_states_all"]
            batch_size, num_rounds, max_len = batch["user_mask"].shape
            encoder_mask = batch["user_utt"].eq(batch["pad_token"])
            encoder_mask = support.flatten(encoder_mask, batch_size,
                                           num_rounds)
            encoder_state = self.attention_net(encoder_state, encoder_mask)
        else:
            encoder_state = prev_outputs["hidden_state"][0][-1]

        encoder_state_old = encoder_state
        # Multimodal state.
        if self.params["use_multimodal_state"]:
            if self.params["domain"] == "furniture":
                encoder_state = self.multimodal_embed(
                    batch["carousel_state"], encoder_state,
                    batch["dialog_mask"].shape[:2])
            elif self.params["domain"] == "fashion":
                multimodal_state = {}
                for ii in ["memory_images", "focus_images"]:
                    multimodal_state[ii] = batch[ii]
                encoder_state = self.multimodal_embed(
                    multimodal_state, encoder_state,
                    batch["dialog_mask"].shape[:2])

        # Belief state (assumed to be provided in the batch as "belief_state").
        if self.params["use_belief_state"]:
            encoder_state = torch.cat(
                (encoder_state, batch["belief_state"]), dim=1)

        # Predict and execute actions.
        action_logits = self.action_net(encoder_state)
        dialog_mask = batch["dialog_mask"]
        batch_size, num_rounds = dialog_mask.shape
        loss_action = self.criterion(action_logits, batch["action"].view(-1))
        loss_action.masked_fill_((~dialog_mask).view(-1), 0.0)
        loss_action_sum = loss_action.sum() / dialog_mask.sum().item()
        outputs["action_loss"] = loss_action_sum
        if not self.training:
            # Check for action accuracy.
            action_logits = support.unflatten(action_logits, batch_size,
                                              num_rounds)
            actions = action_logits.argmax(dim=-1)
            action_logits = nn.functional.log_softmax(action_logits, dim=-1)
            action_list = self.action_map.get_vocabulary_state()
            # Convert predictions to dictionary.
            action_preds_dict = [{
                "dialog_id":
                batch["dialog_id"][ii].item(),
                "predictions": [{
                    "action":
                    self.action_map.word(actions[ii, jj].item()),
                    "action_log_prob": {
                        action_token: action_logits[ii, jj, kk].item()
                        for kk, action_token in enumerate(action_list)
                    },
                    "attributes": {}
                } for jj in range(batch["dialog_len"][ii])]
            } for ii in range(batch_size)]
            outputs["action_preds"] = action_preds_dict
        else:
            actions = batch["action"]

        # Run classifiers based on the action, record supervision if training.
        if self.training:
            assert ("action_super"
                    in batch), "Need supervision to learn action attributes"
        attr_logits = collections.defaultdict(list)
        attr_loss = collections.defaultdict(list)
        encoder_state_unflat = support.unflatten(encoder_state, batch_size,
                                                 num_rounds)

        host = torch.cuda if self.params["use_gpu"] else torch
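        # Walk over every (dialog, round) pair: collect classifier inputs and
        # targets when training, or fill in attribute predictions otherwise.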
        for inst_id in range(batch_size):
            for round_id in range(num_rounds):
                # Skip turns beyond the dialog length.
                if not dialog_mask[inst_id, round_id]:
                    continue

                cur_action_ind = actions[inst_id, round_id].item()
                cur_action = self.action_map.word(cur_action_ind)
                cur_state = encoder_state_unflat[inst_id, round_id]
                supervision = batch["action_super"][inst_id][round_id]
                # If there is no supervision, ignore and move on to next round.
                if supervision is None:
                    continue

                # Run classifiers on attributes.
                # Attributes overlap completely with the ground truth when training.
                if self.training:
                    classifier_list = self.action_metainfo[cur_action][
                        "attributes"]
                    if self.params["domain"] == "furniture":
                        for key in classifier_list:
                            cur_gt = supervision.get(key, None)
                            new_entry = (cur_state, cur_gt, inst_id, round_id)
                            attr_logits[key].append(new_entry)
                    elif self.params["domain"] == "fashion":
                        for key in classifier_list:
                            cur_gt = supervision.get(key, None)
                            gt_indices = host.FloatTensor(
                                len(self.attribute_vocab[key])).fill_(0.)
                            gt_indices[cur_gt] = 1
                            new_entry = (cur_state, gt_indices, inst_id,
                                         round_id)
                            attr_logits[key].append(new_entry)
                    else:
                        raise ValueError(
                            "Domain must be either furniture or fashion!")
                else:
                    classifier_list = self.action_metainfo[cur_action][
                        "attributes"]
                    action_pred_datum = action_preds_dict[inst_id][
                        "predictions"][round_id]
                    if self.params["domain"] == "furniture":
                        # Predict attributes based on the predicted action.
                        for key in classifier_list:
                            classifier = self.classifiers[key]
                            model_pred = classifier(cur_state).argmax(dim=-1)
                            attr_pred = self.attribute_vocab[key][
                                model_pred.item()]
                            action_pred_datum["attributes"][key] = attr_pred
                    elif self.params["domain"] == "fashion":
                        # Predict attributes based on predicted action.
                        for key in classifier_list:
                            classifier = self.classifiers[key]
                            model_pred = classifier(cur_state) > 0
                            attr_pred = [
                                self.attribute_vocab[key][index]
                                for index, ii in enumerate(model_pred) if ii
                            ]
                            action_pred_datum["attributes"][key] = attr_pred
                    else:
                        raise ValueError(
                            "Domain must be either furniture or fashion!")

        # Compute losses if training, else predict.
        if self.training:
            for key, values in attr_logits.items():
                classifier = self.classifiers[key]
                prelogits = [ii[0] for ii in values if ii[1] is not None]
                if not prelogits:
                    continue
                logits = classifier(torch.stack(prelogits, dim=0))
                if self.params["domain"] == "furniture":
                    gt_labels = [ii[1] for ii in values if ii[1] is not None]
                    gt_labels = host.LongTensor(gt_labels)
                    attr_loss[key] = self.criterion_mean(logits, gt_labels)
                elif self.params["domain"] == "fashion":
                    gt_labels = torch.stack(
                        [ii[1] for ii in values if ii[1] is not None], dim=0)
                    attr_loss[key] = self.criterion_multi(logits, gt_labels)
                else:
                    raise ValueError("Domain neither of furniture/fashion!")

            total_attr_loss = host.FloatTensor([0.0])
            if attr_loss:
                total_attr_loss = sum(attr_loss.values()) / len(attr_loss)
            outputs["action_attr_loss"] = total_attr_loss

        # Obtain action outputs as memory cells to attend over.
        if self.params["use_action_output"]:
            if self.params["domain"] == "furniture":
                encoder_state_out = self.action_output_embed(
                    batch["action_output"],
                    encoder_state_old,
                    batch["dialog_mask"].shape[:2],
                )
            elif self.params["domain"] == "fashion":
                multimodal_state = {}
                for ii in ["memory_images", "focus_images"]:
                    multimodal_state[ii] = batch[ii]
                # For action output, advance focus_images by one time step.
                # Output at step t is input at step t+1.
                feature_size = batch["focus_images"].shape[-1]
                zero_tensor = host.FloatTensor(batch_size, 1,
                                               feature_size).fill_(0.)
                multimodal_state["focus_images"] = torch.cat(
                    [batch["focus_images"][:, 1:, :], zero_tensor], dim=1)
                encoder_state_out = self.multimodal_embed(
                    multimodal_state, encoder_state_old,
                    batch["dialog_mask"].shape[:2])
            else:
                raise ValueError(
                    "Domain must be either furniture or fashion!")
            outputs["action_output_all"] = encoder_state_out

        outputs.update({
            "action_logits": action_logits,
            "action_attr_loss_dict": attr_loss
        })
        return outputs
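
A minimal sketch of the masked action loss computed above, assuming self.criterion is nn.CrossEntropyLoss(reduction="none") so per-round losses can be zeroed for padded rounds before averaging (all names, shapes, and values here are illustrative):

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss(reduction="none")
batch_size, num_rounds, num_actions = 2, 3, 5
action_logits = torch.randn(batch_size * num_rounds, num_actions)
actions = torch.randint(0, num_actions, (batch_size, num_rounds))
dialog_mask = torch.tensor([[True, True, False], [True, False, False]])

loss_action = criterion(action_logits, actions.view(-1))
loss_action.masked_fill_((~dialog_mask).view(-1), 0.0)
# Average only over rounds that fall inside the dialogs.
loss_action_sum = loss_action.sum() / dialog_mask.sum().item()
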
Example #4
    def load_one_batch(self, sample_ids):
        """Loads a batch, given the sample ids.

        Args:
            sample_ids: List of instance ids to load data for.

        Returns:
            batch: Dictionary with relevant fields for training/evaluation.
        """
        batch = {
            "pad_token": self.pad_token,
            "start_token": self.start_token,
            "sample_ids": sample_ids,
        }
        batch["dialog_len"] = self.raw_data["dialog_len"][sample_ids]
        batch["dialog_id"] = self.raw_data["dialog_id"][sample_ids]
        max_dialog_len = max(batch["dialog_len"])

        user_utt_id = self.raw_data["user_utt_id"][sample_ids]
        batch["user_utt"], batch["user_utt_len"] = self._sample_utterance_pool(
            user_utt_id,
            self.raw_data["user_sent"],
            self.raw_data["user_sent_len"],
            self.params["max_encoder_len"],
        )

        for key in ("assist_in", "assist_out"):
            batch[key], batch[key + "_len"] = self._sample_utterance_pool(
                self.raw_data["assist_utt_id"][sample_ids],
                self.raw_data[key],
                self.raw_data["assist_sent_len"],
                self.params["max_decoder_len"],
            )
        actions = self.raw_data["action"][sample_ids]
        batch["action"] = np.vectorize(lambda x: self.action_map[x])(actions)
        # Construct user, assistant, and dialog masks.
        batch["dialog_mask"] = user_utt_id != -1
        batch["user_mask"] = (batch["user_utt"] == batch["pad_token"]) | (
            batch["user_utt"] == batch["start_token"])
        batch["assist_mask"] = (batch["assist_out"] == batch["pad_token"]) | (
            batch["assist_out"] == batch["start_token"])

        # Get retrieval candidates if needed.
        if self.params["get_retrieval_candidates"]:
            retrieval_inds = self.raw_data["retrieval_candidates"][sample_ids]
            batch_size, num_rounds, _ = retrieval_inds.shape
            flat_inds = torch_support.flatten(retrieval_inds, batch_size,
                                              num_rounds)
            for key in ("assist_in", "assist_out"):
                new_key = key.replace("assist", "candidate")
                cands, cands_len = self._sample_utterance_pool(
                    flat_inds,
                    self.raw_data[key],
                    self.raw_data["assist_sent_len"],
                    self.params["max_decoder_len"],
                )
                batch[new_key] = torch_support.unflatten(
                    cands, batch_size, num_rounds)
                batch[new_key + "_len"] = torch_support.unflatten(
                    cands_len, batch_size, num_rounds)
            batch["candidate_mask"] = (
                (batch["candidate_out"] == batch["pad_token"])
                | (batch["candidate_out"] == batch["start_token"]))

        # Action supervision.
        batch["action_super"] = [
            self.raw_data["action_supervision"][ii] for ii in sample_ids
        ]

        # Fetch facts if required.
        if self.params["encoder"] == "memory_network":
            batch["fact"] = self.raw_data["fact"][sample_ids]
            batch["fact_len"] = self.raw_data["fact_len"][sample_ids]

        # Trim to the maximum dialog length.
        for key in ("assist_in", "assist_out", "candidate_in", "candidate_out",
                    "user_utt", "fact", "user_mask", "assist_mask",
                    "candidate_mask"):
            if key in batch:
                batch[key] = batch[key][:, :max_dialog_len]
        for key in (
                "action",
                "assist_in_len",
                "assist_out_len",
                "candidate_in_len",
                "candidate_out_len",
                "user_utt_len",
                "dialog_mask",
                "fact_len",
        ):
            if key in batch:
                batch[key] = batch[key][:, :max_dialog_len]
        # TF-IDF features.
        if self.params["encoder"] == "tf_idf":
            batch["user_tf_idf"] = self.compute_tf_features(
                batch["user_utt"], batch["user_utt_len"])

        # Domain-specific processing.
        if self.params["domain"] == "furniture":
            # Carousel states.
            if self.params["use_multimodal_state"]:
                batch["carousel_state"] = [
                    self.raw_data["carousel_state"][ii] for ii in sample_ids
                ]
            # Action output.
            if self.params["use_action_output"]:
                batch["action_output"] = [
                    self.raw_data["action_output_state"][ii]
                    for ii in sample_ids
                ]
        elif self.params["domain"] == "fashion":
            # Asset embeddings -- memory, database, focus images.
            for dtype in ["memory", "database", "focus"]:
                indices = self.raw_data["{}_inds".format(dtype)][sample_ids]
                image_embeds = self.embed_data["embedding"][indices]
                batch["{}_images".format(dtype)] = image_embeds
        else:
            raise ValueError("Domain must be either furniture/fashion!")
        return self._ship_torch_batch(batch)
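
A small standalone illustration of the trimming step above: every round-level array is cut to the longest dialog actually present in the batch (keys and shapes here are illustrative stand-ins):

import numpy as np

batch = {
    "dialog_len": np.array([2, 4]),               # rounds per sampled dialog
    "user_utt": np.zeros((2, 6, 20), dtype=int),  # padded to 6 rounds
    "dialog_mask": np.zeros((2, 6), dtype=bool),
}
max_dialog_len = max(batch["dialog_len"])
for key in ("user_utt", "dialog_mask"):
    batch[key] = batch[key][:, :max_dialog_len]
print(batch["user_utt"].shape)  # (2, 4, 20)
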
Example #5
    def forward(self, batch, mode=None):
        """Forward propagation.

        Args:
          batch: Dict of batch input variables.
          mode: None for training or teacher-forcing evaluation;
                BEAMSEARCH / SAMPLE / MAX to generate text.

        Returns:
          Dict of losses when training, else the updated outputs dict.
        """
        outputs = self.encoder(batch)
        action_output = self.action_executor(batch, outputs)
        outputs.update(action_output)
        decoder_output = self.decoder(batch, outputs)
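        # Generate text (beam search / sampling / greedy max) when a mode is set.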
        if mode:
            generation_output = self.decoder.forward_beamsearch_multiple(
                batch, outputs, mode)
            outputs.update(generation_output)

        # If evaluating by retrieval, construct fake batch for each candidate.
        # Inputs from batch used in decoder:
        #   assist_in, assist_out, assist_in_len, assist_mask
        if self.params["retrieval_evaluation"] and not self.training:
            option_scores = []
            batch_size, num_rounds, num_candidates, _ = batch[
                "candidate_in"].shape
            replace_keys = ("assist_in", "assist_out", "assist_in_len",
                            "assist_mask")
            for ii in range(num_candidates):
                for key in replace_keys:
                    new_key = key.replace("assist", "candidate")
                    batch[key] = batch[new_key][:, :, ii]
                decoder_output = self.decoder(batch, outputs)
                log_probs = torch_support.unflatten(
                    decoder_output["loss_token"], batch_size, num_rounds)
                option_scores.append(-1 * log_probs.sum(-1))
            option_scores = torch.stack(option_scores, 2)
            outputs["candidate_scores"] = [{
                "dialog_id":
                batch["dialog_id"][ii].item(),
                "candidate_scores": [
                    list(option_scores[ii, jj].cpu().numpy())
                    for jj in range(batch["dialog_len"][ii])
                ]
            } for ii in range(batch_size)]

        # Local aliases.
        loss_token = decoder_output["loss_token"]
        pad_mask = decoder_output["pad_mask"]
        if self.training:
            loss_token = loss_token.sum() / (~pad_mask).sum().item()
            loss_action = action_output["action_loss"]
            loss_action_attr = action_output["action_attr_loss"]
            loss_total = loss_action + loss_token + loss_action_attr
            return {
                "token": loss_token,
                "action": loss_action,
                "action_attr": loss_action_attr,
                "total": loss_total,
            }
        else:
            outputs.update({
                "loss_sum": loss_token.sum(),
                "num_tokens": (~pad_mask).sum()
            })
            return outputs
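
A tiny standalone illustration of the retrieval scoring above: each candidate's score is the negated sum of its per-token losses, so candidates the decoder assigns higher likelihood score higher (tensors below are random stand-ins for decoder outputs):

import torch

batch_size, num_rounds, num_candidates, max_len = 2, 3, 4, 5
option_scores = []
for ii in range(num_candidates):
    # Stand-in for decoder_output["loss_token"] unflattened to (batch, rounds, len).
    loss_token = torch.rand(batch_size, num_rounds, max_len)
    option_scores.append(-1 * loss_token.sum(-1))
option_scores = torch.stack(option_scores, 2)  # (batch, rounds, candidates)
best_candidate = option_scores.argmax(dim=-1)  # highest score per round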