def prepare_data_seq(task, batch_size=100, shuffle=True):
    """Load the bAbI dialog train/dev/test(+OOV) splits for *task* and
    batch them into sequence data.

    Args:
        task: bAbI dialog task id; task 6 (DSTC2) has no OOV test split
            and uses its own KB file.
        batch_size: number of pairs per batch handed to get_seq.
        shuffle: kept for interface compatibility; the train split is
            always batched with shuffling (get_seq is called with True).

    Returns:
        (train, dev, test, testOOV, lang, max_len, max_r) where testOOV
        is [] for task 6, lang is the vocabulary accumulated over all
        splits, and max_len / max_r are the maximum input / response
        lengths across splits, plus one.
    """
    # Task 6 (DSTC2) is special-cased throughout: no OOV split and a
    # dedicated KB file.  Evaluate the condition once instead of five times.
    has_oov = int(task) != 6

    file_train = 'data/dialog-bAbI-tasks/dialog-babi-task{}trn.txt'.format(task)
    file_dev = 'data/dialog-bAbI-tasks/dialog-babi-task{}dev.txt'.format(task)
    file_test = 'data/dialog-bAbI-tasks/dialog-babi-task{}tst.txt'.format(task)
    if has_oov:
        file_test_OOV = 'data/dialog-bAbI-tasks/dialog-babi-task{}tst-OOV.txt'.format(task)
        ent = entityList('data/dialog-bAbI-tasks/dialog-babi-kb-all.txt', int(task))
    else:
        ent = entityList('data/dialog-bAbI-tasks/dialog-babi-task6-dstc2-kb.txt', int(task))

    pair_train, max_len_train, max_r_train = read_langs(file_train, ent, max_line=None)
    pair_dev, max_len_dev, max_r_dev = read_langs(file_dev, ent, max_line=None)
    pair_test, max_len_test, max_r_test = read_langs(file_test, ent, max_line=None)
    # Zero defaults keep the max() computations valid when there is no OOV split.
    max_r_test_OOV = 0
    max_len_test_OOV = 0
    if has_oov:
        pair_test_OOV, max_len_test_OOV, max_r_test_OOV = read_langs(file_test_OOV, ent, max_line=None)

    # +1 presumably leaves room for an EOS/padding position — TODO confirm
    # against get_seq.
    max_len = max(max_len_train, max_len_dev, max_len_test, max_len_test_OOV) + 1
    max_r = max(max_r_train, max_r_dev, max_r_test, max_r_test_OOV) + 1

    # One shared Lang instance so the vocabulary covers every split.
    lang = Lang()
    train = get_seq(pair_train, lang, batch_size, True, max_len)
    dev = get_seq(pair_dev, lang, batch_size, False, max_len)
    test = get_seq(pair_test, lang, batch_size, False, max_len)
    testOOV = get_seq(pair_test_OOV, lang, batch_size, False, max_len) if has_oov else []

    logging.info("Read %s sentence pairs train" % len(pair_train))
    logging.info("Read %s sentence pairs dev" % len(pair_dev))
    logging.info("Read %s sentence pairs test" % len(pair_test))
    if has_oov:
        # Fixed: this line previously claimed "test" but reports the OOV split.
        logging.info("Read %s sentence pairs test OOV" % len(pair_test_OOV))
    logging.info("Max len Input %s " % max_len)
    logging.info("Vocab_size %s " % lang.n_words)
    logging.info("USE_CUDA={}".format(USE_CUDA))

    return train, dev, test, testOOV, lang, max_len, max_r
# NOTE(review): stray statement — `return` is illegal at module level and this
# duplicates part of prepare_data_seq's return value; looks like a leftover
# from an earlier revision.  TODO: delete.
return lang, max_len, max_r


# Smoke-test driver: build shuffled training batches for bAbI task 5 and
# iterate over them once.
if __name__=='__main__':
    # String statement recording the intended CLI flags for a full run.
    '''-lr=0.001 -layer=1 -hdd=12 -dr=0.0 -dec=Mem2Seq -bsz=2 -ds=babi -t=1 '''
    # args['task'] = 5
    # args['hidden'] = 12
    # args['learn'] = 0.001
    # args['layer'] = 1
    # args['drop'] = 0.0
    # args['decoder'] = 'Men2Seq'
    # args['batch'] = 2
    # args['dataset'] = 'babi'
    task = 5
    # Paths are relative to a subdirectory ('../data/...'), unlike
    # prepare_data_seq above which uses 'data/...'.
    file_train = '../data/dialog-bAbI-tasks/dialog-babi-task{}trn.txt'.format(task)
    ent = entityList('../data/dialog-bAbI-tasks/dialog-babi-kb-all.txt', int(task))
    pair_train,max_len_train, max_r_train = read_langs(file_train, ent, max_line=None)
    max_len = max_len_train + 1
    batch_size = 100
    lang = Lang()
    # True -> shuffle the training batches.
    train = get_seq(pair_train, lang, batch_size, True, max_len)
    epoch = 1
    logging.info("Epoch:{}".format(epoch))
    # Run the train function
    # NOTE(review): this binds the tqdm *module*, but the call below expects
    # the tqdm *class* (`from tqdm import tqdm`); calling the module raises
    # TypeError — verify.
    import tqdm
    pbar = tqdm(enumerate(train), total=len(train))
    # NOTE(review): the loop body is not visible in this chunk — the view
    # appears truncated here.
    for i, data in pbar:
def evaluate(self, dev, avg_best, BLEU=False):
    """Decode every batch in *dev*, log accuracy / WER / entity-F1 / BLEU,
    and checkpoint the model when the tracked metric beats *avg_best*.

    Args:
        dev: iterable of pre-batched dev/test data.  Batches are indexed
            positionally (data_dev[0]..data_dev[14]); assumes the layout
            produced by the data pipeline for the current dataset — TODO
            confirm indices against get_seq.
        avg_best: best metric value seen so far; the model is saved when
            this run's metric is >= that value.
        BLEU: if True, the tracked and returned metric is corpus BLEU;
            otherwise it is average per-response accuracy.

    Returns:
        float: BLEU score when BLEU=True, else average response accuracy
        over the batches of *dev*.
    """
    logging.info("STARTING EVALUATION")
    # Running metric accumulators, averaged over batches at the end.
    acc_avg = 0.0
    wer_avg = 0.0
    bleu_avg = 0.0  # NOTE(review): never updated or read below
    acc_P = 0.0     # NOTE(review): unused
    acc_V = 0.0     # NOTE(review): unused
    # Entity-F1 accumulators — TRUE/PRED pairs form the ratio logged as
    # "F1 SCORE" below, overall and per KVR domain (_cal/_nav/_wet).
    microF1_PRED, microF1_PRED_cal, microF1_PRED_nav, microF1_PRED_wet = 0, 0, 0, 0
    microF1_TRUE, microF1_TRUE_cal, microF1_TRUE_nav, microF1_TRUE_wet = 0, 0, 0, 0
    ref = []    # gold responses, one string per example
    hyp = []    # decoded responses, one string per example
    ref_s = ""  # newline-joined copies of ref/hyp (built but not used here)
    hyp_s = ""
    dialog_acc_dict = {}  # dialog id -> list of 1/0 per-turn exact matches (babi only)
    # Build the global entity list used for entity-F1 scoring.
    if args['dataset'] == 'kvr':
        with open('data/KVR/kvret_entities.json') as f:
            global_entity = json.load(f)
            global_entity_list = []
            for key in global_entity.keys():
                if key != 'poi':
                    global_entity_list += [
                        item.lower().replace(' ', '_')
                        for item in global_entity[key]
                    ]
                else:
                    # 'poi' entries are dicts; flatten every field value.
                    for item in global_entity['poi']:
                        global_entity_list += [
                            item[k].lower().replace(' ', '_')
                            for k in item.keys()
                        ]
            global_entity_list = list(set(global_entity_list))
    else:
        if int(args["task"]) != 6:
            global_entity_list = entityList(
                'data/dialog-bAbI-tasks/dialog-babi-kb-all.txt',
                int(args["task"]))
        else:
            # Task 6 (DSTC2) ships its own KB file.
            global_entity_list = entityList(
                'data/dialog-bAbI-tasks/dialog-babi-task6-dstc2-kb.txt',
                int(args["task"]))
    pbar = tqdm(enumerate(dev), total=len(dev))
    for j, data_dev in pbar:
        # NOTE(review): both branches pass identical arguments, so this
        # dataset split is redundant as written.
        if args['dataset'] == 'kvr':
            words = self.evaluate_batch(len(data_dev[1]), data_dev[0],
                                        data_dev[1], data_dev[2],
                                        data_dev[3], data_dev[4],
                                        data_dev[5], data_dev[6])
        else:
            words = self.evaluate_batch(len(data_dev[1]), data_dev[0],
                                        data_dev[1], data_dev[2],
                                        data_dev[3], data_dev[4],
                                        data_dev[5], data_dev[6])
        acc = 0  # exact-match responses in this batch
        w = 0    # summed WER over this batch
        temp_gen = []
        # `words` is transposed so each row is one example's token sequence
        # — assumes a time-major decode output; TODO confirm.
        for i, row in enumerate(np.transpose(words)):
            st = ''
            for e in row:
                if e == '<EOS>':  # stop at end-of-sequence marker
                    break
                else:
                    st += e + ' '
            temp_gen.append(st)
            correct = data_dev[7][i]  # gold response string for example i
            ### compute F1 SCORE
            st = st.lstrip().rstrip()
            correct = correct.lstrip().rstrip()
            if args['dataset'] == 'kvr':
                # Overall entity F1 plus per-domain F1 against the gold
                # entity sets stored at indices 8-11 of the batch.
                f1_true, count = self.compute_prf(data_dev[8][i], st.split(),
                                                  global_entity_list,
                                                  data_dev[14][i])
                microF1_TRUE += f1_true
                microF1_PRED += count
                f1_true, count = self.compute_prf(data_dev[9][i], st.split(),
                                                  global_entity_list,
                                                  data_dev[14][i])
                microF1_TRUE_cal += f1_true
                microF1_PRED_cal += count
                f1_true, count = self.compute_prf(data_dev[10][i], st.split(),
                                                  global_entity_list,
                                                  data_dev[14][i])
                microF1_TRUE_nav += f1_true
                microF1_PRED_nav += count
                f1_true, count = self.compute_prf(data_dev[11][i], st.split(),
                                                  global_entity_list,
                                                  data_dev[14][i])
                microF1_TRUE_wet += f1_true
                microF1_PRED_wet += count
            elif args['dataset'] == 'babi' and int(args["task"]) == 6:
                # Entity F1 is only meaningful for the task-6 (DSTC2) data.
                f1_true, count = self.compute_prf(data_dev[10][i], st.split(),
                                                  global_entity_list,
                                                  data_dev[12][i])
                microF1_TRUE += f1_true
                microF1_PRED += count
            if args['dataset'] == 'babi':
                # Track per-turn correctness keyed by dialog id
                # (data_dev[11]) for the dialog-accuracy report below.
                if data_dev[11][i] not in dialog_acc_dict.keys():
                    dialog_acc_dict[data_dev[11][i]] = []
                if (correct == st):
                    acc += 1
                    dialog_acc_dict[data_dev[11][i]].append(1)
                else:
                    dialog_acc_dict[data_dev[11][i]].append(0)
            else:
                if (correct == st):
                    acc += 1
            # print("Correct:"+str(correct))
            # print("\tPredict:"+str(st))
            # print("\tFrom:"+str(self.from_whichs[:,i]))
            w += wer(correct, st)
            ref.append(str(correct))
            hyp.append(str(st))
            ref_s += str(correct) + "\n"
            hyp_s += str(st) + "\n"
        # Per-batch averages accumulated; divided by len(dev) at the end.
        acc_avg += acc / float(len(data_dev[1]))
        wer_avg += w / float(len(data_dev[1]))
        pbar.set_description("R:{:.4f},W:{:.4f}".format(
            acc_avg / float(len(dev)), wer_avg / float(len(dev))))
    # Dialog accuracy: a dialog counts only if every one of its turns was
    # an exact match (list length equals the sum of its 1/0 entries).
    if args['dataset'] == 'babi':
        dia_acc = 0
        for k in dialog_acc_dict.keys():
            if len(dialog_acc_dict[k]) == sum(dialog_acc_dict[k]):
                dia_acc += 1
        logging.info("Dialog Accuracy:\t" + str(
            dia_acc * 1.0 / len(dialog_acc_dict.keys())))
    # NOTE(review): the divisions below raise ZeroDivisionError when no
    # entities were counted (microF1_PRED* == 0) — verify upstream.
    if args['dataset'] == 'kvr':
        logging.info("F1 SCORE:\t{}".format(microF1_TRUE / float(microF1_PRED)))
        logging.info("\tCAL F1:\t{}".format(microF1_TRUE_cal / float(microF1_PRED_cal)))
        logging.info("\tWET F1:\t{}".format(microF1_TRUE_wet / float(microF1_PRED_wet)))
        logging.info("\tNAV F1:\t{}".format(microF1_TRUE_nav / float(microF1_PRED_nav)))
    elif args['dataset'] == 'babi' and int(args["task"]) == 6:
        logging.info("F1 SCORE:\t{}".format(microF1_TRUE / float(microF1_PRED)))
    bleu_score = moses_multi_bleu(np.array(hyp), np.array(ref), lowercase=True)
    logging.info("BLEU SCORE:" + str(bleu_score))
    # Save a checkpoint named after the metric value when it improves.
    if (BLEU):
        if (bleu_score >= avg_best):
            self.save_model(str(self.name) + str(bleu_score))
            logging.info("MODEL SAVED")
        return bleu_score
    else:
        acc_avg = acc_avg / float(len(dev))
        if (acc_avg >= avg_best):
            self.save_model(str(self.name) + str(acc_avg))
            logging.info("MODEL SAVED")
        return acc_avg