Пример #1
0
    def _build_data(self, root_dir, processed_dir):
        """Vectorize system-side MultiWOZ act-policy data and cache it as pickles.

        For each split ('train', 'val', 'test') the system-role data is loaded,
        a dialogue state is built for every turn, and the (state, action) pair
        is vectorized with ``self.vector``. The vectors are stored in
        ``self.data[part]`` and dumped to ``<processed_dir>/<part>.pkl``.

        Args:
            root_dir: Not referenced by this method; kept for interface
                compatibility.
            processed_dir: Directory receiving the per-split pickle files.
                Created if missing; an existing directory is reused.
        """
        parts = ('train', 'val', 'test')
        self.data = {}
        data_loader = ActPolicyDataloader(
            dataset_dataloader=MultiWOZDataloader())
        for part in parts:
            self.data[part] = []
            raw_data = data_loader.load_data(data_key=part, role='sys')[part]

            for belief_state, context_dialog_act, terminated, dialog_act in zip(
                    raw_data['belief_state'], raw_data['context_dialog_act'],
                    raw_data['terminated'], raw_data['dialog_act']):
                state = default_state()
                state['belief_state'] = belief_state
                # The last act in the context is stored as the user action and
                # the one before it (when present) as the previous system action.
                state['user_action'] = context_dialog_act[-1]
                state['system_action'] = (context_dialog_act[-2]
                                          if len(context_dialog_act) > 1 else {})
                state['terminated'] = terminated
                self.data[part].append([
                    self.vector.state_vectorize(state),
                    self.vector.action_vectorize(dialog_act)
                ])

        # exist_ok avoids a FileExistsError when the directory is already
        # present (e.g. left over from an interrupted earlier run).
        os.makedirs(processed_dir, exist_ok=True)
        for part in parts:
            with open(os.path.join(processed_dir, '{}.pkl'.format(part)),
                      'wb') as f:
                pickle.dump(self.data[part], f)
Пример #2
0
    def __init__(self, data_dir='configs', data=None):
        """Constructor of MultiWozMDBT class.

        Args:
            data_dir (str): Name of the data dir, resolved relative to the
                directory containing this module (the original docs give that
                root as convlab2/dst/mdbt/multiwoz).
            data (dict, optional): Pre-loaded DST data in the form returned by
                ``AgentDSTDataloader.load_data``. Only its ``'test'`` split is
                used. When None, the MultiWOZ test split is loaded here.
        """
        if data is None:
            # Only data['test'] is consumed below, so skip loading the
            # train/val splits.
            loader = AgentDSTDataloader(MultiWOZDataloader())
            data = loader.load_data(data_key='test')
        self.file_url = 'https://convlab.blob.core.windows.net/convlab-2/mdbt_multiwoz_sys.zip'
        local_path = os.path.dirname(os.path.abspath(__file__))
        self.data_dir = os.path.join(local_path,
                                     data_dir)  # abstract data path

        self.validation_url = os.path.join(self.data_dir, 'data/validate.json')
        self.training_url = os.path.join(self.data_dir, 'data/train.json')
        self.testing_url = os.path.join(self.data_dir, 'data/test.json')

        self.word_vectors_url = os.path.join(
            self.data_dir, 'word-vectors/paragram_300_sl999.txt')
        self.ontology_url = os.path.join(self.data_dir, 'data/ontology.json')
        self.model_url = os.path.join(self.data_dir, 'models/model-1')
        self.graph_url = os.path.join(self.data_dir, 'graphs/graph-1')
        self.results_url = os.path.join(self.data_dir, 'results/log-1.txt')
        self.kb_url = os.path.join(self.data_dir, 'data/')  # not used
        self.train_model_url = os.path.join(self.data_dir,
                                            'train_models/model-1')
        self.train_graph_url = os.path.join(self.data_dir,
                                            'train_graph/graph-1')

        self.auto_download()

        print('Configuring MDBT model...')
        self.word_vectors = load_word_vectors(self.word_vectors_url)

        # Load the ontology and extract the feature vectors
        self.ontology, self.ontology_vectors, self.slots = load_ontology(
            self.ontology_url, self.word_vectors)

        # Load and process the test data
        self.test_dialogues, self.actual_dialogues = load_woz_data_new(
            data['test'],
            self.word_vectors,
            self.ontology,
            url=self.testing_url)
        self.no_dialogues = len(self.test_dialogues)

        super(MultiWozMDBT,
              self).__init__(self.ontology_vectors, self.ontology, self.slots,
                             self.data_dir)
Пример #3
0
                from convlab2.dst.sumbt.multiwoz.sumbt import SUMBTTracker
                model = SUMBTTracker()
            elif model_name == 'TRADE':
                from convlab2.dst.trade.multiwoz.trade import MultiWOZTRADE
                model = MultiWOZTRADE()
            elif model_name == 'mdbt':
                from convlab2.dst.mdbt.multiwoz.dst import MultiWozMDBT
                model = MultiWozMDBT()
            else:
                raise Exception("Available models: TRADE/mdbt/sumbt")

        ## load data
        from convlab2.util.dataloader.module_dataloader import AgentDSTDataloader
        from convlab2.util.dataloader.dataset_dataloader import MultiWOZDataloader
        dataloader = AgentDSTDataloader(
            dataset_dataloader=MultiWOZDataloader(dataset_name.endswith('zh')))
        data = dataloader.load_data(data_key=data_key)[data_key]
        context, golden_truth = data['context'], data['belief_state']
        all_predictions = {}
        test_set = []
        curr_sess = {}
        session_count = 0
        turn_count = 0
        is_start = True
        for i in tqdm(range(len(context))):
            if len(context[i]) == 0:
                turn_count = 0
                if is_start:
                    is_start = False
                else:  # save session
                    all_predictions[session_count] = copy.deepcopy(curr_sess)
Пример #4
0
    def load_data(self, *args, **kwargs):
        """Forward to the dataset dataloader with belief-state/utterance/context defaults.

        Unless the caller overrides them, requests ``belief_state``,
        ``utterance`` and ``context`` with ``context_window_size=3``.
        All arguments are passed through to ``self.dataset_dataloader.load_data``.
        """
        defaults = (('belief_state', True), ('utterance', True),
                    ('context', True), ('context_window_size', 3))
        for option, value in defaults:
            kwargs.setdefault(option, value)
        return self.dataset_dataloader.load_data(*args, **kwargs)


class SingleTurnNLGDataloader(ModuleDataloader):
    """Module dataloader that requests utterances and dialog acts by default."""

    def load_data(self, *args, **kwargs):
        """Forward to the dataset dataloader with ``utterance`` and ``dialog_act`` on.

        Keyword arguments given by the caller override these defaults.
        """
        for option in ('utterance', 'dialog_act'):
            kwargs.setdefault(option, True)
        return self.dataset_dataloader.load_data(*args, **kwargs)


class MultiTurnNLGDataloader(ModuleDataloader):
    """Module dataloader requesting utterances, dialog acts and context by default."""

    def load_data(self, *args, **kwargs):
        """Forward to the dataset dataloader with multi-turn NLG defaults.

        Unless overridden by the caller, requests ``utterance``, ``dialog_act``
        and ``context`` with ``context_window_size=3``.
        """
        defaults = {
            'utterance': True,
            'dialog_act': True,
            'context': True,
            'context_window_size': 3,
        }
        for option, value in defaults.items():
            kwargs.setdefault(option, value)
        return self.dataset_dataloader.load_data(*args, **kwargs)


if __name__ == '__main__':
    # Demo: print the first five validation-split utterances, contexts and
    # dialog acts loaded with role='usr'.
    demo_loader = MultiTurnNLUDataloader(dataset_dataloader=MultiWOZDataloader())
    val_split = demo_loader.load_data(data_key='val', role='usr')['val']
    for field in ('utterance', 'context', 'dialog_act'):
        pprint(val_split[field][:5])
Пример #5
0
        if model_name == 'TRADE':
            from convlab2.dst.trade.multiwoz.trade import MultiWOZTRADE
            model = MultiWOZTRADE()
        elif model_name == 'mdbt':
            from convlab2.dst.mdbt.multiwoz.dst import MultiWozMDBT
            model = MultiWozMDBT()
        elif model_name == 'sumbt':
            from convlab2.dst.sumbt.multiwoz.sumbt import SUMBTTracker
            model = SUMBTTracker()
        else:
            raise Exception("Available models: TRADE/mdbt/sumbt")

        ## load data
        from convlab2.util.dataloader.module_dataloader import AgentDSTDataloader
        from convlab2.util.dataloader.dataset_dataloader import MultiWOZDataloader
        dataloader = AgentDSTDataloader(dataset_dataloader=MultiWOZDataloader())
        data = dataloader.load_data(data_key='test')['test']
        context, golden_truth = data['context'], data['belief_state']
        all_predictions = {}
        test_set = []
        curr_sess = {}
        session_count = 0
        turn_count = 0
        is_start = True
        for i in tqdm(range(len(context))):
        # for i in tqdm(range(200)):  # for test
            if len(context[i]) == 0:
                turn_count = 0
                if is_start:
                    is_start = False
                else:  # save session
Пример #6
0
    joint_acc_score = joint_acc / float(total) if total != 0 else 0
    turn_acc_score = turn_acc / float(total) if total != 0 else 0
    F1_score = F1_pred / float(F1_count) if F1_count != 0 else 0
    return joint_acc_score, F1_score, turn_acc_score

if __name__ == '__main__':

    ## init phase
    config_data = sys.argv[1]
    model = SUMBTTracker(arg_path = config_data)
    if dataset_name.startswith('MultiWOZ'):
        dataset_name = dataset_name + "zh"
        ## load data
        from convlab2.util.dataloader.module_dataloader import AgentDSTDataloader
        from convlab2.util.dataloader.dataset_dataloader import MultiWOZDataloader
        dataloader = AgentDSTDataloader(dataset_dataloader=MultiWOZDataloader(config['lang']=='zh')))
        data = dataloader.load_data(data_key=data_key)[data_key]
        context, golden_truth = data['context'], data['belief_state']
        all_predictions = {}
        test_set = []
        curr_sess = {}
        session_count = 0
        turn_count = 0
        is_start = True
        for i in tqdm(range(len(context))):
            if len(context[i]) == 0:
                turn_count = 0
                if is_start:
                    is_start = False
                else:  # save session
                    all_predictions[session_count] = copy.deepcopy(curr_sess)
Пример #7
0
     if model_name == 'sumbt':
         from convlab2.dst.sumbt.multiwoz.sum1bt import SUMBTTracker
         model = SUMBTTracker()
     elif model_name == 'TRADE':
         from convlab2.dst.trade.multiwoz.trade import MultiWOZTRADE
         model = MultiWOZTRADE()
     elif model_name == 'mdbt':
         from convlab2.dst.mdbt.multiwoz.dst import MultiWozMDBT
         model = MultiWozMDBT()
     else:
         raise Exception("Available models: TRADE/mdbt/sumbt")
 dataset_name = dataset_name + "zh"
 ## load data
 from convlab2.util.dataloader.module_dataloader import AgentDSTDataloader
 from convlab2.util.dataloader.dataset_dataloader import MultiWOZDataloader
 dataloader = AgentDSTDataloader(dataset_dataloader=MultiWOZDataloader(
     not (dataset_name.endswith('zh'))))
 data = dataloader.load_data(
     data_dir=
     '/fs/startiger0/nmoghe/code/ConvLab-2/data/zh_data_en_states',
     data_key=data_key)[data_key]
 context, golden_truth = data['context'], data['belief_state']
 all_predictions = {}
 test_set = []
 curr_sess = {}
 session_count = 0
 turn_count = 0
 is_start = True
 for i in tqdm(range(len(context))):
     if len(context[i]) == 0:
         turn_count = 0
         if is_start:
Пример #8
0
class WordPolicyDataloader(ModuleDataloader):
    """Module dataloader requesting belief state, utterance and context by default."""

    def load_data(self, *args, **kwargs):
        """Forward to the dataset dataloader with word-policy defaults.

        Unless overridden by the caller, requests ``belief_state``,
        ``utterance`` and ``context`` with ``context_window_size=3``.
        """
        for option, value in (('belief_state', True), ('utterance', True),
                              ('context', True), ('context_window_size', 3)):
            kwargs.setdefault(option, value)
        return self.dataset_dataloader.load_data(*args, **kwargs)


class SingleTurnNLGDataloader(ModuleDataloader):
    """Module dataloader that requests utterances and dialog acts by default."""

    def load_data(self, *args, **kwargs):
        """Forward to the dataset dataloader with ``utterance`` and ``dialog_act`` on.

        Caller-supplied keyword arguments take precedence.
        """
        for option in ('utterance', 'dialog_act'):
            kwargs.setdefault(option, True)
        return self.dataset_dataloader.load_data(*args, **kwargs)


class MultiTurnNLGDataloader(ModuleDataloader):
    """Module dataloader requesting utterances, dialog acts and context by default."""

    def load_data(self, *args, **kwargs):
        """Forward to the dataset dataloader with multi-turn NLG defaults.

        Unless overridden, requests ``utterance``, ``dialog_act`` and
        ``context`` with ``context_window_size=3``.
        """
        defaults = {
            'utterance': True,
            'dialog_act': True,
            'context': True,
            'context_window_size': 3,
        }
        for option, value in defaults.items():
            kwargs.setdefault(option, value)
        return self.dataset_dataloader.load_data(*args, **kwargs)


if __name__ == '__main__':
    # Demo: print the first five validation-split utterances and dialog acts
    # on the user side.
    d = SingleTurnNLUDataloader(dataset_dataloader=MultiWOZDataloader())
    # The dataloader's role values are 'usr' / 'sys'; 'user' does not match
    # either of them.
    data = d.load_data(data_key='val', role='usr')
    pprint(data['val']['utterance'][:5])
    pprint(data['val']['dialog_act'][:5])
Пример #9
0
                model = SCLSTM(is_user=True, use_cuda=True)
            elif role == 'sys':
                model = SCLSTM(is_user=False, use_cuda=True)
        elif model_name == 'TemplateNLG':
            from convlab2.nlg.template.multiwoz import TemplateNLG
            if role == 'usr':
                model = TemplateNLG(is_user=True)
            elif role == 'sys':
                model = TemplateNLG(is_user=False)
        else:
            raise Exception("Available models: SCLSTM, TEMPLATE")

        from convlab2.util.dataloader.module_dataloader import SingleTurnNLGDataloader
        from convlab2.util.dataloader.dataset_dataloader import MultiWOZDataloader
        dataloader = SingleTurnNLGDataloader(
            dataset_dataloader=MultiWOZDataloader())
        data = dataloader.load_data(data_key='test', role=role)['test']

        dialog_acts = []
        golden_utts = []
        gen_utts = []
        gen_slots = []

        sen_num = 0

        for i in tqdm(range(len(data['utterance']))):
            dialog_acts.append(data['dialog_act'][i])
            golden_utts.append(data['utterance'][i])
            gen_utts.append(model.generate(data['dialog_act'][i]))

        bleu4 = get_bleu4(dialog_acts, golden_utts, gen_utts)
Пример #10
0
def test_update():
    """Smoke-test MultiWozMDBT: build, restore, and run one update on a demo history.

    Prints the model build time, the resulting state as JSON, and the total
    elapsed time.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = '0'
    # NOTE(review): a tf.ConfigProto (allow_growth, allow_soft_placement) used
    # to be built here but was never passed to anything, so it had no effect
    # and was removed. If GPU memory growth is required, it has to be wired
    # into MultiWozMDBT's session creation instead.
    # perf_counter is monotonic and intended for measuring elapsed time.
    start_time = time.perf_counter()
    mdbt = MultiWozMDBT()
    print('\tMDBT: model build time: {:.2f} seconds'.format(
        time.perf_counter() - start_time))
    mdbt.restore()
    # demo state history
    mdbt.state['history'] = [
        [
            'null',
            'I\'m trying to find an expensive restaurant in the centre part of town.'
        ],
        [
            'The Cambridge Chop House is an good expensive restaurant in the centre of town. Would you like me to book it for you?',
            'Yes, a table for 1 at 16:15 on sunday.  I need the reference number.'
        ]
    ]
    new_state = mdbt.update('hi, this is not good')
    print(json.dumps(new_state, indent=4))
    print('all time: {:.2f} seconds'.format(time.perf_counter() - start_time))


if __name__ == '__main__':
    # Build the DST dataloader, load the data once, and pass it to the
    # tracker explicitly.
    dst_data = AgentDSTDataloader(MultiWOZDataloader()).load_data()
    model = MultiWozMDBT(data=dst_data)