예제 #1
0
    def __init__(self):
        """Load the Frames training dialogs and set up the multi-source DB loader.

        Reads ``train.json`` out of ``frames/Ori/train.json.zip`` under
        ``DATA_ROOT`` into ``self.frames`` and stores in ``self.db_loader`` a
        ``MultiSourceDBLoader`` built over two database directories
        (frames-db and sgd-db), each paired with its own (domain, slot)
        mapping.
        """
        archive_path = os.path.join(DATA_ROOT, 'frames/Ori', 'train.json.zip')
        self.frames = read_zipped_json(archive_path, 'train.json')

        # (domain, slot) pairs used with the frames-db directory.
        # Entries that were left disabled:
        #   ('frame', 'category'): ('hotel', 'category')
        #   ('frame', 'gst_rating'): ('hotel', 'gst_rating')
        #   ('frame', 'seat'): ('trip', 'seat')
        frames_slot_map = {
            ('frame', 'dst_city'): ('hotel', 'location'),
            ('frame', 'name'): ('hotel', 'name'),
            ('frame', 'or_city'): ('trip', 'or_city'),
        }

        # (domain, slot) pairs used with the sgd-db directory.
        sgd_slot_map = {
            ('frame', 'dst_city'): ('hotels', 'dst_city'),
            ('frame', 'name'): ('hotels', 'hotel_name'),
            ('frame', 'or_city'): ('travel', 'location'),
        }

        relative_db_dirs = (
            "convlab2/laug/Word_Perturbation/db/frames-db/",
            "convlab2/laug/Word_Perturbation/db/sgd-db/",
        )
        db_dirs = [
            os.path.join(get_root_path(), relative_dir)
            for relative_dir in relative_db_dirs
        ]
        slot_maps = [frames_slot_map, sgd_slot_map]

        # One loader-args entry per database directory, in the same order.
        self.db_loader = MultiSourceDBLoader([
            MultiSourceDBLoaderArgs(db_dir, slot_map)
            for db_dir, slot_map in zip(db_dirs, slot_maps)
        ])
예제 #2
0
    def __init__(self):
        """Populate every configuration attribute with its default value.

        Each attribute is assigned exactly once, so the value that is
        actually in effect is unambiguous when reading the code.
        """
        # Sequence-length and model-size settings.
        self.max_label_length = 32
        self.max_turn_length = 22
        self.hidden_dim = 300
        self.num_rnn_layers = 1
        self.zero_init_rnn = False
        self.attn_head = 4

        # Run-mode flags and training/evaluation settings.
        self.do_eval = True
        self.do_train = False
        self.distance_metric = 'euclidean'
        self.train_batch_size = 1
        self.dev_batch_size = 1
        self.eval_batch_size = 16
        self.learning_rate = 5e-5
        self.num_train_epochs = 300
        self.patience = 15
        self.warmup_proportion = 0.1
        self.local_rank = -1
        self.seed = 42
        self.gradient_accumulation_steps = 4
        self.fp16 = False
        self.loss_scale = 0
        self.do_not_use_tensorboard = False
        self.fix_utterance_encoder = False

        # Pre-trained BERT model locations, resolved against the convlab2 root.
        self.bert_model = os.path.join(
            convlab2.get_root_path(),
            "pre-trained-models/bert-base-multilingual-uncased")
        self.bert_model_cache_dir = os.path.join(convlab2.get_root_path(),
                                                 "pre-trained-models/")
        self.bert_model_name = "bert-base-multilingual-uncased"
        self.do_lower_case = True

        # Task selection.
        self.task_name = 'bert-gru-sumbt'
        self.nbt = 'rnn'
        self.target_slot = 'all'

        self.max_seq_length = 64
        self.fp16_loss_scale = 0.0

        # Data and output directories (relative paths).
        self.data_dir = 'data/multiwoz/'
        self.tf_dir = 'tensorboard'
        self.tmp_data_dir = 'processed_data/'
        self.output_dir = 'model_output_mbert_ft/'
예제 #3
0
 def __init__(self,
              goal_model_path=os.path.join(
                  get_root_path(), 'data/multiwoz/goal/new_goal_model.pkl'),
              corpus_path=None,
              boldify=False,
              sample_info_from_trainset=True,
              sample_reqt_from_trainset=False):
     """Set up the generator and load the goal model, building it if absent.

     If a file exists at ``goal_model_path`` the six goal-model
     distributions are unpickled from it; otherwise
     ``self._build_goal_model()`` is called.

     Args:
         goal_model_path: path to a goal model
         corpus_path: path to a dialog corpus to build a goal model
         boldify: highlight some information in the goal message
         sample_info_from_trainset: if True, sample info slots combination from train set, else sample each slot independently
         sample_reqt_from_trainset: if True, sample reqt slots combination from train set, else sample each slot independently
     """
     self.goal_model_path = goal_model_path
     self.corpus_path = corpus_path
     self.db = Database()
     # Store the formatting callable itself rather than the boolean flag.
     self.boldify = do_boldify if boldify else null_boldify
     self.sample_info_from_trainset = sample_info_from_trainset
     self.sample_reqt_from_trainset = sample_reqt_from_trainset
     self.train_database = self.db.query('train', [])
     if os.path.exists(self.goal_model_path):
         # NOTE(review): unpickling can execute arbitrary code; only point
         # goal_model_path at files from a trusted source.
         # The context manager closes the file handle once loading is done.
         with open(self.goal_model_path, 'rb') as model_file:
             (self.ind_slot_dist,
              self.ind_slot_value_dist,
              self.domain_ordering_dist,
              self.book_dist,
              self.slots_num_dist,
              self.slots_combination_dist) = pickle.load(model_file)
         print('Loading goal model is done')
     else:
         self._build_goal_model()
         print('Building goal model is done')
예제 #4
0
    def __init__(self, name, args, sel_args, train=False, diverse=False, max_total_len=100,
                 model_url='https://tatk-data.s3-ap-northeast-1.amazonaws.com/rnnrollout_dealornot.zip'):
        """Load the dialogue and selection models and initialise the agent.

        Args:
            name: agent name, forwarded to the parent constructor.
            args: options for the dialogue model; this method reads
                ``data``, ``domain``, ``unk_threshold``, ``sep_sel``,
                ``model_file`` and ``visual`` from it.
            sel_args: options for the selection model; this method reads
                ``selection_model_file`` from it.
            train: forwarded to the parent constructor.
            diverse: forwarded to the parent constructor.
            max_total_len: forwarded to the parent constructor.
            model_url: stored as ``self.file_url`` before
                ``self.auto_download()`` is called.
        """
        self.config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'configs')

        self.file_url = model_url

        self.auto_download()

        # makedirs creates config_path as an intermediate directory, and
        # exist_ok=True avoids the check-then-create race of testing
        # os.path.exists() before creating.
        self.model_path = os.path.join(self.config_path, 'models')
        os.makedirs(self.model_path, exist_ok=True)

        self.data_path = os.path.join(get_root_path(), args.data)
        domain = get_domain(args.domain)
        corpus = RnnModel.corpus_ty(domain, self.data_path, freq_cutoff=args.unk_threshold, verbose=True,
                                    sep_sel=args.sep_sel)

        # Dialogue model: weights are loaded from config_path/args.model_file.
        model = RnnModel(corpus.word_dict, corpus.item_dict_old,
                         corpus.context_dict, corpus.count_dict, args)
        state_dict = utils.load_model(os.path.join(self.config_path, args.model_file))
        model.load_state_dict(state_dict)

        # Selection model: built from the same corpus dictionaries, with
        # weights loaded from config_path/sel_args.selection_model_file.
        sel_model = SelectionModel(corpus.word_dict, corpus.item_dict_old,
                                   corpus.context_dict, corpus.count_dict, sel_args)
        sel_state_dict = utils.load_model(os.path.join(self.config_path, sel_args.selection_model_file))
        sel_model.load_state_dict(sel_state_dict)

        super(DealornotAgent, self).__init__(model, sel_model, args, name, train, diverse, max_total_len)
        self.vis = args.visual
예제 #5
0
    def __init__(self):
        """Load the MultiWOZ training dialogs and set up the multi-source DB loader.

        Reads ``train.json`` out of ``multiwoz/train.json.zip`` under
        ``DATA_ROOT`` into ``self.multiwoz`` and stores in ``self.db_loader``
        a ``MultiSourceDBLoader`` built over the MultiWOZ ``db`` directory and
        the sgd-db directory, each paired with its own (domain, slot) mapping.
        """
        archive_path = os.path.join(DATA_ROOT, 'multiwoz', 'train.json.zip')
        self.multiwoz = read_zipped_json(archive_path, 'train.json')

        # Per-domain slot renames for the MultiWOZ database; the domain name
        # is the same on both sides of each mapping entry.
        slot_renames_by_domain = {
            'attraction': {'area': 'Area', 'type': 'Type',
                           'name': 'Name', 'address': 'Addr'},
            'hospital': {'department': 'Department', 'address': 'Addr'},
            'hotel': {'type': 'Type', 'area': 'Area',
                      'name': 'Name', 'address': 'Addr'},
            'restaurant': {'food': 'Food', 'area': 'Area',
                           'name': 'Name', 'address': 'Addr'},
            'train': {'destination': 'Dest', 'departure': 'Depart'},
        }
        multiwoz_slot_map = {
            (domain, slot): (domain, renamed_slot)
            for domain, renames in slot_renames_by_domain.items()
            for slot, renamed_slot in renames.items()
        }

        # (domain, slot) pairs used with the sgd-db directory.
        sgd_slot_map = {
            ('train', 'dest'): ('train', 'to'),
            ('train', 'depart'): ('train', 'from'),
            ('hotel', 'name'): ('hotels', 'hotel_name'),
            ('hotel', 'addr'): ('hotels', 'address'),
            ('attraction', 'name'): ('travel', 'attraction_name'),
            ('restaurant', 'name'): ('restaurants', 'restaurant_name'),
            ('restaurant', 'addr'): ('restaurants', 'street_address'),
        }

        self.db_loader = MultiSourceDBLoader([
            MultiSourceDBLoaderArgs(
                os.path.join(DATA_ROOT, 'multiwoz', 'db'),
                multiwoz_slot_map),
            MultiSourceDBLoaderArgs(
                os.path.join(get_root_path(),
                             "convlab2/laug/Word_Perturbation/db/sgd-db/"),
                sgd_slot_map),
        ])
예제 #6
0
                            user_goal[dom]['fail_book'].items())))[0][0][0]

                if adjusted_slot in ['internet', 'parking']:
                    message.append(
                        templates[dom]['fail_book ' + adjusted_slot + ' ' +
                                       user_goal[dom]['book'][adjusted_slot]])
                else:
                    message.append(
                        templates[dom]['fail_book ' + adjusted_slot].format(
                            self.boldify(
                                user_goal[dom]['book'][adjusted_slot])))

            dm = message[mess_ptr4domain:]
            mess_ptr4domain = len(message)
            message_by_domain.append(' '.join(dm))

        if boldify == do_boldify:
            for i, m in enumerate(message):
                message[i] = message[i].replace('wifi', "<b>wifi</b>")
                message[i] = message[i].replace('internet', "<b>internet</b>")
                message[i] = message[i].replace('parking', "<b>parking</b>")

        return message, message_by_domain


if __name__ == '__main__':
    # Script entry point: construct a generator from the MultiWOZ training
    # corpus path and pretty-print one user goal.
    train_corpus = os.path.join(get_root_path(), 'data/multiwoz/train.json')
    goal_generator = GoalGenerator(
        corpus_path=train_corpus,
        sample_reqt_from_trainset=True,
    )
    pprint(goal_generator.get_user_goal())
예제 #7
0
            # fail_book
            if 'fail_book' in user_goal[dom]:
                adjusted_slot = list(filter(lambda x: x[0][1] != x[1][1], zip(user_goal[dom]['book'].items(),
                                                                              user_goal[dom]['fail_book'].items())))[0][
                    0][0]

                if adjusted_slot in ['internet', 'parking']:
                    message.append(
                        templates[dom]['fail_book ' + adjusted_slot + ' ' + user_goal[dom]['book'][adjusted_slot]])
                else:
                    message.append(templates[dom]['fail_book ' + adjusted_slot].format(
                        self.boldify(user_goal[dom]['book'][adjusted_slot])))

            dm = message[mess_ptr4domain:]
            mess_ptr4domain = len(message)
            message_by_domain.append(' '.join(dm))

        if boldify == do_boldify:
            for i, m in enumerate(message):
                message[i] = message[i].replace('wifi', "<b>wifi</b>")
                message[i] = message[i].replace('internet', "<b>internet</b>")
                message[i] = message[i].replace('parking', "<b>parking</b>")

        return message, message_by_domain

if __name__ == '__main__':
    # Script entry point: construct a generator from the MultiWOZ training
    # corpus path and pretty-print one user goal.
    corpus_file = os.path.join(get_root_path(), 'data/multiwoz/train.json')
    goal_generator = GoalGenerator(
        corpus_path=corpus_file,
        sample_reqt_from_trainset=True,
    )
    # goal_generator._build_goal_model()
    user_goal = goal_generator.get_user_goal()
    pprint(user_goal)
예제 #8
0
def get_context_generator(context_file):
    """Create a ``utils.ContextGenerator`` for the given context file.

    Args:
        context_file: path of the context file; it is joined onto the
            directory returned by ``get_root_path()``.

    Returns:
        The ``utils.ContextGenerator`` constructed from the joined path.
    """
    context_path = os.path.join(get_root_path(), context_file)
    return utils.ContextGenerator(context_path)