class Predict():
    """Summary generator built around a PGN model.

    On construction the vocabulary is built from ``config.data_path``, the
    model weights are loaded through ``PGN.load_model`` and the model is moved
    to ``config.DEVICE``. ``predict`` is the public entry point; it delegates
    to ``greedy_search`` or ``beam_search``.
    """

    @timer(module='initalize predicter')
    def __init__(self):
        self.DEVICE = config.DEVICE

        dataset = PairDataset(config.data_path,
                              max_src_len=config.max_src_len,
                              max_tgt_len=config.max_tgt_len,
                              truncate_src=config.truncate_src,
                              truncate_tgt=config.truncate_tgt)

        self.vocab = dataset.build_vocab(embed_file=config.embed_file)

        self.model = PGN(self.vocab)
        # Read the stop words inside a context manager so the file handle is
        # closed deterministically instead of being left to the GC.
        with open(config.stop_word_file, encoding='utf-8') as stop_word_file:
            self.stop_word = list(
                {self.vocab[line.strip()] for line in stop_word_file})
        self.model.load_model()
        self.model.to(self.DEVICE)
        # This object is only used for inference: switch layers such as
        # dropout to evaluation behaviour so decoding is deterministic.
        self.model.eval()

    def greedy_search(self, x, max_sum_len, len_oovs, x_padding_masks):
        """Decode a summary by always picking the most probable next token.

        Args:
            x (Tensor): Source token ids; indexed as ``x.shape[1]`` for the
                source length, so a leading batch dimension is expected.
            max_sum_len (int): The maximum length a summary can have
                (the leading SOS token is counted).
            len_oovs (Tensor): Numbers of out-of-vocabulary tokens.
            x_padding_masks (Tensor):
                The padding masks for the input sequences.
                NOTE(review): ``predict`` builds this mask before adding the
                batch dimension to ``x`` -- confirm the shape the attention
                module expects.

        Returns:
            summary (list): Token ids of the summary, starting with SOS and
                ending with EOS when EOS was produced before the length limit.
        """
        # Encode the source. OOV ids are mapped back into the vocabulary
        # before the embedding lookup.
        encoder_output, encoder_states = self.model.encoder(
            replace_oovs(x, self.vocab), self.model.decoder.embedding)

        # Initialize decoder's hidden states with encoder's hidden states.
        decoder_states = self.model.reduce_state(encoder_states)

        # Initialize decoder's input at time step 0 with the SOS token.
        x_t = torch.ones(1) * self.vocab.SOS
        x_t = x_t.to(self.DEVICE, dtype=torch.int64)
        summary = [self.vocab.SOS]
        coverage_vector = torch.zeros((1, x.shape[1])).to(self.DEVICE)
        # Decode until EOS is fed back or the length limit is reached.
        while int(x_t.item()) != (self.vocab.EOS) \
                and len(summary) < max_sum_len:

            context_vector, attention_weights, coverage_vector = \
                self.model.attention(decoder_states,
                                     encoder_output,
                                     x_padding_masks,
                                     coverage_vector)
            p_vocab, decoder_states, p_gen = \
                self.model.decoder(x_t.unsqueeze(1),
                                   decoder_states,
                                   context_vector)
            final_dist = self.model.get_final_distribution(
                x, p_gen, p_vocab, attention_weights, torch.max(len_oovs))
            # Get next token with maximum probability.
            x_t = torch.argmax(final_dist, dim=1).to(self.DEVICE)
            # Record the raw id (it may point at a source OOV) ...
            summary.append(x_t.item())
            # ... but feed an in-vocabulary id to the next decoder step.
            x_t = replace_oovs(x_t, self.vocab)

        return summary

    #     @timer('best k')

    def best_k(self, beam, k, encoder_output, x_padding_masks, x, len_oovs):
        """Get best k tokens to extend the current sequence at the current time step.

        Args:
            beam (Beam): The candidate beam to be extended.
            k (int): Beam size.
            encoder_output (Tensor): The output from the encoder.
            x_padding_masks (Tensor):
                The padding masks for the input sequences.
            x (Tensor): Source token ids.
            len_oovs (Tensor): Number of oov tokens in a batch.

        Returns:
            best_k (list(Beam)): The list of best k candidates.

        """
        # The last token of the hypothesis is the decoder input for this step.
        x_t = torch.tensor(beam.tokens[-1]).reshape(1, 1)
        x_t = x_t.to(self.DEVICE)

        # Get context vector from attention network.
        context_vector, attention_weights, coverage_vector = \
            self.model.attention(beam.decoder_states,
                                 encoder_output,
                                 x_padding_masks,
                                 beam.coverage_vector)

        # Replace the indexes of OOV words with the index of OOV token
        # to prevent index-out-of-bound error in the decoder.
        p_vocab, decoder_states, p_gen = \
            self.model.decoder(replace_oovs(x_t, self.vocab),
                               beam.decoder_states,
                               context_vector)

        final_dist = self.model.get_final_distribution(x, p_gen, p_vocab,
                                                       attention_weights,
                                                       torch.max(len_oovs))
        # Calculate log probabilities.
        log_probs = torch.log(final_dist.squeeze())
        # Filter forbidden tokens; only applied when the hypothesis holds a
        # single token, i.e. when choosing the first generated token.
        if len(beam.tokens) == 1:
            forbidden_ids = [
                self.vocab[u"台独"], self.vocab[u"吸毒"], self.vocab[u"黄赌毒"]
            ]
            log_probs[forbidden_ids] = -float('inf')
        # EOS token penalty. Follow the definition in
        # https://opennmt.net/OpenNMT/translation/beam_search/.
        log_probs[self.vocab.EOS] *= \
            config.gamma * x.size()[1] / len(beam.tokens)

        # Never emit the UNK token.
        log_probs[self.vocab.UNK] = -float('inf')
        # Get top k tokens; only the indices are needed because the scores
        # are read back from log_probs below.
        _, topk_idx = torch.topk(log_probs, k)

        # Extend the current hypo with top k tokens, resulting k new hypos.
        # The loop variable is deliberately not called ``x`` so it cannot be
        # confused with the source tensor argument.
        return [
            beam.extend(token_id, log_probs[token_id], decoder_states,
                        coverage_vector)
            for token_id in topk_idx.tolist()
        ]

    def beam_search(self, x, max_sum_len, beam_width, len_oovs,
                    x_padding_masks):
        """Using beam search to generate summary.

        Args:
            x (Tensor): Input sequence as the source.
            max_sum_len (int): The maximum number of decoding steps.
            beam_width (int): Beam size.
            len_oovs (Tensor): Number of out-of-vocabulary tokens.
            x_padding_masks (Tensor):
                The padding masks for the input sequences.

        Returns:
            result (list): The token ids of the hypothesis with the highest
                ``seq_score()``.
        """
        # Run the source sequence through the encoder.
        encoder_output, encoder_states = self.model.encoder(
            replace_oovs(x, self.vocab), self.model.decoder.embedding)
        coverage_vector = torch.zeros((1, x.shape[1])).to(self.DEVICE)
        # initialize decoder states with encoder forward states
        decoder_states = self.model.reduce_state(encoder_states)

        # initialize the hypothesis with a class Beam instance.
        init_beam = Beam([self.vocab.SOS], [0], decoder_states,
                         coverage_vector)

        # k is the number of hypotheses still being expanded; it shrinks
        # every time a hypothesis is completed.
        k = beam_width
        curr, completed = [init_beam], []

        # use beam search for max_sum_len (maximum length) steps
        for _ in range(max_sum_len):
            # get k best hypothesis when adding a new token
            topk = []
            for beam in curr:
                # When an EOS token is generated, add the hypo to the completed
                # list and decrease beam size.
                if beam.tokens[-1] == self.vocab.EOS:
                    completed.append(beam)
                    k -= 1
                    continue
                for can in self.best_k(beam, k,
                                       encoder_output, x_padding_masks, x,
                                       torch.max(len_oovs)):
                    # Using topk as a heap to keep track of top k candidates.
                    # Using the sequence scores of the hypos to compare
                    # and object ids to break ties.
                    add2heap(topk, (can.seq_score(), id(can), can), k)

            curr = [items[2] for items in topk]
            # stop when there are enough completed hypothesis
            if len(completed) == beam_width:
                break
        # When there are not enough completed hypotheses,
        # take whatever we have in current best k as the final candidates.
        completed += curr
        # Sort the hypotheses by seq_score and choose the best one.
        result = sorted(completed, key=lambda hypo: hypo.seq_score(),
                        reverse=True)[0].tokens
        return result

    @timer(module='doing prediction')
    def predict(self, text, tokenize=True, beam_search=True):
        """Generate summary.

        Args:
            text (str or list): Source.
            tokenize (bool, optional):
                Whether to tokenize ``text`` with jieba when it is a string.
                Defaults to True.
            beam_search (bool, optional):
                Whether to use beam search or not.
                Defaults to True (beam search); False selects greedy search.

        Returns:
            str: The final summary.
        """
        if isinstance(text, str) and tokenize:
            text = list(jieba.cut(text))
        x, oov = source2ids(text, self.vocab)
        x = torch.tensor(x).to(self.DEVICE)
        len_oovs = torch.tensor([len(oov)]).to(self.DEVICE)
        x_padding_masks = torch.ne(x, 0).byte().float()
        # Decoding never back-propagates, so skip building the autograd
        # graph; this keeps memory flat across decoding steps.
        with torch.no_grad():
            if beam_search:
                summary = self.beam_search(x.unsqueeze(0),
                                           max_sum_len=config.max_dec_steps,
                                           beam_width=config.beam_size,
                                           len_oovs=len_oovs,
                                           x_padding_masks=x_padding_masks)
            else:
                summary = self.greedy_search(x.unsqueeze(0),
                                             max_sum_len=config.max_dec_steps,
                                             len_oovs=len_oovs,
                                             x_padding_masks=x_padding_masks)
        summary = outputids2words(summary, oov, self.vocab)
        return summary.replace('<SOS>', '').replace('<EOS>', '').strip()
Exemple #2
0
def train(dataset, val_dataset, v, start_epoch=0):
    """Train the model, evaluate it and store it.

    After every epoch the model is evaluated with ``evaluate``; whenever the
    validation loss improves on the best value seen so far, the encoder,
    decoder, attention and reduce_state modules are saved to the paths given
    in ``config``. The best validation loss is persisted to
    ``config.losses_path`` after every epoch.

    Args:
        dataset (dataset.PairDataset): The training dataset.
        val_dataset (dataset.PairDataset): The evaluation dataset.
        v (vocab.Vocab): The vocabulary built from the training dataset.
        start_epoch (int, optional): The starting epoch number. Defaults to 0.
    """

    DEVICE = torch.device("cuda" if config.is_cuda else "cpu")

    model = PGN(v)
    model.load_model()
    model.to(DEVICE)
    if config.fine_tune:
        # In fine-tuning mode, we fix the weights of all parameters except attention.wc.
        print('Fine-tuning mode.')
        for name, params in model.named_parameters():
            if name != 'attention.wc.weight':
                params.requires_grad = False
    # forward
    print("loading data")
    train_data = SampleDataset(dataset.pairs, v)
    val_data = SampleDataset(val_dataset.pairs, v)

    print("initializing optimizer")

    # Define the optimizer.
    optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
    train_dataloader = DataLoader(dataset=train_data,
                                  batch_size=config.batch_size,
                                  shuffle=True,
                                  collate_fn=collate_fn)

    # Best validation loss so far; resumed from disk when available.
    val_losses = np.inf
    if os.path.exists(config.losses_path):
        # NOTE(review): pickle.load can execute arbitrary code; only keep
        # this if config.losses_path is always written by this program.
        with open(config.losses_path, 'rb') as f:
            val_losses = pickle.load(f)

    # SummaryWriter: Log writer used for TensorboardX visualization.
    writer = SummaryWriter(config.log_path)
    # scheduled_sampler : A tool for choosing teacher_forcing or not
    num_epochs = len(range(start_epoch, config.epochs))
    scheduled_sampler = ScheduledSampler(num_epochs)
    if config.scheduled_sampling:
        print('scheduled_sampling mode.')

    # try/finally guarantees the writer is flushed and closed even when
    # training is interrupted by an exception.
    try:
        # The bar total is the number of epochs that will actually run, so
        # it completes even when training is resumed with start_epoch > 0.
        with tqdm(total=num_epochs) as epoch_progress:
            for epoch in range(start_epoch, config.epochs):
                print(config_info(config))
                batch_losses = []  # Get loss of each batch.
                num_batches = len(train_dataloader)
                # set a teacher_forcing signal
                if config.scheduled_sampling:
                    teacher_forcing = scheduled_sampler.teacher_forcing(
                        epoch - start_epoch)
                else:
                    teacher_forcing = True
                print('teacher_forcing = {}'.format(teacher_forcing))
                with tqdm(total=num_batches) as batch_progress:
                    # The loader is iterated directly: batch_progress is the
                    # only progress bar for this loop.
                    for batch, data in enumerate(train_dataloader):
                        x, y, x_len, _y_len, _oov, len_oovs = data
                        assert not np.any(np.isnan(x.numpy()))
                        if config.is_cuda:  # Training with GPUs.
                            x = x.to(DEVICE)
                            y = y.to(DEVICE)
                            x_len = x_len.to(DEVICE)
                            len_oovs = len_oovs.to(DEVICE)

                        model.train()  # Sets the module in training mode.
                        optimizer.zero_grad()  # Clear gradients.
                        # Calculate loss.  Call model forward propagation
                        loss = model(x,
                                     x_len,
                                     y,
                                     len_oovs,
                                     batch=batch,
                                     num_batches=num_batches,
                                     teacher_forcing=teacher_forcing)
                        batch_loss = loss.item()
                        batch_losses.append(batch_loss)
                        loss.backward()  # Backpropagation.

                        # Do gradient clipping to prevent gradient explosion.
                        clip_grad_norm_(model.encoder.parameters(),
                                        config.max_grad_norm)
                        clip_grad_norm_(model.decoder.parameters(),
                                        config.max_grad_norm)
                        clip_grad_norm_(model.attention.parameters(),
                                        config.max_grad_norm)
                        optimizer.step()  # Update weights.

                        # Output and record the running loss every 32 batches.
                        if (batch % 32) == 0:
                            batch_progress.set_description(f'Epoch {epoch}')
                            batch_progress.set_postfix(Batch=batch,
                                                       Loss=batch_loss)
                            # Write loss for tensorboard.
                            writer.add_scalar(
                                f'Average loss for epoch {epoch}',
                                np.mean(batch_losses),
                                global_step=batch)
                        # Advance once per batch so the bar reaches its total.
                        batch_progress.update()
                # Calculate average loss over all batches in an epoch.
                epoch_loss = np.mean(batch_losses)

                epoch_progress.set_description(f'Epoch {epoch}')
                epoch_progress.set_postfix(Loss=epoch_loss)
                epoch_progress.update()

                avg_val_loss = evaluate(model, val_data, epoch)

                print('training loss:{}'.format(epoch_loss),
                      'validation loss:{}'.format(avg_val_loss))

                # Update minimum evaluating loss.
                if avg_val_loss < val_losses:
                    torch.save(model.encoder, config.encoder_save_name)
                    torch.save(model.decoder, config.decoder_save_name)
                    torch.save(model.attention, config.attention_save_name)
                    torch.save(model.reduce_state,
                               config.reduce_state_save_name)
                    val_losses = avg_val_loss
                with open(config.losses_path, 'wb') as f:
                    pickle.dump(val_losses, f)
    finally:
        writer.close()
def train(dataset, val_dataset, v, start_epoch=0):
    """Train the model, evaluate it and store it.

    After every epoch the model is evaluated with ``evaluate``. When the
    validation loss improves on the best value so far, the encoder, decoder,
    attention and reduce_state modules are saved and the new best loss is
    appended to ``config.losses_path``. Training stops early once the
    validation loss has failed to improve for ``config.patience``
    consecutive epochs.

    Args:
        dataset (dataset.PairDataset): The training dataset.
        val_dataset (dataset.PairDataset): The evaluation dataset.
        v (vocab.Vocab): The vocabulary built from the training dataset.
        start_epoch (int, optional): The starting epoch number. Defaults to 0.
    """
    # NOTE(review): anomaly detection is a debugging aid that slows down
    # every backward pass; kept because the original enables it.
    torch.autograd.set_detect_anomaly(True)
    DEVICE = torch.device("cuda" if config.is_cuda else "cpu")

    model = PGN(v)
    model.load_model()
    model.to(DEVICE)
    if config.fine_tune:
        # In fine-tuning mode, we fix the weights of all parameters except attention.wc.
        logging.info('Fine-tuning mode.')
        for name, params in model.named_parameters():
            if name != 'attention.wc.weight':
                params.requires_grad = False
    # forward
    logging.info("loading data")
    train_data = dataset
    val_data = val_dataset

    logging.info("initializing optimizer")

    # Define the optimizer.
    optimizer = optim.Adagrad(
        model.parameters(),
        lr=config.learning_rate,
        initial_accumulator_value=config.initial_accumulator_value)
    # Learning-rate schedule.
    # NOTE(review): scheduler.step() is never called in this function, so
    # the learning rate is not actually decayed -- confirm whether that is
    # intended before enabling it.
    scheduler = StepLR(optimizer, step_size=10, gamma=0.2)
    train_dataloader = DataLoader(dataset=train_data,
                                  batch_size=config.batch_size,
                                  shuffle=True,
                                  collate_fn=collate_fn)

    # Best validation loss so far; resumed from the last record on disk.
    val_loss = np.inf
    if os.path.exists(config.losses_path):
        with open(config.losses_path, 'r') as f:
            records = [line for line in f if line.strip()]
        # An existing but empty file simply means "no best loss yet".
        if records:
            val_loss = float(records[-1].split("=")[-1])
            logging.info("the last best val loss is: " + str(val_loss))

    # SummaryWriter: Log writer used for TensorboardX visualization.
    writer = SummaryWriter(config.log_path)
    # Number of consecutive epochs without validation improvement.
    early_stopping_count = 0

    logging.info("start training model {}, ".format(config.model_name) +
                 "epoch : {}, ".format(config.epochs) +
                 "batch_size : {}, ".format(config.batch_size) +
                 "num batches: {}, ".format(len(train_dataloader)))

    # try/finally guarantees the writer is flushed and closed even when
    # training is interrupted by an exception.
    try:
        for epoch in range(start_epoch, config.epochs):
            batch_losses = []  # Get loss of each batch.
            num_batches = len(train_dataloader)
            for batch, data in enumerate(train_dataloader):
                x, y, x_len, _y_len, oov, len_oovs, img_vec = data
                assert not np.any(np.isnan(x.numpy()))
                if config.is_cuda:  # Training with GPUs.
                    x = x.to(DEVICE)
                    y = y.to(DEVICE)
                    x_len = x_len.to(DEVICE)
                    len_oovs = len_oovs.to(DEVICE)
                    img_vec = img_vec.to(DEVICE)
                # Dump the first batch of every epoch as a sanity check.
                if batch == 0:
                    logging.info("x: %s, shape: %s" % (x, x.shape))
                    logging.info("y: %s, shape: %s" % (y, y.shape))
                    logging.info("oov: %s" % oov)
                    logging.info("img_vec: %s, shape: %s" %
                                 (img_vec, img_vec.shape))

                model.train()  # Sets the module in training mode.
                optimizer.zero_grad()  # Clear gradients.

                loss = model(x,
                             y,
                             len_oovs,
                             img_vec,
                             batch=batch,
                             num_batches=num_batches)
                batch_losses.append(loss.item())
                loss.backward()  # Backpropagation.

                # Do gradient clipping to prevent gradient explosion.
                clip_grad_norm_(model.encoder.parameters(),
                                config.max_grad_norm)
                clip_grad_norm_(model.decoder.parameters(),
                                config.max_grad_norm)
                clip_grad_norm_(model.attention.parameters(),
                                config.max_grad_norm)
                clip_grad_norm_(model.reduce_state.parameters(),
                                config.max_grad_norm)
                optimizer.step()  # Update weights.

                # Output and record the running loss every 100 batches.
                if (batch % 100) == 0:
                    running_loss = np.mean(batch_losses)
                    # Write loss for tensorboard.
                    writer.add_scalar(f'Average_loss_for_epoch_{epoch}',
                                      running_loss,
                                      global_step=batch)
                    logging.info(
                        'epoch: {}, batch:{}, training loss:{}'.format(
                            epoch, batch, running_loss))

            # Calculate average loss over all batches in an epoch.
            epoch_loss = np.mean(batch_losses)

            avg_val_loss = evaluate(model, val_data, epoch)

            logging.info('epoch: {} '.format(epoch) +
                         'training loss:{} '.format(epoch_loss) +
                         'validation loss:{} '.format(avg_val_loss))

            # Update minimum evaluating loss.
            if avg_val_loss < val_loss:
                # makedirs also creates missing parent directories and does
                # not fail when the directory already exists; an empty
                # dirname means "current directory" and needs no creation.
                save_dir = os.path.dirname(config.encoder_save_name)
                if save_dir:
                    os.makedirs(save_dir, exist_ok=True)
                logging.info("saving model to ../saved_model/ %s" %
                             config.model_name)
                torch.save(model.encoder, config.encoder_save_name)
                torch.save(model.decoder, config.decoder_save_name)
                torch.save(model.attention, config.attention_save_name)
                torch.save(model.reduce_state, config.reduce_state_save_name)
                val_loss = avg_val_loss
                with open(config.losses_path, 'a') as f:
                    f.write(f"best val loss={val_loss}\n")
                # An improvement ends the current streak of bad epochs, so
                # the stop condition matches the message logged below.
                early_stopping_count = 0
            else:
                early_stopping_count += 1
            if early_stopping_count >= config.patience:
                logging.info(
                    f'Validation loss did not decrease for {config.patience} epochs, stop training.'
                )
                break
    finally:
        writer.close()