def train(self):
    """Train the model. The best parameters are saved to TRAIN_MODEL_URL,
    and TensorBoard summaries are written to TRAIN_GRAPH_URL."""
    # Optional overrides. They all default to None (bidir to True), so the
    # "Setting up ..." blocks below are no-ops unless edited; n2p is unused.
    num_hid, bidir, net_type, n2p, batch_size, model_url, graph_url, dev = \
        None, True, None, None, None, None, None, None
    global train_batch_size, MODEL_URL, GRAPH_URL, device, TRAIN_MODEL_URL, TRAIN_GRAPH_URL
    if batch_size:
        train_batch_size = batch_size
        print("Setting up the batch size to {}.........................".format(batch_size))
    if model_url:
        TRAIN_MODEL_URL = model_url
        print("Setting up the model url to {}.........................".format(TRAIN_MODEL_URL))
    if graph_url:
        TRAIN_GRAPH_URL = graph_url
        print("Setting up the graph url to {}.........................".format(TRAIN_GRAPH_URL))
    if dev:
        device = dev
        print("Setting up the device to {}.........................".format(device))

    # 1. Load and process the input data, including the ontology.
    # Load the word embeddings.
    word_vectors = load_word_vectors(self.word_vectors_url)

    # Load the ontology and extract the feature vectors.
    ontology, ontology_vectors, slots = load_ontology(self.ontology_url, word_vectors)

    # Load and process the training data.
    dialogues, _ = load_woz_data(self.training_url, word_vectors, ontology)
    no_dialogues = len(dialogues)

    # Load and process the validation data.
    val_dialogues, _ = load_woz_data(self.validation_url, word_vectors, ontology)

    # Generate the validation batch data.
    val_data = generate_batch(val_dialogues, 0, len(val_dialogues), len(ontology))
    val_iterations = int(len(val_dialogues) / train_batch_size)

    # 2. Initialise and set up the model graph.
    graph = tf.Graph()
    with graph.as_default():
        model_variables = model_definition(
            ontology_vectors, len(ontology), slots, num_hidden=num_hid,
            bidir=bidir, net_type=net_type, dev=device)
        (user, sys_res, no_turns, user_uttr_len, sys_uttr_len, labels,
         domain_labels, domain_accuracy, slot_accuracy, value_accuracy,
         value_f1, train_step, keep_prob, _, _, _) = model_variables
        [precision, recall, value_f1] = value_f1
        saver = tf.train.Saver()
        if device == 'gpu':
            config = tf.ConfigProto(allow_soft_placement=True)
            config.gpu_options.allow_growth = True
        else:
            config = tf.ConfigProto(device_count={'GPU': 0})
        sess = tf.Session(config=config)
        if os.path.exists(TRAIN_MODEL_URL + ".index"):
            saver.restore(sess, TRAIN_MODEL_URL)
            print("Loading from an existing model {} ....................".format(TRAIN_MODEL_URL))
        else:
            if not os.path.exists(TRAIN_MODEL_URL):
                os.makedirs('/'.join(TRAIN_MODEL_URL.split('/')[:-1]))
                os.makedirs('/'.join(TRAIN_GRAPH_URL.split('/')[:-1]))
            init = tf.global_variables_initializer()
            sess.run(init)
            print("Creating new model parameters.....................................")

        merged = tf.summary.merge_all()
        val_accuracy = tf.summary.scalar('validation_accuracy', value_accuracy)
        val_f1 = tf.summary.scalar('validation_f1_score', value_f1)
        train_writer = tf.summary.FileWriter(TRAIN_GRAPH_URL, graph)
        train_writer.flush()

        # 3. Perform the training epochs. no_epochs, batches_per_eval and
        # start_batch are module-level settings defined elsewhere in this file.
        last_update = -1
        best_f_score = -1
        for epoch in range(no_epochs):
            batch_size = train_batch_size
            sys.stdout.flush()
            iterations = math.ceil(no_dialogues / train_batch_size)
            start_time = time.time()
            val_i = 0
            shuffle(dialogues)
            for batch_id in range(iterations):
                # The last batch may be smaller than train_batch_size.
                if batch_id == iterations - 1 and no_dialogues % train_batch_size != 0:
                    batch_size = no_dialogues % train_batch_size

                batch_user, batch_sys, batch_labels, batch_domain_labels, \
                    batch_user_uttr_len, batch_sys_uttr_len, batch_no_turns = \
                    generate_batch(dialogues, batch_id, batch_size, len(ontology))

                [_, summary, da, sa, va, vf, pr, re] = sess.run(
                    [train_step, merged, domain_accuracy, slot_accuracy,
                     value_accuracy, value_f1, precision, recall],
                    feed_dict={user: batch_user,
                               sys_res: batch_sys,
                               labels: batch_labels,
                               domain_labels: batch_domain_labels,
                               user_uttr_len: batch_user_uttr_len,
                               sys_uttr_len: batch_sys_uttr_len,
                               no_turns: batch_no_turns,
                               keep_prob: 0.5})

                print("The accuracies are: domain {:.2f}, slot {:.2f}, value {:.2f}, "
                      "f1_score {:.2f}, precision {:.2f}, recall {:.2f} for batch {}".format(
                          da, sa, va, vf, pr, re, batch_id + iterations * epoch))
                train_writer.add_summary(summary, start_batch + batch_id + iterations * epoch)

                # ================================ VALIDATION ==============================================
                if batch_id % batches_per_eval == 0 or batch_id == 0:
                    if batch_id == 0:
                        print("Batch", "0", "to", batch_id, "took",
                              round(time.time() - start_time, 2), "seconds.")
                    else:
                        print("Batch", batch_id + iterations * epoch - batches_per_eval,
                              "to", batch_id + iterations * epoch, "took",
                              round(time.time() - start_time, 3), "seconds.")
                        start_time = time.time()

                    _, _, v_acc, f1_score, sm1, sm2 = evaluate_model(
                        sess, model_variables, val_data, [val_accuracy, val_f1],
                        batch_id, val_i)
                    val_i += 1
                    val_i %= val_iterations
                    train_writer.add_summary(sm1, start_batch + batch_id + iterations * epoch)
                    train_writer.add_summary(sm2, start_batch + batch_id + iterations * epoch)
                    stime = time.time()
                    current_metric = f1_score
                    print(" Validation metric:", round(current_metric, 5),
                          " eval took", round(time.time() - stime, 2),
                          "seconds, last update at:", last_update, "/", iterations)

                    # If we got a new best validation f-score, save the parameters.
                    if current_metric > best_f_score:
                        last_update = batch_id + iterations * epoch + 1
                        print("\n ====================== New best validation metric:",
                              round(current_metric, 4),
                              " - saving these parameters. Batch is:",
                              last_update, "/", iterations,
                              "---------------- =========== \n")
                        best_f_score = current_metric
                        saver.save(sess, TRAIN_MODEL_URL)

    print("The best parameters achieved a validation metric of", round(best_f_score, 4))
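# --- Usage sketch (illustrative, not part of the original module) ----------------
# train() relies on module-level imports (tensorflow as tf, numpy as np, os, sys,
# math, time, json, copy, random.shuffle) and module-level settings (no_epochs,
# batches_per_eval, start_batch, train_batch_size, device, TRAIN_MODEL_URL,
# TRAIN_GRAPH_URL). A minimal driver, assuming the surrounding class is called
# MDBTTracker (the actual class name and constructor arguments may differ):
#
#     tracker = MDBTTracker()   # hypothetical constructor
#     tracker.train()           # best checkpoint is saved to TRAIN_MODEL_URL
#     # inspect training curves with: tensorboard --logdir <TRAIN_GRAPH_URL>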
def update(self, user_act=None):
    """Update the dialog state from the latest user utterance."""
    if type(user_act) is not str:
        raise Exception("Expected user_act to be <class 'str'> type, but got {}.".format(
            type(user_act)))
    prev_state = self.state
    if not os.path.exists(os.path.join(self.data_dir, "results")):
        os.makedirs(os.path.join(self.data_dir, "results"))
    global train_batch_size

    model_variables = self.model_variables
    (user, sys_res, no_turns, user_uttr_len, sys_uttr_len, labels,
     domain_labels, domain_accuracy, slot_accuracy, value_accuracy,
     value_f1, train_step, keep_prob, predictions,
     true_predictions, [y, _]) = model_variables

    # Generate a fake dialogue based on the history (this is to reuse the
    # original MDBT code).
    # actual_history = prev_state['history']  # [[sys, user], [sys, user], ...]
    actual_history = copy.deepcopy(prev_state['history'])  # [[sys, user], [sys, user], ...]
    # Only the current user act is used; the deep-copied history above is discarded.
    actual_history = [['null']]
    actual_history[-1].append(user_act)
    actual_history = self.normalize_history(actual_history)
    if len(actual_history) == 0:
        actual_history = [['', user_act if len(user_act) > 0 else 'fake user act']]
    fake_dialogue = {}
    turn_no = 0
    for _sys, _user in actual_history:
        turn = {}
        turn['system'] = _sys
        fake_user = {}
        fake_user['text'] = _user
        fake_user['belief_state'] = default_state()['belief_state']
        turn['user'] = fake_user
        key = str(turn_no)
        fake_dialogue[key] = turn
        turn_no += 1
    context, actual_context = process_history([fake_dialogue], self.word_vectors, self.ontology)
    batch_user, batch_sys, batch_labels, batch_domain_labels, \
        batch_user_uttr_len, batch_sys_uttr_len, batch_no_turns = \
        generate_batch(context, 0, 1, len(self.ontology))  # old feature

    # Run the model.
    [pred, y_pred] = self.sess.run(
        [predictions, y],
        feed_dict={user: batch_user,
                   sys_res: batch_sys,
                   labels: batch_labels,
                   domain_labels: batch_domain_labels,
                   user_uttr_len: batch_user_uttr_len,
                   sys_uttr_len: batch_sys_uttr_len,
                   no_turns: batch_no_turns,
                   keep_prob: 1.0})

    # Convert the predictions to strings.
    dialgs, _, _ = track_dialogue(actual_context, self.ontology, pred, y_pred)
    assert len(dialgs) >= 1
    last_turn = dialgs[0][-1]
    predictions = last_turn['prediction']  # note: shadows the `predictions` tensor above
    new_belief_state = copy.deepcopy(prev_state['belief_state'])

    # Update the belief state. Each item appears to be of the form
    # 'domain-slot-value:score'.
    for item in predictions:
        item = item.lower()
        domain, slot, value = item.strip().split('-')
        # Drop the trailing ':score' suffix, keeping everything before the last colon.
        value = value[::-1].split(':', 1)[1][::-1]
        if slot == 'price range':
            slot = 'pricerange'
        if slot not in ['name', 'book']:
            if domain not in new_belief_state:
                raise Exception('Error: domain <{}> not in belief state'.format(domain))
            slot = REF_SYS_DA[domain.capitalize()].get(slot, slot)
            assert 'semi' in new_belief_state[domain]
            assert 'book' in new_belief_state[domain]
            if 'book' in slot:
                assert slot.startswith('book ')
                slot = slot.strip().split()[1]
            domain_dic = new_belief_state[domain]
            if slot in domain_dic['semi']:
                new_belief_state[domain]['semi'][slot] = normalize_value(
                    self.value_dict, domain, slot, value)
            elif slot in domain_dic['book']:
                new_belief_state[domain]['book'][slot] = value
            elif slot.lower() in domain_dic['book']:
                new_belief_state[domain]['book'][slot.lower()] = value
            else:
                with open('mdbt_unknown_slot.log', 'a+') as f:
                    f.write('unknown slot name <{}> with value <{}> of domain <{}>\nitem: {}\n\n'
                            .format(slot, value, domain, item))

    # Update the request state.
    new_request_state = copy.deepcopy(prev_state['request_state'])
    user_request_slot = self.detect_requestable_slots(user_act)
    for domain in user_request_slot:
        for key in user_request_slot[domain]:
            if domain not in new_request_state:
                new_request_state[domain] = {}
            if key not in new_request_state[domain]:
                new_request_state[domain][key] = user_request_slot[domain][key]

    # Update the state.
    new_state = copy.deepcopy(dict(prev_state))
    new_state['belief_state'] = new_belief_state
    new_state['request_state'] = new_request_state
    self.state = new_state

    return self.state
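# --- Usage sketch (illustrative) --------------------------------------------------
# update() consumes a single user utterance and returns the new dialog state.
# It assumes self.state has been initialised (e.g. via default_state()) and that
# the model and session were restored beforehand. Hypothetical example:
#
#     tracker.state = default_state()
#     state = tracker.update('I need a cheap restaurant in the north')
#     print(state['belief_state']['restaurant']['semi'])
#     # expected to contain e.g. {'pricerange': 'cheap', 'area': 'north', ...}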
def test(self, sess):
    """Test the MDBT model on the mdbt dataset. Almost the same as the original code."""
    if not os.path.exists("../../data/mdbt/results"):
        os.makedirs("../../data/mdbt/results")
    global train_batch_size, MODEL_URL, GRAPH_URL

    model_variables = self.model_variables
    (user, sys_res, no_turns, user_uttr_len, sys_uttr_len, labels,
     domain_labels, domain_accuracy, slot_accuracy, value_accuracy,
     value_f1, train_step, keep_prob, predictions,
     true_predictions, [y, _]) = model_variables
    [precision, recall, value_f1] = value_f1
    # print("\tMDBT: Loading from an existing model {} ....................".format(MODEL_URL))

    iterations = math.ceil(self.no_dialogues / train_batch_size)
    batch_size = train_batch_size
    [slot_acc, tot_accuracy] = [np.zeros(len(self.ontology), dtype="float32"), 0]
    slot_accurac = 0
    # value_accurac = np.zeros((len(slots),), dtype="float32")
    value_accurac = 0
    joint_accuracy = 0
    f1_score = 0
    preci = 0
    recal = 0
    processed_dialogues = []
    # np.set_printoptions(threshold=np.nan)
    for batch_id in range(int(iterations)):
        # The last batch may be smaller than train_batch_size.
        if batch_id == iterations - 1:
            batch_size = self.no_dialogues - batch_id * train_batch_size

        batch_user, batch_sys, batch_labels, batch_domain_labels, \
            batch_user_uttr_len, batch_sys_uttr_len, batch_no_turns = \
            generate_batch(self.dialogues, batch_id, batch_size, len(self.ontology))

        [da, sa, va, vf, pr, re, pred, true_pred, y_pred] = sess.run(
            [domain_accuracy, slot_accuracy, value_accuracy, value_f1,
             precision, recall, predictions, true_predictions, y],
            feed_dict={user: batch_user,
                       sys_res: batch_sys,
                       labels: batch_labels,
                       domain_labels: batch_domain_labels,
                       user_uttr_len: batch_user_uttr_len,
                       sys_uttr_len: batch_sys_uttr_len,
                       no_turns: batch_no_turns,
                       keep_prob: 1.0})

        # Batch-level joint accuracy; note that ja is recomputed by
        # track_dialogue below before it is accumulated.
        true = sum([1 if np.array_equal(pred[k, :], true_pred[k, :])
                    and sum(true_pred[k, :]) > 0 else 0
                    for k in range(true_pred.shape[0])])
        actual = sum([1 if sum(true_pred[k, :]) > 0 else 0
                      for k in range(true_pred.shape[0])])
        ja = true / actual
        tot_accuracy += da
        # joint_accuracy += ja
        slot_accurac += sa
        if math.isnan(pr):
            pr = 0
        preci += pr
        recal += re
        if math.isnan(vf):
            vf = 0
        f1_score += vf
        # value_accurac += va
        slot_acc += np.mean(np.asarray(np.equal(pred, true_pred), dtype="float32"), axis=0)

        dialgs, va1, ja = track_dialogue(
            self.actual_dialogues[batch_id * train_batch_size:
                                  batch_id * train_batch_size + batch_size],
            self.ontology, pred, y_pred)
        processed_dialogues += dialgs
        joint_accuracy += ja
        value_accurac += va1

        print("The accuracies are: domain {:.2f}, slot {:.2f}, value {:.2f}, "
              "other value {:.2f}, f1_score {:.2f}, precision {:.2f}, "
              "recall {:.2f} for batch {}".format(
                  da, sa, np.mean(va), va1, vf, pr, re, batch_id))

    print("End of evaluating the test set"
          "...........................................................................")

    slot_acc /= iterations
    # print("The accuracies for each slot:")
    # print(value_accurac / iterations)
    print("The overall accuracies are: domain {}, slot {}, value {}, f1_score {}, "
          "precision {}, recall {}, joint accuracy {}".format(
              tot_accuracy / iterations, slot_accurac / iterations,
              value_accurac / iterations, f1_score / iterations,
              preci / iterations, recal / iterations,
              joint_accuracy / iterations))

    with open(self.results_url, 'w') as f:
        json.dump(processed_dialogues, f, indent=4)
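# --- Usage sketch (illustrative) --------------------------------------------------
# test() assumes the instance already carries the test data (self.dialogues,
# self.actual_dialogues, self.no_dialogues) and that `sess` holds a session in
# which the graph behind self.model_variables has been restored:
#
#     tracker.test(tracker.sess)   # prints per-batch and overall metrics and
#                                  # dumps the tracked dialogues to self.results_url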