Example #1
    def forward(self, batch):
        """Forward pass through the encoder.

        Args:
            batch: Dict of batch variables.

        Returns:
            encoder_outputs: Dict of outputs from the forward pass.
        """
        encoder_out = {}
        # Flatten for history_agnostic encoder.
        batch_size, num_rounds, max_length = batch["user_utt"].shape
        encoder_in = support.flatten(batch["user_utt"], batch_size, num_rounds)
        encoder_len = support.flatten(batch["user_utt_len"], batch_size,
                                      num_rounds)
        word_embeds_enc = self.word_embed_net(encoder_in)
        # Text encoder: LSTM or Transformer.
        if self.params["text_encoder"] == "lstm":
            all_enc_states, enc_states = rnn.dynamic_rnn(self.encoder_unit,
                                                         word_embeds_enc,
                                                         encoder_len,
                                                         return_states=True)
            encoder_out["hidden_states_all"] = all_enc_states
            encoder_out["hidden_state"] = enc_states

        elif self.params["text_encoder"] == "transformer":
            enc_embeds = self.pos_encoder(word_embeds_enc).transpose(0, 1)
            enc_pad_mask = batch["user_utt"] == batch["pad_token"]
            enc_pad_mask = support.flatten(enc_pad_mask, batch_size,
                                           num_rounds)
            enc_states = self.encoder_unit(enc_embeds,
                                           src_key_padding_mask=enc_pad_mask)
            encoder_out["hidden_states_all"] = enc_states.transpose(0, 1)
        return encoder_out
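
Note: the support.flatten / support.unflatten helpers used throughout these examples merge and split the batch and round dimensions. A minimal sketch of that behavior, shown only as an assumption about the helpers (the actual torch_support module may differ):

import torch


def flatten(tensor, batch_size, num_rounds):
    """Merge batch and round dims: (B, R, ...) -> (B * R, ...)."""
    return tensor.reshape(batch_size * num_rounds, *tensor.shape[2:])


def unflatten(tensor, batch_size, num_rounds):
    """Split the leading dim back: (B * R, ...) -> (B, R, ...)."""
    return tensor.reshape(batch_size, num_rounds, *tensor.shape[1:])


# Example: 2 dialogs x 3 rounds x 5 tokens.
utt = torch.zeros(2, 3, 5, dtype=torch.long)
flat = flatten(utt, 2, 3)          # (6, 5), one row per round
restored = unflatten(flat, 2, 3)   # (2, 3, 5)
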
Example #2
    def forward(self, batch):
        """Forward pass through the encoder.

        Args:
            batch: Dict of batch variables.

        Returns:
            encoder_outputs: Dict of outputs from the forward pass.
        """
        encoder_out = {}
        # Flatten to encode sentences.
        batch_size, num_rounds, _ = batch["user_utt"].shape
        encoder_in = support.flatten(batch["user_utt"], batch_size, num_rounds)
        encoder_len = batch["user_utt_len"].reshape(-1)
        word_embeds_enc = self.word_embed_net(encoder_in)

        # Fake encoder_len to be non-zero even for utterances out of dialog.
        fake_encoder_len = encoder_len.eq(0).long() + encoder_len
        all_enc_states, enc_states = rnn.dynamic_rnn(self.encoder_unit,
                                                     word_embeds_enc,
                                                     fake_encoder_len,
                                                     return_states=True)
        encoder_out["hidden_states_all"] = all_enc_states
        encoder_out["hidden_state"] = enc_states

        utterance_enc = enc_states[0][-1]
        new_size = (batch_size, num_rounds, utterance_enc.shape[-1])
        utterance_enc = utterance_enc.reshape(new_size)
        encoder_out["dialog_context"], _ = rnn.dynamic_rnn(self.dialog_unit,
                                                           utterance_enc,
                                                           batch["dialog_len"],
                                                           return_states=True)
        return encoder_out
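
rnn.dynamic_rnn above runs an RNN over variable-length, padded sequences (the fake_encoder_len trick keeps every length at least 1, which packing requires). A rough sketch of such a wrapper under that assumption; the project's actual rnn module may differ in details:

import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


def dynamic_rnn(rnn_unit, inputs, seq_lens, init_state=None, return_states=False):
    """Run rnn_unit over padded inputs (B, T, E) with per-sequence lengths."""
    packed = pack_padded_sequence(
        inputs, seq_lens.cpu(), batch_first=True, enforce_sorted=False)
    packed_out, states = rnn_unit(packed, init_state)
    outputs, _ = pad_packed_sequence(
        packed_out, batch_first=True, total_length=inputs.shape[1])
    return (outputs, states) if return_states else outputs


# Example: 4 sequences of embedding size 8 through a one-layer LSTM.
lstm = nn.LSTM(8, 16, batch_first=True)
embeds = torch.randn(4, 10, 8)
lengths = torch.tensor([10, 7, 1, 3])
all_states, (h_n, c_n) = dynamic_rnn(lstm, embeds, lengths, return_states=True)
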
Example #3
    def forward(self, batch):
        """Forward pass through the encoder.

        Args:
            batch: Dict of batch variables.

        Returns:
            encoder_outputs: Dict of outputs from the forward pass.
        """
        encoder_out = {}
        # Flatten to encode sentences.
        batch_size, num_rounds, _ = batch["user_utt"].shape
        encoder_in = support.flatten(batch["user_utt"], batch_size, num_rounds)
        encoder_len = batch["user_utt_len"].reshape(-1)
        word_embeds_enc = self.word_embed_net(encoder_in)

        # Fake encoder_len to be non-zero even for utterances out of dialog.
        fake_encoder_len = encoder_len.eq(0).long() + encoder_len
        all_enc_states, enc_states = rnn.dynamic_rnn(self.encoder_unit,
                                                     word_embeds_enc,
                                                     fake_encoder_len,
                                                     return_states=True)
        encoder_out["hidden_states_all"] = all_enc_states
        encoder_out["hidden_state"] = enc_states

        utterance_enc = enc_states[0][-1]
        batch["utterance_enc"] = support.unflatten(utterance_enc, batch_size,
                                                   num_rounds)
        encoder_out["dialog_context"] = self._memory_net_forward(batch)
        return encoder_out
Example #4
    def forward(self, multimodal_state, encoder_state, encoder_size):
        """Multimodal Embedding.

        Args:
            multimodal_state: Dict with memory, database, and focus images
            encoder_state: State of the encoder
            encoder_size: (batch_size, num_rounds)

        Returns:
            multimodal_encode: Encoder state with multimodal information
        """
        # Setup category states if None.
        if self.category_state is None:
            self._setup_category_states()
        # Attend to multimodal memory using encoder states.
        batch_size, num_rounds = encoder_size
        memory_images = multimodal_state["memory_images"]
        memory_images = memory_images.unsqueeze(1).expand(
            -1, num_rounds, -1, -1)
        focus_images = multimodal_state["focus_images"][:, :num_rounds, :]
        focus_images = focus_images.unsqueeze(2)
        all_images = torch.cat([focus_images, memory_images], dim=2)
        all_images_flat = support.flatten(all_images, batch_size, num_rounds)
        category_state = self.category_state.expand(batch_size * num_rounds,
                                                    -1, -1)
        cat_images = torch.cat([all_images_flat, category_state], dim=-1)
        multimodal_memory = self.multimodal_embed_net(cat_images)
        # Key (L, N, E), value (L, N, E), query (S, N, E)
        multimodal_memory = multimodal_memory.transpose(0, 1)
        query = encoder_state.unsqueeze(0)
        attended_query, attended_wts = self.multimodal_attend(
            query, multimodal_memory, multimodal_memory)
        multimodal_encode = torch.cat(
            [attended_query.squeeze(0), encoder_state], dim=-1)
        return multimodal_encode
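
The comment above follows the (L, N, E) / (S, N, E) layout of nn.MultiheadAttention. A small shape sketch of the attention step, assuming self.multimodal_attend is an nn.MultiheadAttention module (the sizes below are placeholders):

import torch
from torch import nn

embed_size = 256
batch_rounds = 4 * 10            # batch_size * num_rounds after flattening
num_memories = 1 + 4             # focus image + memory images

multimodal_attend = nn.MultiheadAttention(embed_size, num_heads=4)
multimodal_memory = torch.randn(num_memories, batch_rounds, embed_size)  # (L, N, E)
encoder_state = torch.randn(batch_rounds, embed_size)
query = encoder_state.unsqueeze(0)                                       # (S=1, N, E)

attended_query, attended_wts = multimodal_attend(
    query, multimodal_memory, multimodal_memory)
multimodal_encode = torch.cat([attended_query.squeeze(0), encoder_state], dim=-1)
# multimodal_encode: (batch_rounds, 2 * embed_size)
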
Example #5
    def _memory_net_forward(self, batch):
        """Forward pass for memory network to look up fact.

        1. Encodes fact via fact rnn.
        2. Computes attention with fact and utterance encoding.
        3. Attended fact vector and question encoding -> new encoding.

        Args:
          batch: Dict of hist, hist_len, hidden_state

        Returns:
            attended_fact: Attention-weighted fact encoding for each round.
        """
        # kwon: a fact is the previous utterance and its response concatenated
        # as one string, e.g. 'What is the color of the couch? A : Red.'
        batch_size, num_rounds, enc_time_steps = batch["fact"].shape
        all_ones = np.full((num_rounds, num_rounds), 1)
        fact_mask = np.triu(all_ones, 1)
        fact_mask = np.expand_dims(np.expand_dims(fact_mask, -1), 0)
        fact_mask = torch.BoolTensor(fact_mask)
        if self.params["use_gpu"]:
            fact_mask = fact_mask.cuda()
        fact_mask.requires_grad_(False)

        fact_in = support.flatten(batch["fact"], batch_size, num_rounds)
        fact_len = support.flatten(batch["fact_len"], batch_size, num_rounds)
        fact_embeds = self.word_embed_net(fact_in)

        # Encode the facts and unflatten the last hidden state.
        _, (hidden_state, _) = rnn.dynamic_rnn(self.fact_unit,
                                               fact_embeds,
                                               fact_len,
                                               return_states=True)
        fact_encode = support.unflatten(hidden_state[-1], batch_size,
                                        num_rounds)
        fact_encode = fact_encode.unsqueeze(1).expand(-1, num_rounds, -1, -1)

        utterance_enc = batch["utterance_enc"].unsqueeze(2)
        utterance_enc = utterance_enc.expand(-1, -1, num_rounds, -1)

        # Combine, compute attention, mask, and weight the fact encodings.
        combined_encode = torch.cat([utterance_enc, fact_encode], dim=-1)
        attention = self.fact_attention_net(combined_encode)
        attention.masked_fill_(fact_mask, float("-Inf"))
        attention = self.softmax(attention, dim=2)
        attended_fact = (attention * fact_encode).sum(2)
        return attended_fact
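
The upper-triangular fact_mask above blocks attention from round i to facts of later rounds j > i, so each utterance only attends to facts up to its own round. A small numeric check of that masking, using the same construction:

import numpy as np
import torch

num_rounds = 4
fact_mask = np.triu(np.full((num_rounds, num_rounds), 1), 1)
# [[0 1 1 1]
#  [0 0 1 1]
#  [0 0 0 1]
#  [0 0 0 0]]   -> 1 wherever fact round j > utterance round i

mask = torch.from_numpy(np.expand_dims(np.expand_dims(fact_mask, -1), 0)).bool()
attention = torch.zeros(1, num_rounds, num_rounds, 1)
attention.masked_fill_(mask, float("-Inf"))
weights = torch.softmax(attention, dim=2)
# Row i of weights spreads probability only over rounds j <= i.
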
Example #6
    def forward(self, multimodal_state, encoder_state, encoder_size):
        """Multimodal Embedding.

        Args:
            multimodal_state: Dict with memory, database, and focus images
            encoder_state: State of the encoder
            encoder_size: (batch_size, num_rounds)

        Returns:
            multimodal_encode: Encoder state with multimodal information
        """
        # Setup category states if None.
        if self.category_state is None:
            self._setup_category_states()
        # Attend to multimodal memory using encoder states.
        batch_size, num_rounds = encoder_size
        memory_images = multimodal_state["memory_images"]
        memory_images = memory_images.unsqueeze(1).expand(
            -1, num_rounds, -1, -1)
        focus_images = multimodal_state["focus_images"][:, :num_rounds, :]
        focus_images = focus_images.unsqueeze(2)
        all_images = torch.cat([focus_images, memory_images], dim=2)
        all_images_flat = support.flatten(all_images, batch_size, num_rounds)
        category_state = self.category_state.expand(batch_size * num_rounds,
                                                    -1, -1)
        cat_images = torch.cat([all_images_flat, category_state], dim=-1)
        multimodal_memory = self.multimodal_embed_net(cat_images)
        # Key (L, N, E), value (L, N, E), query (S, N, E)
        multimodal_memory = multimodal_memory.transpose(0, 1)
        query = encoder_state.unsqueeze(0)

        # Gate type: MAG, plain attention ("none"), or MMI.
        if self.params['gate_type'] == "MAG":
            multimodal_encode = torch.cat(
                [self.MAG(query, multimodal_memory).squeeze(0), encoder_state],
                dim=-1)
        elif self.params['gate_type'] == "none":
            attended_query, attended_wts = self.multimodal_attend(
                query, multimodal_memory, multimodal_memory)
            multimodal_encode = torch.cat(
                [attended_query.squeeze(0), encoder_state], dim=-1)
        elif self.params['gate_type'] == "MMI":
            attended_query_q, attended_wts_q = self.multimodal_attend(
                query, multimodal_memory, multimodal_memory)
            v_input = multimodal_memory.mean(dim=0)
            v_input = v_input.unsqueeze(0)
            attended_query_p, attended_wts_p = self.multimodal_attend(
                v_input, query, query)
            attended_query_a, attended_wts_a = self.multimodal_attend(
                query, attended_query_p, attended_query_p)
            b = self.Visualgate(attended_query_a, attended_query_q)
            multimodal_encode = torch.cat(
                [attended_query_a.squeeze(0),
                 b.squeeze(0)], dim=-1)
        else:
            raise ValueError(
                "Unknown gate_type: {}".format(self.params["gate_type"]))

        return multimodal_encode
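
The Visualgate module used in the MMI branch is defined elsewhere; one plausible, purely hypothetical form of such a gate is a sigmoid-weighted mix of the two attended encodings, sketched below for illustration only:

import torch
from torch import nn


class VisualGateSketch(nn.Module):
    """Hypothetical sigmoid gate over two attended encodings of shape (1, N, E).
    This is an assumption for illustration; the actual Visualgate may differ."""

    def __init__(self, embed_size):
        super().__init__()
        self.gate_net = nn.Linear(2 * embed_size, embed_size)

    def forward(self, attended_a, attended_q):
        gate = torch.sigmoid(
            self.gate_net(torch.cat([attended_a, attended_q], dim=-1)))
        # Convex combination of the two attended views.
        return gate * attended_q + (1 - gate) * attended_a


gate = VisualGateSketch(embed_size=256)
out = gate(torch.randn(1, 8, 256), torch.randn(1, 8, 256))  # (1, 8, 256)
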
Example #7
    def forward(self, batch, prev_outputs):
        """Forward pass a given batch.

        Args:
            batch: Batch to forward pass
            prev_outputs: Output from previous modules.

        Returns:
            outputs: Dict of expected outputs
        """
        outputs = {}

        if self.params[
                "use_action_attention"] and self.params["encoder"] != "tf_idf":
            encoder_state = prev_outputs["hidden_states_all"]
            batch_size, num_rounds, max_len = batch["user_mask"].shape
            encoder_mask = batch["user_utt"].eq(batch["pad_token"])
            encoder_mask = support.flatten(encoder_mask, batch_size,
                                           num_rounds)
            encoder_state = self.attention_net(encoder_state, encoder_mask)
        else:
            encoder_state = prev_outputs["hidden_state"][0][-1]

        encoder_state_old = encoder_state
        # Multimodal state.
        if self.params["use_multimodal_state"]:
            if self.params["domain"] == "furniture":
                encoder_state = self.multimodal_embed(
                    batch["carousel_state"], encoder_state,
                    batch["dialog_mask"].shape[:2])
            elif self.params["domain"] == "fashion":
                multimodal_state = {}
                for ii in ["memory_images", "focus_images"]:
                    multimodal_state[ii] = batch[ii]
                encoder_state = self.multimodal_embed(
                    multimodal_state, encoder_state,
                    batch["dialog_mask"].shape[:2])

        # B: belief state.
        # NOTE: `belief_state` is not computed in this module; it is assumed
        # to be provided by an upstream component when this flag is set.
        if self.params["use_belief_state"]:
            encoder_state = torch.cat((encoder_state, belief_state), dim=1)

        # Predict and execute actions.
        action_logits = self.action_net(encoder_state)
        dialog_mask = batch["dialog_mask"]
        batch_size, num_rounds = dialog_mask.shape
        loss_action = self.criterion(action_logits, batch["action"].view(-1))
        loss_action.masked_fill_((~dialog_mask).view(-1), 0.0)
        loss_action_sum = loss_action.sum() / dialog_mask.sum().item()
        outputs["action_loss"] = loss_action_sum
        if not self.training:
            # Check for action accuracy.
            action_logits = support.unflatten(action_logits, batch_size,
                                              num_rounds)
            actions = action_logits.argmax(dim=-1)
            action_logits = nn.functional.log_softmax(action_logits, dim=-1)
            action_list = self.action_map.get_vocabulary_state()
            # Convert predictions to dictionary.
            action_preds_dict = [{
                "dialog_id":
                batch["dialog_id"][ii].item(),
                "predictions": [{
                    "action":
                    self.action_map.word(actions[ii, jj].item()),
                    "action_log_prob": {
                        action_token: action_logits[ii, jj, kk].item()
                        for kk, action_token in enumerate(action_list)
                    },
                    "attributes": {}
                } for jj in range(batch["dialog_len"][ii])]
            } for ii in range(batch_size)]
            outputs["action_preds"] = action_preds_dict
        else:
            actions = batch["action"]

        # Run classifiers based on the action, record supervision if training.
        if self.training:
            assert ("action_super"
                    in batch), "Need supervision to learn action attributes"
        attr_logits = collections.defaultdict(list)
        attr_loss = collections.defaultdict(list)
        encoder_state_unflat = support.unflatten(encoder_state, batch_size,
                                                 num_rounds)

        host = torch.cuda if self.params["use_gpu"] else torch
        for inst_id in range(batch_size):
            for round_id in range(num_rounds):
                # Turn out of dialog length.
                if not dialog_mask[inst_id, round_id]:
                    continue

                cur_action_ind = actions[inst_id, round_id].item()
                cur_action = self.action_map.word(cur_action_ind)
                cur_state = encoder_state_unflat[inst_id, round_id]
                supervision = batch["action_super"][inst_id][round_id]
                # If there is no supervision, ignore and move on to next round.
                if supervision is None:
                    continue

                # Run classifiers on attributes.
                # Attributes overlap completely with GT when training.
                if self.training:
                    classifier_list = self.action_metainfo[cur_action][
                        "attributes"]
                    if self.params["domain"] == "furniture":
                        for key in classifier_list:
                            cur_gt = (supervision.get(key, None)
                                      if supervision is not None else None)
                            new_entry = (cur_state, cur_gt, inst_id, round_id)
                            attr_logits[key].append(new_entry)
                    elif self.params["domain"] == "fashion":
                        for key in classifier_list:
                            cur_gt = supervision.get(key, None)
                            gt_indices = host.FloatTensor(
                                len(self.attribute_vocab[key])).fill_(0.)
                            gt_indices[cur_gt] = 1
                            new_entry = (cur_state, gt_indices, inst_id,
                                         round_id)
                            attr_logits[key].append(new_entry)
                    else:
                        raise ValueError(
                            "Domain neither of furniture/fashion!")
                else:
                    classifier_list = self.action_metainfo[cur_action][
                        "attributes"]
                    action_pred_datum = action_preds_dict[inst_id][
                        "predictions"][round_id]
                    if self.params["domain"] == "furniture":
                        # Predict attributes based on the predicted action.
                        for key in classifier_list:
                            classifier = self.classifiers[key]
                            model_pred = classifier(cur_state).argmax(dim=-1)
                            attr_pred = self.attribute_vocab[key][
                                model_pred.item()]
                            action_pred_datum["attributes"][key] = attr_pred
                    elif self.params["domain"] == "fashion":
                        # Predict attributes based on predicted action.
                        for key in classifier_list:
                            classifier = self.classifiers[key]
                            model_pred = classifier(cur_state) > 0
                            attr_pred = [
                                self.attribute_vocab[key][index]
                                for index, ii in enumerate(model_pred) if ii
                            ]
                            action_pred_datum["attributes"][key] = attr_pred
                    else:
                        raise ValueError(
                            "Domain neither of furniture/fashion!")

        # Compute losses if training, else predict.
        if self.training:
            for key, values in attr_logits.items():
                classifier = self.classifiers[key]
                prelogits = [ii[0] for ii in values if ii[1] is not None]
                if not prelogits:
                    continue
                logits = classifier(torch.stack(prelogits, dim=0))
                if self.params["domain"] == "furniture":
                    gt_labels = [ii[1] for ii in values if ii[1] is not None]
                    gt_labels = host.LongTensor(gt_labels)
                    attr_loss[key] = self.criterion_mean(logits, gt_labels)
                elif self.params["domain"] == "fashion":
                    gt_labels = torch.stack(
                        [ii[1] for ii in values if ii[1] is not None], dim=0)
                    attr_loss[key] = self.criterion_multi(logits, gt_labels)
                else:
                    raise ValueError("Domain neither of furniture/fashion!")

            total_attr_loss = host.FloatTensor([0.0])
            if len(attr_loss.values()):
                total_attr_loss = sum(attr_loss.values()) / len(
                    attr_loss.values())
            outputs["action_attr_loss"] = total_attr_loss

        # Obtain action outputs as memory cells to attend over.
        if self.params["use_action_output"]:
            if self.params["domain"] == "furniture":
                encoder_state_out = self.action_output_embed(
                    batch["action_output"],
                    encoder_state_old,
                    batch["dialog_mask"].shape[:2],
                )
            elif self.params["domain"] == "fashion":
                multimodal_state = {}
                for ii in ["memory_images", "focus_images"]:
                    multimodal_state[ii] = batch[ii]
                # For action output, advance focus_images by one time step.
                # Output at step t is input at step t+1.
                feature_size = batch["focus_images"].shape[-1]
                zero_tensor = host.FloatTensor(batch_size, 1,
                                               feature_size).fill_(0.)
                multimodal_state["focus_images"] = torch.cat(
                    [batch["focus_images"][:, 1:, :], zero_tensor], dim=1)
                encoder_state_out = self.multimodal_embed(
                    multimodal_state, encoder_state_old,
                    batch["dialog_mask"].shape[:2])
            else:
                raise ValueError("Domain neither furniture/fashion!")
            outputs["action_output_all"] = encoder_state_out

        outputs.update({
            "action_logits": action_logits,
            "action_attr_loss_dict": attr_loss
        })
        return outputs
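
The action loss above is computed per round and then masked so that padded rounds do not contribute. A standalone sketch of that masking pattern, assuming self.criterion is a cross-entropy loss with reduction="none":

import torch
from torch import nn

criterion = nn.CrossEntropyLoss(reduction="none")

batch_size, num_rounds, num_actions = 2, 3, 5
action_logits = torch.randn(batch_size * num_rounds, num_actions)
actions = torch.randint(num_actions, (batch_size, num_rounds))
# True for rounds that exist in the dialog, False for padded rounds.
dialog_mask = torch.tensor([[True, True, False], [True, True, True]])

loss_action = criterion(action_logits, actions.view(-1))
loss_action.masked_fill_((~dialog_mask).view(-1), 0.0)
loss_action_sum = loss_action.sum() / dialog_mask.sum().item()  # mean over real rounds
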
Example #8
    def load_one_batch(self, sample_ids):
        """Loads a batch, given the sample ids.

        Args:
            sample_ids: List of instance ids to load data for.

        Returns:
            batch: Dictionary with relevant fields for training/evaluation.
        """
        batch = {
            "pad_token": self.pad_token,
            "start_token": self.start_token,
            "sample_ids": sample_ids,
        }
        batch["dialog_len"] = self.raw_data["dialog_len"][sample_ids]
        batch["dialog_id"] = self.raw_data["dialog_id"][sample_ids]
        max_dialog_len = max(batch["dialog_len"])

        user_utt_id = self.raw_data["user_utt_id"][sample_ids]
        batch["user_utt"], batch["user_utt_len"] = self._sample_utterance_pool(
            user_utt_id,
            self.raw_data["user_sent"],
            self.raw_data["user_sent_len"],
            self.params["max_encoder_len"],
        )

        for key in ("assist_in", "assist_out"):
            batch[key], batch[key + "_len"] = self._sample_utterance_pool(
                self.raw_data["assist_utt_id"][sample_ids],
                self.raw_data[key],
                self.raw_data["assist_sent_len"],
                self.params["max_decoder_len"],
            )
        actions = self.raw_data["action"][sample_ids]
        batch["action"] = np.vectorize(lambda x: self.action_map[x])(actions)
        # Construct user, assistant, and dialog masks.
        batch["dialog_mask"] = user_utt_id != -1
        batch["user_mask"] = (batch["user_utt"] == batch["pad_token"]) | (
            batch["user_utt"] == batch["start_token"])
        batch["assist_mask"] = (batch["assist_out"] == batch["pad_token"]) | (
            batch["assist_out"] == batch["start_token"])

        # Get retrieval candidates if needed.
        if self.params["get_retrieval_candidates"]:
            retrieval_inds = self.raw_data["retrieval_candidates"][sample_ids]
            batch_size, num_rounds, _ = retrieval_inds.shape
            flat_inds = torch_support.flatten(retrieval_inds, batch_size,
                                              num_rounds)
            for key in ("assist_in", "assist_out"):
                new_key = key.replace("assist", "candidate")
                cands, cands_len = self._sample_utterance_pool(
                    flat_inds,
                    self.raw_data[key],
                    self.raw_data["assist_sent_len"],
                    self.params["max_decoder_len"],
                )
                batch[new_key] = torch_support.unflatten(
                    cands, batch_size, num_rounds)
                batch[new_key + "_len"] = torch_support.unflatten(
                    cands_len, batch_size, num_rounds)
            batch["candidate_mask"] = (
                (batch["candidate_out"] == batch["pad_token"])
                | (batch["candidate_out"] == batch["start_token"]))

        # Action supervision.
        batch["action_super"] = [
            self.raw_data["action_supervision"][ii] for ii in sample_ids
        ]

        # Fetch facts if required.
        if self.params["encoder"] == "memory_network":
            batch["fact"] = self.raw_data["fact"][sample_ids]
            batch["fact_len"] = self.raw_data["fact_len"][sample_ids]

        # Trim all per-round fields to the maximum dialog length.
        for key in ("assist_in", "assist_out", "candidate_in", "candidate_out",
                    "user_utt", "fact", "user_mask", "assist_mask",
                    "candidate_mask", "action", "assist_in_len",
                    "assist_out_len", "candidate_in_len", "candidate_out_len",
                    "user_utt_len", "dialog_mask", "fact_len"):
            if key in batch:
                batch[key] = batch[key][:, :max_dialog_len]
        # TF-IDF features.
        if self.params["encoder"] == "tf_idf":
            batch["user_tf_idf"] = self.compute_tf_features(
                batch["user_utt"], batch["user_utt_len"])

        # Domain-specific processing.
        if self.params["domain"] == "furniture":
            # Carousel states.
            if self.params["use_multimodal_state"]:
                batch["carousel_state"] = [
                    self.raw_data["carousel_state"][ii] for ii in sample_ids
                ]
            # Action output.
            if self.params["use_action_output"]:
                batch["action_output"] = [
                    self.raw_data["action_output_state"][ii]
                    for ii in sample_ids
                ]
        elif self.params["domain"] == "fashion":
            # Asset embeddings -- memory, database, focus images.
            for dtype in ["memory", "database", "focus"]:
                indices = self.raw_data["{}_inds".format(dtype)][sample_ids]
                image_embeds = self.embed_data["embedding"][indices]
                batch["{}_images".format(dtype)] = image_embeds
        else:
            raise ValueError("Domain must be either furniture/fashion!")
        return self._ship_torch_batch(batch)
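
As a small illustration of the masking and trimming above, assuming user_utt_id uses -1 for rounds past the end of a dialog (as implied by the dialog_mask construction):

import numpy as np

pad_token, start_token = 0, 1
# Two dialogs, up to 4 rounds, 3 tokens per user utterance.
user_utt_id = np.array([[5, 6, -1, -1], [7, 8, 9, -1]])
user_utt = np.full((2, 4, 3), pad_token)
dialog_len = np.array([2, 3])

dialog_mask = user_utt_id != -1                       # True only for real rounds
user_mask = (user_utt == pad_token) | (user_utt == start_token)
max_dialog_len = max(dialog_len)                      # 3
user_utt = user_utt[:, :max_dialog_len]               # trim the round dimension
dialog_mask = dialog_mask[:, :max_dialog_len]
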
Example #9
    def forward(self, batch, encoder_output):
        """Forward pass through the decoder.

        Args:
            batch: Dict of batch variables.
            encoder_output: Dict of outputs from the encoder.

        Returns:
            decoder_outputs: Dict of outputs from the forward pass.
        """
        # Flatten for history_agnostic encoder.
        batch_size, num_rounds, max_length = batch["assist_in"].shape
        decoder_in = support.flatten(batch["assist_in"], batch_size,
                                     num_rounds)
        decoder_out = support.flatten(batch["assist_out"], batch_size,
                                      num_rounds)
        decoder_len = support.flatten(batch["assist_in_len"], batch_size,
                                      num_rounds)
        word_embeds_dec = self.word_embed_net(decoder_in)

        if self.params["encoder"] in self.DIALOG_CONTEXT_ENCODERS:
            dialog_context = support.flatten(encoder_output["dialog_context"],
                                             batch_size,
                                             num_rounds).unsqueeze(1)
            dialog_context = dialog_context.expand(-1, max_length, -1)
            decoder_steps_in = torch.cat([dialog_context, word_embeds_dec], -1)
        else:
            decoder_steps_in = word_embeds_dec

        # Encoder states conditioned on action outputs, if need be.
        if self.params["use_action_output"]:
            action_out = encoder_output["action_output_all"].unsqueeze(1)
            time_steps = encoder_output["hidden_states_all"].shape[1]
            fusion_out = torch.cat(
                [
                    encoder_output["hidden_states_all"],
                    action_out.expand(-1, time_steps, -1),
                ],
                dim=-1,
            )
            encoder_output["hidden_states_all"] = self.action_fusion_net(
                fusion_out)

        if self.params["text_encoder"] == "transformer":
            # Check the status of no_peek_mask.
            if self.no_peek_mask is None or self.no_peek_mask.size(
                    0) != max_length:
                self.no_peek_mask = self._generate_no_peek_mask(max_length)

            hidden_state = encoder_output["hidden_states_all"]
            enc_pad_mask = batch["user_utt"] == batch["pad_token"]
            enc_pad_mask = support.flatten(enc_pad_mask, batch_size,
                                           num_rounds)
            dec_pad_mask = batch["assist_in"] == batch["pad_token"]
            dec_pad_mask = support.flatten(dec_pad_mask, batch_size,
                                           num_rounds)
            if self.params["encoder"] != "pretrained_transformer":
                dec_embeds = self.pos_encoder(decoder_steps_in).transpose(0, 1)
                outputs = self.decoder_unit(
                    dec_embeds,
                    hidden_state.transpose(0, 1),
                    memory_key_padding_mask=enc_pad_mask,
                    tgt_mask=self.no_peek_mask,
                    tgt_key_padding_mask=dec_pad_mask,
                )
                outputs = outputs.transpose(0, 1)
            else:
                outputs = self.decoder_unit(
                    inputs_embeds=decoder_steps_in,
                    attention_mask=~dec_pad_mask,
                    encoder_hidden_states=hidden_state,
                    encoder_attention_mask=~enc_pad_mask,
                )
                outputs = outputs[0]
        else:
            hidden_state = encoder_output["hidden_state"]
            if self.params["encoder"] == "tf_idf":
                hidden_state = None

            # If Bahdanau attention is to be used.
            if (self.params["use_bahdanau_attention"]
                    and self.params["encoder"] != "tf_idf"):
                encoder_states = encoder_output["hidden_states_all"]
                max_decoder_len = min(decoder_in.shape[1],
                                      self.params["max_decoder_len"])
                encoder_states_proj = self.attention_net(encoder_states)
                enc_mask = (
                    batch["user_utt"] == batch["pad_token"]).unsqueeze(-1)
                enc_mask = support.flatten(enc_mask, batch_size, num_rounds)
                outputs = []
                for step in range(max_decoder_len):
                    previous_state = hidden_state[0][-1].unsqueeze(1)
                    att_logits = previous_state * encoder_states_proj
                    att_logits = att_logits.sum(dim=-1, keepdim=True)
                    # Use encoder mask to replace <pad> with -Inf.
                    att_logits.masked_fill_(enc_mask, float("-Inf"))
                    att_wts = nn.functional.softmax(att_logits, dim=1)
                    context = (encoder_states * att_wts).sum(1, keepdim=True)
                    # Run through LSTM.
                    concat_in = [
                        context, decoder_steps_in[:, step:step + 1, :]
                    ]
                    step_in = torch.cat(concat_in, dim=-1)
                    decoder_output, hidden_state = self.decoder_unit(
                        step_in, hidden_state)
                    concat_out = torch.cat([decoder_output, context], dim=-1)
                    outputs.append(concat_out)
                outputs = torch.cat(outputs, dim=1)
            else:
                outputs = rnn.dynamic_rnn(
                    self.decoder_unit,
                    decoder_steps_in,
                    decoder_len,
                    init_state=hidden_state,
                )
        if self.params["encoder"] == "pretrained_transformer":
            output_logits = outputs
        else:
            # Logits over vocabulary.
            output_logits = self.inv_word_net(outputs)
        # Mask out padded positions in the token loss.
        pad_mask = support.flatten(batch["assist_mask"], batch_size,
                                   num_rounds)
        loss_token = self.criterion(output_logits.transpose(1, 2), decoder_out)
        loss_token.masked_fill_(pad_mask, 0.0)
        return {"loss_token": loss_token, "pad_mask": pad_mask}
Example #10
    def forward_beamsearch_single(self, batch, encoder_output, mode_params):
        """Evaluates the model using beam search with batch size 1.

        NOTE: Current implementation only supports beam search with batch
              size 1 (both LSTM and transformer text encoders are handled).

        Args:
            batch: Dictionary of inputs, with batch size of 1
            encoder_output: Dict of outputs from the encoder
            mode_params: Dict of mode parameters, including beam_size

        Returns:
            top_beams: Dictionary of top beams
        """
        # Initializations and aliases.
        # Tensors are either on CPU or GPU.
        LENGTH_NORM = True
        self.host = torch.cuda if self.params["use_gpu"] else torch
        end_token = self.params["end_token"]
        start_token = self.params["start_token"]
        beam_size = mode_params["beam_size"]
        max_decoder_len = self.params["max_decoder_len"]

        if self.params["text_encoder"] == "transformer":
            hidden_state = encoder_output["hidden_states_all"].transpose(0, 1)
            max_enc_len, batch_size, enc_embed_size = hidden_state.shape
            hidden_state_expand = hidden_state.expand(max_enc_len, beam_size,
                                                      enc_embed_size)
            enc_pad_mask = batch["user_utt"] == batch["pad_token"]
            enc_pad_mask = support.flatten(enc_pad_mask, 1, 1)
            enc_pad_mask_expand = enc_pad_mask.expand(beam_size, max_enc_len)
            if (self.no_peek_mask is None
                    or self.no_peek_mask.size(0) != max_decoder_len):
                self.no_peek_mask = self._generate_no_peek_mask(
                    max_decoder_len)
        elif self.params["text_encoder"] == "lstm":
            hidden_state = encoder_output["hidden_state"]

        if (self.params["use_bahdanau_attention"]
                and self.params["encoder"] != "tf_idf"
                and self.params["text_encoder"] == "lstm"):
            encoder_states = encoder_output["hidden_states_all"]
            encoder_states_proj = self.attention_net(encoder_states)
            enc_mask = (batch["user_utt"] == batch["pad_token"]).unsqueeze(-1)
            enc_mask = support.flatten(enc_mask, 1, 1)

        # Per instance initializations.
        # Copy the hidden state beam_size number of times.
        if hidden_state is not None:
            hidden_state = [ii.repeat(1, beam_size, 1) for ii in hidden_state]
        beams = {-1: self.host.LongTensor(1, beam_size).fill_(start_token)}
        beam_scores = self.host.FloatTensor(beam_size, 1).fill_(0.)
        finished_beams = self.host.ByteTensor(beam_size, 1).fill_(False)
        zero_tensor = self.host.LongTensor(beam_size, 1).fill_(end_token)
        reverse_inds = {}
        # Generate beams until max_len time steps.
        for step in range(max_decoder_len - 1):
            if self.params["text_encoder"] == "transformer":
                beams, tokens_list = self._backtrack_beams(beams, reverse_inds)
                beam_tokens = torch.cat(tokens_list, dim=0).transpose(0, 1)
                beam_tokens_embed = self.word_embed_net(beam_tokens)

                if self.params["encoder"] != "pretrained_transformer":
                    dec_embeds = self.pos_encoder(beam_tokens_embed).transpose(
                        0, 1)
                    output = self.decoder_unit(
                        dec_embeds,
                        hidden_state_expand,
                        tgt_mask=self.no_peek_mask[:step + 1, :step + 1],
                        memory_key_padding_mask=enc_pad_mask_expand,
                    )
                    logits = self.inv_word_net(output[-1])
                else:
                    outputs = self.decoder_unit(
                        inputs_embeds=beam_tokens_embed,
                        encoder_hidden_states=hidden_state_expand.transpose(
                            0, 1),
                        encoder_attention_mask=~enc_pad_mask_expand,
                    )
                    logits = outputs[0][:, -1, :]

            elif self.params["text_encoder"] == "lstm":
                beam_tokens = beams[step - 1].t()
                beam_tokens_embed = self.word_embed_net(beam_tokens)
                # Append dialog context if exists.
                if self.params["encoder"] in self.DIALOG_CONTEXT_ENCODERS:
                    dialog_context = encoder_output["dialog_context"]
                    beam_tokens_embed = torch.cat(
                        [
                            dialog_context.repeat(beam_size, 1, 1),
                            beam_tokens_embed
                        ],
                        dim=-1,
                    )
                # Use bahdanau attention over encoder hidden states.
                if (self.params["use_bahdanau_attention"]
                        and self.params["encoder"] != "tf_idf"):
                    previous_state = hidden_state[0][-1].unsqueeze(1)
                    att_logits = previous_state * encoder_states_proj
                    att_logits = att_logits.sum(dim=-1, keepdim=True)
                    # Use encoder mask to replace <pad> with -Inf.
                    att_logits.masked_fill_(enc_mask, float("-Inf"))
                    att_wts = nn.functional.softmax(att_logits, dim=1)
                    context = (encoder_states * att_wts).sum(1, keepdim=True)
                    # Run through LSTM.
                    step_in = torch.cat([context, beam_tokens_embed], dim=-1)
                    decoder_output, new_state = self.decoder_unit(
                        step_in, hidden_state)
                    output = torch.cat([decoder_output, context], dim=-1)
                else:
                    output, new_state = self.decoder_unit(
                        beam_tokens_embed, hidden_state)
                logits = self.inv_word_net(output).squeeze(1)
            log_probs = nn.functional.log_softmax(logits, dim=-1)
            # Compute the new beam scores.
            alive = finished_beams.eq(0).float()
            if LENGTH_NORM:
                # Add (current log probs / (step + 1))
                cur_weight = alive / (step + 1)
                # Add (previous log probs * (t/t+1) ) <- Mean update
                prev_weight = alive * step / (step + 1)
            else:
                # No length normalization.
                cur_weight = alive
                prev_weight = alive

            # Compute the new beam extensions.
            if step == 0:
                # For the first step, make all but first beam
                # probabilities -inf.
                log_probs[1:, :] = float("-inf")
            new_scores = log_probs * cur_weight + beam_scores * prev_weight
            finished_beam_scores = beam_scores * finished_beams.float()
            new_scores.scatter_add_(1, zero_tensor, finished_beam_scores)
            # Finished beams scores are set to -inf for all words but one.
            new_scores.masked_fill_(new_scores.eq(0), float("-inf"))

            num_candidates = new_scores.shape[-1]
            new_scores_flat = new_scores.view(1, -1)
            beam_scores, top_inds_flat = torch.topk(new_scores_flat, beam_size)
            beam_scores = beam_scores.t()
            top_beam_inds = (top_inds_flat // num_candidates).squeeze(0)
            top_tokens = top_inds_flat % num_candidates

            # Prepare for next step.
            beams[step] = top_tokens
            reverse_inds[step] = top_beam_inds
            finished_beams = finished_beams[top_beam_inds]
            if self.params["text_encoder"] == "lstm":
                hidden_state = tuple(
                    ii.index_select(1, top_beam_inds) for ii in new_state)

            # Update if any of the latest beams are finished, i.e., have <END>.
            new_finished_beams = beams[step].eq(end_token).type(
                self.host.ByteTensor)
            finished_beams = finished_beams | new_finished_beams.t()
            if torch.sum(finished_beams).item() == beam_size:
                break

        # Backtrack the beam through indices.
        beams, tokens_list = self._backtrack_beams(beams, reverse_inds)
        # Add an <END> token at the end.
        tokens_list.append(self.host.LongTensor(1, beam_size).fill_(end_token))

        sorted_beam_tokens = torch.cat(tokens_list, 0).t()
        sorted_beam_lengths = sorted_beam_tokens.ne(end_token).long().sum(
            dim=1)
        # Trim all the top beams.
        top_beams = []
        for index in range(beam_size):
            beam_length = sorted_beam_lengths[index].view(-1)
            beam = sorted_beam_tokens[index].view(-1, 1)[1:beam_length]
            top_beams.append(beam)
        return {"top_beams": top_beams}