def generate(use_cuda=False, device=-1):
    vocab = du.load_vocab(args.vocab)
    eos_id = vocab.stoi[EOS_TOK]
    pad_id = vocab.stoi[PAD_TOK]
    sos_id = vocab.stoi[SOS_TOK]
    tup_id = vocab.stoi[TUP_TOK]

    # Only one decoding task may be selected at a time.
    assert sum(map(bool, (args.perplexity, args.seed, args.ranking))) <= 1, "Only 1 can be True at a time."
    # Batch size during decoding is set to 1.
    assert args.batch_size == 1, "Set batch size to 1 during decoding."

    # Load the model.
    with open(args.model, 'rb') as f:
        model = torch.load(f, map_location=lambda s, loc: s)

    # Set eval mode and decode without CUDA.
    model.eval()
    model.use_cuda = False

    # TEMP FIX to work with old models without this parameter.
    #model.type_emb = None

    # TASK-SPECIFIC FUNCTION CALLS
    if args.ranking:
        # HARD only. The easy variant has been deactivated.
        do_ranking(model, vocab, use_cuda, device)
    elif args.perplexity:
        get_perplexity(model, vocab, use_cuda, device)
    elif args.seed:
        gen_from_seed(model, vocab, eos_id, pad_id, sos_id, tup_id, use_cuda, device)
    else:
        print("NOT IMPLEMENTED. RETURNING.")
        return
def do_ranking(model, vocab, use_cuda=False, device=-1):
    dataset = du.NarrativeClozeDataset(args.data, vocab,
                                       src_seq_length=MAX_EVAL_SEQ_LEN,
                                       min_seq_length=MIN_EVAL_SEQ_LEN)
    batches = BatchIter(dataset, args.batch_size,
                        sort_key=lambda x: len(x.actual),
                        train=False, device=device)
    ranked_acc = 0.0

    if args.emb_type:
        print("RANKING WITH ROLE EMB")
        vocab2 = du.load_vocab(args.vocab2)
        role_dataset = du.NarrativeClozeDataset(args.role_data, vocab2,
                                                src_seq_length=MAX_EVAL_SEQ_LEN,
                                                min_seq_length=MIN_EVAL_SEQ_LEN)
        role_batches = BatchIter(role_dataset, args.batch_size,
                                 sort_key=lambda x: len(x.actual),
                                 train=False, device=device)
        assert len(dataset) == len(role_dataset), "Dataset and Role dataset must be of same length."

        for iteration, (bl, rbl) in enumerate(zip(batches, role_batches)):
            if (iteration + 1) % 25 == 0:
                print("iteration {}".format(iteration + 1))

            ## DATA STEPS
            all_texts = [bl.actual, bl.actual_tgt, bl.dist1, bl.dist1_tgt,
                         bl.dist2, bl.dist2_tgt, bl.dist3, bl.dist3_tgt,
                         bl.dist4, bl.dist4_tgt, bl.dist5, bl.dist5_tgt]  # each is a tuple
            all_roles = [rbl.actual, rbl.dist1, rbl.dist2,
                         rbl.dist3, rbl.dist4, rbl.dist5]  # targets are not needed for roles
            assert len(all_roles) == 6, "6 = 6 * 1."
            assert len(all_texts) == 12, "12 = 6 * 2."

            all_texts_vars = []
            all_roles_vars = []
            if use_cuda:
                for tup in all_texts:
                    all_texts_vars.append((Variable(tup[0].cuda(), volatile=True), tup[1]))
                for tup in all_roles:
                    all_roles_vars.append((Variable(tup[0].cuda(), volatile=True), tup[1]))
            else:
                for tup in all_texts:
                    all_texts_vars.append((Variable(tup[0], volatile=True), tup[1]))
                for tup in all_roles:
                    all_roles_vars.append((Variable(tup[0], volatile=True), tup[1]))

            # Iterate two at a time (source, target) using an iterator and next().
            vars_iter = iter(all_texts_vars)
            roles_iter = iter(all_roles_vars)

            # Run the model and collect perplexities for all 6 candidate sentences.
            pps = []
            for tup in vars_iter:
                ## INIT AND DECODE before every sentence
                hidden = model.init_hidden(args.batch_size)
                next_tup = next(vars_iter)
                role_tup = next(roles_iter)
                nll = calc_perplexity(args, model, tup[0], vocab, next_tup[0],
                                      next_tup[1], hidden, role_tup[0])
                pp = torch.exp(nll)
                #print("NEG-LOSS {} PPL {}".format(nll.data[0], pp.data[0]))
                pps.append(pp.data.numpy()[0])

            # Lowest perplexity == top-ranked sentence; the correct answer is always the first one.
            assert len(pps) == 6, "6 targets."
            all_texts_str = [transform(text[0].data.numpy()[0], vocab.itos)
                             for text in all_texts_vars]
            #print("ALL: {}".format(all_texts_str))
            min_index = np.argmin(pps)
            if min_index == 0:
                ranked_acc += 1
                #print("TARGET: {}".format(transform(all_texts_vars[1][0].data.numpy()[0], vocab.itos)))
                #print("CORRECT: {}".format(transform(all_texts_vars[1][0].data.numpy()[0], vocab.itos)))
            #else:
                # print the ones that are wrong
                #print("TARGET: {}".format(transform(all_texts_vars[1][0].data.numpy()[0], vocab.itos)))
                #print("WRONG: {}".format(transform(all_texts_vars[min_index+2][0].data.numpy()[0], vocab.itos)))

            if (iteration + 1) == args.max_decode:
                print("Max decode reached. Exiting.")
                break

        ranked_acc = ranked_acc / (iteration + 1) * 100  # convert to a percentage
        print("Average acc(%): {}".format(ranked_acc))
        return ranked_acc

    else:
        # THIS IS FOR A MODEL WITHOUT ROLE EMB
        print("RANKING WITHOUT ROLE EMB.")
        for iteration, bl in enumerate(batches):
            if (iteration + 1) % 25 == 0:
                print("iteration {}".format(iteration + 1))

            ## DATA STEPS
            all_texts = [bl.actual, bl.actual_tgt, bl.dist1, bl.dist1_tgt,
                         bl.dist2, bl.dist2_tgt, bl.dist3, bl.dist3_tgt,
                         bl.dist4, bl.dist4_tgt, bl.dist5, bl.dist5_tgt]  # each is a tuple
            assert len(all_texts) == 12, "12 = 6 * 2."

            all_texts_vars = []
            if use_cuda:
                for tup in all_texts:
                    all_texts_vars.append((Variable(tup[0].cuda(), volatile=True), tup[1]))
            else:
                for tup in all_texts:
                    all_texts_vars.append((Variable(tup[0], volatile=True), tup[1]))

            # Iterate two at a time (source, target) using an iterator and next().
            vars_iter = iter(all_texts_vars)

            # Run the model for all 6 candidate sentences.
            pps = []
            for tup in vars_iter:
                ## INIT AND DECODE before every sentence
                hidden = model.init_hidden(args.batch_size)
                next_tup = next(vars_iter)
                nll = calc_perplexity(args, model, tup[0], vocab, next_tup[0],
                                      next_tup[1], hidden)
                pp = torch.exp(nll)
                #print("NEG-LOSS {} PPL {}".format(nll.data[0], pp.data[0]))
                pps.append(pp.data.numpy()[0])

            # Lowest perplexity == top-ranked sentence; the correct answer is always the first one.
            assert len(pps) == 6, "6 targets."
            all_texts_str = [transform(text[0].data.numpy()[0], vocab.itos)
                             for text in all_texts_vars]
            #print("ALL: {}".format(all_texts_str))
            min_index = np.argmin(pps)
            if min_index == 0:
                ranked_acc += 1
                #print("TARGET: {}".format(transform(all_texts_vars[1][0].data.numpy()[0], vocab.itos)))
                #print("CORRECT: {}".format(transform(all_texts_vars[1][0].data.numpy()[0], vocab.itos)))
            #else:
                # print the ones that are wrong
                #print("TARGET: {}".format(transform(all_texts_vars[1][0].data.numpy()[0], vocab.itos)))
                #print("WRONG: {}".format(transform(all_texts_vars[min_index+2][0].data.numpy()[0], vocab.itos)))

            if (iteration + 1) == args.max_decode:
                print("Max decode reached. Exiting.")
                break

        ranked_acc = ranked_acc / (iteration + 1) * 100  # convert to a percentage
        print("Average acc(%): {}".format(ranked_acc))
        return ranked_acc
def get_perplexity(model, vocab, use_cuda=False, device=-1):
    total_loss = 0.0

    if args.emb_type:
        # GET PERPLEXITY WITH ROLE EMB
        print("PERPLEXITY WITH ROLE EMB")
        vocab2 = du.load_vocab(args.vocab2)
        dataset = du.LMRoleSentenceDataset(args.data, vocab, args.role_data, vocab2,
                                           src_seq_length=MAX_EVAL_SEQ_LEN,
                                           min_seq_length=MIN_EVAL_SEQ_LEN)  # put in filter_pred later
        batches = BatchIter(dataset, args.batch_size,
                            sort_key=lambda x: len(x.text),
                            train=False, device=device)
        print("DATASET {}".format(len(dataset)))

        for iteration, bl in enumerate(batches):
            if (iteration + 1) % 25 == 0:
                print("iteration {}".format(iteration + 1))

            ## DATA STEPS
            batch, batch_lens = bl.text
            target, target_lens = bl.target
            role, role_lens = bl.role
            if use_cuda:
                batch = Variable(batch.cuda(), volatile=True)
                target = Variable(target.cuda(), volatile=True)
                role = Variable(role.cuda(), volatile=True)
            else:
                batch = Variable(batch, volatile=True)
                target = Variable(target, volatile=True)
                role = Variable(role, volatile=True)

            ## INIT AND DECODE
            hidden = model.init_hidden(args.batch_size)
            ce_loss = calc_perplexity(args, model, batch, vocab, target,
                                      target_lens, hidden, role)
            #print("Loss {}".format(ce_loss))
            total_loss = total_loss + ce_loss.data[0]

            if (iteration + 1) == args.max_decode:
                print("Max decode reached. Exiting.")
                break

        # After iterating over all examples.
        loss = total_loss / (iteration + 1)
        print("Average Loss: {}".format(loss))
        return loss

    else:
        print("PERPLEXITY WITHOUT ROLE EMB")
        dataset = du.LMSentenceDataset(args.data, vocab,
                                       src_seq_length=MAX_EVAL_SEQ_LEN,
                                       min_seq_length=MIN_EVAL_SEQ_LEN)  # put in filter_pred later
        batches = BatchIter(dataset, args.batch_size,
                            sort_key=lambda x: len(x.text),
                            train=False, device=device)

        for iteration, bl in enumerate(batches):
            if (iteration + 1) % 25 == 0:
                print("iteration {}".format(iteration + 1))

            ## DATA STEPS
            batch, batch_lens = bl.text
            target, target_lens = bl.target
            if use_cuda:
                batch = Variable(batch.cuda(), volatile=True)
                target = Variable(target.cuda(), volatile=True)
            else:
                batch = Variable(batch, volatile=True)
                target = Variable(target, volatile=True)

            ## INIT AND DECODE
            hidden = model.init_hidden(args.batch_size)
            ce_loss = calc_perplexity(args, model, batch, vocab, target,
                                      target_lens, hidden)
            #print("Loss {}".format(ce_loss))
            total_loss = total_loss + ce_loss.data[0]

            if (iteration + 1) == args.max_decode:
                print("Max decode reached. Exiting.")
                break

        # After iterating over all examples.
        loss = total_loss / (iteration + 1)
        print("Average Loss: {}".format(loss))
        return loss
def gen_from_seed(model, vocab, eos_id, pad_id, sos_id, tup_id, use_cuda=False, device=-1):
    if args.emb_type:
        # GEN FROM SEED WITH ROLE EMB
        print("GEN SEED WITH ROLE EMB")
        vocab2 = du.load_vocab(args.vocab2)
        # Used to feed role ids into beam decode.
        ROLES = [vocab2.stoi[TUP_TOK], vocab2.stoi[VERB], vocab2.stoi[SUB],
                 vocab2.stoi[OBJ], vocab2.stoi[PREP]]
        dataset = du.LMRoleSentenceDataset(args.data, vocab, args.role_data, vocab2,
                                           src_seq_length=MAX_EVAL_SEQ_LEN,
                                           min_seq_length=MIN_EVAL_SEQ_LEN)  # put in filter_pred later
        batches = BatchIter(dataset, args.batch_size,
                            sort_key=lambda x: len(x.text),
                            train=False, device=device)

        for iteration, bl in enumerate(batches):
            if (iteration + 1) % 25 == 0:
                print("iteration {}".format(iteration + 1))

            ## DATA STEPS
            batch, batch_lens = bl.text
            target, target_lens = bl.target
            role, role_lens = bl.role
            if use_cuda:
                batch = Variable(batch.cuda(), volatile=True)
                role = Variable(role.cuda(), volatile=True)
            else:
                batch = Variable(batch, volatile=True)
                role = Variable(role, volatile=True)

            ## INIT AND DECODE
            hidden = model.init_hidden(args.batch_size)

            # Run the model on the first t-1 events (all but the last word);
            # the corresponding role ids are known as well.
            seq_len = batch.size(1)
            for i in range(seq_len - 1):
                inp = batch[:, i]
                inp = inp.unsqueeze(1)  # [batch X 1]
                typ = role[:, i]
                typ = typ.unsqueeze(1)
                _, hidden = model(inp, hidden, typ)
            #print("seq len {}, decode after {} steps".format(seq_len, i+1))

            # The beam sets its current state to the last word in the sequence.
            beam_inp = batch[:, i + 1]
            # Not needed anymore, assuming the last sequence role is PREP.
            #role_inp = role[:, i+1]
            #print("ROLES LIST: {}".format(ROLES))
            #print("FIRST ID: {}".format(role[:, i+1]))

            # init_beam initializes the beam with the last sequence element.
            # ROLES is a list of role type ids.
            outputs = beam_decode(model, beam_inp, hidden, args.max_len_decode,
                                  args.beam_size, pad_id, sos_id, eos_id,
                                  tup_idx=tup_id, init_beam=True, roles=ROLES)
            predicted_events = get_pred_events(outputs, vocab)
            print("CONTEXT: {}".format(transform(batch.data.squeeze(), vocab.itos)))
            print("PRED_t: {}".format(predicted_events))  # n_best stitched together.

            if (iteration + 1) == args.max_decode:
                print("Max decode reached. Exiting.")
                break

    else:
        print("GEN SEED WITHOUT ROLE EMB")
        dataset = du.LMSentenceDataset(args.data, vocab,
                                       src_seq_length=MAX_EVAL_SEQ_LEN,
                                       min_seq_length=MIN_EVAL_SEQ_LEN)  # put in filter_pred later
        batches = BatchIter(dataset, args.batch_size,
                            sort_key=lambda x: len(x.text),
                            train=False, device=device)

        for iteration, bl in enumerate(batches):
            if (iteration + 1) % 25 == 0:
                print("iteration {}".format(iteration + 1))

            ## DATA STEPS
            batch, batch_lens = bl.text
            target, target_lens = bl.target
            if use_cuda:
                batch = Variable(batch.cuda(), volatile=True)
            else:
                batch = Variable(batch, volatile=True)

            ## INIT AND DECODE
            hidden = model.init_hidden(args.batch_size)

            # Run the model on the first t-1 events (all but the last word).
            seq_len = batch.size(1)
            for i in range(seq_len - 1):
                inp = batch[:, i]
                inp = inp.unsqueeze(1)  # [batch X 1]
                _, hidden = model(inp, hidden)
            #print("seq len {}, decode after {} steps".format(seq_len, i+1))

            # The beam sets its current state to the last word in the sequence.
            beam_inp = batch[:, i + 1]
            # init_beam initializes the beam with the last sequence element.
            outputs = beam_decode(model, beam_inp, hidden, args.max_len_decode,
                                  args.beam_size, pad_id, sos_id, eos_id,
                                  tup_idx=tup_id, init_beam=True)
            predicted_events = get_pred_events(outputs, vocab)
            print("CONTEXT: {}".format(transform(batch.data.squeeze(), vocab.itos)))
            print("PRED_t: {}".format(predicted_events))  # n_best stitched together.

            if (iteration + 1) == args.max_decode:
                print("Max decode reached. Exiting.")
                break
if torch.cuda.is_available():
    if not args.cuda:
        logging.warning("WARNING: You have a CUDA device, so you should probably run with --cuda")
        args.device = torch.device('cpu')
    else:
        args.device = torch.device('cuda')
        logging.info("Using GPU {}".format(torch.cuda.get_device_name(args.device)))
else:
    args.device = torch.device('cpu')

evocab = du.load_vocab(args.evocab)

with open(args.pmi_dict, 'r') as fi:
    pmi_dict = json.load(fi)
with open(args.causal_dict, 'rb') as fi:
    causal_dict = pickle.load(fi)

evocab_lm = du.convert_to_lm_vocab(copy.deepcopy(evocab))
lm_model = torch.load(args.lm_model, map_location=args.device)

# Only count subject/object events in the rankings.
so_events = [x for x in evocab.itos
             if len(x.split('->')) == 2 and x.split('->')[1] in ['nsubj', 'dobj', 'iobj']]
print(len(so_events))
def train(args):
    """
    Train the model in the ol' fashioned way, just like grandma used to.
    Args
        args (argparse.Namespace)
    """
    # Load the data
    logging.info("Loading Vocab")
    evocab = du.load_vocab(args.evocab)
    tvocab = du.load_vocab(args.tvocab)
    logging.info("Event Vocab Loaded, Size {}".format(len(evocab.stoi.keys())))
    logging.info("Text Vocab Loaded, Size {}".format(len(tvocab.stoi.keys())))

    if args.use_pretrained:
        pretrained = GloVe(name='6B', dim=args.text_embed_size,
                           unk_init=torch.Tensor.normal_)
        tvocab = du.load_vectors(pretrained)
        logging.info("Loaded Pretrained Word Embeddings")

    if args.load_model:
        logging.info("Loading the Model")
        model = torch.load(args.load_model, map_location=args.device)
    else:
        logging.info("Creating the Model")
        if args.onehot_events:
            logging.info("Model Type: SemiNaiveAdjustmentEstimatorOneHotEvents")
            model = estimators.SemiNaiveAdjustmentEstimatorOneHotEvents(args, evocab, tvocab)
        else:
            logging.info("Model Type: SemiNaiveAdjustmentEstimator")
            model = estimators.SemiNaiveAdjustmentEstimator(args, evocab, tvocab)

    if args.finetune:
        assert args.load_model
        logging.info("Finetuning...")
        if args.freeze:
            logging.info("Freezing...")
            for param in model.parameters():
                param.requires_grad = False
        model = estimators.AdjustmentEstimator(args, evocab, tvocab, model)
        # Still finetune the last layer even if freeze is on
        # (if freeze is on, then everything else is frozen).
        model.expected_outcome.event_text_logits_mlp.weight.requires_grad = True
        model.expected_outcome.event_text_logits_mlp.bias.requires_grad = True
        logging.info("Trainable Params: {}".format(
            [x[0] for x in model.named_parameters() if x[1].requires_grad]))

    model = model.to(device=args.device)

    # Create the optimizer.
    if args.load_opt:
        logging.info("Loading the optimizer state")
        optimizer = torch.load(args.load_opt)
    else:
        if args.optimizer == 'adagrad':
            logging.info("Creating Adagrad optimizer anew")
            optimizer = torch.optim.Adagrad(
                filter(lambda x: x.requires_grad, model.parameters()), lr=args.lr)
        elif args.optimizer == 'sgd':
            logging.info("Creating SGD optimizer anew")
            optimizer = torch.optim.SGD(
                filter(lambda x: x.requires_grad, model.parameters()), lr=args.lr)
        else:
            logging.info("Creating Adam optimizer anew")
            optimizer = torch.optim.Adam(
                filter(lambda x: x.requires_grad, model.parameters()), lr=args.lr)

    logging.info("Loading Datasets")
    # Add extra pads if the text is shorter than the largest CNN kernel size.
    min_size = model.text_encoder.largest_ngram_size

    if args.load_pickle:
        logging.info("Loading Train from Pickled Data")
        with open(args.train_data, 'rb') as pfi:
            pickled_examples = pickle.load(pfi)
        train_dataset = du.InstanceDataset("", evocab, tvocab, min_size=min_size,
                                           pickled_examples=pickled_examples)
    else:
        train_dataset = du.InstanceDataset(args.train_data, evocab, tvocab, min_size=min_size)

    valid_dataset = du.InstanceDataset(args.valid_data, evocab, tvocab, min_size=min_size)

    # Remove UNK events from the e1prev_intext attribute so they don't mess up avg encoders.
    # These take a really long time! Will have to figure something out...
    # train_dataset.filter_examples(['e1prev_intext'])
    # valid_dataset.filter_examples(['e1prev_intext'])

    logging.info("Finished Loading Training Dataset {} examples".format(len(train_dataset)))
    logging.info("Finished Loading Valid Dataset {} examples".format(len(valid_dataset)))

    train_batches = BatchIter(train_dataset, args.batch_size,
                              sort_key=lambda x: len(x.allprev),
                              train=True, repeat=False, shuffle=True,
                              sort_within_batch=True, device=None)
    valid_batches = BatchIter(valid_dataset, args.batch_size,
                              sort_key=lambda x: len(x.allprev),
                              train=False, repeat=False, shuffle=False,
                              sort_within_batch=True, device=None)
    train_data_len = len(train_dataset)
    valid_data_len = len(valid_dataset)

    loss_func = nn.CrossEntropyLoss()
    start_time = time.time()  # start of epoch 1
    best_valid_loss = float('inf')
    best_epoch = args.epochs

    if args.finetune:
        vloss = validation(args, valid_batches, model, loss_func)
        logging.info("Pre Finetune Validation Loss: {}".format(vloss))

    # MAIN TRAINING LOOP
    for curr_epoch in range(args.epochs):
        prev_losses = []
        for iteration, inst in enumerate(train_batches):
            instance = du.send_instance_to(inst, args.device)
            model.train()
            model.zero_grad()

            model_outputs = model(instance)
            # [batch X num events], output prediction for e2.
            exp_outcome_out = model_outputs[EXP_OUTCOME_COMPONENT]
            exp_outcome_loss = loss_func(exp_outcome_out, instance.e2)
            loss = exp_outcome_loss
            loss.backward()
            torch.nn.utils.clip_grad_norm(model.parameters(), args.clip)
            optimizer.step()

            prev_losses.append(loss.cpu().data)
            prev_losses = prev_losses[-50:]

            if (iteration % args.log_every == 0) and iteration != 0:
                past_50_avg = sum(prev_losses) / len(prev_losses)
                logging.info(
                    "Epoch/iteration {}/{}, Past 50 Average Loss {}, Best Val {} at Epoch {}".format(
                        curr_epoch, iteration, past_50_avg,
                        'NA' if best_valid_loss == float('inf') else best_valid_loss,
                        'NA' if best_epoch == args.epochs else best_epoch))

            if (iteration % args.validate_after == 0) and iteration != 0:
                logging.info("Running Validation at Epoch/iteration {}/{}".format(curr_epoch, iteration))
                new_valid_loss = validation(args, valid_batches, model, loss_func)
                logging.info(
                    "Validation loss at Epoch/iteration {}/{}: {:.3f} - Best Validation Loss: {:.3f}".format(
                        curr_epoch, iteration, new_valid_loss, best_valid_loss))
                if new_valid_loss < best_valid_loss:
                    logging.info("New Validation Best...Saving Model Checkpoint")
                    best_valid_loss = new_valid_loss
                    best_epoch = curr_epoch
                    #torch.save(model, "{}.epoch_{}.loss_{:.2f}.pt".format(args.save_model, curr_epoch, best_valid_loss))
                    #torch.save(optimizer, "{}.{}.epoch_{}.loss_{:.2f}.pt".format(args.save_model, "optimizer", curr_epoch, best_valid_loss))
                    torch.save(model, "{}".format(args.save_model))
                    torch.save(optimizer, "{}_optimizer".format(args.save_model))

        # END OF EPOCH
        logging.info("End of Epoch {}, Running Validation".format(curr_epoch))
        new_valid_loss = validation(args, valid_batches, model, loss_func)
        logging.info(
            "Validation loss at end of Epoch {}: {:.3f} - Best Validation Loss: {:.3f}".format(
                curr_epoch, new_valid_loss, best_valid_loss))
        if new_valid_loss < best_valid_loss:
            logging.info("New Validation Best...Saving Model Checkpoint")
            best_valid_loss = new_valid_loss
            best_epoch = curr_epoch
            #torch.save(model, "{}.epoch_{}.loss_{:.2f}.pt".format(args.save_model, curr_epoch, best_valid_loss))
            #torch.save(optimizer, "{}.{}.epoch_{}.loss_{:.2f}.pt".format(args.save_model, "optimizer", curr_epoch, best_valid_loss))
            torch.save(model, "{}".format(args.save_model))
            torch.save(optimizer, "{}_optimizer".format(args.save_model))

        if curr_epoch - best_epoch >= args.stop_after:
            logging.info("No improvement in {} epochs, terminating at epoch {}...".format(
                args.stop_after, curr_epoch))
            logging.info("Best Validation Loss: {:.2f} at Epoch {}".format(
                best_valid_loss, best_epoch))
            break
def train(args):
    """
    Train the model in the ol' fashioned way, just like grandma used to.
    Args
        args (argparse.Namespace)
    """
    # Load the data
    logging.info("Loading Vocab")
    evocab = du.load_vocab(args.evocab)
    logging.info("Event Vocab Loaded, Size {}".format(len(evocab.stoi.keys())))

    # Add SOS/EOS tokens to the event vocab.
    evocab.stoi[SOS_TOK] = len(evocab.itos)
    evocab.itos.append(SOS_TOK)
    evocab.stoi[EOS_TOK] = len(evocab.itos)
    evocab.itos.append(EOS_TOK)
    assert evocab.stoi[EOS_TOK] == evocab.itos.index(EOS_TOK)
    assert evocab.stoi[SOS_TOK] == evocab.itos.index(SOS_TOK)

    if args.load_model:
        logging.info("Loading the Model")
        model = torch.load(args.load_model, map_location=args.device)
    else:
        logging.info("Creating the Model")
        model = LM.EventLM(args.event_embed_size, args.rnn_hidden_dim,
                           args.rnn_layers, len(evocab.itos), dropout=args.dropout)

    model = model.to(device=args.device)

    # Create the optimizer.
    if args.load_opt:
        logging.info("Loading the optimizer state")
        optimizer = torch.load(args.load_opt)
    else:
        logging.info("Creating the optimizer anew")
        optimizer = torch.optim.Adam(
            filter(lambda x: x.requires_grad, model.parameters()), lr=args.lr)
        # optimizer = torch.optim.Adagrad(model.parameters(), lr=args.lr)

    logging.info("Loading Datasets")
    train_dataset = du.LmInstanceDataset(args.train_data, evocab)
    valid_dataset = du.LmInstanceDataset(args.valid_data, evocab)

    # Remove UNK events from the e1prev_intext attribute so they don't mess up avg encoders.
    # These take a really long time! Will have to figure something out...
    # train_dataset.filter_examples(['e1prev_intext'])
    # valid_dataset.filter_examples(['e1prev_intext'])

    logging.info("Finished Loading Training Dataset {} examples".format(len(train_dataset)))
    logging.info("Finished Loading Valid Dataset {} examples".format(len(valid_dataset)))

    train_batches = BatchIter(train_dataset, args.batch_size,
                              sort_key=lambda x: len(x.text),
                              train=True, repeat=False, shuffle=True,
                              sort_within_batch=True, device=None)
    valid_batches = BatchIter(valid_dataset, args.batch_size,
                              sort_key=lambda x: len(x.text),
                              train=False, repeat=False, shuffle=False,
                              sort_within_batch=True, device=None)
    train_data_len = len(train_dataset)
    valid_data_len = len(valid_dataset)

    start_time = time.time()  # start of epoch 1
    best_valid_loss = float('inf')
    best_epoch = args.epochs

    # MAIN TRAINING LOOP
    for curr_epoch in range(args.epochs):
        prev_losses = []
        for iteration, inst in enumerate(train_batches):
            instance = du.lm_send_instance_to(inst, args.device)
            text_inst, text_lens = instance.text
            target_inst, target_lens = instance.target
            model.train()
            model.zero_grad()

            # Unroll the RNN one step at a time over the event sequence.
            logits = []
            hidden = None
            for step in range(text_inst.size(1)):
                step_inp = text_inst[:, step]  # all instances for this step
                step_inp = step_inp.unsqueeze(1)  # [batch X 1]
                logit_i, hidden = model(step_inp, hidden)
                logits += [logit_i]

            logits = torch.stack(logits, dim=1)  # [batch, seq_len, vocab]
            loss = masked_cross_entropy(logits, target_inst, target_lens)
            loss.backward()
            torch.nn.utils.clip_grad_norm(model.parameters(), args.clip)
            optimizer.step()

            prev_losses.append(loss.cpu().data)
            prev_losses = prev_losses[-50:]

            if (iteration % args.log_every == 0) and iteration != 0:
                past_50_avg = sum(prev_losses) / len(prev_losses)
                logging.info(
                    "Epoch/iteration {}/{}, Past 50 Average Loss {}, Best Val {} at Epoch {}".format(
                        curr_epoch, iteration, past_50_avg,
                        'NA' if best_valid_loss == float('inf') else best_valid_loss,
                        'NA' if best_epoch == args.epochs else best_epoch))

            if (iteration % args.validate_after == 0) and iteration != 0:
                logging.info("Running Validation at Epoch/iteration {}/{}".format(curr_epoch, iteration))
                new_valid_loss = validation(args, valid_batches, model)
                logging.info(
                    "Validation loss at Epoch/iteration {}/{}: {:.3f} - Best Validation Loss: {:.3f}".format(
                        curr_epoch, iteration, new_valid_loss, best_valid_loss))
                if new_valid_loss < best_valid_loss:
                    logging.info("New Validation Best...Saving Model Checkpoint")
                    best_valid_loss = new_valid_loss
                    best_epoch = curr_epoch
                    #torch.save(model, "{}.epoch_{}.loss_{:.2f}.pt".format(args.save_model, curr_epoch, best_valid_loss))
                    #torch.save(optimizer, "{}.{}.epoch_{}.loss_{:.2f}.pt".format(args.save_model, "optimizer", curr_epoch, best_valid_loss))
                    torch.save(model, "{}".format(args.save_model))
                    torch.save(optimizer, "{}_optimizer".format(args.save_model))

        # END OF EPOCH
        logging.info("End of Epoch {}, Running Validation".format(curr_epoch))
        new_valid_loss = validation(args, valid_batches, model)
        logging.info(
            "Validation loss at end of Epoch {}: {:.3f} - Best Validation Loss: {:.3f}".format(
                curr_epoch, new_valid_loss, best_valid_loss))
        if new_valid_loss < best_valid_loss:
            logging.info("New Validation Best...Saving Model Checkpoint")
            best_valid_loss = new_valid_loss
            best_epoch = curr_epoch
            #torch.save(model, "{}.epoch_{}.loss_{:.2f}.pt".format(args.save_model, curr_epoch, best_valid_loss))
            #torch.save(optimizer, "{}.{}.epoch_{}.loss_{:.2f}.pt".format(args.save_model, "optimizer", curr_epoch, best_valid_loss))
            torch.save(model, "{}".format(args.save_model))
            torch.save(optimizer, "{}_optimizer".format(args.save_model))

        if curr_epoch - best_epoch >= args.stop_after:
            logging.info("No improvement in {} epochs, terminating at epoch {}...".format(
                args.stop_after, curr_epoch))
            logging.info("Best Validation Loss: {:.2f} at Epoch {}".format(
                best_valid_loss, best_epoch))
            break