import sys
import pickle as pkl

import numpy as np
import tensorflow as tf
from sklearn import metrics

import data_utils
import memn2n

# NOTE: P_DATA_DIR, CKPT_DIR, BATCH_SIZE, parse_args, prepare_data and
# InteractiveSession are defined elsewhere in this module.


def main(args):
    # parse args
    args = parse_args(args)

    # prepare data
    if args['prep_data']:
        print('\n ====== preparing data ====== \n')
        for i in range(1, 7):
            print(' TASK #{}\n'.format(i))
            prepare_data(args, task_id=i)
        sys.exit()

    ##################################################################
    # otherwise, read data and metadata from pickled files
    ##################################################################
    with open(P_DATA_DIR + str(args['task_id']) + '.metadata.pkl', 'rb') as f:
        metadata = pkl.load(f)
    with open(P_DATA_DIR + str(args['task_id']) + '.data.pkl', 'rb') as f:
        data_ = pkl.load(f)

    # read content of data and metadata
    candidates = data_['candidates']
    candid2idx, idx2candid = metadata['candid2idx'], metadata['idx2candid']

    # read train, test, val data
    train, test, val = data_['train'], data_['test'], data_['val']

    # get more required information from metadata
    sentence_size = metadata['sentence_size']
    w2idx = metadata['w2idx']
    idx2w = metadata['idx2w']
    memory_size = metadata['memory_size']
    vocab_size = metadata['vocab_size']
    n_cand = metadata['n_cand']
    candidate_sentence_size = metadata['candidate_sentence_size']

    # vectorize candidate responses
    candidates_vec = data_utils.vectorize_candidates(
        candidates, w2idx, candidate_sentence_size)

    # create model - memn2n
    model = memn2n.MemN2NDialog(batch_size=BATCH_SIZE,
                                vocab_size=vocab_size,
                                candidates_size=n_cand,
                                sentence_size=sentence_size,
                                embedding_size=20,
                                candidates_vec=candidates_vec,
                                hops=3)

    # gather data in batches
    train, val, test, batches = data_utils.get_batches(
        train, val, test, metadata, batch_size=BATCH_SIZE)

    # training happens here
    if args['train']:
        epochs = args['epochs']
        eval_interval = args['eval_interval']
        print('\n>>> Training started...\n')

        # write log to file
        log_handle = open('log/' + args['log_file'], 'w')
        cost_total = 0.
        for i in range(epochs + 1):
            for start, end in batches:
                s = train['s'][start:end]
                q = train['q'][start:end]
                a = train['a'][start:end]
                cost_total += model.batch_fit(s, q, a)

            # evaluate every `eval_interval` epochs (skipping epoch 0)
            if i % eval_interval == 0 and i:
                train_preds = batch_predict(model, train['s'], train['q'],
                                            len(train['s']), batch_size=BATCH_SIZE)
                val_preds = batch_predict(model, val['s'], val['q'],
                                          len(val['s']), batch_size=BATCH_SIZE)
                train_acc = metrics.accuracy_score(np.array(train_preds), train['a'])
                val_acc = metrics.accuracy_score(val_preds, val['a'])
                print('Epoch[{}] : <Accuracy>\n\ttraining : {} \n\tvalidation : {}'
                      .format(i, train_acc, val_acc))
                log_handle.write('{} {} {} {}\n'.format(
                    i, train_acc, val_acc,
                    cost_total / (eval_interval * len(batches))))
                cost_total = 0.  # reset accumulated cost

                # save a checkpoint after each evaluation
                model.saver.save(
                    model._sess,
                    CKPT_DIR + '{}/memn2n_model.ckpt'.format(args['task_id']),
                    global_step=i)
        log_handle.close()

    # inference
    else:
        # restore checkpoint
        ckpt = tf.train.get_checkpoint_state(CKPT_DIR + str(args['task_id']))
        if ckpt and ckpt.model_checkpoint_path:
            print('\n>> restoring checkpoint from', ckpt.model_checkpoint_path)
            model.saver.restore(model._sess, ckpt.model_checkpoint_path)

        # interactive dialog session; loop until the user types 'exit'
        isess = InteractiveSession(model, idx2candid, w2idx, n_cand, memory_size)
        if args['infer']:
            query = input('>> ')
            while query != 'exit':
                print('>> ' + isess.reply(query))
                query = input('>> ')
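# `batch_predict` is called in the training loop above but not defined in this
# section. A minimal sketch follows, assuming the model exposes a
# `predict(stories, queries)` method that returns one candidate index per
# example; treat this as an illustration, not the repository's exact helper.
def batch_predict(model, S, Q, n, batch_size):
    preds = []
    for start in range(0, n, batch_size):
        end = start + batch_size
        # predict one mini-batch of (story, query) pairs
        pred = model.predict(S[start:end], Q[start:end])
        preds += list(pred)
    return preds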
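# Entry point: a minimal sketch, assuming the script is invoked as e.g.
# `python main.py --task_id=1 --train=1` and that `parse_args` (defined
# elsewhere) consumes sys.argv-style arguments; the exact CLI flags are
# an assumption.
if __name__ == '__main__':
    main(sys.argv[1:])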