import os
import sys
import random
from collections import OrderedDict
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical

# Project-level helpers (bert_tokenizer, rouge, get_a_batch, shift_decoder_target,
# load_ami_data, load_cnndm_data, load_podcast_data, EncoderDecoder, DALabeller,
# EXTLabeller, DA_MAPPING, LabelSmoothingLoss, adjust_lr, diverisity_loss,
# length2mask, print_config, AdaptiveBias, HierBatcher, and the token constants
# STOP_TOKEN, SEP_TOKEN, START_TOKEN_ID, SEP_TOKEN_ID, VOCAB_SIZE) are imported
# from elsewhere in the original repository.


# NOTE: excerpted class method (takes self).
def decode_sampling(self, input, u_len, w_len, decode_dict):
    """
    This method is meant to be used at inference time.
        input  = input to the encoder
        u_len  = utterance lengths
        w_len  = word lengths
        decode_dict:
            - batch_size       = batch_size
            - time_step        = max_summary_length
            - vocab_size       = 30522 for BERT
            - device           = cpu or cuda
            - start_token_id   = ID of the start token
            - stop_token_id    = ID of the stop token
            - alpha            = length normalisation
            - length_offset    = length offset
            - keypadmask_dtype = torch.bool
    """
    batch_size = decode_dict['batch_size']
    time_step = decode_dict['time_step']
    vocab_size = decode_dict['vocab_size']
    device = decode_dict['device']
    start_token_id = decode_dict['start_token_id']
    stop_token_id = decode_dict['stop_token_id']
    alpha = decode_dict['alpha']
    length_offset = decode_dict['length_offset']
    keypadmask_dtype = decode_dict['keypadmask_dtype']

    # we should only feed through the encoder once!
    s_output, s_len = self.encoder(input, u_len, w_len)  # memory

    # we run the decoder time_step times (auto-regressive)
    tgt_ids = torch.zeros((batch_size, time_step), dtype=torch.int64).to(device)
    tgt_ids[:, 0] = start_token_id

    for t in range(time_step - 1):
        decoder_output = self.decoder(tgt_ids[:, :t + 1], s_output, s_len,
                                      logsoftmax=False)[:, -1, :]
        pmf = nn.functional.softmax(decoder_output, dim=-1).cpu().numpy()
        for bn in range(batch_size):
            token_id = np.random.choice(vocab_size, p=pmf[bn])  # sample next token
            tgt_ids[bn, t + 1] = token_id
        if (t % 100) == 0:
            print("{}=".format(t), end="")
            sys.stdout.flush()
    print("{}=#".format(t))
    print(bert_tokenizer.decode(tgt_ids[0].cpu().numpy()))

    summaries_id = [None for _ in range(batch_size)]
    for j in range(batch_size):
        summaries_id[j] = tgt_ids[j].cpu().numpy()
    return summaries_id

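# ---------------------------------------------------------------------------------- #
# Illustrative only (not part of the original script): a minimal call site for
# decode_sampling above. The dict keys mirror the docstring; 101/103 are the
# [CLS]/[MASK] ids used as start/stop tokens elsewhere in this file. `model`,
# `input`, `u_len` and `w_len` are assumed to come from the usual data pipeline.
def _example_decode_sampling(model, input, u_len, w_len):
    decode_dict = {
        'batch_size': 1,
        'time_step': 300,       # max summary length
        'vocab_size': 30522,    # BERT tokenizer
        'device': 'cuda',
        'start_token_id': 101,  # [CLS]
        'stop_token_id': 103,   # [MASK]
        'alpha': 1.5,           # length normalisation (used by beam search, not sampling)
        'length_offset': 5,
        'keypadmask_dtype': torch.bool,
    }
    with torch.no_grad():  # inference only --- no gradients needed
        return model.decode_sampling(input, u_len, w_len, decode_dict)
# ---------------------------------------------------------------------------------- #
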
def tgtids2summary(tgt_ids):
    # tgt_ids = a 1-D numpy array containing token ids
    bert_decoded = bert_tokenizer.decode(tgt_ids)
    # truncate the START_TOKEN ('[CLS]' = 5 characters) & the part after STOP_TOKEN
    stop_idx = bert_decoded.find(STOP_TOKEN)
    processed_bert_decoded = bert_decoded[5:stop_idx]
    summary = [s.strip() for s in processed_bert_decoded.split(SEP_TOKEN)]
    return summary

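# Illustrative only: assuming STOP_TOKEN = '[MASK]' and SEP_TOKEN = '[SEP]'
# (the convention used throughout this file), tgtids2summary behaves roughly as:
#   ids = bert_tokenizer.encode("first sentence . [SEP] second sentence . [MASK]")
#   tgtids2summary(np.array(ids))
#   -> ['first sentence .', 'second sentence .']
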
def evaluate_greedy(model, eval_data, eval_batch_size, args, device):
    num_eval_epochs = int(len(eval_data) / eval_batch_size)  # number of eval batches
    print("num_eval_epochs = {}".format(num_eval_epochs))
    eval_idx = 0
    from rouge import Rouge
    rouge = Rouge()
    bert_decoded_outputs = []
    bert_decoded_targets = []
    for bn in range(num_eval_epochs):
        input, u_len, w_len, target, tgt_len, _, _, _ = get_a_batch(
            eval_data, eval_idx, eval_batch_size,
            args['num_utterances'], args['num_words'],
            args['summary_length'], args['summary_type'], device)
        # decoder target
        decoder_target, decoder_mask = shift_decoder_target(target, tgt_len, device)
        decoder_target = decoder_target.view(-1)

        enc_output_dict = model.encoder(input, u_len, w_len)  # memory
        u_output = enc_output_dict['u_output']

        # forward-pass DECODER (greedy, one step at a time)
        xt = torch.zeros((eval_batch_size, 1), dtype=torch.int64).to(device)
        xt.fill_(101)  # 101 = [CLS] start token
        # initial hidden state = last utterance-level encoder state
        ht = torch.zeros((model.decoder.num_layers, eval_batch_size,
                          model.decoder.dec_hidden_size), dtype=torch.float).to(device)
        for bi, l in enumerate(u_len):
            ht[:, bi, :] = u_output[bi, l - 1, :].unsqueeze(0)

        decoded_words = [103 for _ in range(args['summary_length'])]  # 103 = [MASK] stop token
        for t in range(args['summary_length'] - 1):
            decoder_output, ht, _ = model.decoder.forward_step(
                xt, ht, enc_output_dict, logsoftmax=True)
            next_word = decoder_output.argmax().item()  # NOTE: assumes batch size 1
            xt.fill_(next_word)
            decoded_words[t] = next_word

        bert_decoded_output = bert_tokenizer.decode(decoded_words)
        stop_idx = bert_decoded_output.find('[MASK]')
        bert_decoded_output = bert_decoded_output[:stop_idx]
        bert_decoded_output = bert_decoded_output.replace('[SEP] ', '')
        bert_decoded_outputs.append(bert_decoded_output)
        bert_decoded_target = bert_tokenizer.decode(decoder_target.cpu().numpy())
        stop_idx2 = bert_decoded_target.find('[MASK]')
        bert_decoded_target = bert_decoded_target[:stop_idx2]
        bert_decoded_target = bert_decoded_target.replace('[SEP] ', '')
        bert_decoded_targets.append(bert_decoded_target)

        eval_idx += eval_batch_size
        print("#", end="")
        sys.stdout.flush()
    print()
    try:
        scores = rouge.get_scores(bert_decoded_outputs, bert_decoded_targets, avg=True)
        print("--------------------------------------------------")
        print("ROUGE-1 = {:.2f}".format(scores['rouge-1']['f'] * 100))
        print("ROUGE-2 = {:.2f}".format(scores['rouge-2']['f'] * 100))
        print("ROUGE-L = {:.2f}".format(scores['rouge-l']['f'] * 100))
        print("--------------------------------------------------")
        # negated average F-score so that lower = better (matches the loss convention)
        return (scores['rouge-1']['f'] + scores['rouge-2']['f']
                + scores['rouge-l']['f']) * (-100) / 3
    except ValueError:
        print("cannot compute ROUGE score")
        return 0

def evaluate(model, eval_data, eval_batch_size, args, device, use_rouge=False):
    # num_eval_epochs = int(eval_data['num_data']/eval_batch_size) + 1
    num_eval_epochs = int(len(eval_data) / eval_batch_size)
    print("num_eval_epochs = {}".format(num_eval_epochs))
    eval_idx = 0
    eval_total_loss = 0.0
    eval_total_tokens = 0
    if not use_rouge:
        criterion = nn.NLLLoss(reduction='none')
    else:
        from rouge import Rouge
        rouge = Rouge()
        bert_decoded_outputs = []
        bert_decoded_targets = []
    for bn in range(num_eval_epochs):
        input, u_len, w_len, target, tgt_len, _, _, _ = get_a_batch(
            eval_data, eval_idx, eval_batch_size,
            args['num_utterances'], args['num_words'],
            args['summary_length'], args['summary_type'], device)
        # decoder target
        decoder_target, decoder_mask = shift_decoder_target(target, tgt_len, device)
        decoder_target = decoder_target.view(-1)
        decoder_mask = decoder_mask.view(-1)

        decoder_output, _, _, _, _ = model(input, u_len, w_len, target)

        if not use_rouge:
            loss = criterion(decoder_output.view(-1, args['vocab_size']), decoder_target)
            eval_total_loss += (loss * decoder_mask).sum().item()
            eval_total_tokens += decoder_mask.sum().item()
        else:  # use rouge
            if eval_batch_size != 1:
                raise ValueError("VAL_BATCH_SIZE must be 1 to use ROUGE")
            decoder_output = decoder_output.view(-1, args['vocab_size'])
            bert_decoded_output = bert_tokenizer.decode(
                torch.argmax(decoder_output, dim=-1).cpu().numpy())
            stop_idx = bert_decoded_output.find('[MASK]')
            bert_decoded_output = bert_decoded_output[:stop_idx]
            bert_decoded_output = bert_decoded_output.replace('[SEP] ', '')
            bert_decoded_outputs.append(bert_decoded_output)
            bert_decoded_target = bert_tokenizer.decode(decoder_target.cpu().numpy())
            stop_idx2 = bert_decoded_target.find('[MASK]')
            bert_decoded_target = bert_decoded_target[:stop_idx2]
            bert_decoded_target = bert_decoded_target.replace('[SEP] ', '')
            bert_decoded_targets.append(bert_decoded_target)

        eval_idx += eval_batch_size
        print("#", end="")
        sys.stdout.flush()
    print()
    if not use_rouge:
        avg_eval_loss = eval_total_loss / eval_total_tokens
        return avg_eval_loss
    else:
        try:
            scores = rouge.get_scores(bert_decoded_outputs, bert_decoded_targets, avg=True)
            # negated average F-score so that lower = better
            return (scores['rouge-1']['f'] + scores['rouge-2']['f']
                    + scores['rouge-l']['f']) * (-100) / 3
        except ValueError:
            return 0

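# ---------------------------------------------------------------------------------- #
# shift_decoder_target is used throughout this file but is not defined in this
# section. A plausible sketch (an assumption, not the original implementation):
# produce next-token targets by shifting the target left one position, plus a
# float mask over the valid (non-padding) positions; the exact semantics of
# mask_offset are a guess.
def _shift_decoder_target_sketch(target, tgt_len, device, mask_offset=False):
    batch_size, length = target.size()
    decoder_target = torch.zeros_like(target)
    decoder_target[:, :-1] = target[:, 1:]  # predict token t+1 at position t
    mask = torch.zeros((batch_size, length), dtype=torch.float, device=device)
    for bi, l in enumerate(tgt_len):
        end = int(l) + 1 if mask_offset else int(l)  # assumed offset semantics
        mask[bi, :min(end, length)] = 1.0
    return decoder_target, mask
# ---------------------------------------------------------------------------------- #
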
def train_v5():
    print("Start training hierarchical RNN model")
    # ---------------------------------------------------------------------------------- #
    args = {}
    args['use_gpu'] = True
    args['num_utterances'] = 1500   # max no. utterances in a meeting
    args['num_words'] = 64          # max no. words in an utterance
    args['summary_length'] = 300    # max no. words in a summary
    args['summary_type'] = 'short'  # long or short summary
    args['vocab_size'] = 30522      # BERT tokenizer
    args['embedding_dim'] = 256     # word embedding dimension
    args['rnn_hidden_size'] = 512   # RNN hidden size
    args['dropout'] = 0.1
    args['num_layers_enc'] = 2      # in total it's num_layers_enc*2 (word/utt)
    args['num_layers_dec'] = 1
    args['batch_size'] = 1
    args['update_nbatches'] = 2
    args['num_epochs'] = 20
    args['random_seed'] = 777
    args['best_val_loss'] = 1e+10
    args['val_batch_size'] = 1      # 1 for now --- evaluate ROUGE
    args['val_stop_training'] = 5
    args['lr'] = 1.0
    args['adjust_lr'] = True        # if True, overwrite the learning rate above
    args['initial_lr'] = 0.01       # lr = lr_0*step^(-decay_rate)
    args['decay_rate'] = 0.5
    args['label_smoothing'] = 0.1
    args['a_da'] = 0.2
    args['a_ext'] = 0.2
    args['a_cov'] = 0.0
    args['a_div'] = 1.0
    args['memory_utt'] = False
    args['model_save_dir'] = "/home/alta/summary/pm574/summariser1/lib/trained_models2/"
    args['load_model'] = "/home/alta/summary/pm574/summariser1/lib/trained_models2/model-HGRUV5_CNNDM_FEB26A-ep12-bn0"  # add .pt later
    # args['load_model'] = "/home/alta/summary/pm574/summariser1/lib/trained_models2/model-HGRUV5_FEB28A-ep6"
    # args['load_model'] = "/home/alta/summary/pm574/summariser1/lib/trained_models2/model-HGRUV5MEM_APR8A-ep1"
    # args['load_model'] = None
    args['model_name'] = 'HGRUV5_APR16H5'
    # ---------------------------------------------------------------------------------- #
    print_config(args)

    if args['use_gpu']:
        if 'X_SGE_CUDA_DEVICE' in os.environ:  # to run on CUED stack machine
            print('running on the stack... 1 GPU')
            cuda_device = os.environ['X_SGE_CUDA_DEVICE']
            print('X_SGE_CUDA_DEVICE is set to {}'.format(cuda_device))
            os.environ['CUDA_VISIBLE_DEVICES'] = cuda_device
        else:
            print('running locally...')
            os.environ["CUDA_VISIBLE_DEVICES"] = '0'  # choose the device (GPU) here
        device = 'cuda'
    else:
        device = 'cpu'
    print("device = {}".format(device))

    # random seed
    random.seed(args['random_seed'])
    torch.manual_seed(args['random_seed'])
    np.random.seed(args['random_seed'])

    train_data = load_ami_data('train')
    valid_data = load_ami_data('valid')
    # make the training data 100
    random.shuffle(valid_data)
    train_data.extend(valid_data[:6])
    valid_data = valid_data[6:]

    model = EncoderDecoder(args, device=device)
    print(model)

    NUM_DA_TYPES = len(DA_MAPPING)
    da_labeller = DALabeller(args['rnn_hidden_size'], NUM_DA_TYPES, device)
    print(da_labeller)
    ext_labeller = EXTLabeller(args['rnn_hidden_size'], device)
    print(ext_labeller)

    # Load model if specified (path to pytorch .pt)
    if args['load_model'] != None:
        model_path = args['load_model'] + '.pt'
        try:
            model.load_state_dict(torch.load(model_path))
        except RuntimeError:  # need to remove the 'module.' prefix (DataParallel)
            model_state_dict = torch.load(model_path)
            new_model_state_dict = OrderedDict()
            for key in model_state_dict.keys():
                new_model_state_dict[key.replace("module.", "")] = model_state_dict[key]
            if args['memory_utt']:
                model.load_state_dict(new_model_state_dict, strict=False)
            else:
                model.load_state_dict(new_model_state_dict)
        model.train()
        print("Loaded model from {}".format(args['load_model']))
    else:
        print("Train a new model")

    # Hyperparameters
    BATCH_SIZE = args['batch_size']
    NUM_EPOCHS = args['num_epochs']
    VAL_BATCH_SIZE = args['val_batch_size']
    VAL_STOP_TRAINING = args['val_stop_training']

    if args['label_smoothing'] > 0.0:
        criterion = LabelSmoothingLoss(num_classes=args['vocab_size'],
                                       smoothing=args['label_smoothing'],
                                       reduction='none')
    else:
        criterion = nn.NLLLoss(reduction='none')
    da_criterion = nn.NLLLoss(reduction='none')
    ext_criterion = nn.BCELoss(reduction='none')

    # ONLY train the memory part #
    # for name, param in model.named_parameters():
    #     if "utt" in name:
    #         pass
    #     else:
    #         param.requires_grad = False
    # optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
    #                        lr=args['lr'], betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
    # -------------------------- #
    optimizer = optim.Adam(model.parameters(), lr=args['lr'],
                           betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
    optimizer.zero_grad()
    # DA labeller optimiser
    da_optimizer = optim.Adam(da_labeller.parameters(), lr=args['lr'],
                              betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
    da_optimizer.zero_grad()
    # extractive labeller optimiser
    ext_optimizer = optim.Adam(ext_labeller.parameters(), lr=args['lr'],
                               betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
    ext_optimizer.zero_grad()

    # validation losses
    best_val_loss = args['best_val_loss']
    best_epoch = 0
    stop_counter = 0
    training_step = 0

    for epoch in range(NUM_EPOCHS):
        print("======================= Training epoch {} =======================".format(epoch))
        num_train_data = len(train_data)
        # num_batches = int(num_train_data/BATCH_SIZE) + 1
        num_batches = int(num_train_data / BATCH_SIZE)
        print("num_batches = {}".format(num_batches))
        print("shuffle train data")
        random.shuffle(train_data)
        idx = 0
        for bn in range(num_batches):
            input, u_len, w_len, target, tgt_len, _, dialogue_acts, extractive_label = get_a_batch(
                train_data, idx, BATCH_SIZE,
                args['num_utterances'], args['num_words'],
                args['summary_length'], args['summary_type'], device)
            # decoder target
            decoder_target, decoder_mask = shift_decoder_target(target, tgt_len, device, mask_offset=True)
            decoder_target = decoder_target.view(-1)
            decoder_mask = decoder_mask.view(-1)

            decoder_output, u_output, attn_scores, cov_scores, u_attn_scores = model(input, u_len, w_len, target)
            loss = criterion(decoder_output.view(-1, args['vocab_size']), decoder_target)
            loss = (loss * decoder_mask).sum() / decoder_mask.sum()

            # COVLOSS:
            # loss_cov = compute_covloss(attn_scores, cov_scores)
            # loss_cov = (loss_cov.view(-1) * decoder_mask).sum() / decoder_mask.sum()

            # Diversity Loss (4):
            intra_div, inter_div = diverisity_loss(u_attn_scores, decoder_target, u_len, tgt_len)
            if inter_div == 0:
                loss_div = 0
            else:
                loss_div = intra_div / inter_div

            # multitask(2): dialogue act prediction
            da_output = da_labeller(u_output)
            loss_utt_mask = length2mask(u_len, BATCH_SIZE, args['num_utterances'], device)
            loss_da = da_criterion(da_output.view(-1, NUM_DA_TYPES),
                                   dialogue_acts.view(-1)).view(BATCH_SIZE, -1)
            loss_da = (loss_da * loss_utt_mask).sum() / loss_utt_mask.sum()

            # multitask(3): extractive label prediction
            ext_output = ext_labeller(u_output).squeeze(-1)
            loss_ext = ext_criterion(ext_output, extractive_label)
            loss_ext = (loss_ext * loss_utt_mask).sum() / loss_utt_mask.sum()

            # total_loss = loss + args['a_da']*loss_da + args['a_ext']*loss_ext + args['a_cov']*loss_cov
            total_loss = loss + args['a_da']*loss_da + args['a_ext']*loss_ext + args['a_div']*loss_div
            # total_loss = loss + args['a_da']*loss_da + args['a_ext']*loss_ext
            # total_loss = loss + args['a_div']*loss_div
            total_loss.backward()
            # loss.backward()

            idx += BATCH_SIZE

            if bn % args['update_nbatches'] == 0:
                # gradient clipping
                max_norm = 0.5
                nn.utils.clip_grad_norm_(model.parameters(), max_norm)
                nn.utils.clip_grad_norm_(da_labeller.parameters(), max_norm)
                nn.utils.clip_grad_norm_(ext_labeller.parameters(), max_norm)
                # update the gradients
                if args['adjust_lr']:
                    adjust_lr(optimizer, args['initial_lr'], args['decay_rate'], training_step)
                    adjust_lr(da_optimizer, args['initial_lr'], args['decay_rate'], training_step)
                    adjust_lr(ext_optimizer, args['initial_lr'], args['decay_rate'], training_step)
                optimizer.step()
                optimizer.zero_grad()
                da_optimizer.step()
                da_optimizer.zero_grad()
                ext_optimizer.step()
                ext_optimizer.zero_grad()
                training_step += args['batch_size'] * args['update_nbatches']

            if bn % 1 == 0:
                print("[{}] batch {}/{}: loss = {:.5f} | loss_div = {:.5f} | loss_da = {:.5f} | loss_ext = {:.5f}".
                      format(str(datetime.now()), bn, num_batches, loss, loss_div, loss_da, loss_ext))
                # print("[{}] batch {}/{}: loss = {:.5f} | loss_da = {:.5f} | loss_ext = {:.5f}".
                #       format(str(datetime.now()), bn, num_batches, loss, loss_da, loss_ext))
                # print("[{}] batch {}/{}: loss = {:.5f} | loss_div = {:.5f}".
                #       format(str(datetime.now()), bn, num_batches, loss, loss_div))
                # print("[{}] batch {}/{}: loss = {:.5f}".format(str(datetime.now()), bn, num_batches, loss))
                sys.stdout.flush()

            if bn % 10 == 0:
                print("======================== GENERATED SUMMARY ========================")
                print(bert_tokenizer.decode(torch.argmax(decoder_output[0], dim=-1).cpu().numpy()[:tgt_len[0]]))
                print("======================== REFERENCE SUMMARY ========================")
                print(bert_tokenizer.decode(decoder_target.view(BATCH_SIZE, args['summary_length'])[0, :tgt_len[0]].cpu().numpy()))

            if bn == 0:  # e.g. eval every epoch
                # ---------------- Evaluate the model on validation data ---------------- #
                print("Evaluating the model at epoch {} step {}".format(epoch, bn))
                print("learning_rate = {}".format(optimizer.param_groups[0]['lr']))
                # switch to evaluation mode
                model.eval()
                da_labeller.eval()
                ext_labeller.eval()
                with torch.no_grad():
                    avg_val_loss = evaluate(model, valid_data, VAL_BATCH_SIZE, args, device, use_rouge=True)
                    # avg_val_loss = evaluate_greedy(model, valid_data, VAL_BATCH_SIZE, args, device)
                print("avg_val_loss_per_token = {}".format(avg_val_loss))
                # switch to training mode
                model.train()
                da_labeller.train()
                ext_labeller.train()
                # ------------------- Save the model OR Stop training ------------------- #
                state = {
                    'epoch': epoch, 'bn': bn, 'training_step': training_step,
                    'model': model.state_dict(),
                    'da_labeller': da_labeller.state_dict(),
                    'ext_labeller': ext_labeller.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_val_loss': best_val_loss
                }
                if avg_val_loss < best_val_loss:
                    stop_counter = 0
                    best_val_loss = avg_val_loss
                    best_epoch = epoch
                    savepath = args['model_save_dir'] + "model-{}-ep{}.pt".format(args['model_name'], 999)  # 999 = best
                    torch.save(state, savepath)
                    print("Model improved & saved at {}".format(savepath))
                else:
                    print("Model not improved #{}".format(stop_counter))
                    savepath = args['model_save_dir'] + "model-{}-ep{}.pt".format(args['model_name'], 0)  # 0 = current
                    torch.save(state, savepath)
                    print("Model NOT improved & saved at {}".format(savepath))
                    if stop_counter < VAL_STOP_TRAINING:
                        print("Just continue training ---- no loading old weights")
                        stop_counter += 1
                    else:
                        print("Model has not improved for {} times! Stop training.".format(VAL_STOP_TRAINING))
                        return

    print("End of training hierarchical RNN model")

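# ---------------------------------------------------------------------------------- #
# adjust_lr is imported from elsewhere in the repository. A minimal sketch that is
# consistent with its call sites above and with the schedule documented in the args
# comment, lr = lr_0 * step^(-decay_rate); the original helper may differ
# (e.g. include a warm-up phase).
def _adjust_lr_sketch(optimizer, lr0, decay_rate, step):
    step = max(step, 1)  # avoid 0^(-decay_rate) at the first update
    lr = lr0 * step ** (-decay_rate)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
# ---------------------------------------------------------------------------------- #
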
def train_adaptive_bias():
    print("Start training adaptive bias")
    # ---------------------------------------------------------------------------------- #
    args = {}
    args['load_model'] = "/home/alta/summary/pm574/summariser1/lib/trained_models/model-HGRUV2_CNNDM_AMI_JAN24A-ep17.pt"
    args['num_utterances'] = 2000  # max no. utterances in a meeting
    args['num_words'] = 64         # max no. words in an utterance
    args['summary_length'] = 800   # max no. words in a summary
    args['summary_type'] = 'long'  # long or short summary
    args['vocab_size'] = 30522     # BERT tokenizer
    args['dropout'] = 0.0
    args['embedding_dim'] = 256    # word embedding dimension
    args['rnn_hidden_size'] = 512  # RNN hidden size
    args['num_layers_enc'] = 1     # in total it's num_layers_enc*3 (word/utt/seg)
    args['num_layers_dec'] = 1
    args['init_bias'] = 20
    args['random_seed'] = 28
    args['lr'] = 1.0
    args['abias_save_dir'] = "/home/alta/summary/pm574/summariser1/lib/adaptive_bias_weights/"
    args['abias_name'] = "ABIAS_FEB21A2"
    # ---------------------------------------------------------------------------------- #
    print_config(args)

    os.environ["CUDA_VISIBLE_DEVICES"] = '1'  # choose the device (GPU) here
    device = 'cuda'

    train_data = load_ami_data('train')
    valid_data = load_ami_data('valid')

    random.seed(args['random_seed'])
    torch.manual_seed(args['random_seed'])

    adaptivebias = AdaptiveBias(args['vocab_size'], args['init_bias'], device)
    print(adaptivebias)

    model = EncoderDecoder(args, device)
    model_path = args['load_model']
    try:
        model.load_state_dict(torch.load(model_path))
    except RuntimeError:  # need to remove the 'module.' prefix (DataParallel)
        model_state_dict = torch.load(model_path)
        new_model_state_dict = OrderedDict()
        for key in model_state_dict.keys():
            new_model_state_dict[key.replace("module.", "")] = model_state_dict[key]
        model.load_state_dict(new_model_state_dict)
    model.eval()
    for p in model.parameters():
        p.requires_grad = False  # freeze the base model; only the bias is trained

    optimizer = optim.Adam(adaptivebias.parameters(), lr=args['lr'],
                           betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
    optimizer.zero_grad()
    criterion = LabelSmoothingLoss(num_classes=args['vocab_size'],
                                   smoothing=0.1, reduction='none')

    batch_size = 1
    num_epochs = 5
    time_step = args['summary_length']
    vocab_size = args['vocab_size']
    start_token_id = 101  # [CLS]
    stop_token_id = 103   # [MASK]

    for epoch in range(num_epochs):
        print("======================= Training epoch {} =======================".format(epoch))
        num_train_data = len(train_data)
        num_batches = int(num_train_data / batch_size)
        print("num_batches = {}".format(num_batches))
        random.shuffle(train_data)
        idx = 0
        for bn in range(num_batches):
            input, u_len, w_len, target, tgt_len, _, _, _ = get_a_batch(
                train_data, idx, batch_size,
                args['num_utterances'], args['num_words'],
                args['summary_length'], args['summary_type'], device)
            # decoder target
            decoder_target, decoder_mask = shift_decoder_target(target, tgt_len, device, mask_offset=False)
            decoder_target = decoder_target.view(-1)
            decoder_mask = decoder_mask.view(-1)

            enc_output_dict = model.encoder(input, u_len, w_len)
            y_out = torch.ones((batch_size, time_step, vocab_size), dtype=torch.float).to(device)
            y_pred = torch.zeros((batch_size, time_step), dtype=torch.long).to(device)
            y_pred.fill_(stop_token_id)
            y_init = torch.zeros((batch_size, 1), dtype=torch.long).to(device)
            y_init.fill_(start_token_id)
            ht = model.decoder.init_h0(batch_size)

            for t in range(time_step - 1):
                # we have already obtained the prediction up to time step 't'
                # and want to predict 't+1'
                if t == 0:
                    decoder_output, ht = model.decoder.forward_step(
                        y_init, ht, enc_output_dict, logsoftmax=False)
                    output = decoder_output
                else:
                    decoder_output, ht = model.decoder.forward_step(
                        y_pred[:, t - 1].unsqueeze(-1), ht, enc_output_dict, logsoftmax=False)
                    # sum y_out from 0 up to t-1
                    cov_y = y_out[:, :t, :].sum(dim=1)
                    # normalise cov_y
                    cov_y = cov_y / cov_y.sum(dim=-1).unsqueeze(-1)
                    bias = adaptivebias(cov_y)
                    # maybe think about in what domain we should add this bias?? LogSoftmax??
                    output = decoder_output - bias
                y_out[:, t, :] = nn.functional.softmax(output, dim=-1)
                y_pred[:, t] = output.argmax(dim=-1)
                # if t % 100 == 0: print("t = {}".format(t))
                #### ONLY WORKS WITH batch_size = 1
                if y_pred[0, t] == stop_token_id:
                    break

            log_y_out = torch.log(y_out)
            loss = criterion(log_y_out.view(-1, args['vocab_size']), decoder_target)
            loss = (loss * decoder_mask).sum() / decoder_mask.sum()
            loss.backward()

            print("[{}] batch {}/{}: loss = {:5f}".format(str(datetime.now()), bn, num_batches, loss))
            sys.stdout.flush()

            idx += batch_size
            optimizer.step()
            optimizer.zero_grad()

            if bn % 20 == 0:
                print("======================== GENERATED SUMMARY ========================")
                print(bert_tokenizer.decode(y_pred[0].cpu().numpy()[:tgt_len[0]]))
                print("======================== REFERENCE SUMMARY ========================")
                print(bert_tokenizer.decode(
                    decoder_target.view(batch_size, args['summary_length'])[0, :tgt_len[0]].cpu().numpy()))

        print("END OF EPOCH {}".format(epoch))
        # Evaluation
        with torch.no_grad():
            avg_val_loss = eval_model_with_bias(model, adaptivebias, valid_data, args, device)
        print("avg_val_loss_per_token = {}".format(avg_val_loss))
        # Save the model
        savepath = args['abias_save_dir'] + "abias-{}-ep{}.pt".format(args['abias_name'], epoch + 1)
        torch.save(adaptivebias.state_dict(), savepath)

def eval_model_with_bias(model, adaptivebias, eval_data, args, device):
    num_eval_epochs = len(eval_data)
    from rouge import Rouge
    rouge = Rouge()
    bert_decoded_outputs = []
    bert_decoded_targets = []
    time_step = args['summary_length']
    vocab_size = args['vocab_size']
    start_token_id = 101  # [CLS]
    stop_token_id = 103   # [MASK]
    eval_idx = 0
    for bn in range(num_eval_epochs):
        input, u_len, w_len, target, tgt_len, _, _, _ = get_a_batch(
            eval_data, eval_idx, 1,
            args['num_utterances'], args['num_words'],
            args['summary_length'], args['summary_type'], device)
        # decoder target
        decoder_target, decoder_mask = shift_decoder_target(target, tgt_len, device)
        decoder_target = decoder_target.view(-1)
        decoder_mask = decoder_mask.view(-1)

        enc_output_dict = model.encoder(input, u_len, w_len)
        y_out = torch.ones((1, time_step, vocab_size), dtype=torch.float).to(device)
        y_pred = torch.zeros((1, time_step), dtype=torch.long).to(device)
        y_pred.fill_(stop_token_id)
        y_init = torch.zeros((1, 1), dtype=torch.long).to(device)
        y_init.fill_(start_token_id)
        ht = model.decoder.init_h0(1)

        for t in range(time_step - 1):
            # we have already obtained the prediction up to time step 't'
            # and want to predict 't+1'
            if t == 0:
                decoder_output, ht = model.decoder.forward_step(
                    y_init, ht, enc_output_dict, logsoftmax=False)
                output = decoder_output
            else:
                decoder_output, ht = model.decoder.forward_step(
                    y_pred[:, t - 1].unsqueeze(-1), ht, enc_output_dict, logsoftmax=False)
                # sum y_out from 0 up to t-1
                cov_y = y_out[:, :t, :].sum(dim=1)
                # normalise cov_y
                cov_y = cov_y / cov_y.sum(dim=-1).unsqueeze(-1)
                bias = adaptivebias(cov_y)
                # maybe think about in what domain we should add this bias?? LogSoftmax??
                output = decoder_output - bias
            y_out[:, t, :] = nn.functional.softmax(output, dim=-1)
            y_pred[:, t] = output.argmax(dim=-1)
            if y_pred[0, t] == stop_token_id:
                break

        bert_decoded_output = bert_tokenizer.decode(y_pred[0].cpu().numpy())
        stop_idx = bert_decoded_output.find('[MASK]')
        bert_decoded_output = bert_decoded_output[:stop_idx]
        bert_decoded_outputs.append(bert_decoded_output)
        bert_decoded_target = bert_tokenizer.decode(decoder_target.cpu().numpy())
        stop_idx2 = bert_decoded_target.find('[MASK]')
        bert_decoded_target = bert_decoded_target[:stop_idx2]
        bert_decoded_targets.append(bert_decoded_target)

        eval_idx += 1
        print("#", end="")
        sys.stdout.flush()
    print()
    try:
        scores = rouge.get_scores(bert_decoded_outputs, bert_decoded_targets, avg=True)
        # negated average F-score so that lower = better
        return (scores['rouge-1']['f'] + scores['rouge-2']['f']
                + scores['rouge-l']['f']) * (-100) / 3
    except ValueError:
        return 0

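# ---------------------------------------------------------------------------------- #
# The AdaptiveBias module is not defined in this section. A minimal sketch that is
# consistent with how it is constructed (AdaptiveBias(vocab_size, init_bias, device))
# and called (bias = adaptivebias(cov_y), with cov_y a normalised coverage
# distribution over the vocabulary). The internals below are an assumption: one
# learnable scale per vocabulary entry, so tokens that have already received a lot
# of probability mass get biased down in 'output = decoder_output - bias'.
class _AdaptiveBiasSketch(nn.Module):
    def __init__(self, vocab_size, init_bias, device):
        super().__init__()
        self.weight = nn.Parameter(torch.full((vocab_size,), float(init_bias)))
        self.to(device)

    def forward(self, cov_y):
        # cov_y: (batch, vocab_size) --- returns a per-token bias of the same shape
        return cov_y * self.weight
# ---------------------------------------------------------------------------------- #
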
def train_v6():
    print("Start training hierarchical RNN model")
    # ---------------------------------------------------------------------------------- #
    args = {}
    args['use_gpu'] = True
    args['num_utterances'] = 1400   # max no. utterances in a meeting
    args['num_words'] = 50          # max no. words in an utterance
    args['summary_length'] = 280    # max no. words in a summary
    args['summary_type'] = 'short'  # long or short summary
    args['vocab_size'] = 30522      # BERT tokenizer
    args['embedding_dim'] = 256     # word embedding dimension
    args['rnn_hidden_size'] = 512   # RNN hidden size
    args['dropout'] = 0.1
    args['num_layers_enc'] = 2      # in total it's num_layers_enc*2 (word/utt)
    args['num_layers_dec'] = 1
    args['batch_size'] = 1
    args['update_nbatches'] = 2
    args['num_epochs'] = 30
    args['random_seed'] = 666
    args['best_val_loss'] = 1e+10
    args['val_batch_size'] = 1      # 1 for now --- evaluate ROUGE
    args['val_stop_training'] = 30
    args['lr'] = 1.0
    args['adjust_lr'] = True        # if True, overwrite the learning rate above
    args['initial_lr'] = 0.002      # lr = lr_0*step^(-decay_rate)
    args['decay_rate'] = 0.5
    args['label_smoothing'] = 0.1
    args['a_da'] = 0.0
    args['a_ext'] = 0.0
    args['a_cov'] = 0.0
    args['model_save_dir'] = "/home/alta/summary/pm574/summariser1/lib/trained_models2/"
    # args['load_model'] = None
    # args['load_model'] = "/home/alta/summary/pm574/summariser1/lib/trained_models2/model-HGRUV5_FEB28A-ep6"
    args['load_model'] = "/home/alta/summary/pm574/summariser1/lib/trained_models2/model-HGRUV5_CNNDM_FEB26A-ep17-bn0"  # add .pt later
    args['model_name'] = 'HGRUV6_MAR3Bn'
    # ---------------------------------------------------------------------------------- #
    print_config(args)

    if args['use_gpu']:
        if 'X_SGE_CUDA_DEVICE' in os.environ:  # to run on CUED stack machine
            print('running on the stack... 1 GPU')
            cuda_device = os.environ['X_SGE_CUDA_DEVICE']
            print('X_SGE_CUDA_DEVICE is set to {}'.format(cuda_device))
            os.environ['CUDA_VISIBLE_DEVICES'] = cuda_device
        else:
            print('running locally...')
            os.environ["CUDA_VISIBLE_DEVICES"] = '0'  # choose the device (GPU) here
        device = 'cuda'
    else:
        device = 'cpu'
    print("device = {}".format(device))

    # random seed
    random.seed(args['random_seed'])
    torch.manual_seed(args['random_seed'])
    np.random.seed(args['random_seed'])

    train_data = load_ami_data('train')
    valid_data = load_ami_data('valid')
    # make the training data 100
    # random.shuffle(valid_data)
    # train_data.extend(valid_data[:6])
    # valid_data = valid_data[6:]

    model = EncoderDecoder(args, device=device)
    print(model)

    # to use multiple GPUs
    if torch.cuda.device_count() > 1:
        print("Multiple GPUs: {}".format(torch.cuda.device_count()))
        model = nn.DataParallel(model)

    # Load model if specified (path to pytorch .pt)
    if args['load_model'] != None:
        model_path = args['load_model'] + '.pt'
        try:
            model.load_state_dict(torch.load(model_path))
        except RuntimeError:  # need to remove the 'module.' prefix (DataParallel)
            model_state_dict = torch.load(model_path)
            new_model_state_dict = OrderedDict()
            for key in model_state_dict.keys():
                new_model_state_dict[key.replace("module.", "")] = model_state_dict[key]
            # the memory components are not in the checkpoint --- copy the model's
            # current (randomly initialised) weights so load_state_dict does not complain
            new_model_state_dict["decoder.mem_utt_d.weight"] = model.decoder.mem_utt_d.weight
            new_model_state_dict["decoder.mem_utt_y.weight"] = model.decoder.mem_utt_y.weight
            new_model_state_dict["decoder.mem_utt_s.weight"] = model.decoder.mem_utt_s.weight
            new_model_state_dict["decoder.mem_utt_s.bias"] = model.decoder.mem_utt_s.bias
            model.load_state_dict(new_model_state_dict)
            # model.decoder.mem_utt_s.bias.data.fill_(-5)
        model.train()
        print("Loaded model from {}".format(args['load_model']))
    else:
        print("Train a new model")

    # Hyperparameters
    BATCH_SIZE = args['batch_size']
    NUM_EPOCHS = args['num_epochs']
    VAL_BATCH_SIZE = args['val_batch_size']
    VAL_STOP_TRAINING = args['val_stop_training']

    if args['label_smoothing'] > 0.0:
        criterion = LabelSmoothingLoss(num_classes=args['vocab_size'],
                                       smoothing=args['label_smoothing'],
                                       reduction='none')
    else:
        criterion = nn.NLLLoss(reduction='none')

    # a single optimiser over all model parameters
    optimizer = optim.Adam(model.parameters(), lr=args['lr'],
                           betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
    optimizer.zero_grad()

    # validation losses
    best_val_loss = args['best_val_loss']
    best_epoch = 0
    stop_counter = 0
    training_step = 0

    for epoch in range(NUM_EPOCHS):
        print("======================= Training epoch {} =======================".format(epoch))
        num_train_data = len(train_data)
        # num_batches = int(num_train_data/BATCH_SIZE) + 1
        num_batches = int(num_train_data / BATCH_SIZE)
        print("num_batches = {}".format(num_batches))
        print("shuffle train data")
        random.shuffle(train_data)
        idx = 0
        for bn in range(num_batches):
            input, u_len, w_len, target, tgt_len, _, dialogue_acts, extractive_label = get_a_batch(
                train_data, idx, BATCH_SIZE,
                args['num_utterances'], args['num_words'],
                args['summary_length'], args['summary_type'], device)
            # decoder target
            decoder_target, decoder_mask = shift_decoder_target(target, tgt_len, device, mask_offset=True)
            decoder_target = decoder_target.view(-1)
            decoder_mask = decoder_mask.view(-1)

            decoder_output, u_output, attn_scores, cov_scores, u_attn_scores = model(input, u_len, w_len, target)
            loss = criterion(decoder_output.view(-1, args['vocab_size']), decoder_target)
            loss = (loss * decoder_mask).sum() / decoder_mask.sum()
            loss.backward()

            idx += BATCH_SIZE

            if bn % args['update_nbatches'] == 0:
                # gradient clipping
                max_norm = 0.5
                nn.utils.clip_grad_norm_(model.parameters(), max_norm)
                # update the gradients
                if args['adjust_lr']:
                    adjust_lr(optimizer, args['initial_lr'], args['decay_rate'], training_step)
                optimizer.step()
                optimizer.zero_grad()
                training_step += args['batch_size'] * args['update_nbatches']

            if bn % 1 == 0:
                print("[{}] batch {}/{}: loss = {:.5f}".format(str(datetime.now()), bn, num_batches, loss))
                sys.stdout.flush()

            if bn % 10 == 0:
                print("======================== GENERATED SUMMARY ========================")
                print(bert_tokenizer.decode(torch.argmax(decoder_output[0], dim=-1).cpu().numpy()[:tgt_len[0]]))
                print("======================== REFERENCE SUMMARY ========================")
                print(bert_tokenizer.decode(decoder_target.view(BATCH_SIZE, args['summary_length'])[0, :tgt_len[0]].cpu().numpy()))

            if bn == 0:  # e.g. eval every epoch
                # ---------------- Evaluate the model on validation data ---------------- #
                print("Evaluating the model at epoch {} step {}".format(epoch, bn))
                print("learning_rate = {}".format(optimizer.param_groups[0]['lr']))
                # switch to evaluation mode
                model.eval()
                with torch.no_grad():
                    avg_val_loss = evaluate(model, valid_data, VAL_BATCH_SIZE, args, device, use_rouge=True)
                print("avg_val_loss_per_token = {}".format(avg_val_loss))
                # switch to training mode
                model.train()
                # ------------------- Save the model OR Stop training ------------------- #
                if avg_val_loss < best_val_loss:
                    stop_counter = 0
                    best_val_loss = avg_val_loss
                    best_epoch = epoch
                    savepath = args['model_save_dir'] + "model-{}-ep{}.pt".format(args['model_name'], epoch)
                    torch.save(model.state_dict(), savepath)
                    print("Model improved & saved at {}".format(savepath))
                else:
                    print("Model not improved #{}".format(stop_counter))
                    if stop_counter < VAL_STOP_TRAINING:
                        print("Just continue training ---- no loading old weights")
                        stop_counter += 1
                    else:
                        print("Model has not improved for {} times! Stop training.".format(VAL_STOP_TRAINING))
                        return

    print("End of training hierarchical RNN model")

def evaluate_beam(model, eval_data, eval_batch_size, args, device, use_rouge=False):
    # num_eval_epochs = int(eval_data['num_data']/eval_batch_size) + 1
    num_eval_epochs = int(len(eval_data) / eval_batch_size)
    print("num_eval_epochs = {}".format(num_eval_epochs))
    eval_idx = 0
    eval_total_loss = 0.0
    eval_total_tokens = 0
    from rouge import Rouge
    rouge = Rouge()
    bert_decoded_outputs = []
    bert_decoded_targets = []
    decode_dict = {
        'k': 10, 'search_method': 'argmax',
        'time_step': args['summary_length'], 'vocab_size': 30522,
        'device': device, 'start_token_id': 101, 'stop_token_id': 103,
        'alpha': 1.5, 'length_offset': 5, 'penalty_ug': 0,
        'keypadmask_dtype': torch.bool, 'batch_size': 1
    }
    for bn in range(num_eval_epochs):
        input, u_len, w_len, target, tgt_len, _, _, _ = get_a_batch(
            eval_data, eval_idx, eval_batch_size,
            args['num_utterances'], args['num_words'],
            args['summary_length'], args['summary_type'], device)
        # decoder target
        decoder_target, decoder_mask = shift_decoder_target(target, tgt_len, device)
        decoder_target = decoder_target.view(-1)
        decoder_mask = decoder_mask.view(-1)

        # inference-time decoding (silence the per-step beam printouts)
        sys.stdout = open(os.devnull, 'w')
        summaries_id = model.decode_beamsearch(input, u_len, w_len, decode_dict)
        summaries_id = summaries_id[0][1:]  # strip the start token
        bert_decoded_output = bert_tokenizer.decode(summaries_id)
        stop_idx = bert_decoded_output.find('[MASK]')
        bert_decoded_output = bert_decoded_output[:stop_idx]
        bert_decoded_outputs.append(bert_decoded_output)
        bert_decoded_target = bert_tokenizer.decode(decoder_target.cpu().numpy())
        stop_idx2 = bert_decoded_target.find('[MASK]')
        bert_decoded_target = bert_decoded_target[:stop_idx2]
        bert_decoded_targets.append(bert_decoded_target)
        sys.stdout = sys.__stdout__

        eval_idx += eval_batch_size
        print("#", end="")
        sys.stdout.flush()
    print()
    try:
        scores = rouge.get_scores(bert_decoded_outputs, bert_decoded_targets, avg=True)
        # negated average F-score so that lower = better
        return (scores['rouge-1']['f'] + scores['rouge-2']['f']
                + scores['rouge-l']['f']) * (-100) / 3
    except ValueError:
        return 0

# NOTE: despite the name, this script fine-tunes on the Spotify podcast data
# (see load_podcast_data and the model_name below).
def train_v5_cnndm():
    print("Start training hierarchical RNN model")
    # ---------------------------------------------------------------------------------- #
    args = {}
    args['use_gpu'] = True
    args['num_utterances'] = 640   # max no. utterances in a meeting
    args['num_words'] = 50         # max no. words in an utterance
    args['summary_length'] = 144   # max no. words in a summary
    args['summary_type'] = 'long'  # long or short summary
    args['vocab_size'] = 30522     # BERT tokenizer
    args['embedding_dim'] = 256    # word embedding dimension
    args['rnn_hidden_size'] = 512  # RNN hidden size
    args['dropout'] = 0.1
    args['num_layers_enc'] = 2     # in total it's num_layers_enc*2 (word/utt)
    args['num_layers_dec'] = 1
    args['random_seed'] = 78
    # args['a_div'] = 1.0
    args['memory_utt'] = False
    args['model_save_dir'] = "/home/alta/summary/pm574/summariser1/lib/trained_models_spotify/"
    args['load_model'] = "/home/alta/summary/pm574/summariser1/lib/trained_models2/model-HGRUV5_CNNDMDIV_APR14A-ep02.pt"
    args['model_name'] = 'HGRUV5DIV_SPOTIFY_JUNE18_v2'
    # ---------------------------------------------------------------------------------- #
    print_config(args)

    if args['use_gpu']:
        if 'X_SGE_CUDA_DEVICE' in os.environ:  # to run on CUED stack machine
            print('running on the stack... 1 GPU')
            cuda_device = os.environ['X_SGE_CUDA_DEVICE']
            print('X_SGE_CUDA_DEVICE is set to {}'.format(cuda_device))
            os.environ['CUDA_VISIBLE_DEVICES'] = cuda_device
        else:
            print('running locally...')
            os.environ["CUDA_VISIBLE_DEVICES"] = '0,1,2,3'  # choose the device (GPU) here
        device = 'cuda'
    else:
        device = 'cpu'
    print("device = {}".format(device))

    # random seed
    random.seed(args['random_seed'])
    torch.manual_seed(args['random_seed'])
    np.random.seed(args['random_seed'])

    # Data
    podcasts = load_podcast_data(sets=-1)
    batcher = HierBatcher(bert_tokenizer, args, podcasts, device)
    val_podcasts = load_podcast_data(sets=[10])
    val_batcher = HierBatcher(bert_tokenizer, args, val_podcasts, device)

    model = EncoderDecoder(args, device=device)
    # print(model)

    # Load model if specified (path to pytorch .pt)
    state = torch.load(args['load_model'])
    model_state_dict = state['model']
    model.load_state_dict(model_state_dict)
    print("load successful #1")

    criterion = nn.NLLLoss(reduction='none')
    # a single optimiser over all model parameters
    # (lr here is a placeholder; adjust_lr below presumably sets the actual value)
    optimizer = optim.Adam(model.parameters(), lr=2e-20,
                           betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
    optimizer.zero_grad()

    # validation losses
    training_step = 0
    batch_size = 4 * 4
    gradient_accum = 1
    total_step = 1000000
    valid_step = 2000
    best_val_loss = 99999999
    stop_counter = 0  # was missing: NameError if the first validation did not improve

    # to use multiple GPUs
    if torch.cuda.device_count() > 1:
        print("Multiple GPUs: {}".format(torch.cuda.device_count()))
        model = nn.DataParallel(model)

    while training_step < total_step:
        # get a batch
        input, u_len, w_len, target, tgt_len = batcher.get_a_batch(batch_size)
        # decoder target
        decoder_target, decoder_mask = shift_decoder_target(target, tgt_len, device, mask_offset=True)
        decoder_target = decoder_target.view(-1)
        decoder_mask = decoder_mask.view(-1)

        # decoder_output, _, _, _, u_attn_scores = model(input, u_len, w_len, target)
        decoder_output = model(input, u_len, w_len, target)
        loss = criterion(decoder_output.view(-1, args['vocab_size']), decoder_target)
        loss = (loss * decoder_mask).sum() / decoder_mask.sum()
        loss.backward()

        # Diversity Loss:
        # if batch_size == 1:
        #     intra_div, inter_div = diverisity_loss(u_attn_scores, decoder_target, u_len, tgt_len)
        #     if inter_div == 0:
        #         loss_div = 0
        #     else:
        #         loss_div = intra_div/inter_div
        # else:
        #     dec_target_i = 0
        #     loss_div = 0
        #     for bi in range(batch_size):
        #         one_u_attn_scores = u_attn_scores[bi:bi+1,:,:]
        #         one_decoder_target = decoder_target[dec_target_i:dec_target_i+args['summary_length']]
        #         one_u_len = u_len[bi:bi+1]
        #         one_tgt_len = tgt_len[bi:bi+1]
        #         intra_div, inter_div = diverisity_loss(one_u_attn_scores, one_decoder_target, one_u_len, one_tgt_len)
        #         if inter_div == 0:
        #             loss_div += 0
        #         else:
        #             loss_div += intra_div/inter_div
        #         dec_target_i += args['summary_length']
        #     loss_div /= batch_size
        # total_loss = loss + args['a_div']*loss_div
        # total_loss.backward()

        if training_step % gradient_accum == 0:
            adjust_lr(optimizer, training_step)
            optimizer.step()
            optimizer.zero_grad()

        if training_step % 1 == 0:
            # print("[{}] step {}/{}: loss = {:.5f} | loss_div = {:.5f}".format(
            #     str(datetime.now()), training_step, total_step, loss, loss_div))
            print("[{}] step {}/{}: loss = {:.5f}".format(
                str(datetime.now()), training_step, total_step, loss))
            sys.stdout.flush()

        if training_step % 10 == 0:
            print("======================== GENERATED SUMMARY ========================")
            print(bert_tokenizer.decode(
                torch.argmax(decoder_output[0], dim=-1).cpu().numpy()[:tgt_len[0]]))
            print("======================== REFERENCE SUMMARY ========================")
            print(bert_tokenizer.decode(
                decoder_target.view(batch_size, args['summary_length'])[0, :tgt_len[0]].cpu().numpy()))

        if training_step % valid_step == 0:
            # ---------------- Evaluate the model on validation data ---------------- #
            print("Evaluating the model at training step {}".format(training_step))
            print("learning_rate = {}".format(optimizer.param_groups[0]['lr']))
            # switch to evaluation mode
            model.eval()
            with torch.no_grad():
                valid_loss = evaluate(model, val_batcher, batch_size, args, device)
            print("valid_loss = {}".format(valid_loss))
            # switch to training mode
            model.train()
            if valid_loss < best_val_loss:
                stop_counter = 0
                best_val_loss = valid_loss
                print("Model improved")
            else:
                stop_counter += 1
                print("Model not improved #{}".format(stop_counter))
                if stop_counter == 3:
                    print("Stop training!")
                    return
            state = {
                'training_step': training_step,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'best_val_loss': best_val_loss
            }
            savepath = args['model_save_dir'] + "{}-step{}.pt".format(args['model_name'], training_step)
            torch.save(state, savepath)
            print("Saved at {}".format(savepath))

        training_step += 1

    print("End of training hierarchical RNN model")

# NOTE: excerpted class method (takes self).
def decode_beamsearch(self, input, u_len, w_len, decode_dict):
    k = decode_dict['k']
    search_method = decode_dict['search_method']
    batch_size = decode_dict['batch_size']
    time_step = decode_dict['time_step']
    vocab_size = decode_dict['vocab_size']
    device = decode_dict['device']
    start_token_id = decode_dict['start_token_id']
    stop_token_id = decode_dict['stop_token_id']
    alpha = decode_dict['alpha']
    # length_offset = decode_dict['length_offset']
    keypadmask_dtype = decode_dict['keypadmask_dtype']

    # create beam array & scores
    beams = [None for _ in range(k)]
    beam_scores = np.zeros((batch_size, k))

    # we should only feed through the encoder once!
    uw_len, enc_output, input_pgn = self.forward_encoder(input, u_len, w_len)  # memory

    # we run the decoder time_step times (auto-regressive)
    tgt_ids = torch.zeros((batch_size, time_step), dtype=torch.int64).to(device)
    tgt_ids[:, 0] = start_token_id
    for i in range(k):
        beams[i] = tgt_ids
    finished_beams = [[] for _ in range(batch_size)]
    beam_htct = [self.decoder_init_h0c0(batch_size) for _ in range(k)]
    finish = False

    for t in range(time_step - 1):
        if finish:
            break
        decoder_output_t_array = torch.zeros((batch_size, k * vocab_size))
        for i, beam in enumerate(beams):
            # inference decoding
            decoder_output, beam_htct[i] = self.forward_decoder_step(
                uw_len, enc_output, beam[:, t:t + 1],
                beam_htct[i][0], beam_htct[i][1], input_pgn, training=False)

            # check if STOP_TOKEN was already emitted in the previous time step,
            # i.e. if the input at this time step is STOP_TOKEN
            for n_idx in range(batch_size):
                if beam[n_idx][t] == stop_token_id:  # already stopped
                    decoder_output[n_idx, :] = float('-inf')
                    decoder_output[n_idx, stop_token_id] = 0.0  # ensure STOP_TOKEN is picked again!
            decoder_output_t_array[:, i * vocab_size:(i + 1) * vocab_size] = decoder_output

            # add previous beam score bias
            for n_idx in range(batch_size):
                decoder_output_t_array[n_idx, i * vocab_size:(i + 1) * vocab_size] += beam_scores[n_idx, i]
            # only support batch_size = 1!
            if t == 0:
                decoder_output_t_array[n_idx, (i + 1) * vocab_size:] = float('-inf')
                break

        if search_method == 'sampling':
            # Sampling
            scores = np.zeros((batch_size, k))
            indices = np.zeros((batch_size, k))
            pmf = np.exp(decoder_output_t_array.cpu().numpy())
            for bi in range(batch_size):
                if pmf[bi].sum() != 1.0:
                    pmf[bi] /= pmf[bi].sum()
                sampled_ids = np.random.choice(k * vocab_size, size=k, p=pmf[bi])
                for _s, s_id in enumerate(sampled_ids):
                    scores[bi, _s] = decoder_output_t_array[bi, s_id]
                    indices[bi, _s] = s_id
        elif search_method == 'argmax':
            # Argmax
            topk_scores, topk_ids = torch.topk(decoder_output_t_array, k, dim=-1)
            scores = topk_scores.double().cpu().numpy()
            indices = topk_ids.double().cpu().numpy()

        new_beams = [torch.zeros((batch_size, time_step), dtype=torch.int64).to(device)
                     for _ in range(k)]
        for r_idx, row in enumerate(indices):
            for c_idx, node in enumerate(row):
                vocab_idx = node % vocab_size
                beam_idx = int(node / vocab_size)
                new_beams[c_idx][r_idx, :t + 1] = beams[beam_idx][r_idx, :t + 1]
                new_beams[c_idx][r_idx, t + 1] = vocab_idx
                # if there is a beam that has [END_TOKEN] --- store it
                if vocab_idx == stop_token_id:
                    finished_beams[r_idx].append(new_beams[c_idx][r_idx, :t + 1 + 1])
                    scores[r_idx, c_idx] = float('-inf')

        # only support BATCH SIZE = 1
        count_stop = 0
        for ik in range(k):
            if scores[0, ik] == float('-inf'):
                count_stop += 1
        if count_stop == k:
            finish = True
        beams = new_beams

        if search_method == 'sampling':
            # normalise the scores
            scores = np.exp(scores)
            scores = scores / scores.sum(axis=-1).reshape(batch_size, 1)
            beam_scores = np.log(scores + 1e-20)  # suppress log(zero) warning
        elif search_method == 'argmax':
            beam_scores = scores

        print("========================= t = {} =========================".format(t))
        for ik in range(k):
            print("beam{}: [{:.5f}]".format(ik, scores[0, ik]),
                  bert_tokenizer.decode(beams[ik][0].cpu().numpy()[:t + 2]))
        # pdb.set_trace()
        if (t % 50) == 0:
            print("{}=".format(t), end="")
            sys.stdout.flush()
    print("{}=#".format(t))

    for bi in range(batch_size):
        if len(finished_beams[bi]) == 0:
            finished_beams[bi].append(beams[0][bi])

    summaries_id = [None for _ in range(batch_size)]
    # for j in range(batch_size): summaries_id[j] = beams[0][j].cpu().numpy()
    for j in range(batch_size):
        _scores = self.beam_scoring(finished_beams[j], enc_output, uw_len, input_pgn, alpha)
        summaries_id[j] = finished_beams[j][np.argmax(_scores)].cpu().numpy()
        print(bert_tokenizer.decode(summaries_id[j]))
    return summaries_id

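# ---------------------------------------------------------------------------------- #
# beam_scoring (used above with the 'alpha' length-normalisation factor from
# decode_dict) is not defined in this section. A common choice, and a plausible
# sketch, is GNMT-style length normalisation (Wu et al., 2016); the real
# beam_scoring also takes enc_output/uw_len/input_pgn, so it presumably
# re-scores each finished beam with the decoder rather than reusing cached scores.
def _beam_scoring_sketch(finished_beams, beam_log_probs, alpha, length_offset=5):
    scores = []
    for beam, log_prob in zip(finished_beams, beam_log_probs):
        lp = ((length_offset + len(beam)) / (length_offset + 1)) ** alpha
        scores.append(log_prob / lp)  # longer beams are penalised less as alpha grows
    return scores
# ---------------------------------------------------------------------------------- #
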
# NOTE: a second train_v5, presumably excerpted from a separate script
# (it duplicates the name of the train_v5 above and trains on CNN/DM).
def train_v5():
    print("Start training hierarchical RNN model")
    # ---------------------------------------------------------------------------------- #
    args = {}
    args['use_gpu'] = True
    args['num_utterances'] = 36    # max no. utterances in a meeting
    args['num_words'] = 72         # max no. words in an utterance
    args['summary_length'] = 144   # max no. words in a summary
    args['summary_type'] = 'long'  # long or short summary
    args['vocab_size'] = 30522     # BERT tokenizer
    args['embedding_dim'] = 128    # word embedding dimension
    args['rnn_hidden_size'] = 256  # RNN hidden size
    args['dropout'] = 0.1
    args['num_layers_enc'] = 1
    args['num_layers_dec'] = 1
    args['batch_size'] = 128
    args['update_nbatches'] = 1
    args['num_epochs'] = 40
    args['random_seed'] = 29
    args['best_val_loss'] = 1e+10
    args['val_batch_size'] = 128   # 1 for now --- evaluate ROUGE
    args['val_stop_training'] = 5
    args['adjust_lr'] = True       # if True, overwrite the learning rate above
    args['initial_lr'] = 0.02      # lr = lr_0*step^(-decay_rate)
    args['decay_rate'] = 0.20
    args['model_save_dir'] = "/home/alta/summary/pm574/summariser1/lib/trained_models3/"
    args['load_model'] = "/home/alta/summary/pm574/summariser1/lib/trained_models3/model-HGRUV52_CNNDM_APR2B-ep0.pt"  # add .pt later
    # args['load_model'] = None
    args['model_name'] = 'HGRUV52_CNNDM_APR2C'
    # ---------------------------------------------------------------------------------- #
    print_config(args)

    if args['use_gpu']:
        if 'X_SGE_CUDA_DEVICE' in os.environ:  # to run on CUED stack machine
            print('running on the stack... 1 GPU')
            cuda_device = os.environ['X_SGE_CUDA_DEVICE']
            print('X_SGE_CUDA_DEVICE is set to {}'.format(cuda_device))
            os.environ['CUDA_VISIBLE_DEVICES'] = cuda_device
        else:
            print('running locally...')
            os.environ["CUDA_VISIBLE_DEVICES"] = '0,1'  # choose the device (GPU) here
        device = 'cuda'
    else:
        device = 'cpu'
    print("device = {}".format(device))

    # random seed
    random.seed(args['random_seed'])
    torch.manual_seed(args['random_seed'])
    np.random.seed(args['random_seed'])

    args['model_data_dir'] = "/home/alta/summary/pm574/summariser0/lib/model_data/"
    args['max_pos_embed'] = 512
    args['max_num_sentences'] = 32
    args['max_summary_length'] = args['summary_length']
    train_data = load_cnndm_data(args, 'trainx', dump=False)
    # train_data = load_cnndm_data(args, 'test', dump=False)
    # print("loaded TEST data")
    valid_data = load_cnndm_data(args, 'valid', dump=False)

    model = EncoderDecoder(args, device=device)
    print(model)

    # lr here is a placeholder; adjust_lr below overwrites it at each update
    optimizer = optim.Adam(model.parameters(), lr=0.77,
                           betas=(0.9, 0.999), eps=1e-08, weight_decay=0)

    # Load model if specified (path to pytorch .pt)
    if args['load_model'] != None:
        state = torch.load(args['load_model'])
        model_state_dict = state['model']
        optimizer_state_dict = state['optimizer']
        try:
            model.load_state_dict(model_state_dict)
        except RuntimeError:  # need to remove the 'module.' prefix (DataParallel)
            new_model_state_dict = OrderedDict()
            for key in model_state_dict.keys():
                new_model_state_dict[key.replace("module.", "")] = model_state_dict[key]
            model.load_state_dict(new_model_state_dict)
        model.train()
        optimizer.load_state_dict(optimizer_state_dict)
        training_step = state['training_step']
        print("Loaded model from {}".format(args['load_model']))
        print("continue from training_step {}".format(training_step))
    else:
        training_step = 0
        print("Train a new model")

    # to use multiple GPUs
    if torch.cuda.device_count() > 1:
        print("Multiple GPUs: {}".format(torch.cuda.device_count()))
        model = nn.DataParallel(model)

    # Hyperparameters
    BATCH_SIZE = args['batch_size']
    NUM_EPOCHS = args['num_epochs']
    VAL_BATCH_SIZE = args['val_batch_size']
    VAL_STOP_TRAINING = args['val_stop_training']

    criterion = nn.NLLLoss(reduction='none')

    # validation losses
    best_val_loss = args['best_val_loss']
    best_epoch = 0
    stop_counter = 0

    optimizer.zero_grad()
    for epoch in range(NUM_EPOCHS):
        print("======================= Training epoch {} =======================".format(epoch))
        num_train_data = len(train_data)
        # num_batches = int(num_train_data/BATCH_SIZE) + 1
        num_batches = int(num_train_data / BATCH_SIZE)
        print("num_batches = {}".format(num_batches))
        print("shuffle train data")
        random.shuffle(train_data)
        idx = 0
        for bn in range(num_batches):
            input, u_len, w_len, target, tgt_len = get_a_batch(
                train_data, idx, BATCH_SIZE,
                args['num_utterances'], args['num_words'],
                args['summary_length'], args['summary_type'], device)
            if u_len.min().item() == 0:
                print("BN = {}: u_len_min = 0 --- ERROR, just skip this batch!!!".format(bn))
                continue
            # decoder target
            decoder_target, decoder_mask = shift_decoder_target(target, tgt_len, device, mask_offset=True)
            decoder_target = decoder_target.view(-1)
            decoder_mask = decoder_mask.view(-1)

            try:
                decoder_output = model(input, u_len, w_len, target)
                # decoder_output, _, attn_scores, _, u_attn_scores = model(input, u_len, w_len, target)
            except IndexError:
                print("there is an IndexError --- likely from if segment_indices[bn][-1] == u_len[bn]-1:")
                print("for now just skip this batch!")
                idx += BATCH_SIZE  # previously this line was missing!!!
                continue

            loss = criterion(decoder_output.view(-1, args['vocab_size']), decoder_target)
            loss = (loss * decoder_mask).sum() / decoder_mask.sum()
            loss.backward()

            idx += BATCH_SIZE

            if bn % args['update_nbatches'] == 0:
                # gradient clipping
                max_norm = 2.0
                nn.utils.clip_grad_norm_(model.parameters(), max_norm)
                # update the gradients
                if args['adjust_lr']:
                    adjust_lr(optimizer, args['initial_lr'], args['decay_rate'], training_step)
                optimizer.step()
                optimizer.zero_grad()
                training_step += args['batch_size'] * args['update_nbatches']

            if bn % 1 == 0:
                print("[{}] batch {}/{}: loss = {:5f}".format(str(datetime.now()), bn, num_batches, loss))
                sys.stdout.flush()

            if bn % 50 == 0:
                print("======================== GENERATED SUMMARY ========================")
                print(bert_tokenizer.decode(torch.argmax(decoder_output[0], dim=-1).cpu().numpy()[:tgt_len[0]]))
                print("======================== REFERENCE SUMMARY ========================")
                print(bert_tokenizer.decode(decoder_target.view(BATCH_SIZE, args['summary_length'])[0, :tgt_len[0]].cpu().numpy()))

            if bn % 600 == 0:
                # ---------------- Evaluate the model on validation data ---------------- #
                print("Evaluating the model at epoch {} step {}".format(epoch, bn))
                print("learning_rate = {}".format(optimizer.param_groups[0]['lr']))
                # switch to evaluation mode
                model.eval()
                with torch.no_grad():
                    avg_val_loss = evaluate(model, valid_data, VAL_BATCH_SIZE, args, device)
                print("avg_val_loss_per_token = {}".format(avg_val_loss))
                # switch to training mode
                model.train()
                # ------------------- Save the model OR Stop training ------------------- #
                if avg_val_loss < best_val_loss:
                    stop_counter = 0
                    best_val_loss = avg_val_loss
                    best_epoch = epoch
                    state = {
                        'epoch': epoch, 'bn': bn, 'training_step': training_step,
                        'model': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'best_val_loss': best_val_loss
                    }
                    savepath = args['model_save_dir'] + "model-{}-ep{}.pt".format(args['model_name'], epoch)
                    torch.save(state, savepath)
                    print("Model improved & saved at {}".format(savepath))
                else:
                    print("Model not improved #{}".format(stop_counter))
                    if stop_counter < VAL_STOP_TRAINING:
                        print("Just continue training ---- no loading old weights")
                        stop_counter += 1
                    else:
                        print("Model has not improved for {} times! Stop training.".format(VAL_STOP_TRAINING))
                        return

    print("End of training hierarchical RNN model")

def grad_sampling(model, input, u_len, w_len, target, num_samples, lambda1, device, printout):
    total_r1 = 0
    total_r2 = 0
    total_rl = 0
    batch_size = input.size(0)
    if batch_size != 1:
        raise ValueError("batch_size error")

    # tgt -> reference (y)
    reference = bert_tokenizer.decode(target[0].cpu().numpy())
    stop_idx = reference.find('[SEP]')
    reference = reference[:stop_idx]
    time_step = len(bert_tokenizer.encode(reference)) + 1  # plus 1 just in case
    if printout:
        print("reference: {}".format(reference))
        print("-----------------------------------------------------------------------------------------")

    grads = [None for _ in range(num_samples)]
    metrics = [None for _ in range(num_samples)]
    for i in range(num_samples):
        # forward-pass ENCODER --- need to do the forward pass again
        # as autograd has freed up the memory
        enc_output_dict = model.encoder(input, u_len, w_len)  # memory
        u_output = enc_output_dict['u_output']
        # forward-pass DECODER
        xt = torch.zeros((batch_size, 1), dtype=torch.int64).to(device)
        xt.fill_(START_TOKEN_ID)  # 101
        # initial hidden state
        ht = torch.zeros((model.decoder.num_layers, batch_size,
                          model.decoder.dec_hidden_size), dtype=torch.float).to(device)
        for bn, l in enumerate(u_len):
            ht[:, bn, :] = u_output[bn, l - 1, :].unsqueeze(0)

        log_prob_seq = 0
        generated_tokens = []
        for t in range(time_step - 1):
            decoder_output, ht, _ = model.decoder.forward_step(
                xt, ht, enc_output_dict, logsoftmax=False)
            output_prob = F.softmax(decoder_output, dim=-1)
            m = Categorical(output_prob)
            sample = m.sample()
            log_prob_t = m.log_prob(sample)
            xt = sample.unsqueeze(-1)
            log_prob_seq += log_prob_t
            token = sample.item()
            generated_tokens.append(token)
            if token == SEP_TOKEN_ID:
                break

        # generated_tokens -> hypothesis (y_hat)
        hypothesis = bert_tokenizer.decode(generated_tokens)
        stop_idx = hypothesis.find('[SEP]')
        if stop_idx != -1:
            hypothesis = hypothesis[:stop_idx]

        # Compute D(y, y_hat)
        try:
            scores = rouge.get_scores(hypothesis, reference)
            r1 = scores[0]['rouge-1']['f']
            r2 = scores[0]['rouge-2']['f']
            rl = scores[0]['rouge-l']['f']
        except ValueError:
            r1 = 0
            r2 = 0
            rl = 0  # was 'r3 = 0', a typo that left rl undefined (or stale)
        metric = -1 * (r1 + r2 + rl)  # since we 'minimise' the criterion

        if printout:
            print("sample{} [{:.2f}]: {}".format(i, -100 * metric, hypothesis))
            print("-----------------------------------------------------------------------------------------")

        total_r1 += r1
        total_r2 += r2
        total_rl += rl

        # scale the gradient by the metric
        # log_prob_seq *= metric
        # log_prob_seq *= lambda1
        # log_prob_seq.backward()
        # len(grad) = the number of model.parameters() --- checked!
        grad = torch.autograd.grad(log_prob_seq, model.parameters())
        grads[i] = grad
        metrics[i] = metric

    # centre the metrics (mean baseline for variance reduction)
    mean_x = sum(metrics) / len(metrics)
    metrics = [xi - mean_x for xi in metrics]

    # for param in model.parameters(): param.grad /= num_samples
    for i in range(num_samples):
        for n, param in enumerate(model.parameters()):
            if i == 0:
                param.grad = metrics[i] * grads[i][n]
            else:
                param.grad += metrics[i] * grads[i][n]
    for param in model.parameters():
        param.grad *= lambda1 / num_samples

    r1_avg = total_r1 / num_samples
    r2_avg = total_r2 / num_samples
    rl_avg = total_rl / num_samples
    return r1_avg, r2_avg, rl_avg

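# ---------------------------------------------------------------------------------- #
# grad_sampling above is REINFORCE with a mean-of-samples baseline: the per-sample
# metrics are centred before weighting each sample's gradient, which reduces
# variance without changing the expected gradient. Illustrative call site
# (num_samples and lambda1 are assumed values, not taken from the original script):
def _example_grad_sampling_step(model, optimizer, input, u_len, w_len, target, device):
    optimizer.zero_grad()
    r1, r2, rl = grad_sampling(model, input, u_len, w_len, target,
                               num_samples=8, lambda1=1.0,
                               device=device, printout=False)
    nn.utils.clip_grad_norm_(model.parameters(), 0.5)  # same max_norm as train_v5
    optimizer.step()
    return r1, r2, rl
# ---------------------------------------------------------------------------------- #
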
def beamsearch_approx(model, input, u_len, w_len, target, beam_width, device, printout):
    vocab_size = VOCAB_SIZE
    total_r1 = 0
    total_r2 = 0
    total_rl = 0
    batch_size = input.size(0)
    if batch_size != 1:
        raise ValueError("batch_size error")

    # tgt -> reference (y)
    reference = bert_tokenizer.decode(target[0].cpu().numpy())
    stop_idx = reference.find('[SEP]')
    reference = reference[:stop_idx]
    time_step = len(bert_tokenizer.encode(reference)) + 1  # plus 1 just in case
    if printout:
        print("reference: {}".format(reference))
        print("-----------------------------------------------------------------------------------------")

    # forward-pass ENCODER --- only needed once; all beams share the memory
    enc_output_dict = model.encoder(input, u_len, w_len)  # memory
    u_output = enc_output_dict['u_output']

    # initial hidden state
    ht = torch.zeros((model.decoder.num_layers, batch_size, model.decoder.dec_hidden_size),
                     dtype=torch.float).to(device)
    for bn, l in enumerate(u_len):
        ht[:, bn, :] = u_output[bn, l - 1, :].unsqueeze(0)
    beam_ht = [None for _ in range(beam_width)]
    for _k in range(beam_width):
        beam_ht[_k] = ht.clone()

    # beam xt
    beam_xt = [None for _ in range(beam_width)]
    for i in range(beam_width):
        xt = torch.zeros((batch_size, 1), dtype=torch.int64).to(device)
        xt.fill_(START_TOKEN_ID)  # 101
        beam_xt[i] = xt

    beam_scores = [0.0 for _ in range(beam_width)]
    beam_generated_tokens = [[] for _ in range(beam_width)]

    for t in range(time_step - 1):
        decoder_output_t_array = torch.zeros((batch_size, beam_width * vocab_size))
        temp_ht = [None for _ in range(beam_width)]
        for i in range(beam_width):
            decoder_output, temp_ht[i], _ = model.decoder.forward_step(
                beam_xt[i], beam_ht[i], enc_output_dict, logsoftmax=True)
            decoder_output_t_array[0, i * vocab_size:(i + 1) * vocab_size] = decoder_output
            decoder_output_t_array[0, i * vocab_size:(i + 1) * vocab_size] += beam_scores[i]
            if t == 0:
                # all beams are identical at t=0, so only keep the first one
                decoder_output_t_array[0, (i + 1) * vocab_size:] = float('-inf')
                break

        topk_scores, topk_ids = torch.topk(decoder_output_t_array, beam_width, dim=-1)
        scores = topk_scores[0]
        indices = topk_ids[0]

        new_beam_scores = [None for _ in range(beam_width)]
        new_beam_generated_tokens = [None for _ in range(beam_width)]
        for i in range(beam_width):
            vocab_idx = indices[i] % vocab_size
            beam_idx = int(indices[i] / vocab_size)
            new_beam_scores[i] = scores[i]
            new_beam_generated_tokens[i] = list(beam_generated_tokens[beam_idx])
            new_beam_generated_tokens[i].append(vocab_idx.item())
            beam_ht[i] = temp_ht[beam_idx]
            xt = torch.zeros((batch_size, 1), dtype=torch.int64).to(device)
            xt.fill_(vocab_idx)
            beam_xt[i] = xt
        beam_scores = new_beam_scores
        beam_generated_tokens = new_beam_generated_tokens

        if t % 10 == 0:
            print("#", end="")
            sys.stdout.flush()
    print()
    # print("========================= t = {} =========================".format(t))
    # for ik in range(beam_width):
    #     print("beam{}: [{:.5f}]".format(ik, beam_scores[ik]), bert_tokenizer.decode(beam_generated_tokens[ik]))

    # normalise the probability mass across the beams
    sum_prob = 0.0
    for i in range(beam_width):
        sum_prob += torch.exp(beam_scores[i])
    norm_probs = [None for _ in range(beam_width)]
    for i in range(beam_width):
        norm_probs[i] = torch.exp(beam_scores[i]) / sum_prob

    for i in range(beam_width):
        generated_tokens = beam_generated_tokens[i]
        # generated_tokens -> hypothesis (y_hat)
        hypothesis = bert_tokenizer.decode(generated_tokens)
        stop_idx = hypothesis.find('[SEP]')
        if stop_idx != -1:
            hypothesis = hypothesis[:stop_idx]
        if printout:
            print("beam{}: {}".format(i, hypothesis))
            print("-----------------------------------------------------------------------------------------")

        # compute D(y, y_hat)
        scores = rouge.get_scores(hypothesis, reference)
        r1 = scores[0]['rouge-1']['f']
        r2 = scores[0]['rouge-2']['f']
        rl = scores[0]['rouge-l']['f']
        metric = -1 * (r1 + r2 + rl)  # negated since we 'minimise' the criterion
        total_r1 += r1
        total_r2 += r2
        total_rl += rl

        # scale the gradient by the metric
        this_loss = norm_probs[i] * metric
        this_loss.backward(retain_graph=True)

    r1_avg = total_r1 / beam_width
    r2_avg = total_r2 / beam_width
    rl_avg = total_rl / beam_width
    return r1_avg, r2_avg, rl_avg
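# ---------------------------------------------------------------------------
# Note: beamsearch_approx above minimises an expected risk approximated over
# the k beams: loss = sum_i q(y_i) * D(y, y_i), where q is the renormalised
# beam probability and D = -(R1 + R2 + RL). A hedged, self-contained sketch
# of that loss on toy tensors (illustrative names only):
def _expected_risk_sketch_demo():
    import torch
    beam_logps = torch.tensor([-1.0, -2.0, -3.0], requires_grad=True)  # beam log-scores
    metrics = torch.tensor([-0.9, -0.5, -0.2])       # e.g. -(R1+R2+RL) per beam
    q = torch.softmax(beam_logps, dim=-1)            # renormalise over the beams
    loss = (q * metrics).sum()                       # expected risk
    loss.backward()
    return beam_logps.grad
# ---------------------------------------------------------------------------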
def decode_beamsearch_adaptive_bias(self, adaptivebias, input, u_len, w_len, decode_dict):
    """
    this method is meant to be used at inference time

        input = input to the encoder
        u_len = utterance lengths
        w_len = word lengths
        decode_dict:
            - k                = beam width for beam search
            - search_method    = beam search variant, e.g. 'argmax_abias'
            - batch_size       = batch_size
            - time_step        = max_summary_length
            - vocab_size       = 30522 for BERT
            - device           = cpu or cuda
            - start_token_id   = ID of the start token
            - stop_token_id    = ID of the stop token
            - alpha            = length normalisation
            - penalty_ug       = penalty for repeated unigrams
            - keypadmask_dtype = torch.bool
    """
    k = decode_dict['k']
    search_method = decode_dict['search_method']
    batch_size = decode_dict['batch_size']
    time_step = decode_dict['time_step']
    vocab_size = decode_dict['vocab_size']
    device = decode_dict['device']
    start_token_id = decode_dict['start_token_id']
    stop_token_id = decode_dict['stop_token_id']
    alpha = decode_dict['alpha']
    penalty_ug = decode_dict['penalty_ug']
    keypadmask_dtype = decode_dict['keypadmask_dtype']

    # create beam array & scores
    beams = [None for _ in range(k)]
    beam_scores = np.zeros((batch_size, k))

    # feed through the encoder only once!
    enc_output_dict = self.encoder(input, u_len, w_len)  # memory

    # run the decoder time_step times (auto-regressive)
    tgt_ids = torch.zeros((batch_size, time_step), dtype=torch.int64).to(device)
    tgt_ids[:, 0] = start_token_id
    for i in range(k):
        beams[i] = tgt_ids

    finished_beams = [[] for _ in range(batch_size)]
    beam_ht = [self.decoder.init_h0(batch_size) for _ in range(k)]
    finish = False
    y_out = [torch.ones((1, time_step, vocab_size), dtype=torch.float).to(device)
             for _ in range(k)]

    for t in range(time_step - 1):
        if finish:
            break
        decoder_output_t_array = torch.zeros((batch_size, k * vocab_size))
        for i, beam in enumerate(beams):
            # inference decoding
            # decoder_output = self.decoder(beam[:,:t+1], s_output, s_len, logsoftmax=True)[:,-1,:]
            if t == 0:
                decoder_output, beam_ht[i] = self.decoder.forward_step(
                    beam[:, t:t + 1], beam_ht[i], enc_output_dict, logsoftmax=False)
                output = decoder_output
            else:
                decoder_output, beam_ht[i] = self.decoder.forward_step(
                    beam[:, t:t + 1], beam_ht[i], enc_output_dict, logsoftmax=False)
                # accumulate the coverage of past output distributions
                cov_y = y_out[i][:, :t, :].sum(dim=1)
                cov_y = cov_y / cov_y.sum(dim=-1).unsqueeze(-1)
                bias = adaptivebias(cov_y)
                output = decoder_output - bias
            y_out[i][:, t, :] = F.softmax(output, dim=-1)
            decoder_output = F.log_softmax(output, dim=-1)

            # check if STOP_TOKEN was already emitted at the previous time step,
            # i.e. if the input at this time step is STOP_TOKEN
            for n_idx in range(batch_size):
                if beam[n_idx][t] == stop_token_id:  # already stopped
                    decoder_output[n_idx, :] = float('-inf')
                    decoder_output[n_idx, stop_token_id] = 0.0  # ensure STOP_TOKEN is picked again

            decoder_output_t_array[:, i * vocab_size:(i + 1) * vocab_size] = decoder_output

            # add previous beam score bias
            for n_idx in range(batch_size):
                decoder_output_t_array[n_idx, i * vocab_size:(i + 1) * vocab_size] += beam_scores[n_idx, i]

            # only supports batch_size = 1!
            if t == 0:
                # all beams are identical at t=0, so only keep the first one
                decoder_output_t_array[n_idx, (i + 1) * vocab_size:] = float('-inf')
                break

        if search_method == 'argmax_abias':
            # argmax
            topk_scores, topk_ids = torch.topk(decoder_output_t_array, k, dim=-1)
            scores = topk_scores.double().cpu().numpy()
            indices = topk_ids.double().cpu().numpy()

        new_beams = [torch.zeros((batch_size, time_step), dtype=torch.int64).to(device)
                     for _ in range(k)]
        for r_idx, row in enumerate(indices):
            for c_idx, node in enumerate(row):
                vocab_idx = node % vocab_size
                beam_idx = int(node / vocab_size)
                new_beams[c_idx][r_idx, :t + 1] = beams[beam_idx][r_idx, :t + 1]
                new_beams[c_idx][r_idx, t + 1] = vocab_idx
                # if there is a beam that has [END_TOKEN] --- store it
                if vocab_idx == stop_token_id:
                    finished_beams[r_idx].append(new_beams[c_idx][r_idx, :t + 1 + 1])
                    scores[r_idx, c_idx] = float('-inf')

        # only supports BATCH SIZE = 1
        count_stop = 0
        for ik in range(k):
            if scores[0, ik] == float('-inf'):
                count_stop += 1
        if count_stop == k:
            finish = True

        beams = new_beams
        if search_method == 'argmax_abias':
            beam_scores = scores

        # print("========================= t = {} =========================".format(t))
        # for ik in range(k):
        #     print("beam{}: [{:.5f}]".format(ik, scores[0,ik]), bert_tokenizer.decode(beams[ik][0].cpu().numpy()[:t+2]))

        if (t % 50) == 0:
            print("{}=".format(t), end="")
            sys.stdout.flush()
    print("{}=#".format(t))

    for bi in range(batch_size):
        if len(finished_beams[bi]) == 0:
            finished_beams[bi].append(beams[0][bi])

    summaries_id = [None for _ in range(batch_size)]
    # for j in range(batch_size): summaries_id[j] = beams[0][j].cpu().numpy()
    for j in range(batch_size):
        _scores = self.beam_scoring(finished_beams[j], enc_output_dict, alpha)
        summaries_id[j] = finished_beams[j][np.argmax(_scores)].cpu().numpy()
        print(bert_tokenizer.decode(summaries_id[j]))
    return summaries_id
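# ---------------------------------------------------------------------------
# Note: the adaptive-bias decoder above accumulates past output distributions
# into a normalised coverage vector and subtracts a learned bias from the
# logits before the (log-)softmax, discouraging already-covered tokens. A
# minimal sketch of that single step, assuming a linear "adaptivebias" module
# (an illustrative stand-in, not the repository's module):
def _adaptive_bias_step_sketch():
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    vocab_size = 10
    adaptivebias = nn.Linear(vocab_size, vocab_size)   # assumed bias network
    y_out = torch.rand(1, 4, vocab_size)               # past softmax outputs
    logits = torch.randn(1, vocab_size)                # current-step logits
    cov_y = y_out.sum(dim=1)                           # accumulate coverage
    cov_y = cov_y / cov_y.sum(dim=-1, keepdim=True)    # renormalise to a pmf
    biased = logits - adaptivebias(cov_y)              # penalise covered tokens
    return F.log_softmax(biased, dim=-1)
# ---------------------------------------------------------------------------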
def train2():
    print("Start training hierarchical RNN model")
    # ---------------------------------------------------------------------------------- #
    args = {}
    args['use_gpu'] = True
    args['air_multi_gpu'] = False    # to enable running on multiple GPUs on the stack
    args['num_utterances'] = 2000    # max no. utterances in a meeting
    args['num_words'] = 64           # max no. words in an utterance
    args['summary_length'] = 250     # max no. words in a summary
    args['summary_type'] = 'short'   # long or short summary
    args['vocab_size'] = 30522       # BERT tokenizer
    args['embedding_dim'] = 128      # word embedding dimension
    args['rnn_hidden_size'] = 256    # RNN hidden size
    args['dropout'] = 0.0
    args['num_layers_enc'] = 1       # in total it's num_layers_enc*3 (word/utt/seg)
    args['num_layers_dec'] = 1
    args['batch_size'] = 2
    args['update_nbatches'] = 1      # 0 meaning whole batch update & using SGD
    args['num_epochs'] = 1000
    args['random_seed'] = 28
    args['best_val_loss'] = 1e+10
    args['val_batch_size'] = args['batch_size']
    args['val_stop_training'] = 10
    args['adjust_lr'] = True         # if True, overwrite the learning rate above
    args['initial_lr'] = 0.01        # lr = lr_0*step^(-decay_rate)
    args['decay_rate'] = 0.5
    args['label_smoothing'] = 0.1
    # --- PGN parameters --- #
    args['num_words_meeting'] = 8400
    args['model_save_dir'] = "/home/alta/summary/pm574/summariser1/lib/trained_models/"
    # args['load_model'] = "/home/alta/summary/pm574/summariser1/lib/trained_models/model-HGRUL_DEC10B-ep20.pt"
    args['load_model'] = None
    args['model_name'] = 'PGN_DEC17F'
    # ---------------------------------------------------------------------------------- #
    print_config(args)

    if args['use_gpu']:
        if 'X_SGE_CUDA_DEVICE' in os.environ:  # to run on CUED stack machine
            if not args['air_multi_gpu']:
                print('running on the stack... 1 GPU')
                cuda_device = os.environ['X_SGE_CUDA_DEVICE']
                print('X_SGE_CUDA_DEVICE is set to {}'.format(cuda_device))
                os.environ['CUDA_VISIBLE_DEVICES'] = cuda_device
            else:
                print('running on the stack... multiple GPUs')
                os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'
                write_multi_sl(args['model_name'])
        else:
            # pdb.set_trace()
            print('running locally...')
            os.environ["CUDA_VISIBLE_DEVICES"] = '0,1'  # choose the device (GPU) here
        device = 'cuda'
    else:
        device = 'cpu'
    print("device = {}".format(device))

    # random seed
    random.seed(args['random_seed'])
    torch.manual_seed(args['random_seed'])
    np.random.seed(args['random_seed'])

    train_data = load_ami_data('train')
    valid_data = load_ami_data('valid')

    model = PointerGeneratorNetwork(args, device=device)
    print(model)

    # load model if specified (path to pytorch .pt)
    if args['load_model'] is not None:
        model.load_state_dict(torch.load(args['load_model']))
        model.train()
        print("Loaded model from {}".format(args['load_model']))
    else:
        print("Train a new model")

    # to use multiple GPUs
    if torch.cuda.device_count() > 1:
        print("Multiple GPUs: {}".format(torch.cuda.device_count()))
        model = nn.DataParallel(model)

    # hyperparameters
    BATCH_SIZE = args['batch_size']
    NUM_EPOCHS = args['num_epochs']
    VAL_BATCH_SIZE = args['val_batch_size']
    VAL_STOP_TRAINING = args['val_stop_training']

    if args['label_smoothing'] > 0.0:
        criterion = LabelSmoothingLoss(num_classes=args['vocab_size'],
                                       smoothing=args['label_smoothing'],
                                       reduction='none')
    else:
        criterion = nn.NLLLoss(reduction='none')

    # a single Adam optimiser for the whole model
    optimizer = optim.Adam(model.parameters(), lr=0.01,
                           betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
    optimizer.zero_grad()

    # validation losses
    best_val_loss = args['best_val_loss']
    best_epoch = 0
    stop_counter = 0
    training_step = 0

    for epoch in range(NUM_EPOCHS):
        print("======================= Training epoch {} =======================".format(epoch))
        num_train_data = len(train_data)
        # num_batches = int(num_train_data/BATCH_SIZE) + 1
        num_batches = int(num_train_data / BATCH_SIZE)
        print("num_batches = {}".format(num_batches))
        print("shuffle train data")
        random.shuffle(train_data)
        idx = 0

        for bn in range(num_batches):
            input, u_len, w_len, target, tgt_len = get_a_batch(
                train_data, idx, BATCH_SIZE,
                args['num_utterances'], args['num_words'],
                args['summary_length'], args['summary_type'], device)

            # decoder target
            decoder_target, decoder_mask = shift_decoder_target(target, tgt_len, device)
            decoder_target = decoder_target.view(-1)
            decoder_mask = decoder_mask.view(-1)

            decoder_output = model(input, u_len, w_len, target)

            loss = criterion(decoder_output.view(-1, args['vocab_size']), decoder_target)
            loss = (loss * decoder_mask).sum() / decoder_mask.sum()
            loss.backward()

            idx += BATCH_SIZE

            if bn % args['update_nbatches'] == 0:
                # gradient clipping
                max_norm = 0.5
                nn.utils.clip_grad_norm_(model.parameters(), max_norm)
                # update the gradients
                if args['adjust_lr']:
                    adjust_lr(optimizer, args['initial_lr'], args['decay_rate'], training_step)
                optimizer.step()
                optimizer.zero_grad()
                training_step += 1

            if bn % 1 == 0:
                print("[{}] batch number {}/{}: loss = {}".format(
                    str(datetime.now()), bn, num_batches, loss))
                sys.stdout.flush()

            if bn % 1 == 0:
                print("======================== GENERATED SUMMARY ========================")
                print(bert_tokenizer.decode(
                    torch.argmax(decoder_output[0], dim=-1).cpu().numpy()[:tgt_len[0]]))
                print("======================== REFERENCE SUMMARY ========================")
                print(bert_tokenizer.decode(
                    decoder_target.view(BATCH_SIZE, args['summary_length'])[0, :tgt_len[0]].cpu().numpy()))

            if bn == 0:  # e.g. evaluate every epoch
                # ---------------- Evaluate the model on validation data ---------------- #
                print("Evaluating the model at epoch {} step {}".format(epoch, bn))
                print("learning_rate = {}".format(optimizer.param_groups[0]['lr']))
                model.eval()  # switch to evaluation mode
                with torch.no_grad():
                    avg_val_loss = evaluate(model, valid_data, VAL_BATCH_SIZE, args, device)
                print("avg_val_loss_per_token = {}".format(avg_val_loss))
                model.train()  # switch back to training mode

                # ------------------- Save the model OR stop training ------------------- #
                if avg_val_loss < best_val_loss:
                    stop_counter = 0
                    best_val_loss = avg_val_loss
                    best_epoch = epoch
                    savepath = args['model_save_dir'] + "model-{}-ep{}.pt".format(args['model_name'], epoch)
                    torch.save(model.state_dict(), savepath)
                    print("Model improved & saved at {}".format(savepath))
                else:
                    print("Model not improved #{}".format(stop_counter))
                    if stop_counter < VAL_STOP_TRAINING:
                        # restore the previous best model
                        latest_model = args['model_save_dir'] + "model-{}-ep{}.pt".format(args['model_name'], best_epoch)
                        model.load_state_dict(torch.load(latest_model))
                        model.train()
                        print("Restored model from {}".format(latest_model))
                        stop_counter += 1
                    else:
                        print("Model has not improved for {} times! Stop training.".format(VAL_STOP_TRAINING))
                        return

    print("End of training hierarchical RNN model")
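# ---------------------------------------------------------------------------
# Note: train2 (and train1 below) call adjust_lr(...), defined elsewhere in
# the repository, implementing the schedule noted in the config comments,
# lr = lr_0 * step^(-decay_rate). A hedged sketch of what such a helper might
# look like (an assumption, not the repository's definition):
def _adjust_lr_sketch(optimizer, lr0, decay_rate, step):
    lr = lr0 * ((step + 1) ** (-decay_rate))   # +1 avoids division by zero at step 0
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
# ---------------------------------------------------------------------------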
def decode_beamsearch(self, input, u_len, w_len, decode_dict):
    """
    this method is meant to be used at inference time

        input = input to the encoder
        u_len = utterance lengths
        w_len = word lengths
        decode_dict:
            - k                = beam width for beam search
            - search_method    = 'argmax' or 'sampling'
            - batch_size       = batch_size
            - time_step        = max_summary_length
            - vocab_size       = 30522 for BERT
            - device           = cpu or cuda
            - start_token_id   = ID of the start token
            - stop_token_id    = ID of the stop token
            - alpha            = length normalisation
            - penalty_ug       = penalty for repeated unigrams
            - keypadmask_dtype = torch.bool
    """
    k = decode_dict['k']
    search_method = decode_dict['search_method']
    batch_size = decode_dict['batch_size']
    time_step = decode_dict['time_step']
    vocab_size = decode_dict['vocab_size']
    device = decode_dict['device']
    start_token_id = decode_dict['start_token_id']
    stop_token_id = decode_dict['stop_token_id']
    alpha = decode_dict['alpha']
    penalty_ug = decode_dict['penalty_ug']
    keypadmask_dtype = decode_dict['keypadmask_dtype']

    # create beam array & scores
    beams = [None for _ in range(k)]
    beam_scores = np.zeros((batch_size, k))

    # feed through the encoder only once!
    enc_output_dict = self.encoder(input, u_len, w_len)  # memory

    # run the decoder time_step times (auto-regressive)
    tgt_ids = torch.zeros((batch_size, time_step), dtype=torch.int64).to(device)
    tgt_ids[:, 0] = start_token_id
    for i in range(k):
        beams[i] = tgt_ids

    finished_beams = [[] for _ in range(batch_size)]
    beam_ht = [self.decoder.init_h0(batch_size) for _ in range(k)]
    finish = False

    for t in range(time_step - 1):
        if finish:
            break
        decoder_output_t_array = torch.zeros((batch_size, k * vocab_size))
        for i, beam in enumerate(beams):
            # inference decoding
            # decoder_output = self.decoder(beam[:,:t+1], s_output, s_len, logsoftmax=True)[:,-1,:]
            decoder_output, beam_ht[i], attn_scores = self.decoder.forward_step(
                beam[:, t:t + 1], beam_ht[i], enc_output_dict, logsoftmax=True)
            # print("t = {}: attn_scores = {}".format(t, attn_scores))
            # import pdb; pdb.set_trace()

            # check if STOP_TOKEN was already emitted at the previous time step,
            # i.e. if the input at this time step is STOP_TOKEN
            for n_idx in range(batch_size):
                if beam[n_idx][t] == stop_token_id:  # already stopped
                    decoder_output[n_idx, :] = float('-inf')
                    decoder_output[n_idx, stop_token_id] = 0.0  # ensure STOP_TOKEN is picked again

            decoder_output_t_array[:, i * vocab_size:(i + 1) * vocab_size] = decoder_output

            # add previous beam score bias
            for n_idx in range(batch_size):
                decoder_output_t_array[n_idx, i * vocab_size:(i + 1) * vocab_size] += beam_scores[n_idx, i]

                if search_method == 'argmax':
                    # penalty term for repeated unigrams
                    unigram_dict = {}
                    for tt in range(t + 1):
                        v = beam[n_idx, tt].cpu().numpy().item()
                        if v not in unigram_dict:
                            unigram_dict[v] = 1
                        else:
                            unigram_dict[v] += 1
                    for vocab_id, vocab_count in unigram_dict.items():
                        decoder_output_t_array[n_idx, (i * vocab_size) + vocab_id] -= \
                            penalty_ug * vocab_count / (t + 1)

            # only supports batch_size = 1!
            if t == 0:
                # all beams are identical at t=0, so only keep the first one
                decoder_output_t_array[n_idx, (i + 1) * vocab_size:] = float('-inf')
                break

        if search_method == 'sampling':
            # sampling
            scores = np.zeros((batch_size, k))
            indices = np.zeros((batch_size, k))
            pmf = np.exp(decoder_output_t_array.cpu().numpy())
            for bi in range(batch_size):
                if pmf[bi].sum() != 1.0:
                    pmf[bi] /= pmf[bi].sum()
                sampled_ids = np.random.choice(k * vocab_size, size=k, p=pmf[bi])
                for _s, s_id in enumerate(sampled_ids):
                    scores[bi, _s] = decoder_output_t_array[bi, s_id]
                    indices[bi, _s] = s_id
        elif search_method == 'argmax':
            # argmax
            topk_scores, topk_ids = torch.topk(decoder_output_t_array, k, dim=-1)
            scores = topk_scores.double().cpu().numpy()
            indices = topk_ids.double().cpu().numpy()

        new_beams = [torch.zeros((batch_size, time_step), dtype=torch.int64).to(device)
                     for _ in range(k)]
        for r_idx, row in enumerate(indices):
            for c_idx, node in enumerate(row):
                vocab_idx = node % vocab_size
                beam_idx = int(node / vocab_size)
                new_beams[c_idx][r_idx, :t + 1] = beams[beam_idx][r_idx, :t + 1]
                new_beams[c_idx][r_idx, t + 1] = vocab_idx
                # if there is a beam that has [END_TOKEN] --- store it
                if vocab_idx == stop_token_id:
                    finished_beams[r_idx].append(new_beams[c_idx][r_idx, :t + 1 + 1])
                    scores[r_idx, c_idx] = float('-inf')

        # only supports BATCH SIZE = 1
        count_stop = 0
        for ik in range(k):
            if scores[0, ik] == float('-inf'):
                count_stop += 1
        if count_stop == k:
            finish = True

        beams = new_beams
        if search_method == 'sampling':
            # normalise the scores
            scores = np.exp(scores)
            scores = scores / scores.sum(axis=-1).reshape(batch_size, 1)
            beam_scores = np.log(scores + 1e-20)  # suppress warning log(zero)
        elif search_method == 'argmax':
            beam_scores = scores

        # print("========================= t = {} =========================".format(t))
        # for ik in range(k):
        #     print("beam{}: [{:.5f}]".format(ik, scores[0,ik]), bert_tokenizer.decode(beams[ik][0].cpu().numpy()[:t+2]))

        if (t % 50) == 0:
            print("{}=".format(t), end="")
            sys.stdout.flush()
    print("{}=#".format(t))

    for bi in range(batch_size):
        if len(finished_beams[bi]) == 0:
            finished_beams[bi].append(beams[0][bi])

    summaries_id = [None for _ in range(batch_size)]
    # for j in range(batch_size): summaries_id[j] = beams[0][j].cpu().numpy()
    for j in range(batch_size):
        _scores = self.beam_scoring(finished_beams[j], enc_output_dict, alpha)
        summaries_id[j] = finished_beams[j][np.argmax(_scores)].cpu().numpy()
        print(bert_tokenizer.decode(summaries_id[j]))
    return summaries_id
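# ---------------------------------------------------------------------------
# Note: both decode_beamsearch variants re-rank the finished beams with
# self.beam_scoring(..., alpha), i.e. length normalisation; that helper lives
# elsewhere in the repository. A common formulation of alpha-based length
# normalisation is the GNMT-style score / lp(len) with lp(l) = ((5+l)/6)^alpha
# -- a hedged sketch, not necessarily the exact scoring used here:
def _length_normalised_score_sketch(logprob_sum, length, alpha):
    lp = ((5.0 + length) / 6.0) ** alpha       # length penalty; alpha=0 disables it
    return logprob_sum / lp
# ---------------------------------------------------------------------------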
def train1():
    print("Start training hierarchical RNN model")
    # ---------------------------------------------------------------------------------- #
    args = {}
    args['use_gpu'] = True
    args['air_multi_gpu'] = False    # to enable running on multiple GPUs on the stack
    args['num_utterances'] = 2000    # max no. utterances in a meeting
    args['num_words'] = 64           # max no. words in an utterance
    args['summary_length'] = 800     # max no. words in a summary
    args['summary_type'] = 'long'    # long or short summary
    args['vocab_size'] = 30522       # BERT tokenizer
    args['embedding_dim'] = 256      # word embedding dimension
    args['rnn_hidden_size'] = 512    # RNN hidden size
    args['dropout'] = 0.5
    args['num_layers_enc'] = 1       # in total it's num_layers_enc*3 (word/utt/seg)
    args['num_layers_dec'] = 1
    args['batch_size'] = 1
    args['update_nbatches'] = 2      # 0 meaning whole batch update & using SGD
    args['num_epochs'] = 50
    args['random_seed'] = 30
    args['best_val_loss'] = 1e+10
    args['val_batch_size'] = 1       # 1 for now --- evaluate ROUGE
    args['val_stop_training'] = 5
    args['lr'] = 0.01
    args['adjust_lr'] = True         # if True, overwrite the learning rate above
    args['initial_lr'] = 1e-2        # lr = lr_0*step^(-decay_rate)
    args['decay_rate'] = 0.5
    args['label_smoothing'] = 0.1
    args['a_ts'] = 0.0
    args['a_da'] = 0.0
    args['a_ext'] = 0.0
    args['model_save_dir'] = "/home/alta/summary/pm574/summariser1/lib/trained_models/"
    # args['load_model'] = "/home/alta/summary/pm574/summariser1/lib/trained_models/model-HGRUV2_CNNDM_JAN26A-ep3-bn0"  # add .pt later
    args['load_model'] = None
    args['model_name'] = 'HGRUV2_FEB18A'
    # ---------------------------------------------------------------------------------- #
    print_config(args)

    if args['use_gpu']:
        if 'X_SGE_CUDA_DEVICE' in os.environ:  # to run on CUED stack machine
            if not args['air_multi_gpu']:
                print('running on the stack... 1 GPU')
                cuda_device = os.environ['X_SGE_CUDA_DEVICE']
                print('X_SGE_CUDA_DEVICE is set to {}'.format(cuda_device))
                os.environ['CUDA_VISIBLE_DEVICES'] = cuda_device
            else:
                print('running on the stack... multiple GPUs')
                os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'
                write_multi_sl(args['model_name'])
        else:
            print('running locally...')
            os.environ["CUDA_VISIBLE_DEVICES"] = '1'  # choose the device (GPU) here
        device = 'cuda'
    else:
        device = 'cpu'
    print("device = {}".format(device))

    # random seed
    random.seed(args['random_seed'])
    torch.manual_seed(args['random_seed'])
    np.random.seed(args['random_seed'])

    train_data = load_ami_data('train')
    valid_data = load_ami_data('valid')
    # optionally grow the training set with a few validation meetings
    # random.shuffle(valid_data)
    # train_data.extend(valid_data[:6])
    # valid_data = valid_data[6:]

    model = EncoderDecoder(args, device=device)
    print(model)

    NUM_DA_TYPES = len(DA_MAPPING)
    da_labeller = DALabeller(args['rnn_hidden_size'], NUM_DA_TYPES, device)
    print(da_labeller)
    ext_labeller = EXTLabeller(args['rnn_hidden_size'], device)
    print(ext_labeller)

    # to use multiple GPUs
    if torch.cuda.device_count() > 1:
        print("Multiple GPUs: {}".format(torch.cuda.device_count()))
        model = nn.DataParallel(model)
        da_labeller = nn.DataParallel(da_labeller)
        ext_labeller = nn.DataParallel(ext_labeller)

    # load model if specified (path to pytorch .pt)
    if args['load_model'] is not None:
        model_path = args['load_model'] + '.pt'
        da_path = args['load_model'] + '.da.pt'
        ext_path = args['load_model'] + '.ext.pt'
        try:
            model.load_state_dict(torch.load(model_path))
            da_labeller.load_state_dict(torch.load(da_path))
            ext_labeller.load_state_dict(torch.load(ext_path))
        except RuntimeError:  # need to strip the "module." prefix
            # main model
            model_state_dict = torch.load(model_path)
            new_model_state_dict = OrderedDict()
            for key in model_state_dict.keys():
                new_model_state_dict[key.replace("module.", "")] = model_state_dict[key]
            model.load_state_dict(new_model_state_dict)
            # DA model
            # model_state_dict = torch.load(da_path)
            # new_model_state_dict = OrderedDict()
            # for key in model_state_dict.keys():
            #     new_model_state_dict[key.replace("module.","")] = model_state_dict[key]
            # da_labeller.load_state_dict(new_model_state_dict)
            # EXT model
            # model_state_dict = torch.load(ext_path)
            # new_model_state_dict = OrderedDict()
            # for key in model_state_dict.keys():
            #     new_model_state_dict[key.replace("module.","")] = model_state_dict[key]
            # ext_labeller.load_state_dict(new_model_state_dict)
        model.train()
        da_labeller.train()
        ext_labeller.train()
        print("Loaded model from {}".format(args['load_model']))
    else:
        print("Train a new model")

    # hyperparameters
    BATCH_SIZE = args['batch_size']
    NUM_EPOCHS = args['num_epochs']
    VAL_BATCH_SIZE = args['val_batch_size']
    VAL_STOP_TRAINING = args['val_stop_training']

    if args['label_smoothing'] > 0.0:
        criterion = LabelSmoothingLoss(num_classes=args['vocab_size'],
                                       smoothing=args['label_smoothing'],
                                       reduction='none')
    else:
        criterion = nn.NLLLoss(reduction='none')
    topic_segment_criterion = nn.BCELoss(reduction='none')
    da_criterion = nn.NLLLoss(reduction='none')
    ext_criterion = nn.BCELoss(reduction='none')

    # main model optimiser (Adam), plus an SGD alternative
    optimizer = optim.Adam(model.parameters(), lr=args['lr'],
                           betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
    optimizer.zero_grad()
    sgd_optimizer = optim.SGD(model.parameters(), lr=args['lr'])
    sgd_optimizer.zero_grad()
    # DA labeller optimiser
    da_optimizer = optim.Adam(da_labeller.parameters(), lr=args['lr'],
                              betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
    da_optimizer.zero_grad()
    # extractive labeller optimiser
    ext_optimizer = optim.Adam(ext_labeller.parameters(), lr=args['lr'],
                               betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
    ext_optimizer.zero_grad()

    # validation losses
    best_val_loss = args['best_val_loss']
    best_epoch = 0
    stop_counter = 0
    training_step = 0

    for epoch in range(NUM_EPOCHS):
        print("======================= Training epoch {} =======================".format(epoch))
        num_train_data = len(train_data)
        # num_batches = int(num_train_data/BATCH_SIZE) + 1
        num_batches = int(num_train_data / BATCH_SIZE)
        print("num_batches = {}".format(num_batches))
        print("shuffle train data")
        random.shuffle(train_data)
        idx = 0

        # scheduled sampling probability --- start from 1.0 and go to zero
        # sch_prob = 0.5*(1 - (epoch/NUM_EPOCHS))
        # print("epoch {}: scheduled_sampling_prob = {}".format(epoch, sch_prob))

        for bn in range(num_batches):
            input, u_len, w_len, target, tgt_len, topic_boundary_label, dialogue_acts, extractive_label = get_a_batch(
                train_data, idx, BATCH_SIZE,
                args['num_utterances'], args['num_words'],
                args['summary_length'], args['summary_type'], device)

            # decoder target
            decoder_target, decoder_mask = shift_decoder_target(target, tgt_len, device, mask_offset=True)
            decoder_target = decoder_target.view(-1)
            decoder_mask = decoder_mask.view(-1)

            decoder_output, gate_z, u_output = model(input, u_len, w_len, target)
            # decoder_output, gate_z, u_output = model.forward_scheduled_sampling(input, u_len, w_len, target, prob=sch_prob)

            loss = criterion(decoder_output.view(-1, args['vocab_size']), decoder_target)
            loss = (loss * decoder_mask).sum() / decoder_mask.sum()

            # multitask(1): topic segmentation prediction
            loss_ts = topic_segment_criterion(gate_z, topic_boundary_label)
            loss_ts_mask = length2mask(u_len, BATCH_SIZE, args['num_utterances'], device)
            loss_ts = (loss_ts * loss_ts_mask).sum() / loss_ts_mask.sum()

            # multitask(2): dialogue act prediction
            da_output = da_labeller(u_output)
            loss_da = da_criterion(da_output.view(-1, NUM_DA_TYPES),
                                   dialogue_acts.view(-1)).view(BATCH_SIZE, -1)
            loss_da = (loss_da * loss_ts_mask).sum() / loss_ts_mask.sum()

            # multitask(3): extractive label prediction
            ext_output = ext_labeller(u_output).squeeze(-1)
            loss_ext = ext_criterion(ext_output, extractive_label)
            loss_ext = (loss_ext * loss_ts_mask).sum() / loss_ts_mask.sum()

            # # multitask(3.2): extractive label to control attention at the utterance level
            # attn_extsum = (1-extractive_label)*loss_ts_mask
            # loss_ext_attn = (scores_u.sum(dim=1) * attn_extsum).sum() / loss_ts_mask.sum()

            total_loss = loss + args['a_ts'] * loss_ts + args['a_da'] * loss_da + args['a_ext'] * loss_ext
            total_loss.backward()

            idx += BATCH_SIZE

            if bn % args['update_nbatches'] == 0:
                # gradient clipping
                max_norm = 0.5
                nn.utils.clip_grad_norm_(model.parameters(), max_norm)
                nn.utils.clip_grad_norm_(da_labeller.parameters(), max_norm)
                nn.utils.clip_grad_norm_(ext_labeller.parameters(), max_norm)
                # update the gradients
                if args['adjust_lr']:
                    adjust_lr(optimizer, args['initial_lr'], args['decay_rate'], training_step)
                    adjust_lr(da_optimizer, args['initial_lr'], args['decay_rate'], training_step)
                    adjust_lr(ext_optimizer, args['initial_lr'], args['decay_rate'], training_step)
                optimizer.step()
                optimizer.zero_grad()
                da_optimizer.step()
                da_optimizer.zero_grad()
                ext_optimizer.step()
                ext_optimizer.zero_grad()
                training_step += args['batch_size'] * args['update_nbatches']

            if bn % 1 == 0:
                print("[{}] batch {}/{}: loss = {:.5f} | loss_ts = {:.5f} | loss_da = {:.5f} | loss_ext = {:.5f}"
                      .format(str(datetime.now()), bn, num_batches, loss, loss_ts, loss_da, loss_ext))
                # print("[{}] batch {}/{}: loss = {:5f}".format(str(datetime.now()), bn, num_batches, loss))
                sys.stdout.flush()

            if bn % 20 == 0:
                print("======================== GENERATED SUMMARY ========================")
                print(bert_tokenizer.decode(
                    torch.argmax(decoder_output[0], dim=-1).cpu().numpy()[:tgt_len[0]]))
                print("======================== REFERENCE SUMMARY ========================")
                print(bert_tokenizer.decode(
                    decoder_target.view(BATCH_SIZE, args['summary_length'])[0, :tgt_len[0]].cpu().numpy()))

            if bn == 0:  # e.g. evaluate every epoch
                # ---------------- Evaluate the model on validation data ---------------- #
                print("Evaluating the model at epoch {} step {}".format(epoch, bn))
                print("learning_rate = {}".format(optimizer.param_groups[0]['lr']))
                # switch to evaluation mode
                model.eval()
                da_labeller.eval()
                ext_labeller.eval()
                with torch.no_grad():
                    avg_val_loss = evaluate(model, valid_data, VAL_BATCH_SIZE, args, device, use_rouge=True)
                    # avg_val_loss = evaluate_beam(model, valid_data, VAL_BATCH_SIZE, args, device, use_rouge=True)
                print("avg_val_loss_per_token = {}".format(avg_val_loss))
                # switch back to training mode
                model.train()
                da_labeller.train()
                ext_labeller.train()

                # ------------------- Save the model OR stop training ------------------- #
                if avg_val_loss < best_val_loss:
                    stop_counter = 0
                    best_val_loss = avg_val_loss
                    best_epoch = epoch
                    savepath = args['model_save_dir'] + "model-{}-ep{}.pt".format(args['model_name'], epoch)
                    savepath_da = args['model_save_dir'] + "model-{}-ep{}.da.pt".format(args['model_name'], epoch)
                    savepath_ext = args['model_save_dir'] + "model-{}-ep{}.ext.pt".format(args['model_name'], epoch)
                    torch.save(model.state_dict(), savepath)
                    torch.save(da_labeller.state_dict(), savepath_da)
                    torch.save(ext_labeller.state_dict(), savepath_ext)
                    print("Model improved & saved at {}".format(savepath))
                else:
                    print("Model not improved #{}".format(stop_counter))
                    if stop_counter < VAL_STOP_TRAINING:
                        # restore the previous best model
                        latest_model = args['model_save_dir'] + "model-{}-ep{}.pt".format(args['model_name'], best_epoch)
                        latest_model_da = args['model_save_dir'] + "model-{}-ep{}.da.pt".format(args['model_name'], best_epoch)
                        latest_model_ext = args['model_save_dir'] + "model-{}-ep{}.ext.pt".format(args['model_name'], best_epoch)
                        model.load_state_dict(torch.load(latest_model))
                        da_labeller.load_state_dict(torch.load(latest_model_da))
                        ext_labeller.load_state_dict(torch.load(latest_model_ext))
                        model.train()
                        da_labeller.train()
                        ext_labeller.train()
                        print("Restored model from {}".format(latest_model))
                        stop_counter += 1
                    else:
                        print("Model has not improved for {} times! Stop training.".format(VAL_STOP_TRAINING))
                        return

    if args['air_multi_gpu']:
        rm_multi_sl(args['model_name'])
    print("End of training hierarchical RNN model")
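# ---------------------------------------------------------------------------
# Note: both training loops rely on LabelSmoothingLoss (defined elsewhere in
# the repository) with num_classes, smoothing, and reduction='none'. A minimal
# sketch of label-smoothed NLL consistent with that usage -- an assumption,
# not the repository's implementation. log_probs is (N, num_classes) and
# target is (N,) int64:
def _label_smoothing_sketch(log_probs, target, num_classes, smoothing=0.1):
    import torch
    confidence = 1.0 - smoothing
    # spread the smoothing mass uniformly over the non-target classes
    smooth = torch.full_like(log_probs, smoothing / (num_classes - 1))
    smooth.scatter_(1, target.unsqueeze(1), confidence)   # put confidence on the target
    return -(smooth * log_probs).sum(dim=-1)              # per-token loss (reduction='none')
# ---------------------------------------------------------------------------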