Example #1
    def __init__(self, num=1):
        parser = argparse.ArgumentParser(description='S2S')
        parser.add_argument('--no_cuda',
                            type=util.str2bool,
                            nargs='?',
                            const=True,
                            default=True,
                            help='disables CUDA training')
        parser.add_argument('--seed',
                            type=int,
                            default=1,
                            metavar='S',
                            help='random seed (default: 1)')

        parser.add_argument('--no_models',
                            type=int,
                            default=20,
                            help='how many models to evaluate')
        parser.add_argument('--original',
                            type=str,
                            default='model/model/',
                            help='Original path.')

        parser.add_argument('--dropout', type=float, default=0.0)
        parser.add_argument('--use_emb', type=str, default='False')

        parser.add_argument('--beam_width',
                            type=int,
                            default=10,
                            help='Beam width used in beamsearch')
        parser.add_argument('--write_n_best',
                            type=util.str2bool,
                            nargs='?',
                            const=True,
                            default=False,
                            help='Write n-best list (n=beam_width)')

        parser.add_argument('--model_path',
                            type=str,
                            default='model/model/translate.ckpt',
                            help='Path to a specific model checkpoint.')
        parser.add_argument('--model_dir',
                            type=str,
                            default='data/multi-woz/model/model/')
        parser.add_argument('--model_name', type=str, default='translate.ckpt')

        parser.add_argument('--valid_output',
                            type=str,
                            default='model/data/val_dials/',
                            help='Validation Decoding output dir path')
        parser.add_argument('--decode_output',
                            type=str,
                            default='model/data/test_dials/',
                            help='Decoding output dir path')

        # parse an empty argument list so only the defaults above are used
        args = parser.parse_args([])

        args.cuda = not args.no_cuda and torch.cuda.is_available()

        torch.manual_seed(args.seed)

        self.device = torch.device("cuda" if args.cuda else "cpu")
        # merge the hyper-parameters saved alongside the checkpoint into args
        with open(
                os.path.join(os.path.dirname(__file__),
                             args.model_path + '.config'), 'r') as f:
            add_args = json.load(f)
            # print(add_args)
            for k, v in add_args.items():
                setattr(args, k, v)
            # print(args)
            args.mode = 'test'
            args.load_param = True
            args.dropout = 0.0
            assert args.dropout == 0.0

        # Start going through models
        args.original = args.model_path
        self.model = loadModel(num, args)
        self.dial = {"cur": {"log": []}}
        self.prev_state = default_state()
        self.prev_active_domain = None
        self.dic = delexicalize.prepareSlotValuesIndependent()
        self.db = Database()
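
The two boolean flags above (--no_cuda and --write_n_best) are converted through util.str2bool, which is not shown in this example. A minimal sketch of such a converter, assuming the usual argparse idiom (the real util module may differ):

import argparse

def str2bool(v):
    # Accept real booleans and the common textual spellings;
    # anything else is reported by argparse as an invalid value.
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    if v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')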
Example #2
def createDelexData(dialogue):
    """Main function of the script - loads delexical dictionary,
    goes through each dialogue and does:
    1) data normalization
    2) delexicalization
    3) addition of database pointer
    4) saves the delexicalized data
    """

    # create the dictionary of delexicalized values we will search against; order matters here!
    dic = delexicalize.prepareSlotValuesIndependent()
    delex_data = {}

    # fin1 = open('data/multi-woz/data.json', 'r')
    # data = json.load(fin1)

    # dialogue = data[dialogue_name]
    dial = dialogue['cur']
    idx_acts = 1

    for idx, turn in enumerate(dial['log']):
        # print(idx)
        # print(turn)
        # normalization, split and delexicalization of the sentence
        sent = normalize(turn['text'])

        words = sent.split()
        sent = delexicalize.delexicalise(' '.join(words), dic)

        # parsing reference number GIVEN belief state
        sent = delexicaliseReferenceNumber(sent, turn)

        # changes to numbers only here
        digitpat = re.compile(r'\d+')
        sent = re.sub(digitpat, '[value_count]', sent)
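        # e.g. 'i want to book for 3 people' -> 'i want to book for [value_count] people'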
        # print(sent)

        # delexicalized sentence added to the dialogue
        dial['log'][idx]['text'] = sent

        if idx % 2 == 1:  # if it's a system turn
            # add database pointer
            pointer_vector, db_results, num_entities = addDBPointer(turn)
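            # db_results and num_entities are returned but not used further in this variant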
            # add booking pointer
            pointer_vector = addBookingPointer(dial, turn, pointer_vector)

            # print pointer_vector
            dial['log'][idx - 1]['db_pointer'] = pointer_vector.tolist()

        idx_acts += 1
    dial = get_dial(dial)

    if dial:
        dialogue = {}
        dialogue['usr'] = []
        dialogue['sys'] = []
        dialogue['db'] = []
        dialogue['bs'] = []
        for turn in dial:
            # print(turn)
            dialogue['usr'].append(turn[0])
            dialogue['sys'].append(turn[1])
            dialogue['db'].append(turn[2])
            dialogue['bs'].append(turn[3])

    delex_data['cur'] = dialogue

    return delex_data
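
The returned dictionary mirrors the structure built above: four aligned per-turn lists under 'cur'. A small usage sketch, assuming dialogue is a MultiWOZ-style dict of the form {'cur': {'log': [...]}} whose turns carry the usual 'text' and 'metadata' fields:

delex = createDelexData(dialogue)
turns = delex['cur']
for usr, sys, db, bs in zip(turns['usr'], turns['sys'], turns['db'], turns['bs']):
    print(usr)  # delexicalized user utterance
    print(sys)  # delexicalized system response
    # db is the database pointer vector, bs the belief-state features for this turn
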
def createDelexData():
    """Main function of the script - loads delexical dictionary,
    goes through each dialogue and does:
    1) data normalization
    2) delexicalization
    3) addition of database pointer
    4) saves the delexicalized data
    """
    # download the data
    loadData()

    # create the dictionary of delexicalized values we will search against; order matters here!
    dic = delexicalize.prepareSlotValuesIndependent()
    delex_data = {}

    with open('data/multi-woz/data.json', 'r') as fin1:
        data = json.load(fin1)

    with open('data/multi-woz/dialogue_acts.json', 'r') as fin2:
        data2 = json.load(fin2)

    for dialogue_name in tqdm(data):
        dialogue = data[dialogue_name]
        # print dialogue_name

        idx_acts = 1

        for idx, turn in enumerate(dialogue['log']):
            # normalization, split and delexicalization of the sentence
            sent = normalize(turn['text'])

            words = sent.split()
            sent = delexicalize.delexicalise(' '.join(words), dic)

            # parsing reference number GIVEN belief state
            sent = delexicaliseReferenceNumber(sent, turn)

            # changes to numbers only here
            digitpat = re.compile(r'\d+')
            sent = re.sub(digitpat, '[value_count]', sent)

            # delexicalized sentence added to the dialogue
            dialogue['log'][idx]['text'] = sent

            if idx % 2 == 1:  # if it's a system turn
                # add database pointer
                pointer_vector = addDBPointer(turn)
                # add booking pointer
                pointer_vector = addBookingPointer(dialogue, turn,
                                                   pointer_vector)

                # print pointer_vector
                dialogue['log'][idx - 1]['db_pointer'] = pointer_vector.tolist()

            # FIXING delexicalization:
            dialogue = fixDelex(dialogue_name, dialogue, data2, idx, idx_acts)
            idx_acts += 1

        delex_data[dialogue_name] = dialogue

    with open('data/multi-woz/delex.json', 'w') as outfile:
        json.dump(delex_data, outfile)

    return delex_data
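
Once the script has run, the delexicalized corpus can be read back from the file written above; a minimal check:

import json

# load the corpus produced by createDelexData()
with open('data/multi-woz/delex.json', 'r') as f:
    delex = json.load(f)
print('delexicalized dialogues:', len(delex))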