def test_batcher():
    train_file = "/iesl/canvas/rajarshi/data/TextKBQA/small_train_with_facts.json"
    kb_file = "/iesl/canvas/rajarshi/data/TextKBQA/freebase.spades.txt"
    batch_size = 32
    vocab_dir = "/home/rajarshi/research/joint-text-and-kb-inference-semantic-parsing/vocab/"
    min_num_mem_slots = 100
    max_num_mem_slots = 500
    batcher = Batcher(train_file, kb_file, batch_size, vocab_dir,
                      min_num_mem_slots=min_num_mem_slots,
                      max_num_mem_slots=max_num_mem_slots,
                      return_one_epoch=True, shuffle=False)

    batch_counter = 0
    for data in batcher.get_next_batch():
        batch_counter += 1
        batch_question, batch_q_lengths, batch_answer, batch_memory, batch_num_memories = data
    print("####### Test 1: checking number of batches returned #######")
    assert batch_counter == 1
    print("Test passed!")

    # shrinking the batch size should split the same data into more batches
    batch_size = 19
    batcher.batch_size = batch_size
    batcher.reset()
    batch_counter = 0
    for data in batcher.get_next_batch():
        batch_counter += 1
        batch_question, batch_q_lengths, batch_answer, batch_memory, batch_num_memories = data
    print("####### Test 2: checking number of batches returned with a different batch size #######")
    print(batch_counter)
    assert batch_counter == 2
    print("Test passed!")

    # inspect the contents of the first batch
    batch_size = 20
    batcher.batch_size = batch_size
    batcher.reset()
    for data in batcher.get_next_batch():
        batch_counter += 1
        batch_question, batch_q_lengths, batch_answer, batch_memory, batch_num_memories = data
        print(batch_question[0])
        print(batch_answer[0])
    sys.exit(1)
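# A minimal sketch of running test_batcher() directly, assuming this module
# already imports sys and Batcher (from the project's batcher module); note
# that the hard-coded data and vocab paths above are machine-specific and
# would need to point at local copies.
if __name__ == '__main__':
    test_batcher()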
class Trainer(object):
    # TensorFlow implementation of the TextKBQA trainer.
    def __init__(self):
        with tf.Session() as sess:
            print('Blake hack for acquiring gpu')
        # pretraining
        entity_lookup_table = None
        if load_pretrained_vectors:
            print('Loading pretrained word embeddings...')
            with open(pretrained_vector_path, 'rb') as f:
                entity_lookup_table = pickle.load(f)
            if verbose:
                print("Loaded pretrained vectors of size: ", entity_lookup_table.shape)
                print("Entity vocab size: ", entity_vocab_size)
        # data
        self.batcher = Batcher(train_file, kb_file, text_kb_file, batch_size, vocab_dir,
                               min_num_mem_slots=min_facts, max_num_mem_slots=max_facts,
                               use_kb_mem=use_kb, use_text_mem=use_text,
                               max_num_text_mem_slots=max_text_facts,
                               min_num_text_mem_slots=min_facts)
        self.dev_batcher = Batcher(dev_file, kb_file, text_kb_file, dev_batch_size, vocab_dir,
                                   min_num_mem_slots=min_facts, max_num_mem_slots=dev_max_facts,
                                   return_one_epoch=True, shuffle=False,
                                   use_kb_mem=use_kb, use_text_mem=use_text,
                                   max_num_text_mem_slots=dev_max_text_facts,
                                   min_num_text_mem_slots=min_facts)
        # define network
        if use_kb and use_text:
            self.model = TextKBQA(entity_vocab_size=entity_vocab_size,
                                  relation_vocab_size=relation_vocab_size,
                                  embedding_size=embedding_size, hops=hops,
                                  load_pretrained_model=load_model,
                                  load_pretrained_vectors=load_pretrained_vectors,
                                  join=combine_text_kb_answer,
                                  pretrained_entity_vectors=entity_lookup_table,
                                  verbose=verbose, separate_key_lstm=separate_key_lstm)
        elif use_kb:
            self.model = KBQA(entity_vocab_size=entity_vocab_size,
                              relation_vocab_size=relation_vocab_size,
                              embedding_size=embedding_size, hops=hops,
                              load_pretrained_model=load_model,
                              load_pretrained_vectors=load_pretrained_vectors,
                              pretrained_entity_vectors=entity_lookup_table,
                              verbose=verbose)
        elif use_text:
            self.model = TextQA(entity_vocab_size=entity_vocab_size,
                                embedding_size=embedding_size, hops=hops,
                                load_pretrained_model=load_model,
                                load_pretrained_vectors=load_pretrained_vectors,
                                pretrained_entity_vectors=entity_lookup_table,
                                verbose=verbose, separate_key_lstm=separate_key_lstm)
        # optimizer
        self.optimizer = tf.train.AdamOptimizer(lr)
        self.max_dev_acc = -1.0

    def bp(self, cost):
        # backprop: clip gradients by global norm, then apply them
        tvars = tf.trainable_variables()
        grads = tf.gradients(cost, tvars)
        grads, _ = tf.clip_by_global_norm(grads, grad_clip_norm)
        train_op = self.optimizer.apply_gradients(zip(grads, tvars))
        return train_op

    def initialize(self):
        #### inputs ####
        self.question = tf.placeholder(tf.int32, [None, None], name="question")
        self.question_lengths = tf.placeholder(tf.int32, [None], name="question_lengths")
        self.answer = tf.placeholder(tf.int32, [None], name="answer")
        if use_kb and use_text:
            self.memory = tf.placeholder(tf.int32, [None, None, 3], name="memory")
            self.text_key_mem = tf.placeholder(tf.int32, [None, None, None], name="key_mem")
            self.text_key_len = tf.placeholder(tf.int32, [None, None], name="key_len")
            self.text_val_mem = tf.placeholder(tf.int32, [None, None], name="val_mem")
            # network output
            self.output = self.model(self.memory, self.text_key_mem, self.text_key_len,
                                     self.text_val_mem, self.question, self.question_lengths)
        elif use_kb:
            self.memory = tf.placeholder(tf.int32, [None, None, 3], name="memory")
            # network output
            self.output = self.model(self.memory, self.question, self.question_lengths)
        elif use_text:
            self.text_key_mem = tf.placeholder(tf.int32, [None, None, None], name="key_mem")
            self.text_key_len = tf.placeholder(tf.int32, [None, None], name="key_len")
            self.text_val_mem = tf.placeholder(tf.int32, [None, None], name="val_mem")
            # network output
            self.output = self.model(self.text_key_mem, self.text_key_len, self.text_val_mem,
                                     self.question, self.question_lengths)
        # predict
        self.probs = tf.nn.softmax(self.output)
        self.predict_op = tf.argmax(self.output, 1, name="predict_op")
        self.rank_op = tf.nn.top_k(self.output, 50)
        # loss
        cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.output,
                                                                       labels=self.answer)
        self.loss = tf.reduce_mean(cross_entropy, name="loss_op")
        if use_kb and use_text:
            # Graph created; now add a load/save op for it.
            # To restore only the kb-only model's parameters, restrict the var_list:
            # var_list = [v for v in tf.trainable_variables() if v.name.startswith('BiRNN/')]
            # var_list += [self.model.entity_lookup_table, self.model.relation_lookup_table,
            #              self.model.W, self.model.b, self.model.W1, self.model.b1, self.model.R[0]]
            # self.saver = tf.train.Saver(var_list=var_list)
            self.saver = tf.train.Saver()
        else:
            self.saver = tf.train.Saver()
        # Add to the Graph the Ops that calculate and apply gradients.
        self.train_op = self.bp(self.loss)
        # return the variable initializer Op
        init_op = tf.initialize_all_variables()
        return init_op

    def dev_eval(self, sess):
        print('Evaluating on dev set...')
        dev_start_time = time.time()
        num_dev_data = 0
        dev_loss = 0.0
        dev_acc = 0.0
        attn_weight = None
        preds = []
        SRR = 0.0  # sum of reciprocal ranks, for computing MRR
        for data in tqdm(self.dev_batcher.get_next_batch()):
            if use_kb and use_text:
                dev_batch_question, dev_batch_q_lengths, dev_batch_answer, dev_batch_memory, dev_batch_num_memories, \
                    dev_batch_text_key_mem, dev_batch_text_key_len, dev_batch_text_val_mem, dev_batch_num_text_mems = data
                feed_dict_dev = {self.question: dev_batch_question,
                                 self.question_lengths: dev_batch_q_lengths,
                                 self.answer: dev_batch_answer,
                                 self.memory: dev_batch_memory,
                                 self.text_key_mem: dev_batch_text_key_mem,
                                 self.text_key_len: dev_batch_text_key_len,
                                 self.text_val_mem: dev_batch_text_val_mem}
            elif use_kb:
                dev_batch_question, dev_batch_q_lengths, dev_batch_answer, dev_batch_memory, dev_batch_num_memories = data
                feed_dict_dev = {self.question: dev_batch_question,
                                 self.question_lengths: dev_batch_q_lengths,
                                 self.answer: dev_batch_answer,
                                 self.memory: dev_batch_memory}
            elif use_text:
                dev_batch_question, dev_batch_q_lengths, dev_batch_answer, dev_batch_text_key_mem, dev_batch_text_key_len, \
                    dev_batch_text_val_mem, dev_batch_num_text_mems = data
                feed_dict_dev = {self.question: dev_batch_question,
                                 self.question_lengths: dev_batch_q_lengths,
                                 self.answer: dev_batch_answer,
                                 self.text_key_mem: dev_batch_text_key_mem,
                                 self.text_key_len: dev_batch_text_key_len,
                                 self.text_val_mem: dev_batch_text_val_mem}
            dev_batch_loss_value, dev_prediction, batch_attn_weight, topk = sess.run(
                [self.loss, self.predict_op, self.model.attn_weights_all_hops, self.rank_op],
                feed_dict=feed_dict_dev)
            # accumulate the reciprocal rank of the gold answer within the top-50 predictions
            for j, v in enumerate(topk.indices):
                for i, w in enumerate(v):
                    if w == dev_batch_answer[j]:
                        SRR += 1.0 / (i + 1)
            dev_loss += dev_batch_loss_value
            num_dev_data += dev_batch_question.shape[0]
            dev_acc += np.sum(dev_prediction == dev_batch_answer)
            attn_weight = batch_attn_weight[0] if attn_weight is None \
                else np.vstack((attn_weight, batch_attn_weight[0]))
            # store predictions alongside gold answers
            dev_prediction = np.expand_dims(dev_prediction, axis=1)
            dev_batch_answer = np.expand_dims(dev_batch_answer, axis=1)
            preds.append(np.concatenate((dev_prediction, dev_batch_answer), axis=1))
        print('MRR: ', SRR / num_dev_data)
        dev_acc = (1.0 * dev_acc / num_dev_data)
        dev_loss = (1.0 * dev_loss / num_dev_data)
        if print_attention_weights:
            f_out = open(output_dir + "/attn_wts.npy", 'wb')
            np.save(f_out, attn_weight)
            print('Wrote attention weights...')
        self.dev_batcher.reset()
        if dev_acc >= 0.3 or mode == 'test':
            f_out = open(output_dir + "/out_txt." + str(dev_acc), 'wb')
            print('Writing to {}'.format("out_txt." + str(dev_acc)))
            preds = np.vstack(preds)
            preds.tofile(f_out)
            if mode == 'test':
                f_out1 = open(output_dir + "/out.txt", 'wb')
                preds.tofile(f_out1)
                f_out1.close()
            f_out.close()
        print('It took {0:10.4f}s to evaluate on dev set of size: {3:10d} with dev loss: {1:10.4f} and dev acc: {2:10.4f}'
              .format(time.time() - dev_start_time, dev_loss, dev_acc, num_dev_data))
        return dev_acc, dev_loss

    def fit(self):
        train_loss = 0.0
        batch_counter = 0
        train_acc = 0.0
        with tf.Session(config=tf.ConfigProto(log_device_placement=False)) as sess:
            sess.run(self.initialize())
            if load_model:
                print('Loading pretrained model from {}'.format(model_path))
                self.saver.restore(sess, model_path)
            if mode == 'test':
                self.dev_eval(sess)
            if mode == 'train':
                self.start_time = time.time()
                print('Starting to train')
                for data in self.batcher.get_next_batch():
                    batch_counter += 1
                    if use_kb and use_text:
                        batch_question, batch_q_lengths, batch_answer, batch_memory, batch_num_memories, \
                            batch_text_key_mem, batch_text_key_len, batch_text_val_mem, batch_num_text_mems = data
                        feed_dict = {self.question: batch_question,
                                     self.question_lengths: batch_q_lengths,
                                     self.answer: batch_answer,
                                     self.memory: batch_memory,
                                     self.text_key_mem: batch_text_key_mem,
                                     self.text_key_len: batch_text_key_len,
                                     self.text_val_mem: batch_text_val_mem}
                    elif use_kb:
                        batch_question, batch_q_lengths, batch_answer, batch_memory, batch_num_memories = data
                        feed_dict = {self.question: batch_question,
                                     self.question_lengths: batch_q_lengths,
                                     self.answer: batch_answer,
                                     self.memory: batch_memory}
                    elif use_text:
                        batch_question, batch_q_lengths, batch_answer, batch_text_key_mem, batch_text_key_len, \
                            batch_text_val_mem, batch_num_text_mems = data
                        feed_dict = {self.question: batch_question,
                                     self.question_lengths: batch_q_lengths,
                                     self.answer: batch_answer,
                                     self.text_key_mem: batch_text_key_mem,
                                     self.text_key_len: batch_text_key_len,
                                     self.text_val_mem: batch_text_val_mem}
                    # train
                    batch_loss_value, _, prediction = sess.run(
                        [self.loss, self.train_op, self.predict_op], feed_dict=feed_dict)
                    batch_train_acc = (1.0 * np.sum(prediction == batch_answer) /
                                       batch_question.shape[0])
                    # exponential moving averages of training loss and accuracy
                    train_loss = 0.98 * train_loss + 0.02 * batch_loss_value
                    train_acc = 0.98 * train_acc + 0.02 * batch_train_acc
                    print('\t at iter {0:10d} at time {1:10.4f}s train loss: {2:10.4f}, train_acc: {3:10.4f} '.format(
                        batch_counter, time.time() - self.start_time, train_loss, train_acc))
                    if batch_counter != 0 and batch_counter % dev_eval_counter == 0:
                        # predict on dev
                        dev_acc, dev_loss = self.dev_eval(sess)
                        print('\t at iter {0:10d} at time {1:10.4f}s dev loss: {2:10.4f} dev_acc: {3:10.4f} '.format(
                            batch_counter, time.time() - self.start_time, dev_loss, dev_acc))
                        if dev_acc > self.max_dev_acc:
                            self.max_dev_acc = dev_acc
                            # save the best model so far
                            save_path = self.saver.save(sess, output_dir + "/max_dev_out.ckpt")
                            if use_kb and use_text:
                                save_path = self.saver.save(sess, output_dir + "/full_max_dev_out.ckpt")
                            with open(output_dir + "/dev_accuracies.txt", mode='a') as out:
                                out.write('Dev accuracy while writing max_dev_out.ckpt {0:10.4f}\n'
                                          .format(self.max_dev_acc))
                            print("Saved model")
                    if batch_counter % save_counter == 0:
                        save_path = self.saver.save(sess, output_dir + "/out.ckpt")
                        print("Saved model")
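# A minimal sketch of an entry point for the TensorFlow Trainer above, assuming
# the module-level configuration globals it reads (train_file, kb_file, use_kb,
# use_text, lr, mode, output_dir, and the rest) are populated elsewhere in the
# module, e.g. by command-line flag parsing; `main` is a hypothetical name and
# not part of the original code.
def main():
    trainer = Trainer()  # builds the batchers, the model, and the optimizer
    trainer.fit()        # fit() constructs the graph via initialize() and branches on `mode`

if __name__ == '__main__':
    main()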
class Trainer(object):
    # PyTorch port of the TensorFlow Trainer above (written against the older
    # Variable-style PyTorch API: Variable, .data[0], clip_grad_norm).
    def __init__(self):
        # pretraining
        entity_lookup_table = None
        if load_pretrained_vectors:
            print('Loading pretrained word embeddings...')
            with open(pretrained_vector_path, 'rb') as f:
                entity_lookup_table = pickle.load(f)
            if verbose:
                print("Loaded pretrained vectors of size: ", entity_lookup_table.shape)
                print("Entity vocab size: ", entity_vocab_size)
        # data
        self.batcher = Batcher(train_file, kb_file, text_kb_file, batch_size, vocab_dir,
                               min_num_mem_slots=min_facts, max_num_mem_slots=max_facts,
                               use_kb_mem=use_kb, use_text_mem=use_text,
                               max_num_text_mem_slots=max_text_facts,
                               min_num_text_mem_slots=min_facts)
        self.dev_batcher = Batcher(dev_file, kb_file, text_kb_file, dev_batch_size, vocab_dir,
                                   min_num_mem_slots=min_facts, max_num_mem_slots=dev_max_facts,
                                   return_one_epoch=True, shuffle=False,
                                   use_kb_mem=use_kb, use_text_mem=use_text,
                                   max_num_text_mem_slots=dev_max_text_facts,
                                   min_num_text_mem_slots=min_facts)
        # define network
        if use_kb and use_text:
            self.model = TextKBQA(entity_vocab_size=entity_vocab_size,
                                  relation_vocab_size=relation_vocab_size,
                                  embedding_size=embedding_size, hops=hops,
                                  load_pretrained_model=load_model,
                                  load_pretrained_vectors=load_pretrained_vectors,
                                  join=combine_text_kb_answer,
                                  pretrained_entity_vectors=entity_lookup_table,
                                  verbose=verbose,
                                  separate_key_lstm=separate_key_lstm).cuda()
        elif use_kb:
            self.model = walking_memory(entity_vocab_size=entity_vocab_size,
                                        relation_vocab_size=relation_vocab_size,
                                        embedding_size=embedding_size, hops=hops,
                                        load_pretrained_model=load_model,
                                        load_pretrained_vectors=load_pretrained_vectors,
                                        pretrained_entity_vectors=entity_lookup_table,
                                        verbose=verbose).cuda()
        elif use_text:
            # The text-only model has not been ported yet; the intended construction
            # (mirroring the TensorFlow version above) is kept for reference:
            # self.model = TextQA(entity_vocab_size=entity_vocab_size,
            #                     embedding_size=embedding_size, hops=hops,
            #                     load_pretrained_model=load_model,
            #                     load_pretrained_vectors=load_pretrained_vectors,
            #                     pretrained_entity_vectors=entity_lookup_table,
            #                     verbose=verbose,
            #                     separate_key_lstm=separate_key_lstm).cuda()
            raise NotImplementedError
        # optimizer
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr)
        self.max_dev_acc = -1.0

    # The graph-building bp() and initialize() methods are not needed in the
    # eager PyTorch port; see the TensorFlow Trainer above for that code.

    def dev_eval(self):
        print('Evaluating on dev set...')
        dev_start_time = time.time()
        num_dev_data = 0
        dev_loss = 0.0
        dev_acc = 0.0
        attn_weight = None  # attention-weight dumping is not ported yet
        preds = []
        SRR = 0.0  # sum of reciprocal ranks, for computing MRR
        for data in tqdm(self.dev_batcher.get_next_batch()):
            self.model.eval()
            if use_kb and use_text:
                dev_batch_question, dev_batch_q_lengths, dev_batch_answer, dev_batch_memory, dev_batch_num_memories, \
                    dev_batch_text_key_mem, dev_batch_text_key_len, dev_batch_text_val_mem, dev_batch_num_text_mems = data
                logits = self.model(Variable(torch.LongTensor(dev_batch_memory.astype(int))).cuda(),
                                    Variable(torch.LongTensor(dev_batch_text_key_mem.astype(int))).cuda(),
                                    Variable(torch.LongTensor(dev_batch_text_key_len.astype(int))).cuda(),
                                    Variable(torch.LongTensor(dev_batch_text_val_mem.astype(int))).cuda(),
                                    Variable(torch.LongTensor(dev_batch_question.astype(int))).cuda(),
                                    Variable(torch.LongTensor(dev_batch_q_lengths.astype(int))).cuda())
            elif use_kb:
                dev_batch_question, dev_batch_q_lengths, dev_batch_answer, dev_batch_memory, dev_batch_num_memories = data
                logits = self.model(Variable(torch.LongTensor(dev_batch_memory.astype(int))).cuda(),
                                    Variable(torch.LongTensor(dev_batch_question.astype(int))).cuda(),
                                    Variable(torch.LongTensor(dev_batch_q_lengths.astype(int))).cuda())
            elif use_text:
                # text-only evaluation has not been ported (see __init__)
                raise NotImplementedError
            # eval
            dev_batch_loss_value = F.cross_entropy(
                logits, Variable(torch.LongTensor(dev_batch_answer.astype(int))).cuda())
            dev_prediction = torch.max(logits, dim=1)[1]
            topk = torch.topk(logits, 50)[1].data.cpu().numpy()
            # accumulate the reciprocal rank of the gold answer within the top-50 predictions
            for j, v in enumerate(topk):
                for i, w in enumerate(v):
                    if w == dev_batch_answer[j]:
                        SRR += 1.0 / (i + 1)
            dev_loss += dev_batch_loss_value.data[0]
            num_dev_data += dev_batch_question.shape[0]
            dev_acc += (1.0 * np.sum(dev_prediction.data.cpu().numpy() == dev_batch_answer))
            # store predictions alongside gold answers
            dev_prediction = np.expand_dims(dev_prediction.data.cpu().numpy(), axis=1)
            dev_batch_answer = np.expand_dims(dev_batch_answer, axis=1)
            preds.append(np.concatenate((dev_prediction, dev_batch_answer), axis=1))
        print('MRR: ', SRR / num_dev_data)
        dev_acc = (1.0 * dev_acc / num_dev_data)
        dev_loss = (1.0 * dev_loss / num_dev_data)
        self.dev_batcher.reset()
        if dev_acc >= 0.3 or mode == 'test':
            f_out = open(output_dir + "/out_txt." + str(dev_acc), 'wb')
            print('Writing to {}'.format("out_txt." + str(dev_acc)))
            preds = np.vstack(preds)
            preds.tofile(f_out)
            if mode == 'test':
                f_out1 = open(output_dir + "/out.txt", 'wb')
                preds.tofile(f_out1)
                f_out1.close()
            f_out.close()
        print('It took {0:10.4f}s to evaluate on dev set of size: {3:10d} with dev loss: {1:10.4f} and dev acc: {2:10.4f}'.format(
            time.time() - dev_start_time, dev_loss, dev_acc, num_dev_data))
        return dev_acc, dev_loss

    def fit(self):
        train_loss = 0.0
        batch_counter = 0
        train_acc = 0.0
        if load_model:
            print('Loading pretrained model from {}'.format(model_path))
            self.model.load_state_dict(torch.load(model_path))
        if mode == 'test':
            self.model.eval()
            self.dev_eval()
        if mode == 'train':
            self.start_time = time.time()
            print('Starting to train')
            for data in self.batcher.get_next_batch():
                batch_counter += 1
                # train
                self.model.train()
                if use_kb and use_text:
                    batch_question, batch_q_lengths, batch_answer, batch_memory, batch_num_memories, \
                        batch_text_key_mem, batch_text_key_len, batch_text_val_mem, batch_num_text_mems = data
                    logits = self.model(Variable(torch.LongTensor(batch_memory.astype(int))).cuda(),
                                        Variable(torch.LongTensor(batch_text_key_mem.astype(int))).cuda(),
                                        Variable(torch.LongTensor(batch_text_key_len.astype(int))).cuda(),
                                        Variable(torch.LongTensor(batch_text_val_mem.astype(int))).cuda(),
                                        Variable(torch.LongTensor(batch_question.astype(int))).cuda(),
                                        Variable(torch.LongTensor(batch_q_lengths.astype(int))).cuda())
                elif use_kb:
                    batch_question, batch_q_lengths, batch_answer, batch_memory, batch_num_memories = data
                    logits = self.model(Variable(torch.LongTensor(batch_memory.astype(int))).cuda(),
                                        Variable(torch.LongTensor(batch_question.astype(int))).cuda(),
                                        Variable(torch.LongTensor(batch_q_lengths.astype(int))).cuda())
                elif use_text:
                    # text-only training has not been ported (see __init__)
                    raise NotImplementedError
                batch_loss_value = F.cross_entropy(
                    logits, Variable(torch.LongTensor(batch_answer.astype(int))).cuda())
                prediction = torch.max(logits, dim=1)[1]
                self.optimizer.zero_grad()
                batch_loss_value.backward()
                torch.nn.utils.clip_grad_norm(self.model.parameters(), grad_clip_norm)
                self.optimizer.step()
                batch_train_acc = (1.0 * np.sum(prediction.data.cpu().numpy() == batch_answer) /
                                   batch_question.shape[0])
                # exponential moving averages of training loss and accuracy
                train_loss = 0.98 * train_loss + 0.02 * batch_loss_value.data[0]
                train_acc = 0.98 * train_acc + 0.02 * batch_train_acc
                print('\t at iter {0:10d} at time {1:10.4f}s train loss: {2:10.4f}, train_acc: {3:10.4f} '.format(
                    batch_counter, time.time() - self.start_time, train_loss, train_acc))
                if batch_counter != 0 and batch_counter % dev_eval_counter == 0:
                    # predict on dev
                    dev_acc, dev_loss = self.dev_eval()
                    print('\t at iter {0:10d} at time {1:10.4f}s dev loss: {2:10.4f} dev_acc: {3:10.4f} '.format(
                        batch_counter, time.time() - self.start_time, dev_loss, dev_acc))
                    if dev_acc > self.max_dev_acc:
                        self.max_dev_acc = dev_acc
                        # save the best model so far
                        torch.save(self.model.state_dict(), output_dir + "/max_dev_out.ckpt")
                        if use_kb and use_text:
                            torch.save(self.model.state_dict(), output_dir + "/full_max_dev_out.ckpt")
                        with open(output_dir + "/dev_accuracies.txt", mode='a') as out:
                            out.write('Dev accuracy while writing max_dev_out.ckpt {0:10.4f}\n'
                                      .format(self.max_dev_acc))
                        print("Saved model")
                if batch_counter % save_counter == 0:
                    torch.save(self.model.state_dict(), output_dir + "/out.ckpt")
                    print("Saved model")
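# A minimal sketch of driving the PyTorch Trainer above, under the same
# assumption that the module-level configuration globals (train_file, use_kb,
# mode, lr, grad_clip_norm, output_dir, ...) are defined before construction;
# __init__ calls .cuda() on the model, so a CUDA-capable GPU is required.
if __name__ == '__main__':
    trainer = Trainer()
    trainer.fit()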