Example #1
    def beam_search(self, input_seq, beam_size=3, attentionOverrideMap=None, correctionMap=None, unk_map=None,
                    beam_length=0.5, beam_coverage=0.5,
                    max_length=MAX_LENGTH):

        torch.set_grad_enabled(False)

        input_seqs = [indexes_from_sentence(self.input_lang, input_seq)]
        input_lengths = [len(seq) for seq in input_seqs]
        input_batches = Variable(torch.LongTensor(input_seqs)).transpose(0, 1)

        if use_cuda:
            input_batches = input_batches.cuda()

        # switch both networks to evaluation mode for decoding
        self.encoder.train(False)
        self.decoder.train(False)

        encoder_outputs, encoder_hidden = self.encoder(input_batches, input_lengths, None)

        decoder_hidden = encoder_hidden

        beam_search = BeamSearch(self.decoder, encoder_outputs, decoder_hidden, self.output_lang, beam_size,
                                 attentionOverrideMap,
                                 correctionMap, unk_map, beam_length=beam_length, beam_coverage=beam_coverage)
        result = beam_search.search()

        # restore training mode before returning
        self.encoder.train(True)
        self.decoder.train(True)

        torch.set_grad_enabled(True)

        return result  # Return a list of indexes, one for each word in the sentence, plus EOS
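Note on the pattern above: the snippet brackets decoding by toggling grad tracking and train mode by hand. A minimal stand-alone sketch of the same bracketing in the context-manager form, which restores grad state automatically (the Linear module is just a stand-in, not the example's encoder/decoder):

import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # stand-in module

model.eval()                        # same effect as model.train(False)
with torch.no_grad():               # grad tracking is restored on exit
    out = model(torch.randn(1, 4))
model.train()                       # back to training mode
print(out.requires_grad)            # False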
Example #2
    def _get_top_k_sequences(self, log_probs, wordpiece_mask, k):
        batch_size = log_probs.size()[0]
        seq_length = log_probs.size()[1]

        beam_search = BeamSearch(self._end_index,
                                 max_steps=seq_length,
                                 beam_size=k,
                                 per_node_beam_size=self._per_node_beam_size)
        beam_log_probs = torch.nn.functional.pad(
            log_probs, pad=(0, 2, 0, 0, 0, 0), value=-1e7
        )  # add low log probabilities for the start and end tags used in the beam search
        start_predictions = beam_log_probs.new_full(
            (batch_size, ), fill_value=self._start_index).long()

        # Shape: (batch_size, beam_size, seq_length)
        top_k_predictions, seq_log_probs = beam_search.search(
            start_predictions, {
                'log_probs': beam_log_probs,
                'wordpiece_mask': wordpiece_mask,
                'step_num': beam_log_probs.new_zeros((batch_size, )).long()
            }, self.take_step)

        # get rid of start and end tags if they slipped in
        top_k_predictions[top_k_predictions > 2] = 0

        return top_k_predictions
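The pad=(0, 2, 0, 0, 0, 0) call above widens only the last (class) dimension, appending two columns of very low log-probability for the start and end tags so the beam search never prefers them. A self-contained sketch of that padding on a toy tensor (shapes are illustrative):

import torch
import torch.nn.functional as F

log_probs = torch.randn(2, 5, 3)  # (batch, seq_length, classes)
padded = F.pad(log_probs, pad=(0, 2, 0, 0, 0, 0), value=-1e7)
print(padded.shape)       # torch.Size([2, 5, 5])
print(padded[0, 0, -2:])  # tensor([-10000000., -10000000.])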
Example #3
    def beam_search(self,
                    initial_sequence,
                    forbid_movies=None,
                    temperature=1,
                    **kwargs):
        """
        Beam search sentence generation
        :param initial_sequence: list giving the initial sequence of tokens
        :param forbid_movies: set of movie ids that must not be recommended again
        :param temperature: sampling temperature passed to the model
        :param kwargs: additional parameters to pass to model forward pass (e.g. a conditioning context)
        :return: the list of final beams
        """
        if forbid_movies is None:
            forbid_movies = set()  # guard: .union() below would fail on None
        beam_search = BeamSearch(self.beam_size, initial_sequence,
                                 self.word2id["</s>"])
        beams = beam_search.beams
        for i in range(self.max_sequence_length):
            # compute probabilities for each beam
            probabilities = []
            for beam in beams:
                # add batch dimension
                model_input = Variable(torch.LongTensor(
                    beam.sequence)).unsqueeze(0)
                if self.model.cuda_available:
                    model_input = model_input.cuda()
                beam_forbidden_movies = forbid_movies.union(
                    beam.mentioned_movies)
                prob = self.model(input=model_input,
                                  lengths=[len(beam.sequence)],
                                  log_probabilities=False,
                                  forbid_movies=beam_forbidden_movies,
                                  temperature=temperature,
                                  **kwargs)
                # get probabilities for the next token to generate
                probabilities.append(prob[0, -1, :].cpu())
            # update beams
            beams = beam_search.search(probabilities,
                                       n_gram_block=self.n_gram_block)
            # replace movie names with the corresponding words
            for beam in beams:
                if beam.sequence[-1] > len(self.word2id):
                    # update the list of movies mentioned for preventing repeated recommendations
                    beam.mentioned_movies.add(beam.sequence[-1] -
                                              len(self.word2id))
                    beam.sequence[-1:] = replace_movie_with_words(
                        beam.sequence[-1], self.movieId2name, self.word2id)
        return beams
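search() above is called with an n_gram_block argument whose implementation isn't shown. A minimal sketch of the usual n-gram blocking technique it presumably refers to, zeroing the probability of any token that would repeat an already-generated n-gram (the function and its signature are illustrative, not this repository's API):

def block_repeated_ngrams(sequence, probabilities, n):
    """Zero out tokens that would complete an n-gram already in `sequence`.
    `probabilities` is a list indexed by token id."""
    if n < 2 or len(sequence) < n - 1:
        return probabilities
    prefix = tuple(sequence[-(n - 1):])          # last n-1 generated tokens
    banned = {sequence[i + n - 1]
              for i in range(len(sequence) - n + 1)
              if tuple(sequence[i:i + n - 1]) == prefix}
    for token in banned:
        probabilities[token] = 0.0
    return probabilities

probs = [0.1, 0.2, 0.3, 0.4]
print(block_repeated_ngrams([1, 2, 3, 1, 2], probs, n=3))  # [0.1, 0.2, 0.3, 0.0]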
Example #4
class Prediction:
    def __init__(self, args):
        """
        :param model_dir: model dir path
        :param vocab_file: vocab file path
        """
        self.tf = import_tf(0)

        self.args = args
        self.model_dir = args.logdir
        self.vocab_file = args.vocab
        self.token2idx, self.idx2token = _load_vocab(args.vocab)

        hparams = Hparams()
        parser = hparams.parser
        self.hp = parser.parse_args()

        self.model = Transformer(self.hp)

        self._add_placeholder()
        self._init_graph()

    def _init_graph(self):
        """
        init graph
        """
        self.ys = (self.input_y, None, None)
        self.xs = (self.input_x, None)
        self.memory = self.model.encode(self.xs, False)[0]
        self.logits = self.model.decode(self.xs, self.ys, self.memory, False)[0]

        ckpt = self.tf.train.get_checkpoint_state(self.model_dir).all_model_checkpoint_paths[-1]

        graph = self.logits.graph
        sess_config = self.tf.ConfigProto(allow_soft_placement=True)
        sess_config.gpu_options.allow_growth = True

        saver = self.tf.train.Saver()
        self.sess = self.tf.Session(config=sess_config, graph=graph)

        self.sess.run(self.tf.global_variables_initializer())
        self.tf.reset_default_graph()
        saver.restore(self.sess, ckpt)

        self.bs = BeamSearch(self.model,
                             self.hp.beam_size,
                             list(self.idx2token.keys())[2],
                             list(self.idx2token.keys())[3],
                             self.idx2token,
                             self.hp.maxlen2,
                             self.input_x,
                             self.input_y,
                             self.logits)

    def predict(self, content):
        """
        abstract prediction by beam search
        :param content: article content
        :return: prediction result
        """
        input_x = content.split()
        # pad to maxlen1 with '<pad>', then truncate to exactly maxlen1 tokens
        while len(input_x) < self.args.maxlen1:
            input_x.append('<pad>')
        input_x = input_x[:self.args.maxlen1]

        input_x = [self.token2idx.get(s, self.token2idx['<unk>']) for s in input_x]

        memory = self.sess.run(self.memory, feed_dict={self.input_x: [input_x]})

        return self.bs.search(self.sess, input_x, memory[0])

    def _add_placeholder(self):
        """
        add tensorflow placeholder
        """
        self.input_x = self.tf.placeholder(dtype=self.tf.int32, shape=[None, self.args.maxlen1], name='input_x')
        self.input_y = self.tf.placeholder(dtype=self.tf.int32, shape=[None, None], name='input_y')
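predict() above pads the article to maxlen1 and maps out-of-vocabulary tokens to <unk>. The same preprocessing as a self-contained sketch (the vocabulary here is a toy stand-in for token2idx):

token2idx = {'<pad>': 0, '<unk>': 1, 'the': 4, 'cat': 5}  # toy vocab
maxlen = 6

def encode(content, token2idx, maxlen):
    tokens = content.split()[:maxlen]
    tokens += ['<pad>'] * (maxlen - len(tokens))
    return [token2idx.get(t, token2idx['<unk>']) for t in tokens]

print(encode('the cat sat', token2idx, maxlen))  # [4, 5, 1, 0, 0, 0]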
Example #5
class BSDecoder(object):
    def __init__(self, model, batch_reader, model_config, data_config, vocab,
                 data_loader):
        self.model = model
        self.batch_reader = batch_reader
        self.model_config = model_config
        self.data_config = data_config
        self.vocab = vocab
        self.data_loader = data_loader

        self.saver = tf.train.Saver()
        self.session = tf.Session(config=tf.ConfigProto(
            allow_soft_placement=True))

        self.restore_model_flag = self.restore_model()

        self.bs = BeamSearch(
            self.model, self.model_config.beam_size,
            self.data_loader.word_to_id(self.data_config.sentence_start),
            self.data_loader.word_to_id(self.data_config.sentence_end),
            self.model_config.abstract_length)

    def restore_model(self):
        """
        restore model
        :return: if restore model success, return True, else return False
        """
        ckpt_state = tf.train.get_checkpoint_state(
            self.model_config.model_path)
        if not (ckpt_state and ckpt_state.model_checkpoint_path):
            print('No model to decode yet at {0}'.format(
                self.model_config.model_path))
            return False

        ckpt_path = os.path.join(
            self.model_config.model_path,
            os.path.basename(ckpt_state.model_checkpoint_path))

        self.saver.restore(self.session, ckpt_path)

        return True

    def decode(self, article):
        """
        decode article to abstract by model
        :param article: article, which is word id list
        :return: abstract
        """
        if self.restore_model_flag:
            """
                convert to list, which list length is beam size
            """
            article_batch = article * self.model_config.beam_size
            article_length_batch = [len(article)] * self.model_config.beam_size

            best_beam = self.bs.search(self.session, article_batch,
                                       article_length_batch)[0]
            """
                get word id after 1, because 1 is start id
            """
            result = [int(word_id) for word_id in best_beam[1:]]

            return result

        else:
            return None
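decode() builds its per-beam inputs with plain list repetition. Whether that yields one row per beam or a single flat list depends on how article is shaped upstream, which the snippet doesn't show; a toy illustration of both cases (values are made up):

article_rows = [[4, 7, 9]]   # article wrapped as a one-row batch of word ids
print(article_rows * 3)      # [[4, 7, 9], [4, 7, 9], [4, 7, 9]] -> one copy per beam
article_flat = [4, 7, 9]     # bare word-id list
print(article_flat * 3)      # [4, 7, 9, 4, 7, 9, 4, 7, 9] -> flat concatenation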
Example #6
class LstmDecoder(Decoder):
    def __init__(self, embedder, hidden_dim, num_layers, dropout_p, attention,
                 beam_width):
        super(LstmDecoder, self).__init__()
        self._embedder = embedder
        self._hidden_dim = hidden_dim
        self._num_layers = num_layers
        self._attention = attention
        decoder_input_dim = self._embedder.get_embed_dim()
        decoder_input_dim += self._hidden_dim  # for input-feeding

        self._lstm = StackedLSTM(decoder_input_dim, self._hidden_dim,
                                 num_layers, dropout_p)

        self._output_projection_layer = nn.Linear(
            self._hidden_dim, self._embedder.get_vocab_size())

        self._start_index = self._embedder.get_init_token_idx()
        self._eos_index = self._embedder.get_eos_token_idx()
        self._pad_index = self._embedder.get_pad_token_idx()
        self._max_decoding_steps = 100

        self._beam_search = BeamSearch(self._eos_index,
                                       self._max_decoding_steps, beam_width)

    def init_state(self, state):
        hidden = state["encoder_hidden"]
        _, batch_size, _ = hidden.size()
        cell = hidden.new_zeros(self._num_layers, batch_size, self._hidden_dim)
        input_feed = hidden.new_zeros(batch_size, self._hidden_dim)
        return {
            "decoder_hidden": hidden,
            "decoder_cell": cell,
            "decoder_input_feed": input_feed
        }

    def forward(self, state, target_tokens=None):
        source_mask = state["source_mask"]
        batch_size, _ = source_mask.size()

        if target_tokens:
            targets = target_tokens[0]
            target_mask = utils.make_mask_by_lengths(target_tokens[1])
            _, target_sequence_length = targets.size()
            num_decoding_steps = target_sequence_length - 1  # the final target token is predicted, never fed as input
        else:
            num_decoding_steps = self._max_decoding_steps

        last_predictions = source_mask.new_full((batch_size, ),
                                                fill_value=self._start_index)

        step_logits = []
        step_predictions = []
        for timestep in range(num_decoding_steps):
            if target_tokens is None:
                input_choices = last_predictions
            else:
                input_choices = targets[:, timestep]

            output_projections, state = self._forward_step(
                input_choices, state)
            step_logits.append(output_projections.unsqueeze(1))
            class_probabilities = F.softmax(output_projections, dim=-1)
            _, predicted_classes = torch.max(class_probabilities, 1)
            last_predictions = predicted_classes
            step_predictions.append(last_predictions.unsqueeze(1))

        predictions = torch.cat(step_predictions, 1)
        output_dict = {"predictions": predictions}

        if target_tokens:
            logits = torch.cat(step_logits, 1)
            loss = self._get_loss(logits, targets, target_mask)
            output_dict["loss"] = loss

        return output_dict

    def _get_loss(self, logits, targets, target_mask):
        relevant_targets = targets[:, 1:].contiguous()
        _, _, num_classes = logits.size()
        return F.cross_entropy(logits.view(-1, num_classes),
                               relevant_targets.view(-1),
                               ignore_index=self._pad_index)

    def _forward_step(self, last_predictions, state):
        encoder_memory_bank = state["encoder_memory_bank"]
        source_mask = state["source_mask"]
        decoder_hidden = state["decoder_hidden"]
        decoder_cell = state["decoder_cell"]

        last_predictions_embedding = self._embedder(last_predictions)

        input_feed = state["decoder_input_feed"]
        decoder_input = torch.cat((input_feed, last_predictions_embedding), -1)

        decoder_output, (decoder_hidden, decoder_cell) = self._lstm(
            decoder_input, (decoder_hidden, decoder_cell))
        state["decoder_hidden"] = decoder_hidden
        state["decoder_cell"] = decoder_cell

        decoder_output, _ = self._attention(decoder_output,
                                            encoder_memory_bank, source_mask)
        state["decoder_input_feed"] = decoder_output

        output_projections = self._output_projection_layer(decoder_output)
        class_log_probabilities = F.log_softmax(output_projections, dim=-1)
        return class_log_probabilities, state

    def _forward_beam_search(self, state):
        source_mask = state["source_mask"]
        batch_size, _ = source_mask.size()

        start_predictions = source_mask.new_full(
            (batch_size, ), fill_value=self._start_index)

        all_top_k_predictions, log_probabilities = self._beam_search.search(
            start_predictions, state, self._forward_step)

        output_dict = {
            "class_log_probabilities": log_probabilities,
            "predictions": all_top_k_predictions,
        }
        return output_dict
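_get_loss above lines the step logits up against the targets shifted left by one position, then lets ignore_index drop the <pad> positions. A self-contained toy of that alignment (vocab size, token ids and the PAD id are illustrative):

import torch
import torch.nn.functional as F

PAD = 0
targets = torch.tensor([[1, 5, 6, 2, PAD]])  # <bos> w w <eos> <pad>
logits = torch.randn(1, 4, 8)                # one prediction per decoding step
relevant = targets[:, 1:].contiguous()       # what each step should have produced
loss = F.cross_entropy(logits.view(-1, 8), relevant.view(-1), ignore_index=PAD)
print(loss.item())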
Example #7
    classes = list(
        "()-.・0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ゙゚アイウエオカキクケコサシスセソタチツテトナニヌネノハヒフヘホマミムメモヤユヨ"
        "ラリルレロワンガギグゲゴザジズゼゾダヂヅデドバビブベボヴヷパピプペポ")

    beam_search = BeamSearch(beam_width, lm_factor, topk, classes, lm_model)
    mapping_inout_result = []

    s_time = time.time()
    for input_fn, lbl in mapping_inout.items():
        # prepare input
        ctc_logit_matrix = np.load(input_fn)
        ctc_logit_matrix = preprocess_ocr_logit(ctc_logit_matrix,
                                                softmax_theta)

        # search
        predict = beam_search.search(ctc_logit_matrix)

        # post-processing
        predict = jaconv.normalize(predict)
        lbl = jaconv.normalize(lbl)

        # track
        mapping_inout_result += [(predict, lbl)]

    print("take time:", time.time() - s_time)
    """
    3. Calculating accuracy 
    """
    preds = [elem[0] for elem in mapping_inout_result]
    lbls = [elem[1] for elem in mapping_inout_result]
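The snippet stops just before the accuracy computation it announces. One common way to score the collected (predict, label) pairs is character accuracy derived from Levenshtein distance; a self-contained sketch under that assumption (the metric choice and the stand-in preds/lbls are illustrative, not necessarily what this repository computes):

def levenshtein(a, b):
    # classic dynamic-programming edit distance
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        cur = [i]
        for j, cb in enumerate(b, 1):
            cur.append(min(prev[j] + 1,                # deletion
                           cur[j - 1] + 1,             # insertion
                           prev[j - 1] + (ca != cb)))  # substitution
        prev = cur
    return prev[-1]

preds = ["カキク", "0123"]  # stand-ins for the preds built above
lbls = ["カキケ", "0123"]   # stand-ins for the lbls built above
errors = sum(levenshtein(p, l) for p, l in zip(preds, lbls))
total = sum(len(l) for l in lbls)
print("char accuracy:", 1 - errors / max(total, 1))  # 1 - 1/7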
Example #8
class Seq2SeqModel(nn.Module):

    def __init__(self, encoder, decoder,
                decoding_style="greedy", special_tokens_dict=None,
                max_decoding_steps=128, beam_width=10):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        if decoding_style not in ["greedy", "beam_search"]:
            print(f"{decoding_style} is not an allowed decoding style; falling back to greedy")
            decoding_style = "greedy"
        self.decoding_style = decoding_style
        if special_tokens_dict is None:
            self.special_tokens_dict = {"<pad>":0, "<bos>":1, "<eos>":2, "<unk>":3}
        else:
            self.special_tokens_dict = special_tokens_dict
        self.max_decoding_steps = max_decoding_steps
        self._beam_search = BeamSearch(self.special_tokens_dict["<eos>"],
                                       max_decoding_steps,
                                       beam_width)

    def forward(self, src, oracle, teacher_forcing_ratio=0.5):
        state = self.encoder(src)
        if teacher_forcing_ratio == 1:
            outputs, state = self.decoder(state=state, oracle=oracle)
        else:
            len_oracle = oracle.shape[1]
            last_pred = oracle[:, 0].unsqueeze(dim=-1)
            outputs = []
            for i in range(len_oracle):
                step_input = oracle[:, i].unsqueeze(dim=-1) \
                    if random.random() < teacher_forcing_ratio else last_pred
                output, state = self.decoder(state=state, oracle=step_input)
                outputs.append(output)
                last_pred = output.argmax(dim=-1)
            outputs = torch.cat(outputs, dim=1)
        return outputs

    def decode(self, src):
        """
        src : tensor
            src.shape = (bs, length)
        preds : tensor
            preds.shape = (bs, max_length)
        """
        state = self.encoder(src)
        if self.decoding_style == "greedy":
            preds = self.greedy_search(state)
        else:
            preds = self.beam_search(state)
        return preds

    def greedy_search(self, state):
        bs = state["hidden"].size(1)
        device = [j for i, j in self.encoder.named_parameters()][0].device
        bos_token_id = self.special_tokens_dict["<bos>"]
        last_pred = torch.full((bs, 1), bos_token_id, dtype=torch.int64)  
        pred = []
        for i in range(self.max_decoding_steps):
            last_pred = last_pred.to(device)
            output, state = \
                self.decoder(state=state, oracle=last_pred)
            next_pred = output.argmax(dim=-1).detach().cpu().view(-1, 1)
            pred.append(next_pred)
            last_pred = next_pred
        pred = torch.cat(pred, dim=1)
        return pred

    def beam_search(self, state):
        bs = state["encoder_output"].size(0)
        device = [j for i, j in self.encoder.named_parameters()][0].device
        bos_token_id = self.special_tokens_dict["<bos>"]
        start_predictions = torch.full((bs,), bos_token_id, dtype=torch.int64, device=device)
        all_top_k_predictions, log_probabilities = self._beam_search.search(
            start_predictions, state, self.decoder.forward
        )
        ind = log_probabilities.argmax(dim=-1)
        return all_top_k_predictions[range(bs), ind, :]
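beam_search() above keeps, for each batch element, the hypothesis with the highest sequence log-probability out of the k returned. A toy sketch of that argmax-and-gather step (shapes and scores are made up):

import torch

bs, k, length = 2, 3, 4
top_k = torch.arange(bs * k * length).view(bs, k, length)  # stand-in beam outputs
log_probs = torch.tensor([[-1.0, -0.2, -3.0],
                          [-0.5, -2.0, -0.1]])             # per-beam scores
best = log_probs.argmax(dim=-1)                            # best beam per example
print(top_k[torch.arange(bs), best, :])                    # shape (bs, length)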