def beam_search(self, input_seq, beam_size=3, attentionOverrideMap=None,
                correctionMap=None, unk_map=None, beam_length=0.5,
                beam_coverage=0.5, max_length=MAX_LENGTH):
    torch.set_grad_enabled(False)

    input_seqs = [indexes_from_sentence(self.input_lang, input_seq)]
    input_lengths = [len(seq) for seq in input_seqs]
    input_batches = Variable(torch.LongTensor(input_seqs)).transpose(0, 1)

    if use_cuda:
        input_batches = input_batches.cuda()

    # Run beam search in evaluation mode.
    self.encoder.train(False)
    self.decoder.train(False)

    encoder_outputs, encoder_hidden = self.encoder(input_batches, input_lengths, None)
    decoder_hidden = encoder_hidden

    beam_search = BeamSearch(self.decoder, encoder_outputs, decoder_hidden,
                             self.output_lang, beam_size, attentionOverrideMap,
                             correctionMap, unk_map, beam_length=beam_length,
                             beam_coverage=beam_coverage)
    result = beam_search.search()

    self.encoder.train(True)
    self.decoder.train(True)
    torch.set_grad_enabled(True)

    return result  # a list of indexes, one for each word in the sentence, plus EOS
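# The beam_length / beam_coverage arguments above look like length- and
# coverage-penalty weights. As a self-contained sketch, this is the
# GNMT-style length normalization such a knob commonly controls
# (Wu et al., 2016); whether this BeamSearch uses exactly this formula
# is an assumption, not something the snippet states.
def gnmt_length_penalty(length, alpha=0.5):
    # lp(Y) = ((5 + |Y|) / 6) ** alpha
    return ((5.0 + length) / 6.0) ** alpha

def normalized_score(sum_log_probs, length, alpha=0.5):
    # Divide summed log probabilities by the penalty so longer
    # hypotheses are not unfairly punished.
    return sum_log_probs / gnmt_length_penalty(length, alpha)

print(normalized_score(-4.2, 7, alpha=0.5))  # rescoring a 7-token hypothesis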
def _get_top_k_sequences(self, log_probs, wordpiece_mask, k):
    batch_size = log_probs.size()[0]
    seq_length = log_probs.size()[1]

    beam_search = BeamSearch(self._end_index, max_steps=seq_length, beam_size=k,
                             per_node_beam_size=self._per_node_beam_size)

    # Add very low log probabilities for the start and end tags used in the beam search.
    beam_log_probs = torch.nn.functional.pad(
        log_probs, pad=(0, 2, 0, 0, 0, 0), value=-1e7)

    start_predictions = beam_log_probs.new_full(
        (batch_size, ), fill_value=self._start_index).long()

    # Shape: (batch_size, beam_size, seq_length)
    top_k_predictions, seq_log_probs = beam_search.search(
        start_predictions, {
            'log_probs': beam_log_probs,
            'wordpiece_mask': wordpiece_mask,
            'step_num': beam_log_probs.new_zeros((batch_size, )).long()
        }, self.take_step)

    # Zero out the start and end tags if they slipped into the predictions.
    top_k_predictions[top_k_predictions > 2] = 0

    return top_k_predictions
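# Toy demonstration of the padding trick above, runnable on its own:
# F.pad with pad=(0, 2, 0, 0, 0, 0) appends two extra class slots (for the
# synthetic start and end tags) at a log probability so low the beam will
# never prefer them. The tensor sizes here are made up for illustration.
import torch
import torch.nn.functional as F

log_probs = torch.log_softmax(torch.randn(1, 4, 3), dim=-1)  # (batch, seq, classes)
beam_log_probs = F.pad(log_probs, pad=(0, 2, 0, 0, 0, 0), value=-1e7)
print(beam_log_probs.shape)  # torch.Size([1, 4, 5]) -- two extra tag slots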
def beam_search(self, initial_sequence, forbid_movies=None, temperature=1, **kwargs):
    """
    Beam search sentence generation
    :param initial_sequence: list giving the initial sequence of tokens
    :param forbid_movies: set of movie ids that must not be recommended again
    :param temperature: softmax temperature for the model forward pass
    :param kwargs: additional parameters to pass to model forward pass (e.g. a conditioning context)
    :return: list of beams after max_sequence_length steps
    """
    if forbid_movies is None:
        forbid_movies = set()

    beam_search = BeamSearch(self.beam_size, initial_sequence, self.word2id["</s>"])
    beams = beam_search.beams
    for i in range(self.max_sequence_length):
        # Compute next-token probabilities for each beam.
        probabilities = []
        for beam in beams:
            # Add a batch dimension.
            model_input = Variable(torch.LongTensor(beam.sequence)).unsqueeze(0)
            if self.model.cuda_available:
                model_input = model_input.cuda()
            beam_forbidden_movies = forbid_movies.union(beam.mentioned_movies)
            prob = self.model(input=model_input,
                              lengths=[len(beam.sequence)],
                              log_probabilities=False,
                              forbid_movies=beam_forbidden_movies,
                              temperature=temperature,
                              **kwargs)
            # Probabilities for the next token to generate.
            probabilities.append(prob[0, -1, :].cpu())

        # Update the beams.
        beams = beam_search.search(probabilities, n_gram_block=self.n_gram_block)

        # Replace movie ids with the corresponding words.
        for beam in beams:
            if beam.sequence[-1] > len(self.word2id):
                # Track mentioned movies to prevent repeated recommendations.
                beam.mentioned_movies.add(beam.sequence[-1] - len(self.word2id))
                beam.sequence[-1:] = replace_movie_with_words(
                    beam.sequence[-1], self.movieId2name, self.word2id)
    return beams
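# A minimal, self-contained sketch of the kind of n-gram blocking the
# n_gram_block argument above presumably enables (the exact rule used by
# this BeamSearch is an assumption): forbid any token that would repeat
# an n-gram already present in the sequence.
def blocked_tokens(sequence, n=3):
    if len(sequence) < n - 1:
        return set()
    prefix = tuple(sequence[-(n - 1):])
    blocked = set()
    for i in range(len(sequence) - n + 1):
        # If the current (n-1)-gram occurred before, block its continuation.
        if tuple(sequence[i:i + n - 1]) == prefix:
            blocked.add(sequence[i + n - 1])
    return blocked

print(blocked_tokens([5, 8, 9, 5, 8], n=3))  # {9}: "5 8 9" already occurred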
class Prediction:
    def __init__(self, args):
        """
        :param args: parsed arguments providing the model dir (args.logdir) and vocab file (args.vocab)
        """
        self.tf = import_tf(0)

        self.args = args
        self.model_dir = args.logdir
        self.vocab_file = args.vocab
        self.token2idx, self.idx2token = _load_vocab(args.vocab)

        hparams = Hparams()
        parser = hparams.parser
        self.hp = parser.parse_args()

        self.model = Transformer(self.hp)
        self._add_placeholder()
        self._init_graph()

    def _init_graph(self):
        """
        init graph
        """
        self.ys = (self.input_y, None, None)
        self.xs = (self.input_x, None)
        self.memory = self.model.encode(self.xs, False)[0]
        self.logits = self.model.decode(self.xs, self.ys, self.memory, False)[0]

        ckpt = self.tf.train.get_checkpoint_state(self.model_dir).all_model_checkpoint_paths[-1]

        graph = self.logits.graph
        sess_config = self.tf.ConfigProto(allow_soft_placement=True)
        sess_config.gpu_options.allow_growth = True

        saver = self.tf.train.Saver()
        self.sess = self.tf.Session(config=sess_config, graph=graph)
        self.sess.run(self.tf.global_variables_initializer())
        self.tf.reset_default_graph()
        saver.restore(self.sess, ckpt)

        self.bs = BeamSearch(self.model,
                             self.hp.beam_size,
                             list(self.idx2token.keys())[2],
                             list(self.idx2token.keys())[3],
                             self.idx2token,
                             self.hp.maxlen2,
                             self.input_x,
                             self.input_y,
                             self.logits)

    def predict(self, content):
        """
        abstract prediction by beam search
        :param content: article content
        :return: prediction result
        """
        input_x = content.split()

        # Pad with '<pad>' up to maxlen1, then truncate and map to ids.
        while len(input_x) < self.args.maxlen1:
            input_x.append('<pad>')
        input_x = input_x[:self.args.maxlen1]
        input_x = [self.token2idx.get(s, self.token2idx['<unk>']) for s in input_x]

        memory = self.sess.run(self.memory, feed_dict={self.input_x: [input_x]})

        return self.bs.search(self.sess, input_x, memory[0])

    def _add_placeholder(self):
        """
        add tensorflow placeholders
        """
        self.input_x = self.tf.placeholder(dtype=self.tf.int32,
                                           shape=[None, self.args.maxlen1],
                                           name='input_x')
        self.input_y = self.tf.placeholder(dtype=self.tf.int32,
                                           shape=[None, None],
                                           name='input_y')
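# Self-contained illustration of the input preprocessing done in predict():
# pad with '<pad>' up to a fixed length, truncate, then map tokens through
# the vocab with '<unk>' as the fallback. The vocab and ids below are made
# up for the example.
token2idx = {'<pad>': 0, '<unk>': 1, 'the': 4, 'cat': 5}
maxlen1 = 6

tokens = 'the cat sat'.split()
tokens = (tokens + ['<pad>'] * maxlen1)[:maxlen1]
ids = [token2idx.get(t, token2idx['<unk>']) for t in tokens]
print(ids)  # [4, 5, 1, 0, 0, 0] -- 'sat' is out of vocab, rest is padding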
class BSDecoder(object):
    def __init__(self, model, batch_reader, model_config, data_config, vocab, data_loader):
        self.model = model
        self.batch_reader = batch_reader
        self.model_config = model_config
        self.data_config = data_config
        self.vocab = vocab
        self.data_loader = data_loader

        self.saver = tf.train.Saver()
        self.session = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
        self.restore_model_flag = self.restore_model()

        self.bs = BeamSearch(
            self.model,
            self.model_config.beam_size,
            self.data_loader.word_to_id(self.data_config.sentence_start),
            self.data_loader.word_to_id(self.data_config.sentence_end),
            self.model_config.abstract_length)

    def restore_model(self):
        """
        restore model
        :return: True if the model was restored successfully, else False
        """
        ckpt_state = tf.train.get_checkpoint_state(self.model_config.model_path)
        if not (ckpt_state and ckpt_state.model_checkpoint_path):
            print('No model to decode yet at {0}'.format(self.model_config.model_path))
            return False

        ckpt_path = os.path.join(self.model_config.model_path,
                                 os.path.basename(ckpt_state.model_checkpoint_path))
        self.saver.restore(self.session, ckpt_path)
        return True

    def decode(self, article):
        """
        decode article to abstract by model
        :param article: article, given as a list of word ids
        :return: abstract as a list of word ids, or None if no model was restored
        """
        if not self.restore_model_flag:
            return None

        # Replicate the article beam_size times so each beam gets its own copy.
        article_batch = [article] * self.model_config.beam_size
        article_length_batch = [len(article)] * self.model_config.beam_size

        best_beam = self.bs.search(self.session, article_batch, article_length_batch)[0]

        # Drop the first id, which is the sentence-start token.
        result = [int(word_id) for word_id in best_beam[1:]]
        return result
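# Quick, self-contained check of the replication above: multiplying the
# wrapped list ([article] * beam_size) yields beam_size copies of the id
# list, whereas article * beam_size would flatten the ids into one long list.
article = [12, 7, 33]
beam_size = 2
print([article] * beam_size)  # [[12, 7, 33], [12, 7, 33]]
print(article * beam_size)    # [12, 7, 33, 12, 7, 33]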
class LstmDecoder(Decoder):
    def __init__(self, embedder, hidden_dim, num_layers, dropout_p, attention, beam_width):
        super(LstmDecoder, self).__init__()
        self._embedder = embedder
        self._hidden_dim = hidden_dim
        self._num_layers = num_layers
        self._attention = attention

        decoder_input_dim = self._embedder.get_embed_dim()
        decoder_input_dim += self._hidden_dim  # for input-feeding

        self._lstm = StackedLSTM(decoder_input_dim, self._hidden_dim, num_layers, dropout_p)
        self._output_projection_layer = nn.Linear(self._hidden_dim,
                                                  self._embedder.get_vocab_size())

        self._start_index = self._embedder.get_init_token_idx()
        self._eos_index = self._embedder.get_eos_token_idx()
        self._pad_index = self._embedder.get_pad_token_idx()

        self._max_decoding_steps = 100
        self._beam_search = BeamSearch(self._eos_index, self._max_decoding_steps, beam_width)

    def init_state(self, state):
        hidden = state["encoder_hidden"]
        _, batch_size, _ = hidden.size()
        cell = hidden.new_zeros(self._num_layers, batch_size, self._hidden_dim)
        input_feed = hidden.new_zeros(batch_size, self._hidden_dim)
        return {
            "decoder_hidden": hidden,
            "decoder_cell": cell,
            "decoder_input_feed": input_feed
        }

    def forward(self, state, target_tokens=None):
        source_mask = state["source_mask"]
        batch_size, _ = source_mask.size()

        if target_tokens:
            targets = target_tokens[0]
            target_mask = utils.make_mask_by_lengths(target_tokens[1])
            _, target_sequence_length = targets.size()
            num_decoding_steps = target_sequence_length - 1  # ignore eos padding
        else:
            num_decoding_steps = self._max_decoding_steps

        last_predictions = source_mask.new_full((batch_size, ),
                                                fill_value=self._start_index)
        step_logits = []
        step_predictions = []
        for timestep in range(num_decoding_steps):
            if target_tokens is None:
                input_choices = last_predictions
            else:
                input_choices = targets[:, timestep]
            output_projections, state = self._forward_step(input_choices, state)
            step_logits.append(output_projections.unsqueeze(1))
            class_probabilities = F.softmax(output_projections, dim=-1)
            _, predicted_classes = torch.max(class_probabilities, 1)
            last_predictions = predicted_classes
            step_predictions.append(last_predictions.unsqueeze(1))

        predictions = torch.cat(step_predictions, 1)
        output_dict = {"predictions": predictions}

        if target_tokens:
            logits = torch.cat(step_logits, 1)
            loss = self._get_loss(logits, targets, target_mask)
            output_dict["loss"] = loss
        return output_dict

    def _get_loss(self, logits, targets, target_mask):
        # Skip the start token: the prediction at step t is compared to targets[t + 1].
        relevant_targets = targets[:, 1:].contiguous()
        _, _, num_classes = logits.size()
        return F.cross_entropy(logits.view(-1, num_classes),
                               relevant_targets.view(-1),
                               ignore_index=self._pad_index)

    def _forward_step(self, last_predictions, state):
        encoder_memory_bank = state["encoder_memory_bank"]
        source_mask = state["source_mask"]
        decoder_hidden = state["decoder_hidden"]
        decoder_cell = state["decoder_cell"]

        last_predictions_embedding = self._embedder(last_predictions)
        # Input feeding: concatenate the previous attentional state with the embedding.
        input_feed = state["decoder_input_feed"]
        decoder_input = torch.cat((input_feed, last_predictions_embedding), -1)

        decoder_output, (decoder_hidden, decoder_cell) = self._lstm(
            decoder_input, (decoder_hidden, decoder_cell))
        state["decoder_hidden"] = decoder_hidden
        state["decoder_cell"] = decoder_cell

        decoder_output, _ = self._attention(decoder_output, encoder_memory_bank, source_mask)
        state["decoder_input_feed"] = decoder_output

        output_projections = self._output_projection_layer(decoder_output)
        class_log_probabilities = F.log_softmax(output_projections, dim=-1)
        return class_log_probabilities, state

    def _forward_beam_search(self, state):
        source_mask = state["source_mask"]
        batch_size, _ = source_mask.size()
        start_predictions = source_mask.new_full((batch_size, ),
                                                 fill_value=self._start_index)
        all_top_k_predictions, log_probabilities = self._beam_search.search(
            start_predictions, state, self._forward_step)
        output_dict = {
            "class_log_probabilities": log_probabilities,
            "predictions": all_top_k_predictions,
        }
        return output_dict
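# Dimensional sketch of the input-feeding trick used by _forward_step:
# the previous attentional hidden state is concatenated with the token
# embedding, which is why decoder_input_dim = embed_dim + hidden_dim in
# __init__. The sizes below are illustrative assumptions.
import torch

batch_size, embed_dim, hidden_dim = 4, 32, 64
embedding = torch.randn(batch_size, embed_dim)
input_feed = torch.randn(batch_size, hidden_dim)
decoder_input = torch.cat((input_feed, embedding), -1)
print(decoder_input.shape)  # torch.Size([4, 96])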
classes = list(
    "()-.・0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ゙゚アイウエオカキクケコサシスセソタチツテトナニヌネノハヒフヘホマミムメモヤユヨ"
    "ラリルレロワンガギグゲゴザジズゼゾダヂヅデドバビブベボヴヷパピプペポ")

beam_search = BeamSearch(beam_width, lm_factor, topk, classes, lm_model)

mapping_inout_result = []
s_time = time.time()
for input_fn, lbl in mapping_inout.items():
    # Prepare input.
    ctc_logit_matrix = np.load(input_fn)
    ctc_logit_matrix = preprocess_ocr_logit(ctc_logit_matrix, softmax_theta)

    # Search.
    predict = beam_search.search(ctc_logit_matrix)

    # Post-processing: normalize both prediction and label.
    predict = jaconv.normalize(predict)
    lbl = jaconv.normalize(lbl)

    # Track (prediction, label) pairs.
    mapping_inout_result += [(predict, lbl)]
print("take time:", time.time() - s_time)

# 3. Calculating accuracy
preds = [elem[0] for elem in mapping_inout_result]
lbls = [elem[1] for elem in mapping_inout_result]
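# The snippet stops right before the accuracy computation; a plausible,
# assumed continuation is a simple exact-match sequence accuracy over the
# normalized (prediction, label) pairs. This is a sketch, not the
# original metric code.
correct = sum(1 for p, l in zip(preds, lbls) if p == l)
accuracy = correct / len(lbls) if lbls else 0.0
print("sequence accuracy:", accuracy)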
class Seq2SeqModel(nn.Module):
    def __init__(self, encoder, decoder, decoding_style="greedy",
                 special_tokens_dict=None, max_decoding_steps=128, beam_width=10):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

        if decoding_style not in ["greedy", "beam_search"]:
            print(f"{decoding_style} is not an allowed decoding style; falling back to greedy")
            decoding_style = "greedy"
        self.decoding_style = decoding_style

        if special_tokens_dict is None:
            self.special_tokens_dict = {"<pad>": 0, "<bos>": 1, "<eos>": 2, "<unk>": 3}
        else:
            self.special_tokens_dict = special_tokens_dict

        self.max_decoding_steps = max_decoding_steps
        self._beam_search = BeamSearch(self.special_tokens_dict["<eos>"],
                                       max_decoding_steps, beam_width)

    def forward(self, src, oracle, teacher_forcing_ratio=0.5):
        state = self.encoder(src)
        if teacher_forcing_ratio == 1:
            outputs, state = self.decoder(state=state, oracle=oracle)
        else:
            len_oracle = oracle.shape[1]
            last_pred = oracle[:, 0].unsqueeze(dim=-1)
            outputs = []
            for i in range(len_oracle):
                # Feed the gold token with probability teacher_forcing_ratio,
                # otherwise feed the model's own last prediction.
                input = oracle[:, i].unsqueeze(dim=-1) \
                    if random.random() < teacher_forcing_ratio else last_pred
                output, state = self.decoder(state=state, oracle=input)
                outputs.append(output)
                last_pred = output.argmax(dim=-1)
            outputs = torch.cat(outputs, dim=1)
        return outputs

    def decode(self, src):
        """
        src : tensor of shape (bs, length)
        returns preds : tensor of shape (bs, max_length)
        """
        state = self.encoder(src)
        if self.decoding_style == "greedy":
            preds = self.greedy_search(state)
        else:
            preds = self.beam_search(state)
        return preds

    def greedy_search(self, state):
        bs = state["hidden"].size(1)
        device = next(self.encoder.parameters()).device
        bos_token_id = self.special_tokens_dict["<bos>"]
        last_pred = torch.full((bs, 1), bos_token_id, dtype=torch.int64)
        pred = []
        for i in range(self.max_decoding_steps):
            last_pred = last_pred.to(device)
            output, state = self.decoder(state=state, oracle=last_pred)
            next_pred = output.argmax(dim=-1).detach().cpu().view(-1, 1)
            pred.append(next_pred)
            last_pred = next_pred
        pred = torch.cat(pred, dim=1)
        return pred

    def beam_search(self, state):
        bs = state["encoder_output"].size(0)
        device = next(self.encoder.parameters()).device
        bos_token_id = self.special_tokens_dict["<bos>"]
        start_predictions = torch.full((bs,), bos_token_id,
                                       dtype=torch.int64, device=device)
        all_top_k_predictions, log_probabilities = self._beam_search.search(
            start_predictions, state, self.decoder.forward)
        # Pick the best-scoring hypothesis for each batch element.
        ind = log_probabilities.argmax(dim=-1)
        return all_top_k_predictions[range(bs), ind, :]
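# Self-contained illustration of the beam selection at the end of
# beam_search(): log_probabilities has shape (bs, beam), and the
# (bs, beam, steps) prediction tensor is indexed with the argmax beam per
# example. The numbers here are made up.
import torch

bs, beam, steps = 2, 3, 4
top_k = torch.arange(bs * beam * steps).view(bs, beam, steps)
log_probs = torch.tensor([[-1.2, -0.3, -2.0], [-0.5, -0.9, -0.1]])
ind = log_probs.argmax(dim=-1)       # best beam index per batch element
best = top_k[range(bs), ind, :]      # shape (bs, steps)
print(ind.tolist(), best.shape)      # [1, 2] torch.Size([2, 4])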