class Predict():
    @timer(module='initialize predictor')
    def __init__(self):
        self.DEVICE = config.DEVICE

        dataset = PairDataset(config.data_path,
                              max_src_len=config.max_src_len,
                              max_tgt_len=config.max_tgt_len,
                              truncate_src=config.truncate_src,
                              truncate_tgt=config.truncate_tgt)

        self.vocab = dataset.build_vocab(embed_file=config.embed_file)

        self.model = PGN(self.vocab)
        self.stop_word = list(
            set([
                self.vocab[x.strip()] for x in
                open(config.stop_word_file, encoding='utf-8').readlines()
            ]))
        self.model.load_model()
        self.model.to(self.DEVICE)

    def greedy_search(self, x, max_sum_len, len_oovs, x_padding_masks):
        """Return a summary by always picking the highest-probability token
        conditioned on the previous word.

        Args:
            x (Tensor): Input sequence as the source.
            max_sum_len (int): The maximum length a summary can have.
            len_oovs (Tensor): Numbers of out-of-vocabulary tokens.
            x_padding_masks (Tensor):
                The padding masks for the input sequences
                with shape (batch_size, seq_len).

        Returns:
            summary (list): The token list of the result summary.
        """
        # Get encoder output and states. Call encoder forward propagation.
        ###########################################
        #          TODO: module 4 task 2          #
        ###########################################
        encoder_output, encoder_states = self.model.encoder(
            replace_oovs(x, self.vocab), self.model.decoder.embedding)

        # Initialize decoder's hidden states with encoder's hidden states.
        decoder_states = self.model.reduce_state(encoder_states)

        # Initialize the decoder's input at time step 0 with the SOS token.
        x_t = torch.ones(1) * self.vocab.SOS
        x_t = x_t.to(self.DEVICE, dtype=torch.int64)
        summary = [self.vocab.SOS]
        coverage_vector = torch.zeros((1, x.shape[1])).to(self.DEVICE)

        # Generate the hypothesis within the maximum number of decode steps.
        while int(x_t.item()) != (self.vocab.EOS) \
                and len(summary) < max_sum_len:
            context_vector, attention_weights, coverage_vector = \
                self.model.attention(decoder_states,
                                     encoder_output,
                                     x_padding_masks,
                                     coverage_vector)
            # Use the decoder to generate the vocab distribution for the next token.
            p_vocab, decoder_states, p_gen = \
                self.model.decoder(x_t.unsqueeze(1),
                                   decoder_states,
                                   context_vector)
            final_dist = self.model.get_final_distribution(
                x, p_gen, p_vocab, attention_weights, torch.max(len_oovs))

            # Get the next token with maximum probability.
            x_t = torch.argmax(final_dist, dim=1).to(self.DEVICE)
            decoder_word_idx = x_t.item()
            summary.append(decoder_word_idx)
            x_t = replace_oovs(x_t, self.vocab)

        return summary

    # @timer('best k')
    def best_k(self, beam, k, encoder_output, x_padding_masks, x, len_oovs):
        """Get the best k tokens to extend the current sequence at the
        current time step.

        Args:
            beam (utils.Beam): The candidate beam to be extended.
            k (int): Beam size.
            encoder_output (Tensor): The LSTM output from the encoder.
            x_padding_masks (Tensor):
                The padding masks for the input sequences.
            x (Tensor): Source token ids.
            len_oovs (Tensor): Number of OOV tokens in a batch.

        Returns:
            best_k (list(Beam)): The list of best k candidates.
        """
        # Use the decoder to generate the vocab distribution for the next token.
        x_t = torch.tensor(beam.tokens[-1]).reshape(1, 1)
        x_t = x_t.to(self.DEVICE)

        # Get the context vector from the attention network.
        context_vector, attention_weights, coverage_vector = \
            self.model.attention(beam.decoder_states,
                                 encoder_output,
                                 x_padding_masks,
                                 beam.coverage_vector)

        # Replace the indexes of OOV words with the UNK token index
        # to prevent index-out-of-bound errors in the decoder.
        p_vocab, decoder_states, p_gen = \
            self.model.decoder(replace_oovs(x_t, self.vocab),
                               beam.decoder_states,
                               context_vector)

        final_dist = self.model.get_final_distribution(
            x, p_gen, p_vocab, attention_weights, torch.max(len_oovs))

        # Calculate log probabilities.
        log_probs = torch.log(final_dist.squeeze())
        # Filter forbidden tokens at the first decode step.
        if len(beam.tokens) == 1:
            forbidden_ids = [
                self.vocab[u"台独"],
                self.vocab[u"吸毒"],
                self.vocab[u"黄赌毒"]
            ]
            log_probs[forbidden_ids] = -float('inf')
        # EOS token penalty. Follows the definition in
        # https://opennmt.net/OpenNMT/translation/beam_search/.
        log_probs[self.vocab.EOS] *= \
            config.gamma * x.size()[1] / len(beam.tokens)

        log_probs[self.vocab.UNK] = -float('inf')

        # Get the top k tokens and the corresponding log probabilities.
        topk_probs, topk_idx = torch.topk(log_probs, k)

        # Extend the current hypothesis with the top k tokens,
        # resulting in k new hypotheses.
        best_k = [
            beam.extend(idx, log_probs[idx], decoder_states, coverage_vector)
            for idx in topk_idx.tolist()
        ]

        return best_k

    def beam_search(self, x, max_sum_len, beam_width, len_oovs,
                    x_padding_masks):
        """Use beam search to generate the summary.

        Args:
            x (Tensor): Input sequence as the source.
            max_sum_len (int): The maximum length a summary can have.
            beam_width (int): Beam size.
            len_oovs (Tensor): Number of out-of-vocabulary tokens.
            x_padding_masks (Tensor):
                The padding masks for the input sequences.

        Returns:
            result (list): The token list of the best hypothesis.
        """
        # Run the source input through the encoder. Call encoder forward propagation.
        ###########################################
        #          TODO: module 4 task 2          #
        ###########################################
        encoder_output, encoder_states = self.model.encoder(
            replace_oovs(x, self.vocab), self.model.decoder.embedding)
        coverage_vector = torch.zeros((1, x.shape[1])).to(self.DEVICE)
        # Initialize decoder states with the encoder's forward states.
        decoder_states = self.model.reduce_state(encoder_states)

        # Initialize the hypothesis with a Beam instance.
        init_beam = Beam([self.vocab.SOS], [0], decoder_states,
                         coverage_vector)

        # Get the beam size and create a list for storing current candidates
        # and a list for completed hypotheses.
        k = beam_width
        curr, completed = [init_beam], []

        # Run beam search for max_sum_len (maximum length) steps.
        for _ in range(max_sum_len):
            # Get the k best hypotheses when adding a new token.
            topk = []
            for beam in curr:
                # When an EOS token is generated, add the hypothesis to the
                # completed list and decrease the beam size.
                if beam.tokens[-1] == self.vocab.EOS:
                    completed.append(beam)
                    k -= 1
                    continue
                for can in self.best_k(beam,
                                       k,
                                       encoder_output,
                                       x_padding_masks,
                                       x,
                                       torch.max(len_oovs)):
                    # Use topk as a heap to keep track of the top k candidates.
                    # Use the sequence scores of the hypotheses to compare
                    # and object ids to break ties.
                    add2heap(topk, (can.seq_score(), id(can), can), k)

            curr = [items[2] for items in topk]
            # Stop when there are enough completed hypotheses.
            if len(completed) == beam_width:
                break
        # When there are not enough completed hypotheses,
        # take whatever we have in the current best k as the final candidates.
        completed += curr
        # Sort the hypotheses by normalized probability and choose the best one.
        result = sorted(completed,
                        key=lambda x: x.seq_score(),
                        reverse=True)[0].tokens
        return result

    @timer(module='doing prediction')
    def predict(self, text, tokenize=True, beam_search=True):
        """Generate summary.

        Args:
            text (str or list): Source.
            tokenize (bool, optional): Whether to do tokenization or not.
                Defaults to True.
            beam_search (bool, optional): Whether to use beam search or not.
                Defaults to True; otherwise greedy search is used.

        Returns:
            str: The final summary.
        """
        if isinstance(text, str) and tokenize:
            text = list(jieba.cut(text))
        x, oov = source2ids(text, self.vocab)
        x = torch.tensor(x).to(self.DEVICE)
        len_oovs = torch.tensor([len(oov)]).to(self.DEVICE)
        x_padding_masks = torch.ne(x, 0).byte().float()
        if beam_search:
            summary = self.beam_search(x.unsqueeze(0),
                                       max_sum_len=config.max_dec_steps,
                                       beam_width=config.beam_size,
                                       len_oovs=len_oovs,
                                       x_padding_masks=x_padding_masks)
        else:
            summary = self.greedy_search(x.unsqueeze(0),
                                         max_sum_len=config.max_dec_steps,
                                         len_oovs=len_oovs,
                                         x_padding_masks=x_padding_masks)
        summary = outputids2words(summary, oov, self.vocab)
        return summary.replace('<SOS>', '').replace('<EOS>', '').strip()
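

# A minimal usage sketch of the Predict class above (not part of the original
# module): it assumes the config paths, the stop-word file and the saved model
# checkpoints already exist; `sample_text` is a made-up placeholder document.
if __name__ == '__main__':
    predictor = Predict()
    sample_text = '...'  # placeholder: raw source article text
    # Beam search decoding by default; pass beam_search=False for greedy search.
    print(predictor.predict(sample_text, tokenize=True, beam_search=True))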
def train(dataset, val_dataset, v, start_epoch=0):
    """Train the model, evaluate it and store it.

    Args:
        dataset (dataset.PairDataset): The training dataset.
        val_dataset (dataset.PairDataset): The evaluation dataset.
        v (vocab.Vocab): The vocabulary built from the training dataset.
        start_epoch (int, optional): The starting epoch number. Defaults to 0.
    """
    DEVICE = torch.device("cuda" if config.is_cuda else "cpu")

    model = PGN(v)
    model.load_model()
    model.to(DEVICE)
    if config.fine_tune:
        # In fine-tuning mode, we fix the weights of all parameters
        # except attention.wc.
        print('Fine-tuning mode.')
        for name, params in model.named_parameters():
            if name != 'attention.wc.weight':
                params.requires_grad = False

    print("loading data")
    train_data = SampleDataset(dataset.pairs, v)
    val_data = SampleDataset(val_dataset.pairs, v)

    print("initializing optimizer")
    # Define the optimizer.
    optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
    train_dataloader = DataLoader(dataset=train_data,
                                  batch_size=config.batch_size,
                                  shuffle=True,
                                  collate_fn=collate_fn)

    val_losses = np.inf
    if os.path.exists(config.losses_path):
        with open(config.losses_path, 'rb') as f:
            val_losses = pickle.load(f)

    # torch.cuda.empty_cache()
    # SummaryWriter: Log writer used for TensorboardX visualization.
    writer = SummaryWriter(config.log_path)
    # tqdm: A tool for drawing progress bars during training.
    # scheduled_sampler: A tool for choosing whether to use teacher forcing.
    num_epochs = len(range(start_epoch, config.epochs))
    scheduled_sampler = ScheduledSampler(num_epochs)
    if config.scheduled_sampling:
        print('scheduled_sampling mode.')
        # teacher_forcing = True

    with tqdm(total=config.epochs) as epoch_progress:
        for epoch in range(start_epoch, config.epochs):
            print(config_info(config))
            batch_losses = []  # Loss of each batch.
            num_batches = len(train_dataloader)
            # Set a teacher_forcing signal.
            if config.scheduled_sampling:
                teacher_forcing = scheduled_sampler.teacher_forcing(
                    epoch - start_epoch)
            else:
                teacher_forcing = True
            print('teacher_forcing = {}'.format(teacher_forcing))

            with tqdm(total=num_batches) as batch_progress:
                for batch, data in enumerate(tqdm(train_dataloader)):
                    x, y, x_len, y_len, oov, len_oovs = data
                    assert not np.any(np.isnan(x.numpy()))
                    if config.is_cuda:  # Training with GPUs.
                        x = x.to(DEVICE)
                        y = y.to(DEVICE)
                        x_len = x_len.to(DEVICE)
                        len_oovs = len_oovs.to(DEVICE)

                    model.train()  # Set the module to training mode.
                    optimizer.zero_grad()  # Clear gradients.
                    # Calculate loss. Call model forward propagation.
                    loss = model(x, x_len, y,
                                 len_oovs,
                                 batch=batch,
                                 num_batches=num_batches,
                                 teacher_forcing=teacher_forcing)
                    batch_losses.append(loss.item())
                    loss.backward()  # Backpropagation.

                    # Do gradient clipping to prevent gradient explosion.
                    clip_grad_norm_(model.encoder.parameters(),
                                    config.max_grad_norm)
                    clip_grad_norm_(model.decoder.parameters(),
                                    config.max_grad_norm)
                    clip_grad_norm_(model.attention.parameters(),
                                    config.max_grad_norm)
                    optimizer.step()  # Update weights.

                    # Output and record the running loss every 32 batches.
                    if (batch % 32) == 0:
                        batch_progress.set_description(f'Epoch {epoch}')
                        batch_progress.set_postfix(Batch=batch,
                                                   Loss=loss.item())
                        batch_progress.update()
                        # Write loss for tensorboard.
                        writer.add_scalar(f'Average loss for epoch {epoch}',
                                          np.mean(batch_losses),
                                          global_step=batch)
            # Calculate average loss over all batches in an epoch.
            epoch_loss = np.mean(batch_losses)

            epoch_progress.set_description(f'Epoch {epoch}')
            epoch_progress.set_postfix(Loss=epoch_loss)
            epoch_progress.update()

            avg_val_loss = evaluate(model, val_data, epoch)

            print('training loss:{}'.format(epoch_loss),
                  'validation loss:{}'.format(avg_val_loss))

            # Update the minimum validation loss and save the model.
            if avg_val_loss < val_losses:
                torch.save(model.encoder, config.encoder_save_name)
                torch.save(model.decoder, config.decoder_save_name)
                torch.save(model.attention, config.attention_save_name)
                torch.save(model.reduce_state, config.reduce_state_save_name)
                val_losses = avg_val_loss
                with open(config.losses_path, 'wb') as f:
                    pickle.dump(val_losses, f)

    writer.close()
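

# A minimal entry-point sketch (illustration only, not from the original file)
# showing how train() above might be invoked. config.val_data_path is an
# assumed config field; only config.data_path appears elsewhere in this code.
if __name__ == '__main__':
    train_dataset = PairDataset(config.data_path,
                                max_src_len=config.max_src_len,
                                max_tgt_len=config.max_tgt_len,
                                truncate_src=config.truncate_src,
                                truncate_tgt=config.truncate_tgt)
    val_dataset = PairDataset(config.val_data_path,  # assumed config field
                              max_src_len=config.max_src_len,
                              max_tgt_len=config.max_tgt_len,
                              truncate_src=config.truncate_src,
                              truncate_tgt=config.truncate_tgt)
    vocab = train_dataset.build_vocab(embed_file=config.embed_file)
    train(train_dataset, val_dataset, vocab, start_epoch=0)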
def train(dataset, val_dataset, v, start_epoch=0):
    """Train the model, evaluate it and store it.

    Args:
        dataset (dataset.PairDataset): The training dataset.
        val_dataset (dataset.PairDataset): The evaluation dataset.
        v (vocab.Vocab): The vocabulary built from the training dataset.
        start_epoch (int, optional): The starting epoch number. Defaults to 0.
    """
    torch.autograd.set_detect_anomaly(True)
    DEVICE = torch.device("cuda" if config.is_cuda else "cpu")

    model = PGN(v)
    model.load_model()
    model.to(DEVICE)
    if config.fine_tune:
        # In fine-tuning mode, we fix the weights of all parameters
        # except attention.wc.
        logging.info('Fine-tuning mode.')
        for name, params in model.named_parameters():
            if name != 'attention.wc.weight':
                params.requires_grad = False

    logging.info("loading data")
    train_data = dataset
    val_data = val_dataset

    logging.info("initializing optimizer")
    # Define the optimizer.
    # optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
    optimizer = optim.Adagrad(
        model.parameters(),
        lr=config.learning_rate,
        initial_accumulator_value=config.initial_accumulator_value)
    scheduler = StepLR(optimizer, step_size=10, gamma=0.2)  # Learning rate schedule.
    train_dataloader = DataLoader(dataset=train_data,
                                  batch_size=config.batch_size,
                                  shuffle=True,
                                  collate_fn=collate_fn)

    val_loss = np.inf
    if os.path.exists(config.losses_path):
        with open(config.losses_path, 'r') as f:
            val_loss = float(f.readlines()[-1].split("=")[-1])
            logging.info("the last best val loss is: " + str(val_loss))

    # torch.cuda.empty_cache()
    # SummaryWriter: Log writer used for TensorboardX visualization.
    writer = SummaryWriter(config.log_path)
    # tqdm: A tool for drawing progress bars during training.
    early_stopping_count = 0
    logging.info("start training model {}, ".format(config.model_name) +
                 "epoch : {}, ".format(config.epochs) +
                 "batch_size : {}, ".format(config.batch_size) +
                 "num batches: {}, ".format(len(train_dataloader)))
    for epoch in range(start_epoch, config.epochs):
        batch_losses = []  # Loss of each batch.
        num_batches = len(train_dataloader)
        # with tqdm(total=num_batches//100) as batch_progress:
        for batch, data in enumerate(train_dataloader):
            x, y, x_len, y_len, oov, len_oovs, img_vec = data
            assert not np.any(np.isnan(x.numpy()))
            if config.is_cuda:  # Training with GPUs.
                x = x.to(DEVICE)
                y = y.to(DEVICE)
                x_len = x_len.to(DEVICE)
                len_oovs = len_oovs.to(DEVICE)
                img_vec = img_vec.to(DEVICE)
            if batch == 0:
                logging.info("x: %s, shape: %s" % (x, x.shape))
                logging.info("y: %s, shape: %s" % (y, y.shape))
                logging.info("oov: %s" % oov)
                logging.info("img_vec: %s, shape: %s" % (img_vec, img_vec.shape))

            model.train()  # Set the module to training mode.
            optimizer.zero_grad()  # Clear gradients.
            loss = model(x, y, len_oovs, img_vec,
                         batch=batch, num_batches=num_batches)
            batch_losses.append(loss.item())
            loss.backward()  # Backpropagation.

            # Do gradient clipping to prevent gradient explosion.
            clip_grad_norm_(model.encoder.parameters(), config.max_grad_norm)
            clip_grad_norm_(model.decoder.parameters(), config.max_grad_norm)
            clip_grad_norm_(model.attention.parameters(), config.max_grad_norm)
            clip_grad_norm_(model.reduce_state.parameters(),
                            config.max_grad_norm)
            optimizer.step()  # Update weights.
            # scheduler.step()

            # Output and record the running average loss every 100 batches.
            if (batch % 100) == 0:
                # batch_progress.set_description(f'Epoch {epoch}')
                # batch_progress.set_postfix(Batch=batch, Loss=loss.item())
                # batch_progress.update()
                # Write loss for tensorboard.
                writer.add_scalar(f'Average_loss_for_epoch_{epoch}',
                                  np.mean(batch_losses),
                                  global_step=batch)
                logging.info('epoch: {}, batch: {}, training loss: {}'.format(
                    epoch, batch, np.mean(batch_losses)))

        # Calculate average loss over all batches in an epoch.
        epoch_loss = np.mean(batch_losses)
        # epoch_progress.set_description(f'Epoch {epoch}')
        # epoch_progress.set_postfix(Loss=epoch_loss)
        # epoch_progress.update()

        avg_val_loss = evaluate(model, val_data, epoch)

        logging.info('epoch: {} '.format(epoch) +
                     'training loss: {} '.format(epoch_loss) +
                     'validation loss: {} '.format(avg_val_loss))

        # Update the minimum validation loss and save the model.
        if not os.path.exists(os.path.dirname(config.encoder_save_name)):
            os.mkdir(os.path.dirname(config.encoder_save_name))
        if avg_val_loss < val_loss:
            logging.info("saving model to ../saved_model/ %s" %
                         config.model_name)
            torch.save(model.encoder, config.encoder_save_name)
            torch.save(model.decoder, config.decoder_save_name)
            torch.save(model.attention, config.attention_save_name)
            torch.save(model.reduce_state, config.reduce_state_save_name)
            val_loss = avg_val_loss
            with open(config.losses_path, 'a') as f:
                f.write(f"best val loss={val_loss}\n")
        else:
            early_stopping_count += 1
            if early_stopping_count >= config.patience:
                logging.info(
                    f'Validation loss did not decrease for '
                    f'{config.patience} epochs, stop training.')
                break

    writer.close()
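

# A hedged sketch of the evaluate() helper called above; its real
# implementation is not shown in this section. The tensor handling and the
# model call mirror the training loop of this multimodal variant; everything
# else (the body and the name evaluate_sketch) is an assumption for
# illustration only.
def evaluate_sketch(model, val_data, epoch):
    """Return the average validation loss over one pass of val_data."""
    # epoch is accepted to match the call site above; it is unused here.
    DEVICE = torch.device("cuda" if config.is_cuda else "cpu")
    val_dataloader = DataLoader(dataset=val_data,
                                batch_size=config.batch_size,
                                shuffle=False,
                                collate_fn=collate_fn)
    model.eval()  # Disable dropout etc. during evaluation.
    val_losses = []
    with torch.no_grad():  # No gradients are needed for validation.
        for batch, data in enumerate(val_dataloader):
            x, y, x_len, y_len, oov, len_oovs, img_vec = data
            if config.is_cuda:
                x, y = x.to(DEVICE), y.to(DEVICE)
                len_oovs = len_oovs.to(DEVICE)
                img_vec = img_vec.to(DEVICE)
            loss = model(x, y, len_oovs, img_vec,
                         batch=batch, num_batches=len(val_dataloader))
            val_losses.append(loss.item())
    return np.mean(val_losses)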