Esempio n. 1
0
 def init_session(self):
     """Start a new dialog.

     Resets the NLU, policy and NLG modules (in that order), clears
     ``self.history`` and sets ``self.state`` to a fresh ``default_state()``.
     """
     for module_name in ('nlu', 'policy', 'nlg'):
         getattr(self, module_name).init_session()
     self.history = []
     self.state = default_state()
Esempio n. 2
0
def main():
    """Smoke-test ``MDRGWordPolicy`` on a short hand-written restaurant dialog.

    Prints the policy's prediction after the user asks for an address, then
    appends a system reply plus a booking request (4 people, 18:15, friday),
    fills the restaurant booking slots and prints the prediction again.
    """
    state = default_state()
    state['history'] = [
        ['sys', ''],
        ['user', 'Actually I need a expensively priced restaurant . Are there any fitting that description ?'],
        ['sys', 'yes , there are 57 in the centre. do you have a preference?'],
        ['user', 'Can you give me the address ?'],
    ]
    state['belief_state']['restaurant']['semi']['pricerange'] = 'expensive'

    policy = MDRGWordPolicy()
    print(policy.predict(state))

    state['history'].extend([
        ['sys', 'the address is 106 Regent Street City Centre.'],
        ['user', 'It will be for 4 people . Arriving at the restaurant by 18:15 . I \'ll be needing a table for friday .'],
    ])
    # Looked up after predict() so a replaced belief_state dict is respected.
    booking = state['belief_state']['restaurant']['book']
    booking['people'] = '4'
    booking['time'] = '18:15'
    booking['day'] = 'friday'
    print(policy.predict(state))
def fake_state():
    """Return a hand-built dialog state for quick manual testing.

    The returned dict contains:

    * ``'user_action'`` -- a canned user act: a hotel name request and a
      "don't care" train day;
    * ``'belief_state'`` -- the default MultiWOZ belief state;
    * ``'kb_results_dict'`` -- two fake KB rows that differ only in ``'area'``;
    * ``'hotel-request'`` -- a single requested slot, ``'phone'``.
    """
    from convlab2.util.multiwoz.state import default_state

    user_action = {
        'Hotel-Request': [['Name', '?']],
        'Train-Inform': [['Day', 'don\'t care']]
    }
    init_belief_state = default_state()['belief_state']

    # Both fake KB rows share every field except 'area'.
    kb_template = {
        'name': 'xxx_train',
        'day': 'tuesday',
        'dest': 'cam',
        'phone': '123-3333',
    }
    kb_results = [dict(kb_template, area=area) for area in ('south', 'north')]

    return {
        'user_action': user_action,
        'belief_state': init_belief_state,
        'kb_results_dict': kb_results,
        'hotel-request': [['phone']]
    }
Esempio n. 4
0
    def _build_data(self, root_dir, processed_dir):
        """Vectorize the system-side MultiWOZ act-policy data and cache it.

        For each split (``train``, ``val``, ``test``) every example is turned
        into a ``[state_vector, action_vector]`` pair and collected in
        ``self.data[split]``; each list is then pickled to
        ``<processed_dir>/<split>.pkl``.

        Args:
            root_dir: not used by this method; kept for interface compatibility.
            processed_dir: output directory for the pickles. It is created if
                missing and reused if it already exists.
        """
        parts = ['train', 'val', 'test']
        self.data = {}
        data_loader = ActPolicyDataloader(
            dataset_dataloader=MultiWOZDataloader())
        for part in parts:
            self.data[part] = []
            raw_data = data_loader.load_data(data_key=part, role='sys')[part]

            for belief_state, context_dialog_act, terminated, dialog_act in \
                    zip(raw_data['belief_state'], raw_data['context_dialog_act'],
                        raw_data['terminated'], raw_data['dialog_act']):
                state = default_state()
                state['belief_state'] = belief_state
                # The last context act is stored as the user action and the
                # one before it (if present) as the system action.
                state['user_action'] = context_dialog_act[-1]
                state['system_action'] = context_dialog_act[-2] if len(
                    context_dialog_act) > 1 else {}
                state['terminated'] = terminated
                self.data[part].append([
                    self.vector.state_vectorize(state),
                    self.vector.action_vectorize(dialog_act)
                ])

        # exist_ok: rebuilding into an already existing (e.g. partially
        # written) cache directory must not raise FileExistsError.
        os.makedirs(processed_dir, exist_ok=True)
        for part in parts:
            with open(os.path.join(processed_dir, '{}.pkl'.format(part)),
                      'wb') as f:
                pickle.dump(self.data[part], f)
Esempio n. 5
0
def main():
    """Print a one-turn restaurant state and ``MDRGWordPolicy``'s prediction."""
    state = default_state()
    state['history'] = [['null', 'I want a chinese restaurant']]
    restaurant_semi = state['belief_state']['restaurant']['semi']
    restaurant_semi['area'] = 'south'
    restaurant_semi['food'] = 'mexican'

    policy = MDRGWordPolicy()
    print(state)
    print(policy.predict(state))
Esempio n. 6
0
 def __init__(self):
     """Initialise the tracker state and load the MultiWOZ value dictionary.

     The dictionary is read from ``data/multiwoz/value_dict.json`` under the
     directory five levels above this source file.
     """
     DST.__init__(self)
     self.state = default_state()
     path = os.path.abspath(__file__)
     # Climb five directory levels (presumably to the project root).
     for _ in range(5):
         path = os.path.dirname(path)
     path = os.path.join(path, 'data/multiwoz/value_dict.json')
     # Context manager so the file handle is closed right after parsing.
     with open(path) as f:
         self.value_dict = json.load(f)
Esempio n. 7
0
 def init_session(self):
     """Reset the dialog state and, on first use, load the model weights.

     Weights are loaded from ``DOWNLOAD_DIRECTORY/pytorch_model.bin`` when it
     exists, otherwise from ``SUMBT_PATH/args.output_dir/pytorch_model.bin``.

     Raises:
         ValueError: if neither weight file exists.
     """
     self.state = default_state()
     if self.param_restored:
         return
     downloaded = os.path.join(DOWNLOAD_DIRECTORY, 'pytorch_model.bin')
     if os.path.isfile(downloaded):
         print('loading weights from downloaded model')
         self.load_weights(model_path=downloaded)
     else:
         # Only consulted when no downloaded model is present.
         trained = os.path.join(SUMBT_PATH, args.output_dir, 'pytorch_model.bin')
         if not os.path.isfile(trained):
             raise ValueError('no available weights found.')
         print('loading weights from trained model')
         self.load_weights(model_path=trained)
     self.param_restored = True
Esempio n. 8
0
    def __init__(self):
        """Set up the SUMBT tracker, label processor, tokenizer, logging and seeds.

        Raises:
            ValueError: if ``DEVICE == 'cuda'`` but CUDA is unavailable, if
                fewer than ``N_GPU`` GPUs are visible, or if the BERT vocabulary
                file is missing.
        """
        super(MultiWozSUMBT, self).__init__()
        convert_to_glue_format()

        self.belief_tracker = BeliefTracker()
        self.batch = None  # generated with dataloader
        self.current_turn = 0
        self.idx2slot = {}
        self.idx2value = {}  # slot value for each slot, use processor.get_labels()

        # Bound before the CUDA check so the device report below also works
        # when DEVICE is not 'cuda' (otherwise n_gpu would be undefined).
        n_gpu = 0
        if DEVICE == 'cuda':
            if not torch.cuda.is_available():
                raise ValueError('cuda not available')
            n_gpu = torch.cuda.device_count()
            if n_gpu < N_GPU:
                raise ValueError('gpu not enough')

        print("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
            DEVICE, n_gpu, bool(N_GPU > 1), FP16))

        # Get Processor
        self.processor = Processor()
        self.label_list = self.processor.get_labels()
        self.num_labels = [len(labels) for labels in self.label_list]  # number of slot-values in each slot-type
        self.belief_tracker.init_session(self.num_labels)
        if N_GPU > 1:
            self.belief_tracker = torch.nn.DataParallel(self.belief_tracker)

        # tokenizer
        vocab_dir = os.path.join(BERT_DIR, 'vocab.txt')
        if not os.path.exists(vocab_dir):
            raise ValueError("Can't find %s " % vocab_dir)
        self.tokenizer = BertTokenizer.from_pretrained(vocab_dir, do_lower_case=DO_LOWER_CASE)

        self.num_train_steps = None
        self.accumulation = False

        logger.info('dataset processed')
        if os.path.exists(OUTPUT_DIR) and os.listdir(OUTPUT_DIR):
            print("output dir {} not empty".format(OUTPUT_DIR))
        else:
            # An existing but *empty* directory also lands here, so creation
            # must tolerate it instead of raising FileExistsError.
            os.makedirs(OUTPUT_DIR, exist_ok=True)

        # NOTE(review): a new FileHandler is attached on every construction; if
        # `logger` is shared (e.g. module-level), log lines get duplicated.
        fileHandler = logging.FileHandler(os.path.join(OUTPUT_DIR, "log.txt"))
        logger.addHandler(fileHandler)

        random.seed(SEED)
        np.random.seed(SEED)
        torch.manual_seed(SEED)
        if N_GPU > 0:
            torch.cuda.manual_seed_all(SEED)

        self.state = default_state()
Esempio n. 9
0
    def __init__(self, ontology_vectors, ontology, slots, data_dir):
        """Build the MDBT model graph and TF session and load lookup tables.

        Args:
            ontology_vectors: forwarded to ``model_definition``.
            ontology: its ``len()`` is forwarded to ``model_definition``.
            slots: forwarded to ``model_definition``.
            data_dir: base directory for the data, model, graph and result
                paths stored on the instance.
        """
        DST.__init__(self)
        # data profile
        self.data_dir = data_dir
        self.validation_url = os.path.join(self.data_dir, 'data/validate.json')
        self.word_vectors_url = os.path.join(
            self.data_dir, 'word-vectors/paragram_300_sl999.txt')
        self.training_url = os.path.join(self.data_dir, 'data/train.json')
        self.ontology_url = os.path.join(self.data_dir, 'data/ontology.json')
        self.testing_url = os.path.join(self.data_dir, 'data/test.json')
        self.model_url = os.path.join(self.data_dir, 'models/model-1')
        self.graph_url = os.path.join(self.data_dir, 'graphs/graph-1')
        self.results_url = os.path.join(self.data_dir, 'results/log-1.txt')
        self.kb_url = os.path.join(self.data_dir, 'data/')  # not used
        self.train_model_url = os.path.join(self.data_dir,
                                            'train_models/model-1')
        self.train_graph_url = os.path.join(self.data_dir,
                                            'train_graph/graph-1')

        self.model_variables = model_definition(ontology_vectors,
                                                len(ontology),
                                                slots,
                                                num_hidden=None,
                                                bidir=True,
                                                net_type=None,
                                                test=True,
                                                dev='cpu')
        self.state = default_state()
        _config = tf.ConfigProto()
        _config.gpu_options.allow_growth = True
        _config.allow_soft_placement = True
        self.sess = tf.Session(config=_config)
        self.param_restored = False

        # Map lower-cased REF_USR_DA keys and values to '<Key>-<Domain>'.
        self.det_dic = {}
        for domain, dic in REF_USR_DA.items():
            for key, value in dic.items():
                assert '-' not in key
                self.det_dic[key.lower()] = key + '-' + domain
                self.det_dic[value.lower()] = key + '-' + domain

        def parent_dir(path, time=1):
            # Strip `time` trailing components from `path`.
            for _ in range(time):
                path = os.path.dirname(path)
            return path

        root_dir = parent_dir(os.path.abspath(__file__), 4)
        # Context manager so the JSON file is closed right after parsing.
        with open(os.path.join(root_dir, 'data/multiwoz/value_dict.json')) as f:
            self.value_dict = json.load(f)
Esempio n. 10
0
    def generate_dict(self):
        """Build the act<->index lookup tables and the state-vector dimensions.

        Sets:

        * ``act2vec`` / ``vec2act`` -- index maps for ``self.da_voc``;
        * ``opp2vec`` -- index map for ``self.da_voc_opp``;
        * ``da_dim`` / ``da_opp_dim`` -- the sizes of those vocabularies;
        * ``belief_state_dim`` -- the total number of ``semi`` slots over
          ``self.belief_domains`` in the default belief state;
        * ``state_dim`` -- the full state-vector length.
        """
        self.act2vec = {act: idx for idx, act in enumerate(self.da_voc)}
        self.vec2act = {idx: act for act, idx in self.act2vec.items()}
        self.da_dim = len(self.da_voc)
        self.opp2vec = {act: idx for idx, act in enumerate(self.da_voc_opp)}
        self.da_opp_dim = len(self.da_voc_opp)

        # Build the default belief state once instead of once per domain.
        belief_template = default_state()['belief_state']
        self.belief_state_dim = sum(
            len(belief_template[domain.lower()]['semi'])
            for domain in self.belief_domains)

        # NOTE(review): the trailing terms reserve 1 + 6 entries per DB domain
        # plus one extra entry; their meaning is defined by the vectorizer,
        # not visible here.
        self.state_dim = self.da_opp_dim + self.da_dim + self.belief_state_dim + \
                         len(self.db_domains) + 6 * len(self.db_domains) + 1
Esempio n. 11
0
    def setup_class(cls):
        """Build fixtures shared by this test class.

        * ``cls.domain_slots`` maps each lower-cased ``REF_SYS_DA`` domain to
          the set of that domain's ``REF_SYS_DA`` values.
        * ``cls.default_belief_state`` is the default MultiWOZ belief state.
        * ``cls.usr_acts`` is a scripted list of user dialog acts (presumably
          one per turn; ``{}`` is a turn with no acts).
        """
        # domain_slots will be used for checking request_state
        cls.domain_slots = {
            domain.lower(): set(slot_map.values())
            for domain, slot_map in REF_SYS_DA.items()
        }

        cls.default_belief_state = default_state()['belief_state']

        cls.usr_acts = [
            {"Hotel-Inform": [["Area", "east"], ["Stars", "4"]]},
            {"Hotel-Inform": [["Parking", "yes"], ["Internet", "yes"]]},
            {},
            {"Hotel-Inform": [["Day", "friday"]],
             "Hotel-Request": [["Ref", "?"]]},
            {"Train-Inform": [["Dest", "bishops stortford"], ["Day", "friday"],
                              ["Depart", "cambridge"]]},
            {"Train-Inform": [["Arrive", "19:45"]]},
            {"Train-Request": [["Leave", "?"], ["Time", "?"], ["Ticket", "?"]]},
            {"Hotel-Inform": [["Stay", "4"], ["Day", "monday"], ["People", "3"]]},
        ]
Esempio n. 12
0
def revert_state(model_output:dict, reversed_vocab:dict):
    """Decode a model's id-based slot predictions into a MultiWOZ belief state.

    ``model_output`` is a nested mapping: domain id -> slot id -> iterable of
    value ids. Every id is decoded through ``reversed_vocab``.

    Decoding rules:

    * Domains missing from the default belief state are skipped.
    * Slots prefixed with ``'book '`` are written (without the prefix) into the
      domain's ``'book'`` table; all other slots go into ``'semi'``.
    * Spellings are normalised: ``'price range'``/``'leave at'``/``'arrive by'``
      become ``'pricerange'``/``'leaveAt'``/``'arriveBy'``, and ``'do not care'``
      becomes ``'dontcare'``.
    * Slots not present in the target table are ignored.
    * If a slot has several values, the last one wins.

    Returns:
        The filled belief-state dict.
    """
    slot_aliases = {'price range': 'pricerange', 'leave at': 'leaveAt', 'arrive by': 'arriveBy'}
    value_aliases = {'do not care': 'dontcare'}
    belief_state = default_state()['belief_state']
    for domain_id, slots in model_output.items():
        for slot_id, value_ids in slots.items():
            for value_id in value_ids:
                domain = reversed_vocab[domain_id]
                if domain not in belief_state:
                    continue
                slot = reversed_vocab[slot_id]
                value = reversed_vocab[value_id]
                section = 'semi'
                if slot.startswith('book '):
                    slot = slot[len('book '):]
                    section = 'book'
                table = belief_state[domain][section]
                slot = slot_aliases.get(slot, slot)
                value = value_aliases.get(value, value)
                if slot in table:
                    table[slot] = value
    return belief_state
Esempio n. 13
0
    def __init__(
            self,
            archive_file=DEFAULT_ARCHIVE_FILE,
            cuda_device=DEFAULT_CUDA_DEVICE,
            model_file="https://convlab.blob.core.windows.net/convlab-2/larl.zip"
    ):
        """Load the LaRL policy: model archive, vocabularies, config and weights.

        Args:
            archive_file: local path of the model zip. If it is not an existing
                file, ``model_file`` is fetched through ``cached_path`` instead.
            cuda_device: not referenced by this constructor; GPU use is decided
                by ``torch.cuda.is_available()``.
            model_file: URL/path of the model zip, used when ``archive_file``
                does not exist.

        Raises:
            Exception: if ``archive_file`` does not exist and ``model_file`` is
                empty.
        """
        if not os.path.isfile(archive_file):
            if not model_file:
                raise Exception("No model for LaRL is specified!")
            archive_file = cached_path(model_file)

        # The model archive is unpacked next to this source file.
        temp_path = os.path.dirname(os.path.abspath(__file__))
        with zipfile.ZipFile(archive_file, 'r') as zip_ref:
            zip_ref.extractall(temp_path)

        self.prev_state = default_state()
        self.prev_active_domain = None

        domain_name = 'object_division'
        domain_info = domain.get_domain(domain_name)
        self.db = Database()
        data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                                 'data')
        train_data_path = os.path.join(data_path, 'train_dials.json')
        if not os.path.exists(train_data_path):
            # Unpack the bundled data on first use; the handle is closed
            # by the context manager.
            zipped_file = os.path.join(data_path, 'norm-multi-woz.zip')
            with zipfile.ZipFile(zipped_file, 'r') as archive:
                archive.extractall(data_path)

        norm_multiwoz_path = data_path

        def load_json(file_name):
            # Read one vocabulary JSON file from the data directory.
            with open(os.path.join(norm_multiwoz_path, file_name)) as f:
                return json.load(f)

        self.input_lang_index2word = load_json('input_lang.index2word.json')
        self.input_lang_word2index = load_json('input_lang.word2index.json')
        self.output_lang_index2word = load_json('output_lang.index2word.json')
        self.output_lang_word2index = load_json('output_lang.word2index.json')

        config = Pack(
            seed=10,
            train_path=train_data_path,
            max_vocab_size=1000,
            last_n_model=5,
            max_utt_len=50,
            max_dec_len=50,
            backward_size=2,
            batch_size=1,
            use_gpu=True,
            op='adam',
            init_lr=0.001,
            l2_norm=1e-05,
            momentum=0.0,
            grad_clip=5.0,
            dropout=0.5,
            max_epoch=100,
            embed_size=100,
            num_layers=1,
            utt_rnn_cell='gru',
            utt_cell_size=300,
            bi_utt_cell=True,
            enc_use_attn=True,
            dec_use_attn=True,
            dec_rnn_cell='lstm',
            dec_cell_size=300,
            dec_attn_mode='cat',
            y_size=10,
            k_size=20,
            beta=0.001,
            simple_posterior=True,
            contextual_posterior=True,
            use_mi=False,
            use_pr=True,
            use_diversity=False,
            #
            beam_size=20,
            fix_batch=True,
            fix_train_batch=False,
            avg_type='word',
            print_step=300,
            ckpt_step=1416,
            improve_threshold=0.996,
            patient_increase=2.0,
            save_model=True,
            early_stop=False,
            gen_type='greedy',
            preview_batch_num=None,
            k=domain_info.input_length(),
            init_range=0.1,
            pretrain_folder='2019-09-20-21-43-06-sl_cat',
            forward_only=False)

        # Fall back to CPU when CUDA is not available.
        config.use_gpu = config.use_gpu and torch.cuda.is_available()
        self.corpus = corpora_inference.NormMultiWozCorpus(config)
        self.model = SysPerfectBD2Cat(self.corpus, config)
        self.config = config

        # NOTE(review): torch.load and pickle.load unpickle files from a
        # possibly downloaded archive, which can execute arbitrary code; only
        # use trusted model sources.
        model_path = os.path.join(temp_path, 'larl_model/best-model')
        if config.use_gpu:
            self.model.load_state_dict(torch.load(model_path))
            self.model.cuda()
        else:
            # Map GPU-saved tensors onto CPU storage.
            self.model.load_state_dict(
                torch.load(model_path,
                           map_location=lambda storage, loc: storage))
        self.model.eval()
        with open(os.path.join(temp_path, 'larl_model/svdic.pkl'), 'rb') as f:
            self.dic = pickle.load(f)
Esempio n. 14
0
    def update(self, user_act=None):
        """Update the dialog state from the latest user utterance.

        Steps:

        1. The stored history is grouped into ``[system, user]`` utterance pairs
           and wrapped as a single dialogue.
        2. The dialogue is vectorized and run through the model via
           ``self.sess.run``.
        3. The slot predictions for the last turn are merged into a copy of the
           previous belief state.
        4. Requestable slots that ``self.detect_requestable_slots`` finds in
           ``user_act`` are added to ``request_state``; existing entries are kept.

        Args:
            user_act: the user utterance; must be a ``str`` despite the ``None``
                default.

        Returns:
            dict: the new state, which is also stored in ``self.state``.

        Raises:
            Exception: if ``user_act`` is not a ``str``, or if a predicted domain
                is missing from the belief state.
            AssertionError: if the history is empty or cannot be split into
                complete (system, user) pairs.
        """
        if type(user_act) is not str:
            raise Exception('Expected user_act to be <class \'str\'> type, but get {}.'.format(type(user_act)))
        prev_state = copy.deepcopy(self.state)
        # Make sure <data_dir>/results exists.
        if not os.path.exists(os.path.join(self.data_dir, "results")):
            os.makedirs(os.path.join(self.data_dir, "results"))

        # NOTE(review): train_batch_size is declared global but not used below.
        global train_batch_size

        # Unpack the model handles; some are used below as feed_dict keys
        # (inputs) and some as fetches (outputs).
        model_variables = self.model_variables
        (user, sys_res, no_turns, user_uttr_len, sys_uttr_len, labels, domain_labels, domain_accuracy,
         slot_accuracy, value_accuracy, value_f1, train_step, keep_prob, predictions,
         true_predictions, [y, _]) = model_variables

        # Note: the old fallback below (seeding an empty history with a
        # 'sys'/'null' turn) is disabled; the history must already be non-empty.
        # prev_state['history'] = [['sys', 'null']] if len(prev_state['history']) == 0 else prev_state['history']
        assert len(prev_state['history']) > 0
        # If the history starts with a user turn, prepend an empty system turn
        # so utterances can be grouped into (system, user) pairs.
        first_turn = prev_state['history'][0]
        if first_turn[0] != 'sys':
            prev_state['history'] = [['sys', '']] + prev_state['history']
        actual_history = []
        assert len(prev_state['history']) % 2 == 0
        # Pair utterances by position (speaker labels are not checked);
        # empty utterances are replaced by the placeholder 'null'.
        for name, utt in prev_state['history']:
            if not utt:
                utt = 'null'
            if len(actual_history)==0 or len(actual_history[-1])==2:
                actual_history.append([utt])
            else:
                actual_history[-1].append(utt)
        # actual_history[-1].append(user_act)
        # actual_history = self.normalize_history(actual_history)
        # if len(actual_history) == 0:
        #     actual_history = [['', user_act if len(user_act)>0 else 'fake user act']]
        # Wrap the pairs as one dialogue keyed by turn index ('0', '1', ...);
        # every user turn carries a placeholder default belief state.
        fake_dialogue = {}
        turn_no = 0
        for _sys, _user in actual_history:
            turn = {}
            turn['system'] = _sys
            fake_user = {}
            fake_user['text'] = _user
            fake_user['belief_state'] = default_state()['belief_state']
            turn['user'] = fake_user
            key = str(turn_no)
            fake_dialogue[key] = turn
            turn_no += 1
        # Vectorize the single fake dialogue into one batch.
        context, actual_context = process_history([fake_dialogue], self.word_vectors, self.ontology)
        batch_user, batch_sys, batch_labels, batch_domain_labels, batch_user_uttr_len, batch_sys_uttr_len, \
                batch_no_turns = generate_batch(context, 0, 1, len(self.ontology))  # old feature

        # run model (keep_prob=1.0, presumably disabling dropout at inference)
        [pred, y_pred] = self.sess.run(
            [predictions, y],
            feed_dict={user: batch_user, sys_res: batch_sys,
                       labels: batch_labels,
                       domain_labels: batch_domain_labels,
                       user_uttr_len: batch_user_uttr_len,
                       sys_uttr_len: batch_sys_uttr_len,
                       no_turns: batch_no_turns,
                       keep_prob: 1.0})

        # convert to str output; only the last turn of the dialogue is used
        dialgs, _, _ = track_dialogue(actual_context, self.ontology, pred, y_pred)
        assert len(dialgs) >= 1
        last_turn = dialgs[0][-1]
        predictions = last_turn['prediction']
        new_belief_state = copy.deepcopy(prev_state['belief_state'])

        # update belief state
        # Each prediction string is lower-cased and split on '-' into exactly
        # three parts (domain, slot, value); any other count raises ValueError.
        for item in predictions:
            item = item.lower()
            domain, slot, value = item.strip().split('-')
            # Keep only the text before the last ':' in value (the suffix is
            # presumably a score/tag from track_dialogue -- confirm). A value
            # without ':' raises IndexError here.
            value = value[::-1].split(':', 1)[1][::-1]
            if slot == 'price range':
                slot = 'pricerange'
            # Slots named exactly 'name' or 'book' are ignored.
            if slot not in ['name', 'book']:
                if domain not in new_belief_state:
                    raise Exception('Error: domain <{}> not in belief state'.format(domain))
                # Remap the slot name through REF_SYS_DA (unchanged if absent).
                slot = REF_SYS_DA[domain.capitalize( )].get(slot, slot)
                assert 'semi' in new_belief_state[domain]
                assert 'book' in new_belief_state[domain]
                # 'book <x>' slots are reduced to '<x>'.
                if 'book' in slot:
                    assert slot.startswith('book ')
                    slot = slot.strip().split()[1]
                # Restore the camelCase keys used by the belief state.
                if slot == 'arriveby':
                    slot = 'arriveBy'
                elif slot == 'leaveat':
                    slot = 'leaveAt'
                domain_dic = new_belief_state[domain]
                # 'semi' values are normalised via value_dict; 'book' values
                # are stored as predicted.
                if slot in domain_dic['semi']:
                    new_belief_state[domain]['semi'][slot] = normalize_value(self.value_dict, domain, slot, value)
                elif slot in domain_dic['book']:
                    new_belief_state[domain]['book'][slot] = value
                elif slot.lower() in domain_dic['book']:
                    new_belief_state[domain]['book'][slot.lower()] = value
                else:
                    # Unknown slot: append a note to a log file in the current
                    # working directory instead of failing.
                    with open('mdbt_unknown_slot.log', 'a+') as f:
                        f.write('unknown slot name <{}> with value <{}> of domain <{}>\nitem: {}\n\n'.format(slot, value,
                                domain, item))
        new_request_state = copy.deepcopy(prev_state['request_state'])
        # update request_state: add newly detected slots, keep existing ones
        user_request_slot = self.detect_requestable_slots(user_act)
        for domain in user_request_slot:
            for key in user_request_slot[domain]:
                if domain not in new_request_state:
                    new_request_state[domain] = {}
                if key not in new_request_state[domain]:
                    new_request_state[domain][key] = user_request_slot[domain][key]
        # update state
        new_state = copy.deepcopy(dict(prev_state))
        new_state['belief_state'] = new_belief_state
        new_state['request_state'] = new_request_state
        self.state = new_state
        return self.state
Esempio n. 15
0
 def init_session(self):
     """Reset ``self.state`` and call ``self.restore()`` unless parameters are
     already marked as restored (``self.param_restored``)."""
     self.state = default_state()
     if self.param_restored:
         return
     self.restore()
Esempio n. 16
0
 def init_session(self):
     """Reset the tracker: ``self.state`` becomes a fresh ``default_state()``
     (from ``convlab2.util.multiwoz.state``)."""
     fresh_state = default_state()
     self.state = fresh_state
Esempio n. 17
0
    def __init__(self, data_dir=DATA_PATH, model_file='https://convlab.blob.core.windows.net/convlab-2/sumbt.tar.gz', eval_slots=multiwoz_slot_list):
        """Build the SUMBT tracker, its label/slot embeddings and the data splits.

        Args:
            data_dir: NOTE(review) not referenced by this constructor; data is
                converted from ``DATA_PATH`` into ``SUMBT_PATH``.
            model_file: NOTE(review) not referenced by this constructor;
                ``self.download_model()`` is called at the end instead.
            eval_slots: stored as ``self.eval_slots``.
        """
        DST.__init__(self)

        processor = Processor(args)
        self.processor = processor
        label_list = processor.get_labels()
        num_labels = [len(labels) for labels in label_list]  # number of slot-values in each slot-type

        # tokenizer
        self.tokenizer = BertTokenizer.from_pretrained(args.bert_model_name, cache_dir=args.bert_model_cache_dir)
        # Seed the Python, NumPy and torch RNGs before the model is built.
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)

        self.device = torch.device("cuda" if USE_CUDA else "cpu")

        self.sumbt_model = BeliefTracker(args, num_labels, self.device)
        if USE_CUDA and N_GPU > 1:
            self.sumbt_model = torch.nn.DataParallel(self.sumbt_model)
        if args.fp16:
            self.sumbt_model.half()
        self.sumbt_model.to(self.device)

        ## Get slot-value embeddings
        self.label_token_ids, self.label_len = [], []
        for labels in label_list:
            token_ids, lens = get_label_embedding(labels, args.max_label_length, self.tokenizer, self.device)
            self.label_token_ids.append(token_ids)
            self.label_len.append(lens)
        self.label_map = [{label: i for i, label in enumerate(labels)} for labels in label_list]
        self.label_map_inv = [{i: label for i, label in enumerate(labels)} for labels in label_list]
        self.label_list = label_list
        self.target_slot = processor.target_slot
        ## Get domain-slot-type embeddings
        self.slot_token_ids, self.slot_len = \
            get_label_embedding(processor.target_slot, args.max_label_length, self.tokenizer, self.device)

        self.args = args
        self.state = default_state()
        self.param_restored = False
        # With DataParallel (N_GPU > 1) the wrapped model sits under `.module`.
        # NOTE(review): when USE_CUDA is false the lookup is not initialised
        # here -- confirm it is done elsewhere (e.g. when weights are loaded).
        if USE_CUDA and N_GPU == 1:
            self.sumbt_model.initialize_slot_value_lookup(self.label_token_ids, self.slot_token_ids)
        elif USE_CUDA and N_GPU > 1:
            self.sumbt_model.module.initialize_slot_value_lookup(self.label_token_ids, self.slot_token_ids)

        # Map lower-cased REF_USR_DA keys and values to '<Key>-<Domain>'.
        self.det_dic = {}
        for domain, dic in REF_USR_DA.items():
            for key, value in dic.items():
                assert '-' not in key
                self.det_dic[key.lower()] = key + '-' + domain
                self.det_dic[value.lower()] = key + '-' + domain

        self.cached_res = {}
        convert_to_glue_format(DATA_PATH, SUMBT_PATH)
        os.makedirs(os.path.join(SUMBT_PATH, args.output_dir), exist_ok=True)
        tmp_data_dir = os.path.join(SUMBT_PATH, args.tmp_data_dir)
        self.train_examples = processor.get_train_examples(tmp_data_dir, accumulation=False)
        self.dev_examples = processor.get_dev_examples(tmp_data_dir, accumulation=False)
        self.test_examples = processor.get_test_examples(tmp_data_dir, accumulation=False)
        self.eval_slots = eval_slots
        self.download_model()
Esempio n. 18
0
    def __init__(self, num=1):
        """Load S2S model ``num`` for inference.

        Only the argument defaults below are used (``parse_args([])`` ignores
        the real command line). They are then overridden by the JSON config
        stored at ``<model_path>.config`` next to this file. Finally, test-time
        settings are forced: ``mode='test'``, ``load_param=True`` and
        ``dropout=0.0``.

        Args:
            num: forwarded to ``loadModel``.
        """
        parser = argparse.ArgumentParser(description='S2S')
        parser.add_argument('--no_cuda',
                            type=util.str2bool,
                            nargs='?',
                            const=True,
                            default=True,
                            help='disables CUDA training')
        parser.add_argument('--seed',
                            type=int,
                            default=1,
                            metavar='S',
                            help='random seed (default: 1)')

        parser.add_argument('--no_models',
                            type=int,
                            default=20,
                            help='how many models to evaluate')
        parser.add_argument('--original',
                            type=str,
                            default='model/model/',
                            help='Original path.')

        parser.add_argument('--dropout', type=float, default=0.0)
        parser.add_argument('--use_emb', type=str, default='False')

        parser.add_argument('--beam_width',
                            type=int,
                            default=10,
                            help='Beam width used in beamsearch')
        parser.add_argument('--write_n_best',
                            type=util.str2bool,
                            nargs='?',
                            const=True,
                            default=False,
                            help='Write n-best list (n=beam_width)')

        parser.add_argument('--model_path',
                            type=str,
                            default='model/model/translate.ckpt',
                            help='Path to a specific model checkpoint.')
        parser.add_argument('--model_dir',
                            type=str,
                            default='data/multi-woz/model/model/')
        parser.add_argument('--model_name', type=str, default='translate.ckpt')

        parser.add_argument('--valid_output',
                            type=str,
                            default='model/data/val_dials/',
                            help='Validation Decoding output dir path')
        parser.add_argument('--decode_output',
                            type=str,
                            default='model/data/test_dials/',
                            help='Decoding output dir path')

        # Deliberately ignore sys.argv: only the defaults above are used.
        args = parser.parse_args([])

        args.cuda = not args.no_cuda and torch.cuda.is_available()

        torch.manual_seed(args.seed)

        self.device = torch.device("cuda" if args.cuda else "cpu")
        config_path = os.path.join(os.path.dirname(__file__),
                                   args.model_path + '.config')
        with open(config_path, 'r') as f:
            add_args = json.load(f)
        # Stored settings override the parser defaults.
        for k, v in add_args.items():
            setattr(args, k, v)
        # Force inference settings regardless of the stored config.
        args.mode = 'test'
        args.load_param = True
        args.dropout = 0.0

        # Start going through models
        args.original = args.model_path
        self.model = loadModel(num, args)
        self.dial = {"cur": {"log": []}}
        self.prev_state = default_state()
        self.prev_active_domain = None
Esempio n. 19
0
def main():
    """Build a small sample dialogue, preprocess it, then run the data pipeline.

    The sample dialogue has five user/system exchanges. User turns carry only
    ``text``; each system turn also carries ``metadata`` set to a fresh
    ``default_state()['belief_state']``. It is passed to ``create_data`` (its
    return value is not used), after which the full pipeline runs via
    ``createData()`` and ``divideData()``.
    """
    # (user utterance, system response) pairs, in dialogue order.
    exchanges = [
        ("am looking for a place to to stay that has cheap price range it should be in a type of hotel",
         "Okay, do you have a specific area you want to stay in?"),
        ("no, i just need to make sure it's cheap. oh, and i need parking",
         "I found 1 cheap hotel for you that includes parking. Do you like me to book it?"),
        ("Yes, please. 6 people 3 nights starting on tuesday.",
         "I am sorry but I wasn't able to book that for you for Tuesday. Is there another day you would like to stay or perhaps a shorter stay?"),
        ("how about only 2 nights.",
         "Booking was successful.\nReference number is : 7GAWK763. Anything else I can do for you?"),
        ("No, that will be all. Good bye.",
         "Thank you for using our services."),
    ]

    log = []
    for user_text, sys_text in exchanges:
        log.append({"text": user_text})
        # A separate default_state() call per system turn, so turns never
        # share (and accidentally co-mutate) the same belief-state dict.
        log.append({
            "text": sys_text,
            "metadata": default_state()['belief_state']
        })
    dialogue2 = {"log": log}

    create_data(dialogue2)
    print(
        'Create WOZ-like dialogues. Get yourself a coffee, this might take a while.'
    )
    delex_data = createData()
    print('Divide dialogues...')
    divideData(delex_data)
Esempio n. 20
0
 def init_session(self):
     """Reset the per-session bookkeeping of this object.

     Starts an empty ``dial`` log, sets ``prev_state`` to a fresh
     ``default_state()`` and clears ``prev_active_domain``.
     """
     fresh_log = {"log": []}
     self.dial = {"cur": fresh_log}
     self.prev_state = default_state()
     self.prev_active_domain = None
Esempio n. 21
0
 def reset(self):
     """Set ``prev_state`` back to a fresh ``default_state()``.

     ``self`` is required here: the body writes an instance attribute, so a
     signature without it fails either with NameError (called bare) or
     TypeError (called as a method).
     """
     self.prev_state = default_state()
Esempio n. 22
0
def train(model, reader, optimizer, writer, config, local_rank, evaluator):
    """Run one pass of belief-tracking training over ``reader.train``.

    Each batch's dialogue turns are processed in order. For every turn:

    1. The inputs are sharded to this process with ``distribute_data``.
    2. ``model.forward`` is run.
    3. An optimizer step is taken on the sum of the gate and value losses.

    After each batch, the size-weighted gate/value losses and the joint/slot
    accuracies are averaged. On rank 0 the total loss is logged to ``writer``
    and the statistics are shown in the tqdm bar.

    Args:
        model: wrapped model exposing ``module.states`` and a ``forward``
            returning ``(gate_loss, value_loss, acc_belief, belief_gen,
            action_history)``.
        reader: data reader providing ``train``, ``make_batch`` and
            ``make_input``.
        optimizer: optimizer over ``model``'s parameters.
        writer: summary writer with ``add_scalar`` (used on rank 0 only).
        config: must provide ``num_gpus``.
        local_rank: index of this process; rank 0 does progress/logging.
        evaluator: currently unused; kept for interface compatibility.

    Reads the function attribute ``train.max_iter`` (rank 0, progress bar
    length) and increments ``train.global_step`` once per processed batch,
    so both must exist before the call.
    """
    iterator = reader.make_batch(reader.train)

    if local_rank == 0:  # only one process prints something
        t = tqdm(enumerate(iterator), total=train.max_iter, ncols=250)
    else:
        t = enumerate(iterator)

    for batch_idx, batch in t:
        inputs, contexts, context_lengths, _dial_ids = reader.make_input(batch)
        batch_size = len(contexts[0])
        turns = len(inputs)

        # Per-batch statistics, each turn weighted by this process's shard size.
        gate_loss = 0
        value_loss = 0
        slot_acc = 0
        joint_acc = 0
        batch_count = 0  # number of examples accumulated over all turns

        distributed_batch_size = math.ceil(batch_size / config.num_gpus)

        action_history = [[] for _ in range(distributed_batch_size)]

        # user_action, system_action, belief_state, request_state, terminated, history
        model.module.states = [
            default_state() for _ in range(distributed_batch_size)
        ]
        for b_idx in range(distributed_batch_size):
            model.module.states[b_idx]["history"].append(["sys", "null"])

        for turn_idx in range(turns):
            # distribute batches to each gpu
            for key, value in inputs[turn_idx].items():
                if key != "prev_gate":
                    inputs[turn_idx][key] = distribute_data(
                        value, config.num_gpus)[local_rank]
            contexts[turn_idx] = distribute_data(contexts[turn_idx],
                                                 config.num_gpus)[local_rank]
            context_lengths[turn_idx] = distribute_data(
                context_lengths[turn_idx], config.num_gpus)[local_rank]

            # Truncate to the longest context length present in this shard.
            contexts[turn_idx] = contexts[
                turn_idx][:, :context_lengths[turn_idx].max()]

            teacher_forcing = 0

            optimizer.zero_grad()

            gate_loss_, value_loss_, acc_belief, _belief_gen, action_history = \
                model.forward(inputs[turn_idx], contexts[turn_idx],
                              context_lengths[turn_idx], action_history,
                              teacher_forcing)
            loss = gate_loss_ + value_loss_

            # Accumulate detached values: summing the raw losses would keep
            # every turn's autograd graph referenced until the batch ends.
            gate_loss += gate_loss_.detach() * distributed_batch_size
            value_loss += value_loss_.detach() * distributed_batch_size
            slot_acc += acc_belief.sum(dim=1).sum(dim=0)
            joint_acc += (acc_belief.mean(dim=1) == 1).sum(dim=0).float()
            batch_count += distributed_batch_size

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 10)
            optimizer.step()
            torch.cuda.empty_cache()

        # NOTE: action/response losses and evaluator-based inform/success
        # metrics are not computed in this loop.
        if batch_count == 0:
            continue  # nothing accumulated (no turns or an empty shard)

        gate_loss = float(gate_loss) / batch_count
        value_loss = float(value_loss) / batch_count
        slot_acc = float(slot_acc) / batch_count / len(
            ontology.all_info_slots) * 100
        joint_acc = float(joint_acc) / batch_count * 100
        total_loss = gate_loss + value_loss
        train.global_step += 1

        if local_rank == 0:
            writer.add_scalar("Train/loss", total_loss, train.global_step)
            t.set_description("iter: {}, total loss: {:.4f}, gate loss: {:.4f}, value loss: {:.4f}, joint accuracy: {:.4f}, slot accuracy: {:.4f}"\
                .format(batch_idx+1, total_loss, gate_loss, value_loss, joint_acc, slot_acc))
Esempio n. 23
0
    def predict_response(self, state):
        """Generate a system response for the current dialogue state.

        Args:
            state: dialogue state dict. Reads ``state['history']`` (a list of
                ``[speaker, utterance]`` pairs) and ``state['belief_state']``.

        Returns:
            ``(response, active_domain)``:
            - ``response`` comes from ``self.populate_template``.
            - ``active_domain`` comes from ``self.get_active_domain`` and may
              be ``None``.

        NOTE(review): raises IndexError when the selected context is empty,
        e.g. an empty history or ``self.config.backward_size <= 0``.
        """
        # Keep only the utterance text of each turn, and use at most the last
        # ``backward_size`` utterances as context.
        history = [turn[1] for turn in state['history']]
        e_idx = len(history)
        s_idx = max(0, e_idx - self.config.backward_size)
        context = history[s_idx:e_idx]

        # A single-turn history is treated as a new dialogue: drop the
        # previous state.
        if len(state['history']) == 1:
            self.prev_state = default_state()

        prepared_data = {}
        prepared_data['context'] = []
        prepared_data['response'] = {}

        prev_bstate = deepcopy(self.prev_state['belief_state'])
        state_history = state['history']
        bstate = deepcopy(state['belief_state'])

        active_domain = self.get_active_domain(self.prev_active_domain,
                                               prev_bstate, bstate)
        domain_mark_not_mentioned(bstate, active_domain)

        # Raw string: '\d' is an invalid escape sequence in a normal literal.
        digit_pattern = re.compile(r'\d+')

        top_results, num_results = None, None
        for usr in context:
            # Collapse whitespace, then delexicalise using self.dic.
            usr = delexicalize.delexicalise(' '.join(usr.split()), self.dic)

            # parsing reference number GIVEN belief state
            usr = delexicaliseReferenceNumber(usr, bstate)

            # changes to numbers only here
            usr = digit_pattern.sub('[value_count]', usr)

            # add database pointer
            pointer_vector, top_results, num_results = addDBPointer(
                bstate, self.db)
            # add booking pointer
            pointer_vector = addBookingPointer(bstate, pointer_vector)
            belief_summary = get_summary_bstate(bstate)

            usr_utt = [BOS] + usr.split() + [EOS]
            prepared_data['context'].append({
                'bs': belief_summary,
                'db': pointer_vector,
                'utt': self.corpus._sent2id(usr_utt),
            })

        prepared_data['response']['bs'] = prepared_data['context'][-1]['bs']
        prepared_data['response']['db'] = prepared_data['context'][-1]['db']
        results = [
            Pack(context=prepared_data['context'],
                 response=prepared_data['response'])
        ]

        data_feed = prepare_batch_gen(results, self.config)

        outputs = self.model_predict(data_feed)

        # Restrict DB results to the active domain, if any.
        if active_domain is not None and active_domain in num_results:
            num_results = num_results[active_domain]
        else:
            num_results = 0

        if active_domain is not None and active_domain in top_results:
            top_results = {active_domain: top_results[active_domain]}
        else:
            top_results = {}

        state_with_history = deepcopy(bstate)
        state_with_history['history'] = deepcopy(state_history)

        response = self.populate_template(outputs, top_results, num_results,
                                          state_with_history)

        return response, active_domain
Esempio n. 24
0
def validate(model, reader, evaluator, config, local_rank):
    """Run one pass over the dev split and return averaged DST metrics.

    The model is put in eval mode for the pass and is always put back in
    train mode afterwards, even if an exception is raised mid-evaluation.

    Args:
        model: wrapped model exposing ``.module.states`` and a ``forward``
            returning ``(gate_loss, value_loss, acc_belief, belief_gen,
            action_history)``; presumably a DistributedDataParallel wrapper
            (TODO confirm).
        reader: data reader providing ``dev``, ``make_batch`` and
            ``make_input``.
        evaluator: currently unused (response-level evaluation is disabled);
            kept for interface compatibility with callers.
        config: configuration object; ``num_gpus`` is read here.
        local_rank: rank of this process; only rank 0 shows a progress bar.

    Returns:
        Tuple ``(val_loss, joint_acc, slot_acc)``. ``joint_acc`` and
        ``slot_acc`` are percentages.
    """
    model.eval()
    val_loss = 0
    slot_acc = 0
    joint_acc = 0
    batch_count = 0

    try:
        with torch.no_grad():
            iterator = reader.make_batch(reader.dev)

            if local_rank == 0:
                # ``max_iter`` is attached to this function by the caller; fall
                # back to an unknown total rather than raising AttributeError.
                t = tqdm(enumerate(iterator),
                         total=getattr(validate, "max_iter", None),
                         ncols=150)
            else:
                t = enumerate(iterator)

            for batch_idx, batch in t:
                inputs, contexts, context_lengths, _dial_ids = \
                    reader.make_input(batch)
                batch_size = len(contexts[0])
                turns = len(inputs)

                # Number of dialogues handled by this rank.
                distributed_batch_size = math.ceil(batch_size / config.num_gpus)

                action_history = [[] for _ in range(distributed_batch_size)]

                # Reset the per-dialogue states; each dialogue history starts
                # with a ["sys", "null"] turn.
                model.module.states = [
                    default_state() for _ in range(distributed_batch_size)
                ]
                for state in model.module.states:
                    state["history"].append(["sys", "null"])

                for turn_idx in range(turns):
                    # Shard every per-turn input except "prev_gate" across ranks.
                    for key, value in inputs[turn_idx].items():
                        if key != "prev_gate":
                            inputs[turn_idx][key] = distribute_data(
                                value, config.num_gpus)[local_rank]
                    contexts[turn_idx] = distribute_data(
                        contexts[turn_idx], config.num_gpus)[local_rank]
                    context_lengths[turn_idx] = distribute_data(
                        context_lengths[turn_idx], config.num_gpus)[local_rank]

                    # Trim context padding to the longest context on this rank.
                    contexts[turn_idx] = contexts[
                        turn_idx][:, :context_lengths[turn_idx].max()]

                    teacher_forcing = 0

                    gate_loss_, value_loss_, acc_belief, _belief_gen, action_history = \
                        model.forward(inputs[turn_idx], contexts[turn_idx],
                                      context_lengths[turn_idx], action_history,
                                      teacher_forcing)

                    loss = gate_loss_ + value_loss_
                    val_loss += loss * distributed_batch_size
                    slot_acc += acc_belief.sum(dim=1).sum(dim=0)
                    # Presumably acc_belief is per-slot 0/1 correctness of shape
                    # (batch, slots) -- TODO confirm; a turn then counts towards
                    # joint accuracy only when every slot is correct.
                    joint_acc += (acc_belief.mean(dim=1) == 1).sum(dim=0).float()
                    batch_count += distributed_batch_size

                    torch.cuda.empty_cache()

                if local_rank == 0:
                    t.set_description("iter: {}".format(batch_idx + 1))

        # Sum each metric over all ranks, then divide by the number of ranks.
        dist.all_reduce(val_loss)
        val_loss = val_loss / config.num_gpus
        dist.all_reduce(slot_acc)
        slot_acc = slot_acc / config.num_gpus
        dist.all_reduce(joint_acc)
        joint_acc = joint_acc / config.num_gpus
    finally:
        # Restore training mode even if evaluation failed part-way.
        model.train()

    val_loss = val_loss.item() / batch_count
    slot_acc = slot_acc.item() / batch_count / len(
        ontology.all_info_slots) * 100
    joint_acc = joint_acc.item() / batch_count * 100

    return val_loss, joint_acc, slot_acc
Esempio n. 26
0
 def init_session(self):
     self.state = default_state()