def train(print_loss_total, print_act_total, print_grad_total, input_tensor, target_tensor, bs_tensor, db_tensor, name=None):
    """Run one training step on a batch and fold its statistics into the running totals.

    The input and target sequences are padded with ``util.padSequence``, the
    ``bs_tensor`` / ``db_tensor`` rows are converted to float tensors on the
    module-level ``device``, and everything is passed to ``model.train`` on
    the module-level ``model``.

    Args:
        print_loss_total: Running total that the step's ``loss`` is added to.
        print_act_total: Running total that the step's ``loss_acts`` is added to.
        print_grad_total: Running total that the step's ``grad`` is added to.
        input_tensor: Unpadded input sequences of the batch.
        target_tensor: Unpadded target sequences of the batch.
        bs_tensor: Rows convertible to a float tensor (presumably belief-state
            features -- confirm against ``util.loadDialogue``).
        db_tensor: Rows convertible to a float tensor (presumably database
            features -- confirm against ``util.loadDialogue``).
        name: Forwarded unchanged as the last argument of ``model.train``.

    Returns:
        Tuple ``(print_loss_total, print_act_total, print_grad_total)`` after
        this step's values have been added.
    """
    # Pad the variable-length sequences; the helper also reports their lengths.
    padded_inputs, input_lengths = util.padSequence(input_tensor)
    padded_targets, target_lengths = util.padSequence(target_tensor)

    bs_batch = torch.tensor(bs_tensor, dtype=torch.float, device=device)
    db_batch = torch.tensor(db_tensor, dtype=torch.float, device=device)

    step_loss, step_act_loss, step_grad = model.train(
        padded_inputs, input_lengths,
        padded_targets, target_lengths,
        db_batch, bs_batch, name,
    )

    print_loss_total += step_loss
    print_act_total += step_act_loss
    print_grad_total += step_grad

    # Per-step bookkeeping kept on the shared model object.
    model.global_step += 1
    model.sup_loss = torch.zeros(1)

    return print_loss_total, print_act_total, print_grad_total
def decode(data, model, device):
    """Generate output for the dialogue stored in ``data['cur']`` using greedy decoding.

    Beam search is switched off on ``model`` (``model.beam_search = False``),
    the dialogue is loaded with ``util.loadDialogue``, padded, converted to
    tensors on ``device`` and passed to ``model.predict``.

    Args:
        data: Mapping whose ``'cur'`` entry is the dialogue to decode.
        model: Model exposing ``beam_search`` and ``predict``; it is also
            passed to ``util.loadDialogue``.
        device: Torch device on which the ``bs``/``db`` tensors are created.

    Returns:
        The last element of the output sequence returned by ``model.predict``.

    NOTE(review): another ``decode(num=1)`` with the same name is visible in
    this file; if both live in one module the later definition replaces this
    one -- confirm which is intended to be importable.
    """
    # Only the greedy path was ever reachable in the previous loop-based form.
    model.beam_search = False

    input_tensor, target_tensor, bs_tensor, db_tensor = util.loadDialogue(
        model, data['cur'], [], [], [], [])

    # Pad the variable-length sequences and move the feature rows to `device`.
    input_tensor, input_lengths = util.padSequence(input_tensor)
    target_tensor, target_lengths = util.padSequence(target_tensor)
    bs_tensor = torch.tensor(bs_tensor, dtype=torch.float, device=device)
    db_tensor = torch.tensor(db_tensor, dtype=torch.float, device=device)

    # The second value returned by predict (a loss) is not used here.
    output_words, _ = model.predict(input_tensor, input_lengths,
                                    target_tensor, target_lengths,
                                    db_tensor, bs_tensor)
    return output_words[-1]
def trainIters(model, n_epochs=10, args=args):
    """Train for ``n_epochs`` epochs, validating and saving a checkpoint after each one.

    Each epoch:
      1. rebuilds ``model.optimizer`` and ``model.optimizer_policy``;
      2. shuffles the dialogue names of the module-level ``train_dials`` and
         accumulates dialogues via ``util.loadDialogue`` until more than
         ``args.batch_size`` rows are pending, then calls ``train`` on them;
      3. trains on any rows still pending once the dialogues run out, so the
         final partial batch is not discarded;
      4. prints the epoch statistics, computes the mean generation loss over
         the module-level ``val_dials`` and calls ``model.saveModel(epoch)``.

    Args:
        model: Model to optimise. ``parameters()``, ``policy``, ``forward``,
            ``gen_criterion``, ``vocab_size`` and ``saveModel`` are used.
        n_epochs: Number of passes over ``train_dials``.
        args: Settings object; ``lr_rate``, ``l2_norm`` and ``batch_size``
            are read.

    NOTE(review): the per-batch step in ``train`` operates on the module-level
    ``model`` rather than on this parameter -- presumably they are the same
    object; confirm at the call site.
    """
    start = time.time()

    for epoch in range(1, n_epochs + 1):
        # Per-epoch accumulators, updated by every call to train().
        print_loss_total = 0
        print_grad_total = 0
        print_act_total = 0
        start_time = time.time()

        # NOTE(review): the optimizers are re-created every epoch, which
        # discards any state Adam has accumulated. Kept as in the original
        # ("watch out where do you put it") -- confirm this is intended.
        model.optimizer = Adam(lr=args.lr_rate,
                               params=filter(lambda x: x.requires_grad, model.parameters()),
                               weight_decay=args.l2_norm)
        model.optimizer_policy = Adam(lr=args.lr_rate,
                                      params=filter(lambda x: x.requires_grad, model.policy.parameters()),
                                      weight_decay=args.l2_norm)

        dials = list(train_dials.keys())
        random.shuffle(dials)

        input_tensor, target_tensor, bs_tensor, db_tensor = [], [], [], []
        for name in dials:
            model.optimizer.zero_grad()
            model.optimizer_policy.zero_grad()

            input_tensor, target_tensor, bs_tensor, db_tensor = util.loadDialogue(
                model, train_dials[name], input_tensor, target_tensor, bs_tensor, db_tensor)

            if len(db_tensor) > args.batch_size:
                print_loss_total, print_act_total, print_grad_total = train(
                    print_loss_total, print_act_total, print_grad_total,
                    input_tensor, target_tensor, bs_tensor, db_tensor)
                input_tensor, target_tensor, bs_tensor, db_tensor = [], [], [], []

        # Train on the rows that never exceeded the batch-size threshold;
        # previously they were dropped at the end of every epoch. Gradients
        # were already cleared at the top of the last loop iteration.
        if len(db_tensor) > 0:
            print_loss_total, print_act_total, print_grad_total = train(
                print_loss_total, print_act_total, print_grad_total,
                input_tensor, target_tensor, bs_tensor, db_tensor)

        # Totals are summed per batch but normalised by the dialogue count.
        n_train = len(train_dials)
        print_loss_avg = print_loss_total / n_train
        print_act_total_avg = print_act_total / n_train
        print_grad_avg = print_grad_total / n_train
        print('TIME:', time.time() - start_time)
        print(
            'Time since %s (Epoch:%d %d%%) Loss: %.4f, Loss act: %.4f, Grad: %.4f'
            % (util.timeSince(start, epoch / n_epochs), epoch,
               epoch / n_epochs * 100, print_loss_avg, print_act_total_avg,
               print_grad_avg))

        # VALIDATION
        # No backward pass happens here, so gradient tracking is disabled.
        # NOTE(review): model.eval() is deliberately not called --
        # nn.Module.eval() dispatches to self.train(False), and this model
        # appears to expose a custom train(...) step; confirm before adding it.
        valid_loss = 0
        with torch.no_grad():
            for val_file in val_dials.values():
                input_tensor, target_tensor, bs_tensor, db_tensor = util.loadDialogue(
                    model, val_file, [], [], [], [])

                # Pad the sequences and build the feature tensors on `device`.
                input_tensor, input_lengths = util.padSequence(input_tensor)
                target_tensor, target_lengths = util.padSequence(target_tensor)
                bs_tensor = torch.tensor(bs_tensor, dtype=torch.float, device=device)
                db_tensor = torch.tensor(db_tensor, dtype=torch.float, device=device)

                proba, _, _ = model.forward(input_tensor, input_lengths,
                                            target_tensor, target_lengths,
                                            db_tensor, bs_tensor)
                proba = proba.view(-1, model.vocab_size)  # flatten all predictions
                loss = model.gen_criterion(proba, target_tensor.view(-1))
                valid_loss += loss.item()

        valid_loss /= len(val_dials)
        print('Current Valid LOSS:', valid_loss)

        model.saveModel(epoch)
def _generate_responses(model, dials):
    """Run ``model.predict`` on every dialogue in ``dials`` and return the outputs keyed by name."""
    generated = {}
    for name, dial_file in dials.items():
        input_tensor, target_tensor, bs_tensor, db_tensor = util.loadDialogue(
            model, dial_file, [], [], [], [])

        # Pad the sequences and build the feature tensors on the module-level device.
        input_tensor, input_lengths = util.padSequence(input_tensor)
        target_tensor, target_lengths = util.padSequence(target_tensor)
        bs_tensor = torch.tensor(bs_tensor, dtype=torch.float, device=device)
        db_tensor = torch.tensor(db_tensor, dtype=torch.float, device=device)

        # The loss returned alongside the words is not used by the caller.
        output_words, _ = model.predict(input_tensor, input_lengths,
                                        target_tensor, target_lengths,
                                        db_tensor, bs_tensor)
        generated[name] = output_words
    return generated


def decode(num=1):
    """Decode the validation and test dialogues, first greedily and then with beam search.

    The model and both dialogue sets come from ``loadModelAndData(num)``. For
    each decoding mode the generated outputs are written as JSON to
    ``args.valid_output + 'val_dials_gen.json'`` and
    ``args.decode_output + 'test_dials_gen.json'`` and scored with
    ``evaluateModel``. The total wall-clock time is printed at the end.

    Args:
        num: Passed to ``loadModelAndData`` to select what is loaded.

    NOTE(review): the loss returned by ``model.predict`` has never been
    accumulated, so the reported VALID/TEST losses are always zero; the
    printed lines are kept for log compatibility -- confirm whether a real
    loss should be reported.

    NOTE(review): a ``decode(data, model, device)`` with the same name is
    visible earlier in this file; if both live in one module this definition
    replaces it -- confirm which is intended to be importable.
    """
    model, val_dials, test_dials = loadModelAndData(num)

    start_time = time.time()
    for use_beam in (False, True):
        print(50 * '-' + ('BEAM' if use_beam else 'GREEDY'))
        model.beam_search = use_beam

        # VALIDATION
        val_dials_gen = _generate_responses(model, val_dials)
        valid_loss = 0  # placeholder, see docstring
        print('Current VALID LOSS:', valid_loss)
        with open(args.valid_output + 'val_dials_gen.json', 'w') as outfile:
            json.dump(val_dials_gen, outfile)
        evaluateModel(val_dials_gen, val_dials, mode='valid')

        # TESTING
        test_dials_gen = _generate_responses(model, test_dials)
        test_loss = 0.0  # placeholder, see docstring
        print('Current TEST LOSS:', test_loss)
        with open(args.decode_output + 'test_dials_gen.json', 'w') as outfile:
            json.dump(test_dials_gen, outfile)
        evaluateModel(test_dials_gen, test_dials, mode='test')

    print('TIME:', time.time() - start_time)