Ejemplo n.º 1
0
def test_end2end():
    """Evaluate a pipeline system agent against a simulated user on MultiWOZ.

    See the README.md of each component for more information on the models.
    """
    # System side: BERT NLU -> rule-based DST -> PPO policy -> template NLG.
    sys_agent = PipelineAgent(
        BERTNLU(),
        RuleDST(),
        PPOPolicy(),
        TemplateNLG(is_user=False),
        name='sys',
    )

    # User side: BERT NLU trained on system utterances, no DST, rule policy,
    # template NLG.
    user_agent = PipelineAgent(
        BERTNLU(
            mode='sys',
            config_file='multiwoz_sys_context.json',
            model_file='https://convlab.blob.core.windows.net/convlab-2/bert_multiwoz_sys_context.zip',
        ),
        None,
        RulePolicy(character='usr'),
        TemplateNLG(is_user=True),
        name='user',
    )

    evaluator = Analyzer(user_agent=user_agent, dataset='multiwoz')

    # The seed is fixed only after every component has been constructed.
    set_seed(20200202)
    evaluator.comprehensive_analyze(
        sys_agent=sys_agent,
        model_name='BERTNLU-RuleDST-PPOPolicy-TemplateNLG',
        total_dialog=1000,
    )
Ejemplo n.º 2
0
    def __init__(self, model_file=DEFAULT_MODEL_URL, name="Dialog"):
        """Build the agent: BERT NLU, neural DST, rule policy and template NLG.

        Args:
            model_file: Kept for interface compatibility; this constructor does
                not read it and loads a fixed checkpoint path instead.
            name: Agent name forwarded to the base-class constructor.
        """
        super(Dialog, self).__init__(name=name)

        # makedirs instead of mkdir: the targets are nested ("multiwoz/..."), so
        # mkdir fails when the intermediate "multiwoz" directory is missing.
        # exist_ok covers the directory appearing between the check and the call.
        # NOTE(review): only a message is printed below; nothing is downloaded
        # by this constructor.
        data_dir = os.path.join(DEFAULT_DIRECTORY, 'multiwoz/data')
        if not os.path.exists(data_dir):
            os.makedirs(data_dir, exist_ok=True)
            ### download multiwoz data
            print('down load data from', DEFAULT_ARCHIVE_FILE_URL)
        save_dir = os.path.join(DEFAULT_DIRECTORY, 'multiwoz/save')
        if not os.path.exists(save_dir):
            os.makedirs(save_dir, exist_ok=True)
            ### download trained model
            print('down load model from', DEFAULT_MODEL_URL)

        # The Config object is replaced by the namespace its parser produces.
        config = Config()
        parser = config.parser
        config = parser.parse_args()
        with open("assets/never_split.txt") as f:
            never_split = f.read().split("\n")
        self.tokenizer = BertTokenizer("assets/vocab.txt", never_split=never_split)
        self.nlu = BERTNLU()
        self.dst_ = DST(config).cuda()
        # NOTE(review): local_rank is not defined in this method; it must come
        # from module scope -- confirm it exists there.
        ckpt = torch.load("save/model_Sun_Jun_21_07:08:48_2020.pt", map_location = lambda storage, loc: storage.cuda(local_rank))
        self.dst_.load_state_dict(ckpt["model"])
        self.dst_.eval()
        self.policy = RulePolicy()
        self.nlg = TemplateNLG(is_user=False)
        self.init_session()
        # DST slot names that differ from the belief-state keys.
        self.slot_mapping = {
            "leave": "leaveAt",
            "arrive": "arriveBy"
        }
 def __init__(self, vocab, db, config):
     """Store the vocabulary, database and sizes, then build the sub-modules.

     The belief decoder is given the context encoder's embedding, so the
     encoder has to be created first.
     """
     super(DIALOG, self).__init__()
     self.vocab = vocab
     self.db = db
     self.pointer_size = config.pointer_size
     self.max_belief_len = config.max_belief_len
     # fixed, not finetuning
     self.nlu = BERTNLU()
     self.context_encoder = ContextEncoder(
         vocab.vocab_size,
         config.hidden_size,
         config.hidden_size,
         config.dropout,
         config.num_layers,
         vocab.word2idx["<pad>"],
     )
     self.belief_decoder = BeliefDecoder(
         vocab,
         self.context_encoder.embedding,
         config.hidden_size,
         config.dropout,
         config.num_layers,
         config.max_value_len,
     )
     self.policy = RulePolicy()
     self.nlg = TemplateNLG(is_user=False)
     # Disabled decoders, kept for reference:
     # self.action_decoder = ActionDecoder(vocab, self.context_encoder.embedding, config.hidden_size, config.pointer_size, config.dropout, config.num_layers, config.max_act_len)
     # self.response_decoder = ResponseDecoder(vocab, self.context_encoder.embedding, config.hidden_size, config.pointer_size, config.dropout, config.num_layers, config.max_sentence_len)
     # load Glove & Kazuma embedding
     self.load_embedding()
Ejemplo n.º 4
0
def set_user(user):
    """Build the user-simulator components, optionally swapping the goal generator.

    Args:
        user: Selects a specialised goal generator: '7', 'attraction',
            'hospital', 'hotel', 'police', 'restaurant', 'taxi' or 'train'.
            A falsy or unrecognised value leaves the policy's generator as is.

    Returns:
        Tuple ``(user_nlu, user_dst, user_policy, user_nlg)``; ``user_dst`` is
        always None because the user side uses no DST.
    """
    nlu = BERTNLU()
    policy = RulePolicy(character='usr')

    if user:
        # Ordered (selector, generator class) pairs; the first match wins.
        generators = (
            ('7', GoalGenerator_7),
            ('attraction', GoalGenerator_attraction),
            ('hospital', GoalGenerator_hospital),
            ('hotel', GoalGenerator_hotel),
            ('police', GoalGenerator_police),
            ('restaurant', GoalGenerator_restaurant),
            ('taxi', GoalGenerator_taxi),
            ('train', GoalGenerator_train),
        )
        for selector, generator_cls in generators:
            if user == selector:
                policy.policy.goal_generator = generator_cls()
                break

    nlg = TemplateNLG(is_user=True)
    return nlu, None, policy, nlg
Ejemplo n.º 5
0
def build_user_agent_bertnlu():
    """Return the 'user' pipeline agent: default BERT NLU, rule policy, template NLG."""
    return PipelineAgent(
        BERTNLU(),
        None,  # the user side runs without a DST
        RulePolicy(character='usr'),
        TemplateNLG(is_user=True),
        'user',
    )
Ejemplo n.º 6
0
def build_sys_agent_bertnlu():
    """Return the 'sys' pipeline agent: BERT NLU, rule DST, rule policy, template NLG."""
    return PipelineAgent(
        BERTNLU(),
        RuleDST(),
        RulePolicy(character='sys'),
        TemplateNLG(is_user=False),
        'sys',
    )
Ejemplo n.º 7
0
def set_system(sys_policy, sys_path):
    """Build the system-side components.

    Args:
        sys_policy: Policy selector forwarded to ``get_policy``.
        sys_path: Second argument forwarded to ``get_policy``.

    Returns:
        Tuple ``(sys_nlu, sys_dst, sys_policy, sys_nlg)``.
    """
    # Tuple items are evaluated left to right, so the components are built in
    # the order NLU, DST, policy, NLG.
    return (
        BERTNLU(),
        RuleDST(),
        get_policy(sys_policy, sys_path),
        TemplateNLG(is_user=False),
    )
Ejemplo n.º 8
0
def build_sys_agent_bertnlu_context(use_nlu=True):
    """Return the 'sys' pipeline agent built around the context-aware BERT NLU.

    Args:
        use_nlu: When false, the agent is assembled with neither NLU nor NLG
            (both slots are passed as None).
    """
    nlu = BERTNLU(
        mode='all',
        config_file='multiwoz_all_context.json',
        model_file='https://tatk-data.s3-ap-northeast-1.amazonaws.com/bert_multiwoz_all_context.zip',
    )
    tracker = RuleDST()
    policy = RulePolicy(character='sys')
    nlg = TemplateNLG(is_user=False)

    # NLU and NLG are constructed above in either case; the flag only decides
    # whether they are attached to the agent.
    if not use_nlu:
        return PipelineAgent(None, tracker, policy, None, 'sys')
    return PipelineAgent(nlu, tracker, policy, nlg, 'sys')
Ejemplo n.º 9
0
def build_user_agent_bertnlu(use_nlu=True):
    """Return the 'user' pipeline agent built around the 'all'-mode BERT NLU.

    Args:
        use_nlu: When false, the agent is assembled with neither NLU nor NLG
            (both slots are passed as None).
    """
    nlu = BERTNLU(
        mode='all',
        config_file='multiwoz_all.json',
        model_file='https://tatk-data.s3-ap-northeast-1.amazonaws.com/bert_multiwoz_all.zip',
    )
    policy = RulePolicy(character='usr')
    nlg = TemplateNLG(is_user=True)

    # The user side never has a DST. NLU and NLG are constructed above in
    # either case; the flag only decides whether they are attached.
    if not use_nlu:
        return PipelineAgent(None, None, policy, None, 'user')
    return PipelineAgent(nlu, None, policy, nlg, 'user')
Ejemplo n.º 10
0
 def __init__(self, dataset='multiwoz-test'):
     """Create the simulated user agent and initialise the base class with it.

     Args:
         dataset: Dataset identifier; only the first element returned by
             ``data.split_name`` is passed on to the base class.
     """
     # User simulator: BERT NLU trained on system utterances, no DST,
     # rule policy, template NLG.
     simulator = PipelineAgent(
         BERTNLU(
             mode='sys',
             config_file='multiwoz_sys_context.json',
             model_file='https://convlab.blob.core.windows.net/convlab-2/bert_multiwoz_sys_context.zip',
         ),
         None,
         RulePolicy(character='usr'),
         TemplateNLG(is_user=True),
         name='user',
     )
     dataset_name, _unused = data.split_name(dataset)
     super().__init__(simulator, dataset_name)
Ejemplo n.º 11
0
class Dialog(Agent):
    """Dialog agent: BERT NLU plus a neural DST feeding a rule policy and template NLG."""

    def __init__(self, model_file=DEFAULT_MODEL_URL, name="Dialog"):
        """Build the tokenizer, NLU, DST, policy and NLG, then start a session.

        Args:
            model_file: Kept for interface compatibility; this constructor does
                not read it and loads a fixed checkpoint path instead.
            name: Agent name forwarded to the base-class constructor.
        """
        super(Dialog, self).__init__(name=name)

        # makedirs instead of mkdir: the targets are nested ("multiwoz/..."), so
        # mkdir fails when the intermediate "multiwoz" directory is missing.
        # exist_ok covers the directory appearing between the check and the call.
        # NOTE(review): only a message is printed below; nothing is downloaded
        # by this constructor.
        data_dir = os.path.join(DEFAULT_DIRECTORY, 'multiwoz/data')
        if not os.path.exists(data_dir):
            os.makedirs(data_dir, exist_ok=True)
            ### download multiwoz data
            print('down load data from', DEFAULT_ARCHIVE_FILE_URL)
        save_dir = os.path.join(DEFAULT_DIRECTORY, 'multiwoz/save')
        if not os.path.exists(save_dir):
            os.makedirs(save_dir, exist_ok=True)
            ### download trained model
            print('down load model from', DEFAULT_MODEL_URL)

        # The Config object is replaced by the namespace its parser produces.
        config = Config()
        parser = config.parser
        config = parser.parse_args()
        with open("assets/never_split.txt") as f:
            never_split = f.read().split("\n")
        self.tokenizer = BertTokenizer("assets/vocab.txt", never_split=never_split)
        self.nlu = BERTNLU()
        self.dst_ = DST(config).cuda()
        # NOTE(review): local_rank is not defined in this class; it must come
        # from module scope -- confirm it exists there.
        ckpt = torch.load("save/model_Sun_Jun_21_07:08:48_2020.pt", map_location = lambda storage, loc: storage.cuda(local_rank))
        self.dst_.load_state_dict(ckpt["model"])
        self.dst_.eval()
        self.policy = RulePolicy()
        self.nlg = TemplateNLG(is_user=False)
        self.init_session()
        # DST slot names that differ from the belief-state keys.
        self.slot_mapping = {
            "leave": "leaveAt",
            "arrive": "arriveBy"
        }

    def init_session(self):
        """Reset the sub-modules, the utterance history and the dialog state."""
        self.nlu.init_session()
        self.policy.init_session()
        self.nlg.init_session()
        self.history = []
        self.state = default_state()

    def response(self, user):
        """Produce the system reply to one user utterance.

        Request acts come from the NLU; Inform acts are derived from the
        DST-generated belief for slots that were previously empty.

        Args:
            user: The user utterance.

        Returns:
            The reply generated by the NLG, which is also appended to the history.
        """
        self.history.append(["user", user])

        user_action = []

        # The NLU context is every earlier utterance, excluding the current one.
        self.input_action = self.nlu.predict(user, context=[x[1] for x in self.history[:-1]])
        self.input_action = deepcopy(self.input_action)
        for act in self.input_action:
            intent, domain, slot, _ = act
            if intent == "Request":
                user_action.append(act)
                # NOTE(review): domain is used as returned by the NLU here, while
                # the belief-state lookup below uses ontology domain names --
                # confirm both use the same casing.
                if not self.state["request_state"].get(domain):
                    self.state["request_state"][domain] = {}
                if slot not in self.state["request_state"][domain]:
                    self.state['request_state'][domain][slot] = 0

        # The DST input is the joined history, truncated to its last
        # MAX_CONTEXT_LENGTH characters before tokenisation.
        context = " ".join([utterance[1] for utterance in self.history])
        context = context[-MAX_CONTEXT_LENGTH:]
        context = self.tokenizer.encode(context)
        context = torch.tensor(context, dtype=torch.int64).unsqueeze(dim=0).cuda()  # [1, len]

        belief_gen = self.dst_(None, context, 0, test=True)[0]  # [slots, len]
        for slot_idx, domain_slot in enumerate(ontology.all_info_slots):
            domain, slot = domain_slot.split("-")
            slot = self.slot_mapping.get(slot, slot)
            value = belief_gen[slot_idx][:-1]  # remove <EOS>
            value = self.tokenizer.decode(value)
            if value == "none":
                continue
            domain_belief = self.state["belief_state"][domain]
            # "book" is checked before "semi"; the first section holding the
            # slot is the only one updated.
            for section in ("book", "semi"):
                section_slots = domain_belief[section]
                if slot not in section_slots:
                    continue
                if section_slots[slot] == "":
                    # A slot filled for the first time is reported as an Inform act.
                    action = ["Inform", domain.capitalize(), REF_USR_DA[domain].get(slot, slot), value]
                    user_action.append(action)
                section_slots[slot] = value
                break

        self.state["user_action"] = user_action

        self.output_action = deepcopy(self.policy.predict(self.state))
        model_response = self.nlg.generate(self.output_action)
        self.history.append(["sys", model_response])

        return model_response
Ejemplo n.º 12
0
    if not randomize:
        random.seed(r_seed)
        np.random.seed(r_seed)
        torch.manual_seed(r_seed)
    else:
        seed = int(time.time())
        print(seed)
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)


if __name__ == '__main__':

    # BERT nlu
    sys_nlu = BERTNLU()
    # simple rule DST
    sys_dst = RuleDST()
    # rule policy
    sys_policy = RulePolicy()
    from convlab2.policy.mgail.multiwoz import MGAIL
    policy_sys = MGAIL()
    policy_sys.load(
        '/home/nightop/ConvLab-2/convlab2/policy/mgail/multiwoz/save/all/99')
    # template NLG
    sys_nlg = TemplateNLG(is_user=False)
    # assemble
    sys_agent = PipelineAgent(sys_nlu,
                              sys_dst,
                              sys_policy,
                              sys_nlg,
class DIALOG(nn.Module):
    """Belief-tracking model: a fixed BERT NLU feeding a context encoder and belief decoder.

    A rule policy and a template NLG are instantiated as well, but in the
    methods shown here they are referenced only from commented-out code.
    """

    def __init__(self, vocab, db, config):
        """Store the vocabulary, database and sizes, then build the sub-modules.

        The belief decoder is given the context encoder's embedding, so the
        encoder has to be created first.
        """
        super(DIALOG, self).__init__()
        self.vocab = vocab
        self.db = db
        self.pointer_size = config.pointer_size
        self.max_belief_len = config.max_belief_len
        self.nlu = BERTNLU()  # fixed, not finetuning
        self.context_encoder = ContextEncoder(vocab.vocab_size, config.hidden_size, config.hidden_size, config.dropout, config.num_layers, vocab.word2idx["<pad>"])
        self.belief_decoder = BeliefDecoder(vocab, self.context_encoder.embedding, config.hidden_size, config.dropout, config.num_layers, config.max_value_len)
        self.policy = RulePolicy()
        self.nlg = TemplateNLG(is_user=False)
        # self.action_decoder = ActionDecoder(vocab, self.context_encoder.embedding, config.hidden_size, config.pointer_size, config.dropout, config.num_layers, config.max_act_len)
        # self.response_decoder = ResponseDecoder(vocab, self.context_encoder.embedding, config.hidden_size, config.pointer_size, config.dropout, config.num_layers, config.max_sentence_len)
        self.load_embedding()  # load Glove & Kazuma embedding

    def forward(self, turn_inputs, turn_contexts, turn_context_lengths, action_history, teacher_forcing, test=False):
        """
        Dispatch to the test or the training forward pass.

        turn_inputs: {
            "user": [batch, len]
            "response": [batch, len]
            "belief": [batch, slots, len]
            "gate": [batch, slots]
            "action": [batch, len]
            "usr": [batch] => string list
            "context": [batch, turns, 2] => string list
            "prev_gate": [batch, slots]
        }
        turn_contexts: [batch, len]
        turn_context_lengths: [batch]
        action_history: [batch, actions, 4] => string list
        """

        # NOTE(review): test_forward is not defined in the visible part of this
        # class and is called without any of the inputs -- confirm it exists.
        if test:
            return self.test_forward()
        else:
            return self.train_forward(turn_inputs, turn_contexts, turn_context_lengths, action_history, teacher_forcing)

    def train_forward(self, turn_inputs, turn_contexts, turn_context_lengths, action_history, teacher_forcing):
        """Run NLU, encode context and action history, decode beliefs and compute losses.

        Returns:
            (gate_loss, value_loss, acc_belief, belief_gen, action_history).
            action_history is extended with this turn's NLU actions via ``+=``
            (in place if its elements are lists -- TODO confirm).
        """
        batch_size = turn_contexts.size(0)

        user = turn_inputs["usr"]
        history = turn_inputs["context"]
        # NOTE(review): self.states is read and mutated below but is not assigned
        # anywhere in the visible part of this class -- confirm who sets it.
        for batch_idx in range(batch_size):
            self.states[batch_idx]["history"].append(["usr", turn_inputs["usr"][batch_idx]])

            input_action = self.nlu.predict(user[batch_idx], history[batch_idx])  # [[intent, domain, slot, value], ..., [intent, domain, slot, value]]
            self.states[batch_idx]["user_action"] = input_action
            action_history[batch_idx] += input_action

            # update request state
            for action in input_action:
                # "_" holds the action's value here; it is passed on to
                # clean_slot_values and then overwritten with its second result.
                intent, domain, slot, _ = action
                slot, _ = clean_slot_values(domain, slot, _)
                domain = domain.lower()
                if intent == "Request":
                    if not self.states[batch_idx]["request_state"].get(domain):
                        self.states[batch_idx]["request_state"][domain] = {}
                    self.states[batch_idx]["request_state"][domain][slot] = 0

        # calculate score of input action generated by BERT NLU
        encoded_action_history = []
        history_lengths = []
        for batch_idx, actions in enumerate(action_history):
            encoded_actions = []
            for action in actions:
                intent, domain, slot, value = action
                slot = ontology.normlize_slot_names.get(slot, slot)
                # [1:-1] drops the first and last token of each encoded action;
                # a single <bos>/<eos> pair is added around the sequence below.
                encoded_actions += self.vocab.encode(" ".join([intent, domain, slot, value]))[1:-1]
            encoded_actions = [self.vocab.word2idx["<bos>"]] + encoded_actions + [self.vocab.word2idx["<eos>"]]
            history_lengths.append(len(encoded_actions))
            encoded_action_history.append(encoded_actions)
        history_lengths = torch.tensor(history_lengths).cuda()
        # Zero-filled [batch, max_history_len] buffer; each row is filled up to its own length.
        encoded_action_history_ = torch.zeros((batch_size, history_lengths.max().item()), dtype=torch.int64).cuda()
        for batch_idx, actions in enumerate(encoded_action_history):
            encoded_action_history_[batch_idx, :history_lengths[batch_idx]] = torch.tensor(actions)

        # The same context encoder is applied to both the action history and the dialogue context.
        action_history_outputs, action_history_hidden = self.context_encoder(encoded_action_history_, history_lengths)  # [batch, len, hidden], [layers, batch, hidden]
        encoder_outputs_, encoder_hidden_ = self.context_encoder(turn_contexts, turn_context_lengths)  # [batch, len, hidden], [layers, batch, hidden]

        # Sigmoid of the dot product between the two final-layer hidden states.
        action_score = torch.sigmoid(torch.matmul(action_history_hidden[-1].unsqueeze(dim=1), encoder_hidden_[-1].unsqueeze(dim=2)))  # [batch, 1, 1]

        # action score for attention score
        action_score_attenion = torch.zeros(batch_size, (turn_context_lengths+history_lengths).max().item()).cuda()  # [batch, len]
        for batch_idx, context_len in enumerate(turn_context_lengths):
            # Context positions get (1 - score); action-history positions get score.
            action_score_attenion[batch_idx, :context_len] = 1-action_score.squeeze()[batch_idx]
            action_score_attenion[batch_idx, context_len:context_len+history_lengths[batch_idx].item()] = action_score.squeeze()[batch_idx]

        # weighted sum & weighted concat
        # NOTE: turn_contexts and turn_context_lengths are rebound to the
        # concatenated (context + action history) versions from here on.
        encoder_outputs, turn_contexts, turn_context_lengths = self.concat(encoder_outputs_, action_history_outputs, turn_contexts, turn_context_lengths, \
            encoded_action_history_, history_lengths)
        action_score = action_score.transpose(0,1).contiguous()
        encoder_hidden = (1 - action_score) * encoder_hidden_ + action_score * action_history_hidden  # [layers, batch, hidden]

        gate_outputs, all_probs, all_pred_words = self.belief_decoder(encoder_outputs, encoder_hidden, turn_contexts, turn_context_lengths, \
            turn_inputs["belief"], teacher_forcing, action_score_attenion)  # [batch, slots, 3], [batch, slots, len, vocab], [batch, slots, len], [batch, slots]

        gate_preds = gate_outputs.argmax(dim=2)

        # prev_gate = turn_inputs.get("prev_gate")

        # NOTE(review): max_value_len is computed below but not used afterwards
        # in this method.
        max_value_len = 0
        belief_gen = []  # [batch, slots, len]
        for batch_idx, batch in enumerate(all_pred_words):
            belief_gen_ = []  # [slots, len]
            for slot_idx, pred_words in enumerate(batch):
                if gate_preds[batch_idx, slot_idx].item() == ontology.gate_dict["none"]:
                    belief_gen_.append(self.vocab.encode("none")[1:])
                    len_ = len(self.vocab.encode("none")[1:])
                elif gate_preds[batch_idx, slot_idx].item() == ontology.gate_dict["don't care"]:
                    belief_gen_.append(self.vocab.encode("don't care")[1:])
                    len_ = len(self.vocab.encode("don't care")[1:])
                else:
                    # Keep the predicted tokens up to and including the first <eos>
                    # (or the whole sequence when no <eos> is present).
                    # NOTE(review): if pred_words were empty, idx would be unbound
                    # or left over from a previous slot.
                    for idx, value in enumerate(pred_words):
                        if value == self.vocab.word2idx["<eos>"]:
                            break
                    belief_gen_.append(pred_words[:idx+1].tolist())
                    len_ = idx + 1
                max_value_len = max(max_value_len, len_)
            belief_gen.append(belief_gen_)

        gate_label = turn_inputs["gate"]
        gate_loss = F.cross_entropy(gate_outputs.view(-1, 3), gate_label.view(-1))

        # if prev_gate is not None:
        #     turn_domain = self.make_turn_domain(prev_gate, gate_preds)  # [batch]
        # else:
        #     turn_domain = self.make_turn_domain(None, gate_preds)

        # prev_gate = gate_preds.detach()  # [batch, slots]

        # Per-slot correctness flags: start at 1, cleared on a gate or value mismatch.
        acc_belief = torch.ones(batch_size, len(ontology.all_info_slots)).cuda()

        gate_mask = (gate_label != gate_preds)
        acc_belief.masked_fill_(gate_mask, value=0)  # fail to predict gate

        value_label = turn_inputs["belief"]
        value_label_lengths = torch.zeros(batch_size, len(ontology.all_info_slots), dtype=torch.int64).cuda()
        for batch_idx, batch in enumerate(value_label):
            for slot_idx, pred_words in enumerate(batch):
                value = pred_words[pred_words != self.vocab.word2idx["<pad>"]].tolist()  # remove padding
                value_label_lengths[batch_idx, slot_idx] = len(value)
                if value != belief_gen[batch_idx][slot_idx]:
                    acc_belief[batch_idx, slot_idx] = 0  # fail to predict value

        if teacher_forcing:
            value_loss = masked_cross_entropy_for_value(all_probs, value_label, value_label_lengths)
        else:
            # Without teacher forcing the decoded length can differ from the label
            # length, so both are cut to the shorter of the two.
            min_len = min(value_label.size(2), all_probs.size(2))
            value_loss = masked_cross_entropy_for_value(all_probs[:, :, :min_len, :].contiguous(), value_label[:, :, :min_len].contiguous(), value_label_lengths)

        ### make state for Rule policy
        # for batch_idx, belief in enumerate(belief_gen):
        #     for slot_idx, slot in enumerate(belief):
        #         value = self.vocab.decode(slot[:-1])
        #         domain, slot = ontology.all_info_slots[slot_idx].split("-")

        #         # book slots
        #         for slot_ in self.states[batch_idx]["belief_state"][domain]["book"].keys():
        #             if slot_ == "booked":
        #                 continue
        #             slot_ = ontology.normlize_slot_names.get(slot_, slot_)
        #             if slot_ == slot:
        #                 self.states[batch_idx]["belief_state"][domain]["book"][slot_] = value

        #         # semi slots
        #         for slot_ in self.states[batch_idx]["belief_state"][domain]["semi"].keys():
        #             slot_ = ontology.normlize_slot_names.get(slot_, slot_)
        #             if slot_ == slot:
        #                 self.states[batch_idx]["belief_state"][domain]["semi"][slot_] = value

        # ### policy
        # output_actions = []
        # for batch_idx, state in enumerate(self.states):
        #     output_actions.append(self.policy.predict(state))

        # ### NLG
        # model_responses = []
        # for batch_idx, output_action in enumerate(output_actions):
        #     model_responses.append(self.nlg.generate(output_actions[batch_idx]))

        return gate_loss, value_loss, acc_belief, belief_gen, action_history

    def load_embedding(self):
        """Overwrite the context encoder's embedding rows with GloVe + Kazuma vectors.

        Every vocabulary word gets the combination (``+``) of its GloVe and
        Kazuma character embeddings, written in place into the weight data.
        """
        glove = GloveEmbedding()
        kazuma = KazumaCharEmbedding()
        embed = self.context_encoder.embedding.weight.data
        for word, idx in self.vocab.word2idx.items():
            # presumably both emb() calls return lists, making "+" a concatenation -- TODO confirm
            embed[idx] = torch.tensor(glove.emb(word, default="zero") + kazuma.emb(word, default="zero"))
        # self.context_encoder.embedding.weight.data = embed
        # self.belief_decoder.slot_embedding.weight.data = embed

    def make_turn_domain(self, prev_gate, gate):
        """Return, per batch element, the index of the domain with the most active gates.

        On the first turn (prev_gate is None) a slot counts when its gate is not
        "none"; otherwise it counts when its gate differs from prev_gate.
        """
        batch_size = gate.size(0)
        turn_domain = torch.zeros((batch_size, len(ontology.all_domains))).cuda()
        if prev_gate is None:  # first turn
            turn_gate = (gate != ontology.gate_dict["none"]).long()
        else:
            turn_gate = (gate != prev_gate).long()  # find changed gate
        # Accumulate the per-slot flags into their domain's column.
        for slot_idx in range(len(ontology.all_info_slots)):
            domain, slot = ontology.all_info_slots[slot_idx].split("-")
            domain_idx = ontology.all_domains.index(domain)
            turn_domain[:, domain_idx] += turn_gate[:, slot_idx]
        turn_domain = turn_domain.argmax(dim=1).tolist()  # [batch]

        return turn_domain

    def parse_action(self, action):
        """
        Decode a token-id sequence into "domain-action[-slot]" strings.

        action: [len] => list
        """

        domains = ontology.domain_action_slot["domain"]
        actions = ontology.domain_action_slot["action"]
        slots = ontology.domain_action_slot["slot"]

        parsed_action = []
        decoded_action = []

        for token in action:
            decoded_action.append(self.vocab.idx2word[token])

        # NOTE(review): "domain" is bound only once a domain token has been seen,
        # so an action or slot token arriving first raises UnboundLocalError.
        # "action" rebinds the parameter; a slot token seen before any action
        # token would format the original token list instead.
        for token in decoded_action:
            if token in domains:
                domain = token
            if token in actions:
                action = token
                # "general" acts are emitted without a slot.
                if domain == "general":
                    parsed_action.append("{}-{}".format(domain, action))
            if token in slots:
                slot = token
                parsed_action.append("{}-{}-{}".format(domain, action, slot))

        return parsed_action

    def concat(self, encoder_outputs, action_history_outputs, contexts, context_lengths, action_history, history_lengths):
        """Append each sample's valid action-history part to its valid context part.

        Both the token ids and the encoder outputs are concatenated per sample and
        written into zero-initialised buffers sized for the longest combined
        length in the batch.

        Returns:
            (new_encoder_outputs, new_contexts, lengths), where lengths is
            context_lengths + history_lengths.
        """
        batch_size = contexts.size(0)
        hidden_size = encoder_outputs.size(2)
        lengths = context_lengths + history_lengths
        new_contexts = torch.zeros(size=(batch_size, lengths.max().item()), dtype=torch.int64).cuda()
        new_encoder_outputs = torch.zeros(size=(batch_size, lengths.max().item(), hidden_size)).cuda()

        for batch_idx in range(batch_size):
            new_contexts[batch_idx, :lengths[batch_idx]] = torch.cat([contexts[batch_idx, :context_lengths[batch_idx]], \
                action_history[batch_idx, :history_lengths[batch_idx]]], dim=0)  # [batch, len]
            new_encoder_outputs[batch_idx, :lengths[batch_idx], :] = torch.cat([encoder_outputs[batch_idx, :context_lengths[batch_idx], :], \
                action_history_outputs[batch_idx, :history_lengths[batch_idx], :]], dim=0)  # [batch, len, hidden]

        return new_encoder_outputs, new_contexts, lengths
Ejemplo n.º 14
0
import os
import sys
import json

from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F

import ontology
from reader import Reader, Vocab
from config import Config
from convlab2.nlu.jointBERT.multiwoz import BERTNLU
from convlab2.dst.rule.multiwoz import RuleDST

nlu = BERTNLU()
dst = RuleDST()

# The Config object is replaced by the namespace its parser produces.
config = Config()
parser = config.parser
config = parser.parse_args()

vocab = Vocab(config)
vocab.load("save/vocab")
reader = Reader(vocab, config)
reader.load_data("train")
# Context manager so the file handle is closed as soon as the JSON is parsed,
# instead of being left open until garbage collection.
with open("data/MultiWOZ_2.1/dev_data.json", "r") as dev_file:
    data = json.load(dev_file)

# The dev batches are walked once just to count them (the progress-bar total)
# without keeping them in memory; a fresh iterator is created for actual use.
max_iter = sum(1 for _ in reader.make_batch(reader.dev))
iterator = reader.make_batch(reader.dev)
t = tqdm(enumerate(iterator), total=max_iter, ncols=250)