def __init__(self, num=1):
    """Build the decoding configuration, load the trained S2S model, and
    initialize per-session state.

    Args:
        num: index/count of the model checkpoint to load; forwarded to
            `loadModel` (project helper — exact semantics defined there).

    Side effects: seeds torch's RNG, reads `<model_path>.config` (JSON)
    from this file's directory, and loads model weights from disk.
    """
    # Defaults below mirror the training-time CLI; parse_args([]) means no
    # real command line is consulted — everything comes from defaults plus
    # the JSON config merged in further down.
    parser = argparse.ArgumentParser(description='S2S')
    parser.add_argument('--no_cuda', type=util.str2bool, nargs='?', const=True, default=True, help='enables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)')
    parser.add_argument('--no_models', type=int, default=20, help='how many models to evaluate')
    parser.add_argument('--original', type=str, default='model/model/', help='Original path.')
    parser.add_argument('--dropout', type=float, default=0.0)
    parser.add_argument('--use_emb', type=str, default='False')
    parser.add_argument('--beam_width', type=int, default=10, help='Beam width used in beamsearch')
    parser.add_argument('--write_n_best', type=util.str2bool, nargs='?', const=True, default=False, help='Write n-best list (n=beam_width)')
    parser.add_argument('--model_path', type=str, default='model/model/translate.ckpt', help='Path to a specific model checkpoint.')
    parser.add_argument('--model_dir', type=str, default='data/multi-woz/model/model/')
    parser.add_argument('--model_name', type=str, default='translate.ckpt')
    parser.add_argument('--valid_output', type=str, default='model/data/val_dials/', help='Validation Decoding output dir path')
    parser.add_argument('--decode_output', type=str, default='model/data/test_dials/', help='Decoding output dir path')

    # Parse the defaults only (empty argv), then derive device settings.
    args = parser.parse_args([])
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    torch.manual_seed(args.seed)
    self.device = torch.device("cuda" if args.cuda else "cpu")

    # Merge the checkpoint's saved training config on top of the defaults.
    # Keys in the JSON file override anything set above.
    with open(
            os.path.join(os.path.dirname(__file__), args.model_path + '.config'), 'r') as f:
        add_args = json.load(f)
        for k, v in add_args.items():
            setattr(args, k, v)
        # Force inference-time settings regardless of what the config said.
        args.mode = 'test'
        args.load_param = True
        args.dropout = 0.0
        assert args.dropout == 0.0

    # Start going through models
    # NOTE(review): these two assignments are mutually redundant — they only
    # copy model_path into original; model_path is unchanged.
    args.original = args.model_path
    args.model_path = args.original
    self.model = loadModel(num, args)

    # Per-session dialogue state: running log, previous belief state, and
    # previously active domain (project helpers — semantics defined there).
    self.dial = {"cur": {"log": []}}
    self.prev_state = default_state()
    self.prev_active_domain = None
    # Delexicalization dictionary and DB handle, reused across turns.
    self.dic = delexicalize.prepareSlotValuesIndependent()
    self.db = Database()
def createDelexData(dialogue):
    """Delexicalize a single in-memory dialogue.

    For each turn: 1) normalizes the text, 2) delexicalizes slot values,
    3) replaces booking reference numbers and bare digits with placeholder
    tokens, 4) attaches a database/booking pointer vector to the preceding
    user turn of every system turn, and finally repackages the dialogue via
    `get_dial` into parallel usr/sys/db/bs lists.

    Args:
        dialogue: dict with a 'cur' entry holding {'log': [turns...]};
            each turn is a dict with at least a 'text' field.

    Returns:
        dict with key 'cur' mapping to {'usr', 'sys', 'db', 'bs'} lists,
        or an empty dict if `get_dial` rejects the dialogue.
    """
    # Dictionary of delexicalized values we search against — order matters!
    dic = delexicalize.prepareSlotValuesIndependent()
    delex_data = {}

    # Raw string avoids the invalid-escape deprecation of '\d+'; compiled
    # once here instead of once per turn.
    digitpat = re.compile(r'\d+')

    dial = dialogue['cur']
    for idx, turn in enumerate(dial['log']):
        # Normalization, whitespace collapse, and delexicalization.
        sent = normalize(turn['text'])
        sent = delexicalize.delexicalise(' '.join(sent.split()), dic)
        # Parsing reference number GIVEN belief state.
        sent = delexicaliseReferenceNumber(sent, turn)
        # Any remaining digit runs become a generic count token.
        sent = digitpat.sub('[value_count]', sent)
        dial['log'][idx]['text'] = sent

        if idx % 2 == 1:  # system turn: attach DB + booking pointers
            pointer_vector, db_results, num_entities = addDBPointer(turn)
            pointer_vector = addBookingPointer(dial, turn, pointer_vector)
            # Pointer is stored on the preceding user turn.
            dial['log'][idx - 1]['db_pointer'] = pointer_vector.tolist()

    # Repackage turn tuples into parallel lists; get_dial may return a
    # falsy value for unusable dialogues, in which case we return {}.
    dial = get_dial(dial)
    if dial:
        packed = {'usr': [], 'sys': [], 'db': [], 'bs': []}
        for turn in dial:
            packed['usr'].append(turn[0])
            packed['sys'].append(turn[1])
            packed['db'].append(turn[2])
            packed['bs'].append(turn[3])
        delex_data['cur'] = packed
    return delex_data
def createDelexData():
    """Build the delexicalized MultiWOZ corpus from the raw data files.

    Loads the raw data (downloading it if needed via `loadData`), then for
    every dialogue: 1) normalizes each turn, 2) delexicalizes slot values,
    3) replaces reference numbers and digits with placeholder tokens,
    4) attaches a database/booking pointer to the preceding user turn of
    every system turn, and 5) repairs delexicalization using the dialogue
    acts annotations (`fixDelex`).

    Side effects: reads 'data/multi-woz/data.json' and
    'data/multi-woz/dialogue_acts.json'; writes the result to
    'data/multi-woz/delex.json'.

    Returns:
        dict mapping dialogue name -> delexicalized dialogue.
    """
    # Download the data if it is not present yet.
    loadData()

    # Dictionary of delexicalized values we search against — order matters!
    dic = delexicalize.prepareSlotValuesIndependent()
    delex_data = {}

    # Raw string avoids the invalid-escape deprecation of '\d+'; compiled
    # once here instead of once per turn.
    digitpat = re.compile(r'\d+')

    # Context managers close the handles (the originals leaked them).
    with open('data/multi-woz/data.json', 'r') as fin1:
        data = json.load(fin1)
    with open('data/multi-woz/dialogue_acts.json', 'r') as fin2:
        data2 = json.load(fin2)

    for dialogue_name in tqdm(data):
        dialogue = data[dialogue_name]
        idx_acts = 1  # 1-based act index consumed by fixDelex

        for idx, turn in enumerate(dialogue['log']):
            # Normalization, whitespace collapse, and delexicalization.
            sent = normalize(turn['text'])
            sent = delexicalize.delexicalise(' '.join(sent.split()), dic)
            # Parsing reference number GIVEN belief state.
            sent = delexicaliseReferenceNumber(sent, turn)
            # Any remaining digit runs become a generic count token.
            sent = digitpat.sub('[value_count]', sent)
            dialogue['log'][idx]['text'] = sent

            if idx % 2 == 1:  # system turn: attach DB + booking pointers
                pointer_vector = addDBPointer(turn)
                pointer_vector = addBookingPointer(dialogue, turn, pointer_vector)
                # Pointer is stored on the preceding user turn.
                dialogue['log'][idx - 1]['db_pointer'] = pointer_vector.tolist()

            # FIXING delexicalization using the dialogue-acts annotations.
            dialogue = fixDelex(dialogue_name, dialogue, data2, idx, idx_acts)
            idx_acts += 1

        delex_data[dialogue_name] = dialogue

    with open('data/multi-woz/delex.json', 'w') as outfile:
        json.dump(delex_data, outfile)

    return delex_data