Пример #1
0
    def train_and_validate(self, train_data: Dataset, valid_data: Dataset) \
            -> None:
        """
        Train the model and validate it from time to time on the validation set.

        Runs up to `self.epochs` epochs over `train_data`. Gradients are
        accumulated over `self.batch_multiplier` consecutive batches before an
        optimizer update is requested from `self._train_batch`. Whenever
        `self.steps` is a multiple of `self.validation_freq` (and an update
        was just performed), the model is evaluated greedily on `valid_data`,
        results are written to Tensorboard and the validation report, and a
        checkpoint is saved if the early-stopping metric improved.

        Training ends early as soon as `self.stop` is set.
        NOTE(review): `self.stop` is not written in this method; presumably a
        helper such as `_add_report` sets it when the minimum learning rate is
        reached (see the log message below) - confirm.

        :param train_data: training data
        :param valid_data: validation data
        """
        train_iter = make_data_iter(train_data,
                                    batch_size=self.batch_size,
                                    batch_type=self.batch_type,
                                    train=True,
                                    shuffle=self.shuffle)

        # For last batch in epoch batch_multiplier needs to be adjusted
        # to fit the number of leftover training examples.
        # This is the number of examples that do not fill a complete
        # "effective" batch of batch_multiplier * batch_size examples.
        leftover_batch_size = len(train_data) % (self.batch_multiplier *
                                                 self.batch_size)

        for epoch_no in range(self.epochs):
            self.logger.info("EPOCH %d", epoch_no + 1)

            # epoch-wise schedulers are stepped at the start of every epoch
            if self.scheduler is not None and self.scheduler_step_at == "epoch":
                self.scheduler.step(epoch=epoch_no)

            self.model.train()

            # Reset statistics for each epoch.
            start = time.time()
            total_valid_duration = 0
            start_tokens = self.total_tokens
            self.current_batch_multiplier = self.batch_multiplier
            # `count` counts down the batches remaining until the next
            # update; an update is triggered when it reaches 0.
            count = self.current_batch_multiplier - 1
            epoch_loss = 0

            for i, batch in enumerate(iter(train_iter)):
                # reactivate training
                # (validation below may have switched the model to eval mode)
                self.model.train()
                # create a Batch object from torchtext batch
                batch = Batch(batch, self.pad_index, use_cuda=self.use_cuda)

                # only update every batch_multiplier batches
                # see https://medium.com/@davidlmorton/
                # increasing-mini-batch-size-without-increasing-
                # memory-6794e10db672

                # Set current_batch_multiplier to fit
                # number of leftover examples for last batch in epoch.
                # The condition fires on the first batch of the final,
                # incomplete accumulation group; when there are no leftovers
                # the right-hand side equals len(train_iter) and never matches.
                if self.batch_multiplier > 1 and i == len(train_iter) - \
                        math.ceil(leftover_batch_size / self.batch_size):
                    self.current_batch_multiplier = math.ceil(
                        leftover_batch_size / self.batch_size)
                    count = self.current_batch_multiplier - 1

                update = count == 0
                # print(count, update, self.steps)
                batch_loss = self._train_batch(batch,
                                               update=update,
                                               count=count)

                # Only save finally computed batch_loss of full batch
                if update:
                    self.tb_writer.add_scalar("train/train_batch_loss",
                                              batch_loss, self.steps)

                # after an update restart the countdown from the full
                # multiplier, then consume one batch from it
                count = self.batch_multiplier if update else count
                count -= 1

                # Only add complete batch_loss of full mini-batch to epoch_loss
                if update:
                    epoch_loss += batch_loss.detach().cpu().numpy()

                # step-wise schedulers advance once per parameter update
                if self.scheduler is not None and \
                        self.scheduler_step_at == "step" and update:
                    self.scheduler.step()

                # log learning progress
                if self.steps % self.logging_freq == 0 and update:
                    # exclude time spent validating from the throughput figure
                    elapsed = time.time() - start - total_valid_duration
                    elapsed_tokens = self.total_tokens - start_tokens
                    self.logger.info(
                        "Epoch %3d Step: %8d Batch Loss: %12.6f "
                        "Tokens per Sec: %8.0f, Lr: %.6f", epoch_no + 1,
                        self.steps, batch_loss, elapsed_tokens / elapsed,
                        self.optimizer.param_groups[0]["lr"])
                    # restart the measurement window for the next log line
                    start = time.time()
                    total_valid_duration = 0
                    start_tokens = self.total_tokens

                # validate on the entire dev set
                if self.steps % self.validation_freq == 0 and update:
                    valid_start_time = time.time()

                    valid_score, valid_loss, valid_ppl, valid_sources, \
                        valid_sources_raw, valid_references, valid_hypotheses, \
                        valid_hypotheses_raw, valid_attention_scores = \
                        validate_on_data(
                            logger=self.logger,
                            batch_size=self.eval_batch_size,
                            data=valid_data,
                            eval_metric=self.eval_metric,
                            level=self.level, model=self.model,
                            use_cuda=self.use_cuda,
                            max_output_length=self.max_output_length,
                            loss_function=self.loss,
                            beam_size=1,  # greedy validations
                            batch_type=self.eval_batch_type
                        )

                    self.tb_writer.add_scalar("valid/valid_loss", valid_loss,
                                              self.steps)
                    self.tb_writer.add_scalar("valid/valid_score", valid_score,
                                              self.steps)
                    self.tb_writer.add_scalar("valid/valid_ppl", valid_ppl,
                                              self.steps)

                    # pick the quantity that drives checkpointing and
                    # validation-based scheduling
                    if self.early_stopping_metric == "loss":
                        ckpt_score = valid_loss
                    elif self.early_stopping_metric in ["ppl", "perplexity"]:
                        ckpt_score = valid_ppl
                    else:
                        ckpt_score = valid_score

                    # `new_best` is only set when a checkpoint is actually
                    # written, i.e. when checkpoint keeping is enabled
                    new_best = False
                    if self.is_best(ckpt_score):
                        self.best_ckpt_score = ckpt_score
                        self.best_ckpt_iteration = self.steps
                        self.logger.info(
                            'Hooray! New best validation result [%s]!',
                            self.early_stopping_metric)
                        if self.ckpt_queue.maxsize > 0:
                            self.logger.info("Saving new checkpoint.")
                            new_best = True
                            self._save_checkpoint()

                    # validation-based schedulers receive the monitored score
                    if self.scheduler is not None \
                            and self.scheduler_step_at == "validation":
                        self.scheduler.step(ckpt_score)

                    # append to validation report
                    self._add_report(valid_score=valid_score,
                                     valid_loss=valid_loss,
                                     valid_ppl=valid_ppl,
                                     eval_metric=self.eval_metric,
                                     new_best=new_best)

                    self._log_examples(
                        sources_raw=[v for v in valid_sources_raw],
                        sources=valid_sources,
                        hypotheses_raw=valid_hypotheses_raw,
                        hypotheses=valid_hypotheses,
                        references=valid_references)

                    valid_duration = time.time() - valid_start_time
                    total_valid_duration += valid_duration
                    self.logger.info(
                        'Validation result (greedy) at epoch %3d, '
                        'step %8d: %s: %6.2f, loss: %8.4f, ppl: %8.4f, '
                        'duration: %.4fs', epoch_no + 1, self.steps,
                        self.eval_metric, valid_score, valid_loss, valid_ppl,
                        valid_duration)

                    # store validation set outputs
                    #self._store_outputs(valid_hypotheses)

                    # store attention plots for selected valid sentences
                    if valid_attention_scores:
                        store_attention_plots(
                            attentions=valid_attention_scores,
                            targets=valid_hypotheses_raw,
                            sources=[s for s in valid_data.src],
                            indices=self.log_valid_sents,
                            output_prefix="{}/att.{}".format(
                                self.model_dir, self.steps),
                            tb_writer=self.tb_writer,
                            steps=self.steps)

                # leave the batch loop as soon as a stop was requested
                if self.stop:
                    break
            # ... and leave the epoch loop as well
            if self.stop:
                self.logger.info(
                    'Training ended since minimum lr %f was reached.',
                    self.learning_rate_min)
                break

            self.logger.info('Epoch %3d: total training loss %.2f',
                             epoch_no + 1, epoch_loss)
        else:
            # for-else: only reached when the epoch loop was not broken out of.
            # NOTE(review): if self.epochs is 0 the loop body never runs and
            # `epoch_no` is unbound here - confirm epochs >= 1 is guaranteed.
            self.logger.info('Training ended after %3d epochs.', epoch_no + 1)
        self.logger.info(
            'Best validation result (greedy) at step '
            '%8d: %6.2f %s.', self.best_ckpt_iteration, self.best_ckpt_score,
            self.early_stopping_metric)

        self.tb_writer.close()  # close Tensorboard writer
Пример #2
0
def Q_learning(cfg_file: str) -> None:
    """
    Train a translation model with a DQN-style Q-learning loop.

    For every training batch the policy network greedily translates the
    sources, each translation is scored with sentence-level BLEU against its
    reference, and the resulting (source, prefix, next_word, reward)
    transitions are pushed into a replay memory. The policy network is then
    optimised for `inner_epochs` steps on mini-batches sampled from that
    memory, using an MSE loss between its Q-value for the taken action and
    the target `reward` (non-zero reward) or `Gamma * max_a Q_hat` (zero
    reward), where `Q_hat` comes from a periodically synchronised target
    network. The policy network is validated on the dev set every
    `validation_freq` optimisation steps and checkpointed on improvement.

    The function relies on a module-level `logger`; it does not bind one
    itself. The test split returned by `load_data` is not used here.

    :param cfg_file: path to configuration yaml file
    """
    cfg = load_config(cfg_file)  # config is a dict
    # create the model directory and set up logging
    model_dir = make_model_dir(cfg["training"]["model_dir"],
                               overwrite=cfg["training"].get(
                                   "overwrite", False))
    _ = make_logger(model_dir, mode="train")  # version string returned
    # TODO: save version number in model checkpoints

    # set the random seed
    set_seed(seed=cfg["training"].get("random_seed", 42))

    # load the data
    print("loadding data here")
    train_data, dev_data, _, src_vocab, trg_vocab = load_data(
        data_cfg=cfg["data"])
    # The training data is filtered to include sentences up to `max_sent_length`
    #     on source and target side.

    # training config:
    train_config = cfg["training"]
    shuffle = train_config.get("shuffle", True)
    batch_size = train_config["batch_size"]
    mini_batch_size = train_config["mini_batch_size"]
    batch_type = train_config.get("batch_type", "sentence")
    outer_epochs = train_config.get("outer_epochs", 10)
    inner_epochs = train_config.get("inner_epochs", 10)
    target_update = train_config.get("target_update", 10)
    gamma = train_config.get("Gamma", 0.999)
    use_cuda = train_config["use_cuda"] and torch.cuda.is_available()

    # validation config
    validation_freq = train_config.get("validation_freq", 1000)
    ckpt_queue = queue.Queue(maxsize=train_config.get("keep_last_ckpts", 5))
    eval_batch_size = train_config.get("eval_batch_size", batch_size)
    level = cfg["data"]["level"]
    eval_metric = train_config.get("eval_metric", "bleu")
    n_gpu = torch.cuda.device_count() if use_cuda else 0
    eval_batch_type = train_config.get("eval_batch_type", batch_type)
    # eval options
    test_config = cfg["testing"]
    bpe_type = test_config.get("bpe_type", "subword-nmt")
    max_output_length = train_config.get("max_output_length", None)
    minimize_metric = True
    # initialize training statistics
    stats = TrainStatistics(
        steps=0,
        stop=False,
        total_tokens=0,
        best_ckpt_iter=0,
        best_ckpt_score=np.inf if minimize_metric else -np.inf,
        minimize_metric=minimize_metric)

    # decide whether the early-stopping metric is minimised or maximised;
    # any metric not matched below keeps the initial "minimise" setting
    early_stopping_metric = train_config.get("early_stopping_metric",
                                             "eval_metric")
    if early_stopping_metric in ["ppl", "loss"]:
        stats.minimize_metric = True
    elif early_stopping_metric == "eval_metric":
        if eval_metric in [
                "bleu", "chrf", "token_accuracy", "sequence_accuracy"
        ]:
            stats.minimize_metric = False
        # eval metric that has to get minimized (not yet implemented)
        else:
            stats.minimize_metric = True

    # data loader (modified from the train_and_validate function):
    # a torchtext iterator over the training set
    train_iter = make_data_iter(train_data,
                                batch_size=batch_size,
                                batch_type=batch_type,
                                train=True,
                                shuffle=shuffle)

    # initialize the Replay Memory D with capacity N
    memory = ReplayMemory(10000)
    # number of batches translated so far; handed to transformer_greedy
    steps_done = 0

    # initialize two DQN networks
    policy_net = build_model(cfg["model"],
                             src_vocab=src_vocab,
                             trg_vocab=trg_vocab)  # Q_network
    target_net = build_model(cfg["model"],
                             src_vocab=src_vocab,
                             trg_vocab=trg_vocab)  # Q_hat_network
    if use_cuda:
        policy_net.cuda()
        target_net.cuda()

    # Initialize target net Q_hat with weights equal to policy_net
    target_net.load_state_dict(policy_net.state_dict())

    # target_net is never optimised directly; keep it in eval mode
    target_net.eval()

    # Optimizer (only the policy network's parameters are trained)
    optimizer = build_optimizer(config=cfg["training"],
                                parameters=policy_net.parameters())
    # Loss function for the temporal-difference error
    loss_function = torch.nn.MSELoss()

    pad_index = policy_net.pad_index

    # attached to the model, which is later passed to validate_on_data
    cross_entropy_loss = XentLoss(pad_index=pad_index)
    policy_net.loss_function = cross_entropy_loss

    # the join character only depends on the segmentation level
    join_char = " " if level in ["word", "bpe"] else ""

    for i_episode in range(outer_epochs):
        # Outer loop: one pass over the training data

        for batch in train_iter:
            # create a Batch object from torchtext batch
            # (class Batch from batch.py); sources are token indices,
            # the model embeds them during encoding
            batch = Batch(batch, policy_net.pad_index, use_cuda=use_cuda)

            encoder_output_batch, _, _, _ = policy_net(
                return_type="encode",
                src=batch.src,
                src_length=batch.src_length,
                src_mask=batch.src_mask,
            )

            # get the translated output of a batch
            trans_output_batch, _ = transformer_greedy(
                src_mask=batch.src_mask,
                max_output_length=max_output_length,
                model=policy_net,
                encoder_output=encoder_output_batch,
                steps_done=steps_done,
                use_cuda=use_cuda)

            steps_done += 1

            # Reward of each sentence: BLEU of its hypothesis against its
            # reference, both decoded back to strings and cut at <eos>.
            reward_batch = []
            for sent_idx in range(int(batch.src.shape[0])):
                # the first output position is skipped
                # (presumably the <bos> token - TODO confirm)
                all_outputs = [(trans_output_batch[sent_idx])[1:]]
                hypotheses = policy_net.trg_vocab.arrays_to_sentences(
                    arrays=all_outputs, cut_at_eos=True)

                all_ref = [batch.trg[sent_idx]]
                references = policy_net.trg_vocab.arrays_to_sentences(
                    arrays=all_ref, cut_at_eos=True)

                valid_references = [join_char.join(t) for t in references]
                valid_hypotheses = [join_char.join(t) for t in hypotheses]
                print(valid_references, valid_hypotheses)

                current_valid_score = sacrebleu.corpus_bleu(
                    sys_stream=valid_hypotheses,
                    ref_streams=[valid_references],
                    smooth_method='floor',
                    smooth_value=0.01).score

                reward_batch.append(current_valid_score)
            print('reward batch is', reward_batch)
            reward_batch = torch.tensor(reward_batch, dtype=torch.float)

            # make prefixes and push the transition tuples into memory
            push_sample_to_memory(eos_index=policy_net.eos_index,
                                  memory=memory,
                                  src_batch=batch.src,
                                  trans_output_batch=trans_output_batch,
                                  reward_batch=reward_batch,
                                  max_output_length=max_output_length)

            # inner loop: optimise on replayed transitions
            for _ in range(inner_epochs):
                # Sample mini-batch from the memory and transpose the list of
                # Transition tuples into one Transition of per-field tuples.
                transitions = memory.sample(mini_batch_size)
                mini_batch = Transition(*zip(*transitions))

                # actions taken and rewards received, shape (mini_batch_size,)
                mini_next_word = torch.cat(
                    [word.unsqueeze(0) for word in mini_batch.next_word])
                mini_reward = torch.tensor(mini_batch.reward)

                mini_src_length = torch.Tensor(
                    [len(item) for item in mini_batch.source_sentence])

                # shape (mini_batch_size, max_length_src)
                mini_src = pad_sequence(mini_batch.source_sentence,
                                        batch_first=True,
                                        padding_value=float(pad_index))

                mini_prefix_length = torch.Tensor(
                    [len(item) for item in mini_batch.prefix])

                # prefixes are stored as numpy arrays
                prefix_list = [
                    torch.from_numpy(prefix_) for prefix_ in mini_batch.prefix
                ]

                # shape (mini_batch_size, max_length_prefix)
                mini_prefix = pad_sequence(prefix_list,
                                           batch_first=True,
                                           padding_value=pad_index)

                mini_src_mask = (mini_src != pad_index).unsqueeze(1)
                mini_trg_mask = (mini_prefix != pad_index).unsqueeze(1)

                if use_cuda:
                    mini_src = mini_src.cuda()
                    mini_prefix = mini_prefix.cuda()
                    mini_src_mask = mini_src_mask.cuda()
                    mini_src_length = mini_src_length.cuda()
                    mini_trg_mask = mini_trg_mask.cuda()
                    mini_next_word = mini_next_word.cuda()

                # calculate the Q_value
                logits_Q, _, _, _ = policy_net._encode_decode(
                    src=mini_src,
                    trg_input=mini_prefix,
                    src_mask=mini_src_mask,
                    src_length=mini_src_length,
                    trg_mask=mini_trg_mask)

                # logits at the last real prefix position of every sample:
                # (mini_batch_size, vocab) ...
                logits_Q = logits_Q[range(mini_batch_size),
                                    mini_prefix_length.long() - 1, :]
                # ... and the entry of the word that was actually chosen
                Q_value = logits_Q[range(mini_batch_size), mini_next_word]

                # NOTE(review): the next word is appended as a new last
                # column, i.e. after any padding of shorter prefixes, while
                # the target logits below are read at position
                # `prefix_length`. For prefixes shorter than the longest one
                # that position holds padding, not the next word - confirm
                # this is intended.
                mini_prefix_add = torch.cat(
                    [mini_prefix, mini_next_word.unsqueeze(1)], dim=1)
                mini_trg_mask_add = (mini_prefix_add != pad_index).unsqueeze(1)

                if use_cuda:
                    mini_prefix_add = mini_prefix_add.cuda()
                    mini_trg_mask_add = mini_trg_mask_add.cuda()

                # The bootstrap target must be a constant w.r.t. the loss:
                # evaluate the target network without building an autograd
                # graph, so backward() neither walks through target_net nor
                # accumulates gradients on its parameters.
                with torch.no_grad():
                    logits_Q_hat, _, _, _ = target_net._encode_decode(
                        src=mini_src,
                        trg_input=mini_prefix_add,
                        src_mask=mini_src_mask,
                        src_length=mini_src_length,
                        trg_mask=mini_trg_mask_add)
                    logits_Q_hat = logits_Q_hat[range(mini_batch_size),
                                                mini_prefix_length.long(), :]
                    Q_hat_value, _ = torch.max(logits_Q_hat, dim=1)

                # TD target: the reward itself where it is non-zero,
                # otherwise the discounted best target-network value
                yj = mini_reward.float()
                index = (mini_reward == 0)

                if use_cuda:
                    yj = yj.cuda()
                    Q_hat_value = Q_hat_value.cuda()

                yj[index] = gamma * Q_hat_value[index]

                # Optimize the model
                policy_net.zero_grad()

                # Compute loss
                loss = loss_function(yj, Q_value)
                logger.info("step = {}, loss = {}".format(
                    stats.steps, loss.item()))
                loss.backward()
                optimizer.step()

                stats.steps += 1

                # periodically copy the policy weights into the target net
                if stats.steps % target_update == 0:
                    target_net.load_state_dict(policy_net.state_dict())

                if stats.steps % validation_freq == 0:  # Validation
                    print('Start validation')

                    # NOTE(review): `sacrebleu` here is the same name used
                    # above as a module (`sacrebleu.corpus_bleu`); confirm
                    # that validate_on_data does not expect an options dict.
                    valid_score, valid_loss, valid_ppl, *_ = \
                        validate_on_data(
                            model=policy_net,
                            data=dev_data,
                            batch_size=eval_batch_size,
                            use_cuda=use_cuda,
                            level=level,
                            eval_metric=eval_metric,
                            n_gpu=n_gpu,
                            compute_loss=True,
                            beam_size=1,
                            beam_alpha=-1,
                            batch_type=eval_batch_type,
                            postprocess=True,
                            bpe_type=bpe_type,
                            sacrebleu=sacrebleu,
                            max_output_length=max_output_length
                        )
                    print('validation_loss: {}, validation_score: {}'.format(
                        valid_loss, valid_score))
                    logger.info(valid_loss)
                    print('average loss: total_loss/n_tokens:', valid_ppl)

                    if early_stopping_metric == "loss":
                        ckpt_score = valid_loss
                    elif early_stopping_metric in ["ppl", "perplexity"]:
                        ckpt_score = valid_ppl
                    else:
                        ckpt_score = valid_score
                    if stats.is_best(ckpt_score):
                        stats.best_ckpt_score = ckpt_score
                        stats.best_ckpt_iter = stats.steps
                        logger.info('Hooray! New best validation result [%s]!',
                                    early_stopping_metric)
                        if ckpt_queue.maxsize > 0:
                            logger.info("Saving new checkpoint.")

                            # Save the model's current parameters and the
                            # training state to a checkpoint. The state holds
                            # the number of training steps and tokens, the
                            # best checkpoint score and iteration so far, and
                            # the optimizer state (no scheduler state here).
                            model_path = "{}/{}.ckpt".format(
                                model_dir, stats.steps)
                            model_state_dict = policy_net.module.state_dict() \
                                if isinstance(policy_net, torch.nn.DataParallel) \
                                else policy_net.state_dict()
                            state = {
                                "steps": stats.steps,
                                "total_tokens": stats.total_tokens,
                                "best_ckpt_score": stats.best_ckpt_score,
                                "best_ckpt_iteration": stats.best_ckpt_iter,
                                "model_state": model_state_dict,
                                "optimizer_state": optimizer.state_dict(),
                            }
                            torch.save(state, model_path)
                            # keep at most `keep_last_ckpts` checkpoints
                            if ckpt_queue.full():
                                to_delete = ckpt_queue.get(
                                )  # delete oldest ckpt
                                try:
                                    os.remove(to_delete)
                                except FileNotFoundError:
                                    logger.warning(
                                        "Wanted to delete old checkpoint %s but "
                                        "file does not exist.", to_delete)

                            ckpt_queue.put(model_path)

                            best_path = "{}/best.ckpt".format(model_dir)
                            try:
                                # create/modify symbolic link for best checkpoint
                                symlink_update("{}.ckpt".format(stats.steps),
                                               best_path)
                            except OSError:
                                # overwrite best.ckpt
                                torch.save(state, best_path)
Пример #3
0
def validate_on_data(model: Model, data: Dataset,
                     logger: Logger,
                     batch_size: int,
                     use_cuda: bool, max_output_length: int,
                     level: str, eval_metric: Optional[str],
                     loss_function: torch.nn.Module = None,
                     beam_size: int = 1, beam_alpha: int = -1,
                     batch_type: str = "sentence",
                     postprocess: bool = True
                     ) \
        -> (float, float, float, List[str], List[List[str]], List[str],
            List[str], List[List[str]], List[np.array]):
    """
    Generate translations for the given data.
    If `loss_function` is not None and references are given,
    also compute the loss.

    The model is switched to evaluation mode and is NOT switched back;
    callers that continue training must call `model.train()` themselves.

    :param model: model module
    :param logger: logger (only used to warn about very large batch sizes)
    :param data: dataset for validation
    :param batch_size: validation batch size
    :param use_cuda: if True, use CUDA
    :param max_output_length: maximum length for generated hypotheses
    :param level: segmentation level, one of "char", "bpe", "word"
    :param eval_metric: evaluation metric, one of "bleu", "chrf",
        "token_accuracy", "sequence_accuracy" (case-insensitive).
        If None, no score is computed.
    :param loss_function: loss function that computes a scalar loss
        for given inputs and targets
    :param beam_size: beam size for validation.
        If <2 then greedy decoding (default).
    :param beam_alpha: beam search alpha for length penalty,
        disabled if set to -1 (default).
    :param batch_type: validation batch type (sentence or token)
    :param postprocess: if True, remove BPE segmentation from translations

    :return:
        - current_valid_score: current validation score [eval_metric];
          -1 if there are no references or `eval_metric` is None,
          0 if `eval_metric` is not one of the supported names,
        - valid_loss: validation loss summed over all batches;
          -1 if no loss was computed,
        - valid_ppl: validation perplexity; -1 if no loss was computed,
        - valid_sources: validation sources,
        - valid_sources_raw: raw validation sources (before post-processing),
        - valid_references: validation references,
        - valid_hypotheses: validation_hypotheses,
        - decoded_valid: raw validation hypotheses (before post-processing),
        - valid_attention_scores: attention scores for validation hypotheses
    """
    if batch_size > 1000 and batch_type == "sentence":
        logger.warning(
            "WARNING: Are you sure you meant to work on huge batches like "
            "this? 'batch_size' is > 1000 for sentence-batching. "
            "Consider decreasing it or switching to"
            " 'eval_batch_type: token'.")
    valid_iter = make_data_iter(dataset=data,
                                batch_size=batch_size,
                                batch_type=batch_type,
                                shuffle=False,
                                train=False)
    valid_sources_raw = data.src
    pad_index = model.src_vocab.stoi[PAD_TOKEN]
    # disable dropout
    model.eval()
    # don't track gradients during validation
    with torch.no_grad():
        all_outputs = []
        valid_attention_scores = []
        total_loss = 0
        total_ntokens = 0
        total_nseqs = 0
        for valid_batch in iter(valid_iter):
            # run as during training to get validation loss (e.g. xent)

            batch = Batch(valid_batch, pad_index, use_cuda=use_cuda)
            # sort batch now by src length and keep track of order
            sort_reverse_index = batch.sort_by_src_lengths()

            # run as during training with teacher forcing
            if loss_function is not None and batch.trg is not None:
                batch_loss = model.get_loss_for_batch(
                    batch, loss_function=loss_function)
                total_loss += batch_loss
                total_ntokens += batch.ntokens
                total_nseqs += batch.nseqs

            # run as during inference to produce translations
            output, attention_scores = model.run_batch(
                batch=batch,
                beam_size=beam_size,
                beam_alpha=beam_alpha,
                max_output_length=max_output_length)

            # sort outputs back to original order
            all_outputs.extend(output[sort_reverse_index])
            valid_attention_scores.extend(
                attention_scores[sort_reverse_index]
                if attention_scores is not None else [])

        # internal invariant: exactly one hypothesis per input example
        assert len(all_outputs) == len(data)

        if loss_function is not None and total_ntokens > 0:
            # total validation loss
            valid_loss = total_loss
            # exponent of token-level negative log prob
            valid_ppl = torch.exp(total_loss / total_ntokens)
        else:
            # no loss function, or no batch carried targets
            valid_loss = -1
            valid_ppl = -1

        # decode back to symbols
        decoded_valid = model.trg_vocab.arrays_to_sentences(arrays=all_outputs,
                                                            cut_at_eos=True)

        # evaluate with metric on full dataset
        join_char = " " if level in ["word", "bpe"] else ""
        valid_sources = [join_char.join(s) for s in data.src]
        valid_references = [join_char.join(t) for t in data.trg]
        valid_hypotheses = [join_char.join(t) for t in decoded_valid]

        # post-process
        if level == "bpe" and postprocess:
            valid_sources = [bpe_postprocess(s) for s in valid_sources]
            valid_references = [bpe_postprocess(v) for v in valid_references]
            valid_hypotheses = [bpe_postprocess(v) for v in valid_hypotheses]

        # if references are given, evaluate against them
        if valid_references:
            assert len(valid_hypotheses) == len(valid_references)

            # `eval_metric` is declared Optional: lower-case it once and
            # treat None as "no metric requested" instead of failing on
            # `None.lower()`.
            metric = eval_metric.lower() if eval_metric is not None else None

            if metric is None:
                current_valid_score = -1
            elif metric == 'bleu':
                # this version does not use any tokenization
                current_valid_score = bleu(valid_hypotheses, valid_references)
            elif metric == 'chrf':
                current_valid_score = chrf(valid_hypotheses, valid_references)
            elif metric == 'token_accuracy':
                current_valid_score = token_accuracy(valid_hypotheses,
                                                     valid_references,
                                                     level=level)
            elif metric == 'sequence_accuracy':
                current_valid_score = sequence_accuracy(
                    valid_hypotheses, valid_references)
            else:
                # unrecognised metric name: keep the historical score of 0
                current_valid_score = 0
        else:
            current_valid_score = -1

    return current_valid_score, valid_loss, valid_ppl, valid_sources, \
        valid_sources_raw, valid_references, valid_hypotheses, \
        decoded_valid, valid_attention_scores
Пример #4
0
    def train_and_validate(self, train_data, valid_data):
        """
        Train the model and validate it from time to time on the validation set.

        Runs up to `self.epochs` epochs over `train_data`. Every batch is
        passed to `self._train_batch`; `update=True` is passed for the first
        batch of an epoch and then for every `self.batch_multiplier`-th batch.
        On such update batches:

        - progress is logged when `self.steps` is a multiple of
          `self.logging_freq`,
        - the model is validated on `valid_data` when `self.steps` is a
          multiple of `self.validation_freq`; a checkpoint is saved if the
          result is a new best according to `self.ckpt_metric`, the scheduler
          (if any) is stepped with the `self.schedule_metric` value, and
          hypotheses and attention plots are stored.

        Training stops early as soon as `self.stop` is true.

        :param train_data: training data
        :param valid_data: validation data
        :return: None
        """
        train_iter = make_data_iter(train_data,
                                    batch_size=self.batch_size,
                                    train=True,
                                    shuffle=self.shuffle)
        for epoch_no in range(self.epochs):
            self.logger.info("EPOCH {}".format(epoch_no + 1))
            self.model.train()

            # baselines for the throughput report; both are reset after
            # every progress log so that tokens and seconds always refer
            # to the same interval
            start = time.time()
            total_valid_duration = 0
            processed_tokens = self.total_tokens
            # countdown to the next update batch (0 means "update now")
            count = 0

            for batch_no, batch in enumerate(iter(train_iter), 1):
                # reactivate training
                self.model.train()
                batch = Batch(batch, self.pad_index, use_cuda=self.use_cuda)

                # only update every batch_multiplier batches
                # see https://medium.com/@davidlmorton/increasing-mini-batch-size-without-increasing-memory-6794e10db672
                update = count == 0
                batch_loss = self._train_batch(batch, update=update)
                count = self.batch_multiplier if update else count
                count -= 1

                # log learning progress
                if self.model.training and self.steps % self.logging_freq == 0 \
                        and update:
                    # time spent validating is excluded from the throughput
                    elapsed = time.time() - start - total_valid_duration
                    elapsed_tokens = self.total_tokens - processed_tokens
                    self.logger.info(
                        "Epoch %d Step: %d Loss: %f Tokens per Sec: %f" %
                        (epoch_no + 1, self.steps, batch_loss,
                         elapsed_tokens / elapsed))
                    start = time.time()
                    total_valid_duration = 0
                    # reset the token baseline together with the clock;
                    # otherwise later logs would divide tokens since epoch
                    # start by seconds since the last log
                    processed_tokens = self.total_tokens

                # validate on the entire dev set
                if self.steps % self.validation_freq == 0 and update:
                    valid_start_time = time.time()

                    # NOTE(review): the loss is passed as `criterion=` and no
                    # logger is passed; confirm this matches the signature of
                    # the `validate_on_data` that is actually in scope.
                    valid_score, valid_loss, valid_ppl, valid_sources, \
                        valid_sources_raw, valid_references, valid_hypotheses, \
                        valid_hypotheses_raw, valid_attention_scores = \
                        validate_on_data(
                            batch_size=self.batch_size, data=valid_data,
                            eval_metric=self.eval_metric,
                            level=self.level, model=self.model,
                            use_cuda=self.use_cuda,
                            max_output_length=self.max_output_length,
                            criterion=self.criterion)

                    # pick the quantity that decides about checkpointing
                    if self.ckpt_metric == "loss":
                        ckpt_score = valid_loss
                    elif self.ckpt_metric in ["ppl", "perplexity"]:
                        ckpt_score = valid_ppl
                    else:
                        ckpt_score = valid_score

                    new_best = False
                    if self.is_best(ckpt_score):
                        self.best_ckpt_score = ckpt_score
                        self.best_ckpt_iteration = self.steps
                        self.logger.info(
                            'Hooray! New best validation result [{}]!'.format(
                                self.ckpt_metric))
                        new_best = True
                        self.save_checkpoint()

                    # pass validation score or loss or ppl to scheduler
                    if self.schedule_metric == "loss":
                        # schedule based on loss
                        schedule_score = valid_loss
                    elif self.schedule_metric in ["ppl", "perplexity"]:
                        # schedule based on perplexity
                        schedule_score = valid_ppl
                    else:
                        # schedule based on evaluation score
                        schedule_score = valid_score
                    if self.scheduler is not None:
                        self.scheduler.step(schedule_score)

                    # append to validation report
                    self._add_report(valid_score=valid_score,
                                     valid_loss=valid_loss,
                                     valid_ppl=valid_ppl,
                                     eval_metric=self.eval_metric,
                                     new_best=new_best)

                    # always print first x sentences, but never index past
                    # the end of a validation set smaller than x
                    num_examples = min(self.print_valid_sents,
                                       len(valid_sources))
                    for p in range(num_examples):
                        self.logger.debug("Example #{}".format(p))
                        self.logger.debug("\tRaw source: {}".format(
                            valid_sources_raw[p]))
                        self.logger.debug("\tSource: {}".format(
                            valid_sources[p]))
                        self.logger.debug("\tReference: {}".format(
                            valid_references[p]))
                        self.logger.debug("\tRaw hypothesis: {}".format(
                            valid_hypotheses_raw[p]))
                        self.logger.debug("\tHypothesis: {}".format(
                            valid_hypotheses[p]))
                    valid_duration = time.time() - valid_start_time
                    total_valid_duration += valid_duration
                    self.logger.info(
                        'Validation result at epoch {}, step {}: {}: {}, '
                        'loss: {}, ppl: {}, duration: {:.4f}s'.format(
                            epoch_no + 1, self.steps, self.eval_metric,
                            valid_score, valid_loss, valid_ppl,
                            valid_duration))

                    # store validation set outputs
                    self.store_outputs(valid_hypotheses)

                    # store attention plots for first three sentences of
                    # valid data and one randomly chosen example
                    store_attention_plots(attentions=valid_attention_scores,
                                          targets=valid_hypotheses_raw,
                                          sources=[s for s in valid_data.src],
                                          idx=[
                                              0, 1, 2,
                                              np.random.randint(
                                                  0, len(valid_hypotheses))
                                          ],
                                          output_prefix="{}/att.{}".format(
                                              self.model_dir, self.steps))

                # `self.stop` is not assigned in this method.
                # NOTE(review): presumably set during `_train_batch` once the
                # learning rate drops below `learning_rate_min` - confirm.
                if self.stop:
                    break
            if self.stop:
                self.logger.info(
                    'Training ended since minimum lr {} was reached.'.format(
                        self.learning_rate_min))
                break
        else:
            # for/else: only reached when no epoch was cut short by a break
            self.logger.info(
                'Training ended after {} epochs.'.format(epoch_no + 1))
        self.logger.info('Best validation result at step {}: {} {}.'.format(
            self.best_ckpt_iteration, self.best_ckpt_score, self.ckpt_metric))