Example #1
0
def delexicaliseReferenceNumber(sent, turn):
    """Replace booked values (e.g. reference numbers) with placeholders.

    Reference numbers were created randomly during data gathering, so they
    cannot be found in a static value dictionary. Instead they are taken
    from the booking information in ``turn``.

    For each domain listed below that has a booking, every slot value ``v``
    of the first booking is replaced by ``[<domain>_<slot>]``. The match is
    on a whole space-delimited token sequence. The spellings ``#v`` and
    ``ref#v`` are tried as well.

    Args:
        sent: Sentence to delexicalise. It should presumably already be
            normalised, since the booked values are passed through
            ``normalize`` before matching.
        turn: Object whose ``turn['metadata'][domain]['book']['booked']``
            holds the bookings of each domain.

    Returns:
        The delexicalised sentence. The padding used for whole-token
        matching is removed again before returning.
    """
    domains = [
        'restaurant', 'hotel', 'attraction', 'train', 'taxi', 'hospital'
    ]  # 'police' is intentionally not included
    metadata = turn['metadata']
    if not metadata:
        return sent

    # Pad once so values at the very start/end of the sentence also match
    # as whole tokens; the padding is stripped again before returning.
    padded = ' ' + sent + ' '
    for domain in domains:
        booked = metadata[domain]['book']['booked']
        if not booked:
            continue
        booking = booked[0]  # only the first booking is considered
        for slot in booking:
            placeholder = ' [' + domain + '_' + slot + '] '
            # Try the bare value, then the "#value" and "ref#value" spellings.
            for prefix in ('', '#', 'ref#'):
                key = normalize(prefix + booking[slot])
                if not key:
                    # An empty key would match (and clobber) any double space.
                    continue
                padded = padded.replace(' ' + key + ' ', placeholder)
    # Replacements keep the surrounding spaces, so the outer padding is intact.
    return padded[1:-1]
Example #2
0
def queryResultVenues(domain, turn, real_belief=False):
    """Return the database rows of ``domain`` that satisfy a belief state.

    Args:
        domain: Domain name. It is used as the table name and as the key
            into the module-level ``dbs`` mapping of database handles.
        turn: Source of the constraints; how it is read depends on
            ``real_belief``:

            * ``real_belief == True``: ``turn`` itself maps slot -> value.
            * ``real_belief == 'tracking'``: ``turn[domain]`` is a sequence
              of slots whose first element is a ``"<domain>-<slot>-<value>"``
              string.
            * otherwise: ``turn['metadata'][domain]['semi']`` maps
              slot -> value.
        real_belief: Selects the interpretation of ``turn`` (see above).

    Returns:
        ``fetchall()`` of the query. In ``'tracking'`` mode any error while
        querying yields ``[]``. In the other modes the error propagates.
    """
    if real_belief == True:  # equality on purpose: 1 is accepted as before
        items = turn.items()
    elif real_belief == 'tracking':
        # Tracker slot names -> database column names.
        column_names = {
            'price range': 'pricerange',
            'leave at': 'leaveAt',
            'arrive by': 'arriveBy',
        }
        items = []
        for slot in turn[domain]:
            # maxsplit=2 keeps values that themselves contain '-' intact.
            parts = slot[0].split('-', 2)
            key, val = parts[1], parts[2]
            items.append((column_names.get(key, key), val))
    else:
        items = turn['metadata'][domain]['semi'].items()

    sql_query, params = _build_venue_query(domain, items)

    if real_belief == 'tracking':
        try:
            return dbs[domain].execute(sql_query, params).fetchall()
        except Exception:
            # Tracking mode is best effort (original TODO: test it).
            return []
    return dbs[domain].execute(sql_query, params).fetchall()


def _build_venue_query(domain, items):
    """Build ``(sql, params)`` selecting rows of ``domain`` that match ``items``.

    ``items`` yields ``(column, value)`` pairs, and unconstrained values are
    skipped: "", the various "don't care" spellings, and "not mentioned".

    ``leaveAt`` becomes a ``>`` constraint and ``arriveBy`` a ``<``
    constraint; every other column becomes an equality. Values are
    normalised and bound as ``?`` parameters, so no manual quote escaping
    is needed. Column names cannot be bound and are interpolated directly,
    so they must come from the belief-state schema, not from free text.
    """
    ignored = ("", "dontcare", "not mentioned", "don't care", "dont care",
               "do n't care")
    conditions = []
    params = []
    for key, val in items:
        if val in ignored:
            continue
        if key == 'leaveAt':
            op = '>'
        elif key == 'arriveBy':
            op = '<'
        else:
            op = '='
        conditions.append('{} {} ?'.format(key, op))
        params.append(normalize(val))

    sql_query = 'select * from {}'.format(domain)
    if conditions:
        sql_query += ' where ' + ' and '.join(conditions)
    return sql_query, params
Example #3
0
def predict(model, prev_state, prev_active_domain, state, dic):
    """Generate a response to the latest user utterance.

    The utterance is delexicalised and encoded, together with the belief
    summary and the database/booking pointer. It is decoded by
    ``model.predict``, and the resulting template is filled by
    ``populate_template``.

    Args:
        model: Model providing ``input_word2index``, ``predict`` and a
            ``beam_search`` attribute (set to ``False`` here).
        prev_state: Previous dialogue state; only its ``'belief_state'`` is
            used (on a deep copy).
        prev_active_domain: Active domain of the previous turn, passed to
            ``get_active_domain``.
        state: Current dialogue state. ``state['history'][-1][-1]`` is the
            user utterance and ``state['belief_state']`` the belief state
            (used on a deep copy).
        dic: Delexicalisation dictionary passed to
            ``delexicalize.delexicalise``.

    Returns:
        ``(response, active_domain)``, where ``active_domain`` may be ``None``.
    """
    model.beam_search = False

    usr = state['history'][-1][-1]

    # Work on copies: mark_not_mentioned's return value is unused, so it
    # presumably updates the belief states in place.
    prev_state = deepcopy(prev_state['belief_state'])
    state = deepcopy(state['belief_state'])
    mark_not_mentioned(prev_state)
    mark_not_mentioned(state)

    # Collapse whitespace, then delexicalise dictionary values.
    usr = delexicalize.delexicalise(' '.join(usr.split()), dic)
    # Reference numbers are resolved against the belief state.
    usr = delexicaliseReferenceNumber(usr, state)
    # Remaining digit runs become a generic count token.
    usr = re.sub(r'\d+', '[value_count]', usr)

    # Database pointer, extended with the booking pointer.
    pointer_vector, top_results, num_results = addDBPointer(state)
    pointer_vector = addBookingPointer(state, pointer_vector)
    belief_summary = get_summary_bstate(state)

    words = normalize(usr).strip(' ').split(' ')
    tokens = [model.input_word2index(word) for word in words] + [util.EOS_token]
    # A batch containing this single utterance.
    input_tensor, input_lengths = util.padSequence([torch.LongTensor(tokens)])
    bs_tensor = torch.tensor([belief_summary], dtype=torch.float, device=device)
    db_tensor = torch.tensor([pointer_vector], dtype=torch.float, device=device)

    # NOTE(review): the input is also passed where a target is expected;
    # presumably the target is unused at inference time -- confirm.
    output_words, _ = model.predict(input_tensor, input_lengths,
                                    input_tensor, input_lengths,
                                    db_tensor, bs_tensor)

    active_domain = get_active_domain(prev_active_domain, prev_state, state)
    if active_domain is not None and active_domain in num_results:
        domain_count = num_results[active_domain]
    else:
        domain_count = 0
    if active_domain is not None and active_domain in top_results:
        domain_top = {active_domain: top_results[active_domain]}
    else:
        domain_top = {}
    response = populate_template(output_words[0], domain_top, domain_count,
                                 state)
    return response, active_domain
Example #4
0
def createDelexData():
    """Delexicalise every MultiWOZ dialogue and save the result.

    Loads the delexicalisation dictionary, then goes through each dialogue
    and performs:
    1) data normalisation
    2) delexicalisation
    3) addition of the database pointer
    4) saving of the delexicalised data to ``data/multi-woz/delex.json``

    Returns:
        dict mapping each dialogue name to its delexicalised dialogue.
    """
    # Download the data.
    loadData()

    # Dictionary of delexicalised values searched against; order matters.
    dic = delexicalize.prepareSlotValuesIndependent()
    delex_data = {}

    # JSON is UTF-8 by specification; don't depend on the locale encoding.
    # The context managers close both files (previously they leaked).
    with open('data/multi-woz/data.json', encoding='utf-8') as fin:
        data = json.load(fin)
    with open('data/multi-woz/dialogue_acts.json', encoding='utf-8') as fin:
        data2 = json.load(fin)

    # Remaining digit runs are replaced by a generic count token.
    digitpat = re.compile(r'\d+')

    for dialogue_name in tqdm(data):
        dialogue = data[dialogue_name]
        idx_acts = 1

        for idx, turn in enumerate(dialogue['log']):
            # Normalisation, whitespace collapsing and delexicalisation.
            sent = normalize(turn['text'])
            sent = delexicalize.delexicalise(' '.join(sent.split()), dic)

            # Parsing reference numbers GIVEN the belief state.
            sent = delexicaliseReferenceNumber(sent, turn)

            # Changes to numbers only here.
            sent = digitpat.sub('[value_count]', sent)

            dialogue['log'][idx]['text'] = sent

            if idx % 2 == 1:  # system turn
                pointer_vector = addDBPointer(turn)
                pointer_vector = addBookingPointer(dialogue, turn,
                                                   pointer_vector)
                # The pointer is stored on the preceding turn.
                dialogue['log'][idx - 1]['db_pointer'] = pointer_vector.tolist()

            # Fixing delexicalisation.
            dialogue = fixDelex(dialogue_name, dialogue, data2, idx, idx_acts)
            idx_acts += 1

        delex_data[dialogue_name] = dialogue

    with open('data/multi-woz/delex.json', 'w', encoding='utf-8') as outfile:
        json.dump(delex_data, outfile)

    return delex_data
def prepareSlotValuesIndependent():
    """Build the ordered delexicalisation dictionary.

    Returns:
        list of ``(value, placeholder)`` pairs. Order matters, and entries
        are appended in this sequence:

        1. Domain-specific entity values read from ``db/<domain>_db.json``:
           names, addresses, postcodes, phones, train ids and departments.
        2. Hard-coded hospital and police values.
        3. Train departure/destination places.
        4. Weekdays.
        5. The generic area, food and price-range values.
    """
    import re  # word-boundary matching for the address variants below

    domains = [
        'restaurant', 'hotel', 'attraction', 'train', 'taxi', 'hospital',
        'police'
    ]
    # (word, replacement) swaps for an alternative spelling of an address.
    # Only the first word that occurs *as a whole word* is swapped; substring
    # matching used to turn e.g. "street" into "streetreet".
    address_swaps = (('road', 'rd'), ('rd', 'road'), ('st', 'street'),
                     ('street', 'st'))

    def address_variant(address):
        """Return ``address`` with one abbreviation swapped, or ``None``."""
        for word, replacement in address_swaps:
            pattern = r'\b' + word + r'\b'
            if re.search(pattern, address):
                return re.sub(pattern, replacement, address)
        return None

    def name_variant(name):
        """Return an alternative spelling of a venue name, or ``None``."""
        if "b & b" in name:
            return name.replace("b & b", "bed and breakfast")
        if "bed and breakfast" in name:
            return name.replace("bed and breakfast", "b & b")
        if "hotel" in name and 'gonville' not in name:
            return name.replace("hotel", "")
        if "restaurant" in name:
            return name.replace("restaurant", "")
        return None

    dic = []
    dic_area = []
    dic_food = []
    dic_price = []

    # Read databases.
    for domain in domains:
        # Best effort: a file may be missing or may not be a list of flat
        # records. Entries collected before a failure are kept.
        try:
            with open('db/' + domain + '_db.json', encoding='utf-8') as fin:
                db_json = json.load(fin)

            for ent in db_json:
                for key, val in ent.items():
                    if val == '?' or val == 'free':
                        continue
                    if key in ('address', 'name'):
                        tag = '[' + domain + '_' + key + ']'
                        dic.append((normalize(val), tag))
                        if key == 'address':
                            variant = address_variant(val)
                        else:
                            variant = name_variant(val)
                        if variant is not None:
                            dic.append((normalize(variant), tag))
                    elif key == 'postcode':
                        dic.append((normalize(val), '[' + domain + '_postcode]'))
                    elif key == 'phone':
                        # Phone numbers are kept verbatim (not normalised).
                        dic.append((val, '[' + domain + '_phone]'))
                    elif key == 'trainID':
                        dic.append((normalize(val), '[' + domain + '_id]'))
                    elif key == 'department':
                        dic.append((normalize(val),
                                    '[' + domain + '_department]'))
                    # NORMAL DELEX: generic values, appended at the very end.
                    elif key == 'area':
                        dic_area.append((normalize(val), '[value_area]'))
                    elif key == 'food':
                        dic_food.append((normalize(val), '[value_food]'))
                    elif key == 'pricerange':
                        dic_price.append((normalize(val), '[value_pricerange]'))
                    # TODO car type?
        except Exception:
            pass

        if domain == 'hospital':
            dic.append((normalize('Hills Rd'), '[hospital_address]'))
            dic.append((normalize('Hills Road'), '[hospital_address]'))
            dic.append((normalize('CB20QQ'), '[hospital_postcode]'))
            dic.append(('01223245151', '[hospital_phone]'))
            dic.append(('1223245151', '[hospital_phone]'))
            dic.append(('0122324515', '[hospital_phone]'))
            dic.append((normalize('Addenbrookes Hospital'), '[hospital_name]'))
        elif domain == 'police':
            dic.append((normalize('Parkside'), '[police_address]'))
            dic.append((normalize('CB11JG'), '[police_postcode]'))
            dic.append(('01223358966', '[police_phone]'))
            dic.append(('1223358966', '[police_phone]'))
            dic.append((normalize('Parkside Police Station'), '[police_name]'))

    # Add places from trains at the end.
    with open('db/train_db.json', encoding='utf-8') as fin:
        db_json = json.load(fin)

    for ent in db_json:
        for key, val in ent.items():
            if key == 'departure' or key == 'destination':
                dic.append((normalize(val), '[value_place]'))

    # Add specific values.
    for day in ('monday', 'tuesday', 'wednesday', 'thursday', 'friday',
                'saturday', 'sunday'):
        dic.append((normalize(day), '[value_day]'))

    # More general values are added at the end.
    dic.extend(dic_area)
    dic.extend(dic_food)
    dic.extend(dic_price)

    return dic