Example #1
0
    def __init__(self):
        """Create the MultiWOZ reader, the DAMD model, its evaluator and optimiser.

        Unless exactly one entry is listed in ``cfg.cuda_device``, the model is
        wrapped in ``torch.nn.DataParallel`` over those device ids. When
        ``cfg.limit_bspn_vocab`` / ``cfg.limit_aspn_vocab`` are set, the
        reader's ``bspn_masks`` / ``aspn_masks`` are also converted into
        ``bspn_masks_tensor`` / ``aspn_masks_tensor`` dicts of long tensors.
        """
        self.reader = MultiWozReader()
        network = DAMD(self.reader)
        if len(cfg.cuda_device) != 1:
            # Not exactly one device configured: replicate across the listed ids.
            network = torch.nn.DataParallel(network, device_ids=cfg.cuda_device)
        self.m = network
        self.evaluator = MultiWozEvaluator(self.reader)  # evaluator class
        if cfg.cuda:
            self.m = self.m.cuda()
        # Only parameters that require gradients are handed to the optimiser.
        trainable_params = filter(lambda p: p.requires_grad, self.m.parameters())
        self.optim = Adam(lr=cfg.lr, params=trainable_params, weight_decay=5e-5)
        self.base_epoch = -1

        def to_long_tensors(masks):
            # Convert each per-key value list to a long tensor and pass it
            # through cuda_ (helper defined elsewhere; presumably moves the
            # tensor to the GPU when CUDA is enabled -- verify).
            return {key: cuda_(torch.Tensor(values).long())
                    for key, values in masks.items()}

        if cfg.limit_bspn_vocab:
            self.reader.bspn_masks_tensor = to_long_tensors(self.reader.bspn_masks)
        if cfg.limit_aspn_vocab:
            self.reader.aspn_masks_tensor = to_long_tensors(self.reader.aspn_masks)
Example #2
0
    def __init__(self, device):
        """Set up the GPT-2 tokenizer, MultiWOZ reader, model, evaluator and log writer.

        Args:
            device: target passed to ``self.model.to`` (single-device setup).
        """
        self.device = device
        # Tokenizer and model weights both come from cfg.gpt_path.
        self.tokenizer = GPT2Tokenizer.from_pretrained(cfg.gpt_path)
        self.reader = MultiWozReader(self.tokenizer)

        self.model = GPT2LMHeadModel.from_pretrained(cfg.gpt_path)
        is_training = cfg.mode == 'train'
        if is_training:
            # Resize the embedding table to the tokenizer's length.
            self.model.resize_token_embeddings(len(self.tokenizer))
        self.model.to(self.device)  # single gpu

        self.evaluator = MultiWozEvaluator(self.reader)
        # A summary writer is only created when training with cfg.save_log set.
        self.tb_writer = (SummaryWriter(log_dir='./log')
                          if cfg.save_log and is_training else None)
def dialog_turn_state_analysis(mode='train'):
    """Group system act types by dialog turn state and dump the alternatives.

    Loads the DAMD-processed MultiWOZ data and the ConvLab user-act
    annotations. For every turn with exactly one non-general domain and a
    non-empty system act, a string "turn state" is built from the domain, the
    mentioned constraint slots, a bucketed DB match count, the booking state
    and the user act types. The system act types observed for each turn state
    are then collected. For every turn of a dialog that is in neither
    ``reader.dev_files`` nor ``reader.test_files``, the act spans of the
    *other* act types seen under the same state are written to
    ``data/multi-woz-processed/multi_act_mapping_<mode>.json``.

    Args:
        mode: ``'train'`` collects act statistics from dialogs not in
            ``reader.test_files``; ``'test'`` only from dialogs in it. Any other
            value uses every dialog. Also used in the output file name.
    """
    data_path = 'data/multi-woz-processed/data_for_damd.json'
    conv_data = 'data/multi-woz/annotated_user_da_with_span_full.json'
    # Context managers close the archive, its member and the data file.
    with zipfile.ZipFile(conv_data + '.zip', 'r') as archive:
        with archive.open(conv_data.split('/')[-1], 'r') as member:
            convlab_data = json.loads(member.read().lower())
    reader = MultiWozReader()
    with open(data_path, 'r', encoding='utf-8') as f:
        data = json.loads(f.read().lower())

    turn_state_record = {}  # state str -> {act-state str -> record dict}
    golden_no_order = {}    # state str -> no_order_act of the first turn seen
    act_states_seen = set()  # distinct act states (as key sets)
    raw_act_spans = set()    # distinct raw system act spans
    state_valid_acts = {}
    dial_count = 0
    turn_count = 0

    for fn, dial in data.items():
        dial_count += 1
        valid_turns = state_valid_acts[fn] = {}
        for turn_no, turn in enumerate(dial['log']):
            turn_domain = turn['turn_domain'].split()
            cons_delex = turn['cons_delex'].split()
            sys_act = turn['sys_act']
            # ConvLab logs interleave user/system turns; even indices are user turns.
            usr_act = convlab_data[fn]['log'][turn_no * 2]['dialog_act']
            db_ptr = [int(i) for i in turn['pointer'].split(',')]
            match = turn['match']
            if len(turn_domain) != 1 or turn_domain[0] == '[general]' or not sys_act:
                continue
            turn_count += 1

            # Key order fixes the order of fields in the serialised state string.
            turn_state = {
                'domain': turn_domain,
                'slot_mentioned': _mentioned_slots(cons_delex, turn_domain),
                'match': _bucket_match(match),
                'book': _booking_state(db_ptr),
                'usract': _user_act_types(usr_act),
            }
            turn_state_str = _turn_state_string(turn_state)

            raw_act_spans.add(sys_act)
            act_list = reader.aspan_to_act_list(sys_act)
            act_state = _system_act_state(act_list)
            act_state_str = _act_state_string(act_state)
            no_order_act = dict.fromkeys(act_list, 1)

            valid_turns[turn_no] = {
                'usdx': turn['user_delex'],
                'state': turn_state_str,
                'gold': {act_state_str: {'resp': turn['resp'], 'act': sys_act}},
            }

            if mode == 'test' and fn not in reader.test_files:
                continue
            if mode == 'train' and fn in reader.test_files:
                continue
            # All act_state values are '', so key sets fully identify an act state.
            act_states_seen.add((frozenset(act_state['domain']),
                                 frozenset(act_state['general'])))
            raw_sys_rec = fn + '-' + str(turn_no) + ':' + sys_act

            if turn_state_str not in turn_state_record:
                # First turn with this state: its act set becomes the golden one.
                turn_state_record[turn_state_str] = {
                    act_state_str: _new_act_record(raw_sys_rec, no_order_act, turn)}
                golden_no_order[turn_state_str] = no_order_act
                continue
            records = turn_state_record[turn_state_str]
            if act_state_str not in records:
                records[act_state_str] = _new_act_record(raw_sys_rec, no_order_act, turn)
                continue
            record = records[act_state_str]
            # Skip act sets equal to the golden one or to one already recorded.
            if (no_order_act == golden_no_order[turn_state_str]
                    or no_order_act in record['no_order_act']):
                continue
            record['num'] += 1
            record['raw_acts'].append(raw_sys_rec)
            record['user'].append(turn['user'])
            record['resp'].append(turn['resp'])
            record['no_order_act'].append(no_order_act)

    # Most frequent act types first within each state (sort is stable).
    for state, acts in turn_state_record.items():
        turn_state_record[state] = OrderedDict(
            sorted(acts.items(), key=lambda i: i[1]['num'], reverse=True))

    print('dialog count:', dial_count, 'turn count: ', turn_count)
    print('state count:', len(turn_state_record))
    print('raw act span count:', len(raw_act_spans))
    print('act state count:', len(act_states_seen))

    act_span_save = {}
    for fn, turns in state_valid_acts.items():
        if fn in reader.dev_files or fn in reader.test_files:
            continue
        act_span_save[fn] = {}
        for turn_no, turn_info in turns.items():
            others = _alternative_acts(turn_info, turn_state_record)
            turn_info['other'] = others
            # raw_acts entries look like '<fn>-<turn>:<sys_act>'; keep the span.
            act_span_save[fn][turn_no] = {
                act_type: [m['act'].split(':', 1)[1] for m in matches]
                for act_type, matches in others.items()
            }

    with open('data/multi-woz-processed/multi_act_mapping_%s.json' % mode, 'w') as f:
        json.dump(act_span_save, f, indent=2)


def _mentioned_slots(cons_delex, turn_domain):
    """Return the sorted constraint tokens that follow a turn-domain token.

    After each token of ``cons_delex`` (excluding the last) that is in
    ``turn_domain``, the following tokens are collected up to, but not
    including, the next token containing '['.
    """
    slots = []
    for idx, tok in enumerate(cons_delex[:-1]):
        if tok not in turn_domain:
            continue
        for nxt in cons_delex[idx + 1:]:
            if '[' in nxt:
                break
            slots.append(nxt)
    slots.sort()
    return slots


def _bucket_match(match):
    """Collapse the raw match value into '', '0', '1', '2-3' or '>3'."""
    if match in ('', '0', '1'):
        return match
    if match in ('2', '3'):
        return '2-3'
    return '>3'


def _booking_state(db_ptr):
    """Map the last two DB-pointer entries to '' ([0, 0]), 'no' ([1, 0]) or 'yes'."""
    tail = db_ptr[-2:]
    if tail == [0, 0]:
        return ''
    if tail == [1, 0]:
        return 'no'
    return 'yes'


def _user_act_types(usr_act):
    """Summarise user dialog acts as a sorted list of act-type strings.

    ``usr_act`` is iterated for keys of the form '<x>-<act>' and indexed by
    those keys to get [slot, value] pairs -- presumably the ConvLab
    'dialog_act' mapping; verify. A request act is rendered as
    'request(slot1,slot2)', with slot names expanded through
    ``ontology.da_abbr_to_slot_name``. Other acts are rendered as the bare act.
    """
    act_types = []
    for act in usr_act:
        _, a = act.split('-')
        # NOTE(review): request entries are stored as 'request(...)', so this
        # guard does not merge requests from several domains; kept as-is.
        if a in act_types:
            continue
        slots = []
        if a == 'request':
            for slot_value in usr_act[act]:
                slot = slot_value[0]
                if slot == 'none':
                    continue
                # Expand the abbreviation first so de-duplication compares
                # like with like.
                slot = ontology.da_abbr_to_slot_name.get(slot, slot)
                if slot not in slots:
                    slots.append(slot)
        if slots:
            act_types.append(a + '(' + ','.join(sorted(slots)) + ')')
        else:
            act_types.append(a)
    act_types.sort()
    return act_types


def _turn_state_string(turn_state):
    """Serialise a turn-state dict as 'key(v1,v2);key(v);...' in key order."""
    parts = []
    for key, value in turn_state.items():
        if isinstance(value, list):
            value = ','.join(value)
        elif isinstance(value, int):
            value = str(value)
        parts.append('%s(%s)' % (key, value))
    return ';'.join(parts)


def _system_act_state(act_list):
    """Split '<domain>-<act>-<slot>' strings into general and domain act-type sets."""
    act_state = {'domain': {}, 'general': {}}
    for act in act_list:
        d, a, _ = act.split('-')
        act_state['general' if d == 'general' else 'domain'][a] = ''
    return act_state


def _act_state_string(act_state):
    """Serialise an act state as 'domain(a(),b());general(c())', omitting empty groups."""
    parts = []
    for key, acts in act_state.items():
        inner = ','.join('%s(%s)' % (a, v) for a, v in acts.items())
        if inner:
            parts.append('%s(%s)' % (key, inner))
    return ';'.join(parts)


def _new_act_record(raw_sys_rec, no_order_act, turn):
    """Create the record for the first occurrence of an act type under a state."""
    return {'num': 1, 'raw_acts': [raw_sys_rec], 'no_order_act': [no_order_act],
            'user': [turn['user']], 'resp': [turn['resp']]}


def _alternative_acts(turn_info, turn_state_record):
    """Return {act_type: [{'act', 'resp'}, ...]} for non-gold act types of the turn's state."""
    gold_act_type = next(iter(turn_info['gold']))
    others = {}
    for act_type, record in turn_state_record.get(turn_info['state'], {}).items():
        if act_type == gold_act_type:
            continue
        others[act_type] = [{'act': act, 'resp': resp}
                            for act, resp in zip(record['raw_acts'], record['resp'])]
    return others