Example #1
    def train(engine, mini_batch):
        # You have to reset the gradients of all model parameters
        # before taking another step of gradient descent.
        engine.model.train() # Because the model is assigned as a class attribute, we can access it easily.
        engine.optimizer.zero_grad()

        x, y = mini_batch
        x, y = x.to(engine.device), y.to(engine.device)

        # Run the feed-forward pass.
        y_hat = engine.model(x)

        loss = engine.crit(y_hat, y)
        loss.backward()

        # Calculate accuracy only if 'y' is a LongTensor,
        # which means that 'y' holds class indices for classification.
        if isinstance(y, torch.LongTensor) or isinstance(y, torch.cuda.LongTensor):
            accuracy = (torch.argmax(y_hat, dim=-1) == y).sum() / float(y.size(0))
        else:
            accuracy = 0

        p_norm = float(get_parameter_norm(engine.model.parameters()))
        g_norm = float(get_grad_norm(engine.model.parameters()))

        # Take a step of gradient descent.
        engine.optimizer.step()

        return {
            'loss': float(loss),
            'accuracy': float(accuracy),
            '|param|': p_norm,
            '|g_param|': g_norm,
        }
Example #2
    def train(engine, mini_batch):
        engine.model.train()
        engine.optimizer.zero_grad()

        x, y = mini_batch
        x, y = x.to(engine.device), y.to(engine.device)

        y_hat = engine.model(x)
        y_hat = y_hat.view(-1)

        loss = engine.crit(y_hat, y)
        loss.backward()

        if isinstance(y, torch.LongTensor) or isinstance(
                y, torch.cuda.LongTensor):
            accuracy = (torch.argmax(y_hat, dim=-1) == y).sum() / float(
                y.size(0))
        else:
            accuracy = 0

        # Watch this grow to confirm that the model is learning.
        p_norm = float(get_parameter_norm(engine.model.parameters()))

        # Watch the gradient norm shrink to confirm that descent is taking place.
        g_norm = float(get_grad_norm(engine.model.parameters()))

        engine.optimizer.step()

        return {
            "loss": float(loss),
            "accuracy": float(accuracy),
            "|param|": p_norm,
            "|g_param|": g_norm
        }
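Every example on this page calls get_parameter_norm and get_grad_norm, usually imported from a local utils module that is not shown here. A minimal sketch consistent with how they are used (callers wrap the result in float(...); the L2 default is an assumption):

def get_parameter_norm(parameters, norm_type=2):
    # Accumulate the p-norm of every parameter tensor, then reduce.
    total_norm = 0.0
    for p in parameters:
        total_norm += float(p.data.norm(norm_type)) ** norm_type
    return total_norm ** (1.0 / norm_type)

def get_grad_norm(parameters, norm_type=2):
    # Same reduction over the gradients; parameters without a gradient are skipped.
    total_norm = 0.0
    for p in parameters:
        if p.grad is not None:
            total_norm += float(p.grad.data.norm(norm_type)) ** norm_type
    return total_norm ** (1.0 / norm_type)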
Example #3
    def train_epoch(self, train, optimizer):
        '''
        Train an epoch with given train iterator and optimizer.
        '''
        total_loss, total_word_count = 0, 0
        total_grad_norm = 0
        avg_loss, avg_grad_norm = 0, 0
        sample_cnt = 0

        # Iterate whole train-set.
        for idx, mini_batch in enumerate(train):
            # The raw target variable has both BOS and EOS tokens.
            # The output of the sequence-to-sequence model does not have a BOS token.
            # Thus, remove the BOS token from the reference.
            x, y = mini_batch.src, mini_batch.tgt[0][:, 1:]
            # |x| = (batch_size, length)
            # |y| = (batch_size, length)

            # You have to reset the gradients of all model parameters before taking another step of gradient descent.
            optimizer.zero_grad()

            # Run the feed-forward pass.
            # As before, the input of the decoder does not have an EOS token.
            # Thus, remove the EOS token from the decoder input.
            y_hat = self.model(x, mini_batch.tgt[0][:, :-1])
            # |y_hat| = (batch_size, length, output_size)

            # Calculate loss and gradients with back-propagation.
            loss = self._get_loss(y_hat, y)
            loss.div(y.size(0)).backward()

            # Simple math to show stats.
            # Don't forget to detach final variables.
            total_loss += float(loss)
            total_word_count += int(mini_batch.tgt[1].sum())
            param_norm = float(
                utils.get_parameter_norm(self.model.parameters()))
            total_grad_norm += float(
                utils.get_grad_norm(self.model.parameters()))

            avg_loss = total_loss / total_word_count
            avg_grad_norm = total_grad_norm / (idx + 1)

            # In order to avoid exploding gradients, we apply gradient clipping.
            torch_utils.clip_grad_norm_(self.model.parameters(),
                                        self.config.max_grad_norm)
            # Take a step of gradient descent.
            optimizer.step()

            sample_cnt += mini_batch.tgt[0].size(0)

            if idx >= len(train) * self.config.train_ratio_per_epoch:
                break

        return avg_loss, param_norm, avg_grad_norm
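Note the ordering these loops share: clipping must happen after loss.backward() has populated the gradients and before optimizer.step() consumes them. A self-contained sketch of that sequence with generic names:

import torch
from torch import nn, optim
import torch.nn.utils as torch_utils

model = nn.Linear(8, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1)
x, y = torch.randn(4, 8), torch.randint(0, 2, (4,))

optimizer.zero_grad()
loss = nn.functional.cross_entropy(model(x), y)
loss.backward()                                                # 1. compute gradients
torch_utils.clip_grad_norm_(model.parameters(), max_norm=5.0)  # 2. clip them in-place
optimizer.step()                                               # 3. apply the update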
Example #4
    def train_epoch(self, 
                    train, 
                    optimizer, 
                    batch_size=64, 
                    verbose=VERBOSE_SILENT
                    ):
        '''
        Train an epoch with given train iterator and optimizer.
        '''
        total_loss, total_param_norm, total_grad_norm = 0, 0, 0
        avg_loss, avg_param_norm, avg_grad_norm = 0, 0, 0
        sample_cnt = 0

        progress_bar = tqdm(train, 
                            desc='Training: ', 
                            unit='batch'
                            ) if verbose is VERBOSE_BATCH_WISE else train
        # Iterate whole train-set.
        for idx, mini_batch in enumerate(progress_bar):
            x, y = mini_batch.text, mini_batch.label
            # Don't forget to zero the gradients before another back-prop.
            optimizer.zero_grad()

            y_hat = self.model(x)

            loss = self.get_loss(y_hat, y)
            loss.backward()

            # Don't forget to detach these by casting to float.
            total_loss += float(loss)
            total_param_norm += float(utils.get_parameter_norm(self.model.parameters()))
            total_grad_norm += float(utils.get_grad_norm(self.model.parameters()))

            # Calculation to show status
            avg_loss = total_loss / (idx + 1)
            avg_param_norm = total_param_norm / (idx + 1)
            avg_grad_norm = total_grad_norm / (idx + 1)

            if verbose is VERBOSE_BATCH_WISE:
                progress_bar.set_postfix_str('|param|=%.2f |g_param|=%.2f loss=%.4e' % (avg_param_norm,
                                                                                        avg_grad_norm,
                                                                                        avg_loss
                                                                                        ))

            optimizer.step()

            sample_cnt += mini_batch.text.size(0)
            if sample_cnt >= len(train.dataset.examples):
                break

        if verbose is VERBOSE_BATCH_WISE:
            progress_bar.close()

        return avg_loss, avg_param_norm, avg_grad_norm
Example #5
    def train(engine, mini_batch):
        from utils import get_grad_norm, get_parameter_norm

        # You have to reset the gradients of all model parameters
        # before taking another step of gradient descent.
        engine.model.train()
        engine.optimizer.zero_grad()

        # If 'is_src_target' is True, the trainer trains a language model for the source language.
        # In the DSL case, both x and y have BOS and EOS tokens.
        # Thus, we need to strip them (EOS from the input, BOS from the target) before training.
        x = mini_batch.src[0][:, :-1] if engine.is_src_target else mini_batch.tgt[0][:, :-1]
        y = mini_batch.src[0][:, 1:] if engine.is_src_target else mini_batch.tgt[0][:, 1:]
        # |x| = |y| = (batch_size, length)

        y_hat = engine.model(x)
        # |y_hat| = (batch_size, length, output_size)

        loss = engine.crit(
            y_hat.contiguous().view(-1, y_hat.size(-1)),
            y.contiguous().view(-1),
        ).sum()
        loss.div(y.size(0)).backward()
        word_count = int(mini_batch.src[1].sum()) if engine.is_src_target \
            else int(mini_batch.tgt[1].sum())

        p_norm = float(get_parameter_norm(engine.model.parameters()))
        g_norm = float(get_grad_norm(engine.model.parameters()))

        # In order to avoid exploding gradients, we apply gradient clipping.
        torch_utils.clip_grad_norm_(
            engine.model.parameters(),
            engine.config.max_grad_norm,
        )
        # Take a step of gradient descent.
        engine.optimizer.step()

        return {
            'loss': float(loss / word_count),
            '|param|': p_norm,
            '|g_param|': g_norm,
        }
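Example #5 returns float(loss / word_count), i.e., the negative log-likelihood per word; the PPL values printed by the later examples are simply the exponential of that quantity. A small helper making the (assumed) relation explicit:

import math

def perplexity(per_word_nll):
    # PPL = exp(average negative log-likelihood per word),
    # matching the np.exp(avg_loss) prints in the epoch-level examples below.
    return math.exp(per_word_nll)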
Example #6
    def step(engine, mini_batch):
        from utils import get_grad_norm, get_parameter_norm

        engine.model.train()
        engine.optimizer.zero_grad()

        x, y = mini_batch.text, mini_batch.label

        y_hat = engine.model(x)
        loss = engine.crit(y_hat, y)
        loss.backward()

        p_norm = float(get_parameter_norm(engine.model.parameters()))
        g_norm = float(get_grad_norm(engine.model.parameters()))

        engine.optimizer.step()

        return float(loss), p_norm, g_norm
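The train(engine, mini_batch) / step(engine, mini_batch) signatures above match the process-function interface of a pytorch-ignite Engine. The custom attributes they read (engine.model, engine.optimizer, engine.crit, engine.device) are not built into ignite, so presumably they are attached to the engine before running; a sketch of that wiring, with the setup names assumed:

import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset
from ignite.engine import Engine

model = nn.Linear(8, 2)
trainer = Engine(train)  # 'train' (or 'step') is a process function like the ones above.
# Attach the attributes the process function expects (assumed from usage):
trainer.model = model
trainer.optimizer = optim.Adam(model.parameters())
trainer.crit = nn.CrossEntropyLoss()
trainer.device = torch.device('cpu')

loader = DataLoader(TensorDataset(torch.randn(32, 8),
                                  torch.randint(0, 2, (32,))),
                    batch_size=8)
trainer.run(loader, max_epochs=1)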
Example #7
    def train(engine, mini_batch):
        # You have to reset the gradients of all model parameters
        # before taking another step of gradient descent.
        engine.model.train()  # Because the model is assigned as a class attribute, we can access it easily.
        engine.optimizer.zero_grad()

        images, targets = mini_batch
        images = list(image.to(engine.device) for image in images)
        targets = [{k: v.to(engine.device)
                    for k, v in t.items()} for t in targets]

        # Run the feed-forward pass.

        losses = engine.model(images, targets)
        loss = sum(loss for loss in losses.values())
        loss.backward()

        # Detection targets are float box coordinates, not class indices,
        # so classification accuracy does not apply here.
        accuracy = 0

        p_norm = float(get_parameter_norm(engine.model.parameters()))
        g_norm = float(get_grad_norm(engine.model.parameters()))

        # Take a step of gradient descent.
        engine.optimizer.step()

        # Take a step of scheduler
        engine.scheduler.step(loss)

        return {
            'loss': float(loss),
            'accuracy': float(accuracy),
            '|param|': p_norm,
            '|g_param|': g_norm,
        }
Example #8
    def train(engine, mini_batch):
        # You have to reset the gradients of all model parameters
        # before taking another step of gradient descent.
        engine.model.train()  # Because the model is assigned as a class attribute, we can access it easily.
        engine.optimizer.zero_grad()

        x, y = mini_batch
        x, y = x.to(engine.device), y.to(engine.device)
        # Move the tensors to the same device as the model.

        # Run the feed-forward pass.
        y_hat = engine.model(x)
        # |y_hat| = (batch_size, 10); ten-dimensional class scores.

        loss = engine.crit(y_hat, y)  # After the criterion, the loss is a scalar.
        loss.backward()

        # Calculate accuracy only if 'y' is a LongTensor,
        # which means that 'y' holds class indices for classification.
        if isinstance(y, torch.LongTensor) or isinstance(
                y, torch.cuda.LongTensor):
            accuracy = (torch.argmax(y_hat, dim=-1) == y).sum() / float(
                y.size(0))
        else:
            accuracy = 0

        p_norm = float(get_parameter_norm(engine.model.parameters()))
        # L2 norm of the parameters (should grow as training progresses).
        g_norm = float(get_grad_norm(engine.model.parameters()))
        # L2 norm of the gradients (should shrink as training progresses).
        # Think of p_norm and g_norm as indicators of whether training is going well.

        engine.optimizer.step()
        # Take a step of gradient descent.

        return {
            'loss': float(loss),
            'accuracy': float(accuracy),
            '|param|': p_norm,
            '|g_param|': g_norm,
        }
Example #9
    def train(engine, mini_batch):
        from utils import get_grad_norm, get_parameter_norm

        # You have to reset the gradients of all model parameters
        # before taking another step of gradient descent.
        engine.model.train()        
        engine.optimizer.zero_grad()

        # The raw target variable has both BOS and EOS tokens.
        # The output of the sequence-to-sequence model does not have a BOS token.
        # Thus, remove the BOS token from the reference.
        x, y = mini_batch.src, mini_batch.tgt[0][:, 1:]
        # |x| = (batch_size, length)
        # |y| = (batch_size, length)

        # Run the feed-forward pass.
        # As before, the input of the decoder does not have an EOS token.
        # Thus, remove the EOS token from the decoder input.
        y_hat = engine.model(x, mini_batch.tgt[0][:, :-1])
        # |y_hat| = (batch_size, length, output_size)

        loss = engine.crit(y_hat.contiguous().view(-1, y_hat.size(-1)),
                           y.contiguous().view(-1)
                           )
        loss.div(y.size(0)).backward()
        word_count = int(mini_batch.tgt[1].sum())

        p_norm = float(get_parameter_norm(engine.model.parameters()))
        g_norm = float(get_grad_norm(engine.model.parameters()))

        # In order to avoid exploding gradients, we apply gradient clipping.
        torch_utils.clip_grad_norm_(
            engine.model.parameters(),
            engine.config.max_grad_norm,
        )
        # Take a step of gradient descent.
        engine.optimizer.step()

        if engine.config.use_noam_decay and engine.lr_scheduler is not None:
            engine.lr_scheduler.step()

        return float(loss / word_count), p_norm, g_norm
Example #10
    def train(engine, mini_batch):
        engine.model.train()
        engine.optimizer.zero_grad()
        x, y = mini_batch
        # Send the mini-batch to the GPU in the middle of training.
        x, y = x.to(engine.device), y.to(engine.device)

        pred_y = engine.model(x)

        loss = engine.crit(pred_y, y)
        loss.backward()

        # For regression, accuracy cannot be measured, so we have to skip it.
        # If y comes out as an integer type, we assume the task is classification.
        if isinstance(y, torch.LongTensor) or isinstance(
                y, torch.cuda.LongTensor):
            accuracy = (torch.argmax(pred_y, dim=-1) == y).sum() / float(
                y.size(0))
        else:
            accuracy = 0

        # Parameter norm: grows gradually as training progresses.
        p_norm = float(get_parameter_norm(engine.model.parameters()))
        # Gradient norm: the larger it is, the more the model is learning.
        # It starts large, since the model learns a lot at first (the slope is steep),
        # and usually shrinks gradually as training progresses (though not always).
        # If it bounces around wildly, or the loss blows up to NaN, training may fail.
        # In other words, it is an indicator of training stability.
        g_norm = float(get_grad_norm(engine.model.parameters()))

        engine.optimizer.step()

        return {
            'loss': float(loss),
            'accuracy': float(accuracy),
            '|param|': p_norm,
            '|g_param|': g_norm,
        }
Example #11
    def train_epoch(self, train, optimizer, verbose=VERBOSE_BATCH_WISE):
        '''
        Train an epoch with given train iterator and optimizer.
        '''
        total_loss, total_word_count = 0, 0
        total_grad_norm = 0
        avg_loss, avg_grad_norm = 0, 0
        sample_cnt = 0

        progress_bar = tqdm(
            train, desc='Training: ',
            unit='batch') if verbose is VERBOSE_BATCH_WISE else train
        # Iterate whole train-set.
        for idx, mini_batch in enumerate(progress_bar):
            # The raw target variable has both BOS and EOS tokens.
            # The output of the sequence-to-sequence model does not have a BOS token.
            # Thus, remove the BOS token from the reference.
            x = mini_batch.src[0][:, :-1] if self.is_src else mini_batch.tgt[0][:, :-1]
            y = mini_batch.src[0][:, 1:] if self.is_src else mini_batch.tgt[0][:, 1:]
            # |x| = (batch_size, length)

            # You have to reset the gradients of all model parameters before taking another step of gradient descent.
            optimizer.zero_grad()

            # Run the feed-forward pass.
            y_hat = self.model(x)
            loss = self._get_loss(y_hat, y).sum()
            loss.div(y.size(0)).backward()

            # Simple math to show stats.
            # Don't forget to detach final variables.
            total_loss += float(loss)
            total_word_count += int(mini_batch.src[1].sum() if self.is_src
                                    else mini_batch.tgt[1].sum())
            param_norm = float(
                utils.get_parameter_norm(self.model.parameters()))
            total_grad_norm += float(
                utils.get_grad_norm(self.model.parameters()))

            avg_loss = total_loss / total_word_count
            avg_grad_norm = total_grad_norm / (idx + 1)

            if verbose is VERBOSE_BATCH_WISE:
                progress_bar.set_postfix_str(
                    '|param|=%.2f |g_param|=%.2f loss=%.4e PPL=%.2f' %
                    (param_norm, avg_grad_norm, avg_loss, exp(avg_loss)))

            # In order to avoid exploding gradients, we apply gradient clipping.
            torch_utils.clip_grad_norm_(self.model.parameters(),
                                        self.config.max_grad_norm)
            # Take a step of gradient descent.
            optimizer.step()

            sample_cnt += x.size(0)
            if idx >= len(progress_bar):
                break

        if verbose is VERBOSE_BATCH_WISE:
            progress_bar.close()

        return avg_loss, param_norm, avg_grad_norm
Example #12
def train_epoch(model, criterion, train_iter, valid_iter, config, start_epoch = 1, others_to_save = None):
    current_lr = config.rl_lr

    highest_valid_bleu = -np.inf
    no_improve_cnt = 0

    # Print initial valid BLEU before we start RL.
    model.eval()
    total_reward, sample_cnt = 0, 0
    for batch_index, batch in enumerate(valid_iter):
        current_batch_word_cnt = torch.sum(batch.tgt[1])
        x = batch.src
        y = batch.tgt[0][:, 1:]
        batch_size = y.size(0)
        # |x| = (batch_size, length)
        # |y| = (batch_size, length)

        # feed-forward
        y_hat, indice = model.search(x, is_greedy = True, max_length = config.max_length)
        # |y_hat| = (batch_size, length, output_size)
        # |indice| = (batch_size, length)

        reward = get_reward(y, indice, n_gram = config.rl_n_gram)

        total_reward += float(reward.sum())
        sample_cnt += batch_size
        if sample_cnt >= len(valid_iter.dataset.examples):
            break
    avg_bleu = total_reward / sample_cnt
    print("initial valid BLEU: %.4f" % avg_bleu) # You can figure-out improvement.
    model.train() # Now, begin training.

    # Start RL
    for epoch in range(start_epoch, config.rl_n_epochs + 1):
        #optimizer = optim.Adam(model.parameters(), lr = current_lr)
        optimizer = optim.SGD(model.parameters(), lr = current_lr) # The default hyper-parameters are set for SGD.
        print("current learning rate: %f" % current_lr)
        print(optimizer)

        sample_cnt = 0
        total_loss, total_bleu, total_sample_count, total_word_count, total_parameter_norm, total_grad_norm = 0, 0, 0, 0, 0, 0
        start_time = time.time()
        train_bleu = np.inf

        for batch_index, batch in enumerate(train_iter):
            optimizer.zero_grad()

            current_batch_word_cnt = torch.sum(batch.tgt[1])
            x = batch.src  
            y = batch.tgt[0][:, 1:]  # The reference; drop the leading BOS token (because we proceed by sampling?)
            batch_size = y.size(0)
            # |x| = (batch_size, length)
            # |y| = (batch_size, length)

            # Take the sampling process, since is_greedy is set to False.
            y_hat, indice = model.search(x, is_greedy = False, max_length = config.max_length)
            # Based on the result of sampling, get the reward.
            q_actor = get_reward(y, indice, n_gram = config.rl_n_gram)
            # |y_hat| = (batch_size, length, output_size) # softmax values before sampling
            # |indice| = (batch_size, length) # LongTensor produced by the sampling
            # |q_actor| = (batch_size)

            # Take as many samples as n_samples, and get the average reward over them.
            # I figured out that n_samples = 1 would be enough.
            baseline = []
            with torch.no_grad():
                for i in range(config.n_samples):
                    _, sampled_indice = model.search(x, is_greedy = False, max_length = config.max_length)
                    baseline += [get_reward(y, sampled_indice, n_gram = config.rl_n_gram)]
                baseline = torch.stack(baseline).sum(dim = 0).div(config.n_samples)
                # |baseline| = (n_samples, batch_size) --> (batch_size)

            # Now we have a relative expected cumulative reward,
            # which is the score obtained by subtracting the baseline from q_actor.
            tmp_reward = q_actor - baseline
            # |tmp_reward| = (batch_size)
            # calculate gradients with back-propagation
            get_gradient(indice, y_hat, criterion, reward = tmp_reward)

            # simple math to show stats
            total_loss += float(tmp_reward.sum())
            total_bleu += float(q_actor.sum())
            total_sample_count += batch_size
            total_word_count += int(current_batch_word_cnt)
            total_parameter_norm += float(utils.get_parameter_norm(model.parameters()))
            total_grad_norm += float(utils.get_grad_norm(model.parameters()))

            if (batch_index + 1) % config.print_every == 0:
                avg_loss = total_loss / total_sample_count
                avg_bleu = total_bleu / total_sample_count
                avg_parameter_norm = total_parameter_norm / config.print_every
                avg_grad_norm = total_grad_norm / config.print_every
                elapsed_time = time.time() - start_time

                print("epoch: %d batch: %d/%d\t|param|: %.2f\t|g_param|: %.2f\trwd: %.4f\tBLEU: %.4f\t%5d words/s %3d secs" % (epoch, 
                                                                                                            batch_index + 1, 
                                                                                                            int(len(train_iter.dataset.examples) // config.batch_size), 
                                                                                                            avg_parameter_norm, 
                                                                                                            avg_grad_norm, 
                                                                                                            avg_loss,
                                                                                                            avg_bleu,
                                                                                                            total_word_count // elapsed_time,
                                                                                                            elapsed_time
                                                                                                            ))

                total_loss, total_bleu, total_sample_count, total_word_count, total_parameter_norm, total_grad_norm = 0, 0, 0, 0, 0, 0
                start_time = time.time()

                train_bleu = avg_bleu

            # In order to avoid exploding gradients, we apply gradient clipping.
            torch_utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
            # Take a step of gradient descent.
            optimizer.step()

            sample_cnt += batch_size
            if sample_cnt >= len(train_iter.dataset.examples):
                break

        sample_cnt = 0
        total_reward = 0

        # Start validation
        with torch.no_grad():
            model.eval() # Turn off dropout.

            for batch_index, batch in enumerate(valid_iter):
                current_batch_word_cnt = torch.sum(batch.tgt[1])
                x = batch.src
                y = batch.tgt[0][:, 1:]
                batch_size = y.size(0)
                # |x| = (batch_size, length)
                # |y| = (batch_size, length)

                # feed-forward
                y_hat, indice = model.search(x, is_greedy = True, max_length = config.max_length)
                # |y_hat| = (batch_size, length, output_size)
                # |indice| = (batch_size, length)

                reward = get_reward(y, indice, n_gram = config.rl_n_gram)

                total_reward += float(reward.sum())
                sample_cnt += batch_size
                if sample_cnt >= len(valid_iter.dataset.examples):
                    break

            avg_bleu = total_reward / sample_cnt
            print("valid BLEU: %.4f" % avg_bleu)

            if highest_valid_bleu < avg_bleu:
                highest_valid_bleu = avg_bleu
                no_improve_cnt = 0
            else:
                no_improve_cnt += 1

            model.train()

        model_fn = config.model.split(".")
        model_fn = model_fn[:-1] + ["%02d" % (config.n_epochs + epoch), "%.2f-%.4f" % (train_bleu, avg_bleu)] + [model_fn[-1]]

        # PyTorch provides an efficient way to save and load models, using Python pickle.
        to_save = {"model": model.state_dict(),
                    "config": config,
                    "epoch": config.n_epochs + epoch + 1,
                    "current_lr": current_lr
                    }
        if others_to_save is not None:
            for k, v in others_to_save.items():
                to_save[k] = v
        torch.save(to_save, '.'.join(model_fn))

        if config.early_stop > 0 and no_improve_cnt > config.early_stop:
            break
Example #13
    def train_epoch(self,
                    train,
                    optimizer,
                    max_grad_norm=5,
                    verbose=VERBOSE_SILENT
                    ):
        '''
        Train an epoch with given train iterator and optimizer.
        '''
        total_reward, total_actor_reward = 0, 0
        total_grad_norm = 0
        avg_reward, avg_actor_reward = 0, 0
        avg_param_norm, avg_grad_norm = 0, 0
        sample_cnt = 0

        progress_bar = tqdm(train,
                            desc='Training: ',
                            unit='batch'
                            ) if verbose is VERBOSE_BATCH_WISE else train
        # Iterate whole train-set.
        for idx, mini_batch in enumerate(progress_bar):
            # The raw target variable has both BOS and EOS tokens.
            # The output of the sequence-to-sequence model does not have a BOS token.
            # Thus, remove the BOS token from the reference.
            x, y = mini_batch.src, mini_batch.tgt[0][:, 1:]
            # |x| = (batch_size, length)
            # |y| = (batch_size, length)
            
            # You have to reset the gradients of all model parameters before taking another step of gradient descent.
            optimizer.zero_grad()

            # Take the sampling process, since is_greedy is set to False.
            y_hat, indice = self.model.search(x,
                                              is_greedy=False,
                                              max_length=self.config.max_length
                                              )
            # Based on the result of sampling, get the reward.
            actor_reward = self._get_reward(indice,
                                            y,
                                            n_gram=self.config.rl_n_gram
                                            )
            # |y_hat| = (batch_size, length, output_size)
            # |indice| = (batch_size, length)
            # |actor_reward| = (batch_size)

            # Take as many samples as n_samples, and get the average reward over them.
            # I figured out that n_samples = 1 would be enough.
            baseline = []
            with torch.no_grad():
                for i in range(self.config.n_samples):
                    _, sampled_indice = self.model.search(x,
                                                          is_greedy=False,
                                                          max_length=self.config.max_length
                                                          )
                    baseline += [self._get_reward(sampled_indice,
                                                  y,
                                                  n_gram=self.config.rl_n_gram
                                                  )]
                baseline = torch.stack(baseline).sum(dim=0).div(self.config.n_samples)
                # |baseline| = (n_samples, batch_size) --> (batch_size)

            # Now we have a relative expected cumulative reward,
            # which is the score obtained by subtracting the baseline from actor_reward.
            final_reward = actor_reward - baseline
            # |final_reward| = (batch_size)

            # calculate gradients with back-propagation
            self._get_gradient(y_hat, indice, reward=final_reward)
            
            # Simple math to show stats.
            total_reward += float(final_reward.sum())
            total_actor_reward += float(actor_reward.sum())
            sample_cnt += int(actor_reward.size(0))
            param_norm = float(utils.get_parameter_norm(self.model.parameters()))
            total_grad_norm += float(utils.get_grad_norm(self.model.parameters()))

            avg_reward = total_reward / sample_cnt
            avg_actor_reward = total_actor_reward / sample_cnt
            avg_grad_norm = total_grad_norm / (idx + 1)

            if verbose is VERBOSE_BATCH_WISE:
                progress_bar.set_postfix_str('|g_param|=%.2f rwd=%4.2f avg_frwd=%.2e BLEU=%.4f' % (avg_grad_norm,
                                                                                                 float(actor_reward.sum().div(y.size(0))),
                                                                                                 avg_reward,
                                                                                                 avg_actor_reward
                                                                                                 ))

            # In order to avoid exploding gradients, we apply gradient clipping.
            torch_utils.clip_grad_norm_(self.model.parameters(),
                                        self.config.max_grad_norm
                                        )
            # Take a step of gradient descent.
            optimizer.step()


            if idx >= len(progress_bar) * self.config.train_ratio_per_epoch:
                break

        if verbose is VERBOSE_BATCH_WISE:
            progress_bar.close()

        return avg_actor_reward, param_norm, avg_grad_norm
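The _get_gradient / get_gradient helper these RL examples rely on is never shown. A sketch of the REINFORCE-style backward pass implied by the surrounding code, under two stated assumptions: criterion is an NLLLoss with reduction='none' over per-token log-probabilities, and reward is the baseline-subtracted score where higher is better:

def get_gradient(indice, y_hat, criterion, reward=1.):
    # y_hat:  (batch_size, length, output_size) log-probabilities
    # indice: (batch_size, length) sampled token indices
    # reward: (batch_size,) baseline-subtracted reward
    batch_size = y_hat.size(0)
    # Treat the sampled indices as targets: NLL = -log P(sampled token).
    loss = criterion(y_hat.contiguous().view(-1, y_hat.size(-1)),
                     indice.contiguous().view(-1))
    # Weight each sequence's NLL by its reward, so sequences with positive
    # reward have their log-probability pushed up by gradient descent.
    loss = (loss.view(batch_size, -1).sum(dim=-1) * reward).sum()
    loss.div(batch_size).backward()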
Example #14
def train_epoch(model,
                criterion,
                train_iter,
                valid_iter,
                config,
                start_epoch=1,
                others_to_save=None):
    current_lr = config.lr

    lowest_valid_loss = np.inf
    no_improve_cnt = 0

    for epoch in range(start_epoch, config.n_epochs + 1):
        if config.adam:
            optimizer = optim.Adam(model.parameters(), lr=current_lr)
        else:
            optimizer = optim.SGD(model.parameters(), lr=current_lr)
        print("current learning rate: %f" % current_lr)
        print(optimizer)

        sample_cnt = 0
        total_loss, total_word_count, total_parameter_norm, total_grad_norm = 0, 0, 0, 0
        start_time = time.time()
        train_loss = np.inf

        for batch_index, batch in enumerate(train_iter):
            optimizer.zero_grad()

            current_batch_word_cnt = torch.sum(batch.tgt[1])
            x = batch.src
            y = batch.tgt[0][:, 1:]

            # |x| = (batch_size, length)
            # |y| = (batch_size, length)

            # feed-forward
            y_hat = model(x, batch.tgt[0][:, :-1])

            # |y_hat| = (batch_size, length, output_size)

            # calculate loss and gradients with back-propagation
            loss = get_loss(y, y_hat, criterion)

            # simple math to show stats
            total_loss += float(loss)
            total_word_count += int(current_batch_word_cnt)
            total_parameter_norm += float(
                utils.get_parameter_norm(model.parameters()))
            total_grad_norm += float(utils.get_grad_norm(model.parameters()))

            if (batch_index + 1) % config.print_every == 0:
                avg_loss = total_loss / total_word_count
                avg_parameter_norm = total_parameter_norm / config.print_every
                avg_grad_norm = total_grad_norm / config.print_every
                elapsed_time = time.time() - start_time

                print(
                    "epoch: %d batch: %d/%d\t|param|: %.2f\t|g_param|: %.2f\tloss: %.4f\tPPL: %.2f\t%5d words/s %3d secs"
                    % (epoch, batch_index + 1,
                       int(
                           len(train_iter.dataset.examples) //
                           config.batch_size), avg_parameter_norm,
                       avg_grad_norm, avg_loss, np.exp(avg_loss),
                       total_word_count // elapsed_time, elapsed_time))

                total_loss, total_word_count, total_parameter_norm, total_grad_norm = 0, 0, 0, 0
                start_time = time.time()

                train_loss = avg_loss

            # Another important line in this method.
            # In order to avoid exploding gradients, we apply gradient clipping.
            torch_utils.clip_grad_norm_(model.parameters(),
                                        config.max_grad_norm)
            # Take a step of gradient descent.
            optimizer.step()

            sample_cnt += batch.tgt[0].size(0)
            if sample_cnt >= len(train_iter.dataset.examples):
                break

        sample_cnt = 0
        total_loss, total_word_count = 0, 0

        with torch.no_grad():
            model.eval()

            for batch_index, batch in enumerate(valid_iter):
                current_batch_word_cnt = torch.sum(batch.tgt[1])
                x = batch.src
                y = batch.tgt[0][:, 1:]

                # |x| = (batch_size, length)
                # |y| = (batch_size, length)

                # feed-forward
                y_hat = model(x, batch.tgt[0][:, :-1])

                # |y_hat| = (batch_size, length, output_size)

                loss = get_loss(y, y_hat, criterion, do_backward=False)

                total_loss += float(loss)
                total_word_count += int(current_batch_word_cnt)

                sample_cnt += batch.tgt[0].size(0)
                if sample_cnt >= len(valid_iter.dataset.examples):
                    break

            avg_loss = total_loss / total_word_count
            print("valid loss: %.4f\tPPL: %.2f" % (avg_loss, np.exp(avg_loss)))

            if lowest_valid_loss > avg_loss:
                lowest_valid_loss = avg_loss
                no_improve_cnt = 0

                if epoch >= config.lr_decay_start_at:
                    current_lr = max(config.min_lr,
                                     current_lr * config.lr_decay_rate)
            else:
                # decrease the learning rate if there is no improvement.
                current_lr = max(config.min_lr,
                                 current_lr * config.lr_decay_rate)
                no_improve_cnt += 1

            model.train()

        model_fn = config.model.split(".")
        model_fn = model_fn[:-1] + [
            "%02d" % epoch,
            "%.2f-%.2f" % (train_loss, np.exp(train_loss)),
            "%.2f-%.2f" % (avg_loss, np.exp(avg_loss))
        ] + [model_fn[-1]]

        # PyTorch provides an efficient way to save and load models, using Python pickle.
        to_save = {
            "model": model.state_dict(),
            "config": config,
            "epoch": epoch + 1,
            "current_lr": current_lr
        }
        if others_to_save is not None:
            for k, v in others_to_save.items():
                to_save[k] = v
        torch.save(to_save, '.'.join(model_fn))

        if config.early_stop > 0 and no_improve_cnt > config.early_stop:
            break
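Examples #14 and #16 implement the decay-on-no-improvement schedule by hand and rebuild the optimizer each epoch. For reference, torch.optim.lr_scheduler.ReduceLROnPlateau implements a similar policy; a self-contained sketch (the hyper-parameter values are placeholders, not the ones used above):

from torch import nn, optim

model = nn.Linear(8, 2)
optimizer = optim.SGD(model.parameters(), lr=1.0)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=0, min_lr=1e-4)

for epoch in range(5):
    # ... a training pass over train_iter would go here ...
    avg_loss = 1.0            # stand-in for the validation loss
    scheduler.step(avg_loss)  # decays the lr whenever avg_loss stops improving
    print(optimizer.param_groups[0]['lr'])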
Example #15
def train_epoch(model, criterion, train_iter, valid_iter, config, start_epoch = 1, others_to_save = None):
    current_lr = config.rl_lr

    highest_valid_bleu = -np.inf
    no_improve_cnt = 0

    # Print initial valid BLEU before we start RL.
    model.eval()
    total_reward, sample_cnt = 0, 0
    for batch_index, batch in enumerate(valid_iter):
        current_batch_word_cnt = torch.sum(batch.tgt[1])
        x = batch.src
        y = batch.tgt[0][:, 1:]
        batch_size = y.size(0)
        # |x| = (batch_size, length)
        # |y| = (batch_size, length)

        # feed-forward
        y_hat, indice = model.search(x, is_greedy = True, max_length = config.max_length)
        # |y_hat| = (batch_size, length, output_size)
        # |indice| = (batch_size, length)

        reward = get_reward(y, indice)

        total_reward += float(reward.sum())
        sample_cnt += batch_size
        if sample_cnt >= len(valid_iter.dataset.examples):
            break
    avg_bleu = total_reward / sample_cnt
    print("initial valid BLEU: %.4f" % avg_bleu)
    model.train()

    # Start RL
    for epoch in range(start_epoch, config.rl_n_epochs + 1):
        #optimizer = optim.Adam(model.parameters(), lr = current_lr)
        optimizer = optim.SGD(model.parameters(), lr = current_lr)
        print("current learning rate: %f" % current_lr)
        print(optimizer)

        sample_cnt = 0
        total_loss, total_bleu, total_sample_count, total_word_count, total_parameter_norm, total_grad_norm = 0, 0, 0, 0, 0, 0
        start_time = time.time()
        train_bleu = np.inf

        for batch_index, batch in enumerate(train_iter):
            optimizer.zero_grad()

            current_batch_word_cnt = torch.sum(batch.tgt[1])
            x = batch.src
            y = batch.tgt[0][:, 1:]
            batch_size = y.size(0)
            # |x| = (batch_size, length)
            # |y| = (batch_size, length)

            # feed-forward
            y_hat, indice = model.search(x, is_greedy = False, max_length = config.max_length)
            q_actor = get_reward(y, indice)
            # |y_hat| = (batch_size, length, output_size)
            # |indice| = (batch_size, length)
            # |q_actor| = (batch_size)

            baseline = []
            with torch.no_grad():
                for i in range(config.n_samples):
                    _, sampled_indice = model.search(x, is_greedy = False, max_length = config.max_length)
                    baseline += [get_reward(y, sampled_indice)]
                baseline = torch.stack(baseline).sum(dim = 0).div(config.n_samples)
                # |baseline| = (n_samples, batch_size) --> (batch_size)

            # calculate gradients with back-propagation
            tmp_reward = q_actor - baseline
            # |tmp_reward| = (batch_size)
            get_gradient(indice, y_hat, criterion, reward = tmp_reward)

            # simple math to show stats
            total_loss += float(tmp_reward.sum())
            total_bleu += float(q_actor.sum())
            total_sample_count += batch_size
            total_word_count += int(current_batch_word_cnt)
            total_parameter_norm += float(utils.get_parameter_norm(model.parameters()))
            total_grad_norm += float(utils.get_grad_norm(model.parameters()))

            if (batch_index + 1) % config.print_every == 0:
                avg_loss = total_loss / total_sample_count
                avg_bleu = total_bleu / total_sample_count
                avg_parameter_norm = total_parameter_norm / config.print_every
                avg_grad_norm = total_grad_norm / config.print_every
                elapsed_time = time.time() - start_time

                print("epoch: %d batch: %d/%d\t|param|: %.2f\t|g_param|: %.2f\trwd: %.4f\tBLEU: %.4f\t%5d words/s %3d secs" % (epoch, 
                                                                                                            batch_index + 1, 
                                                                                                            int(len(train_iter.dataset.examples) // config.batch_size), 
                                                                                                            avg_parameter_norm, 
                                                                                                            avg_grad_norm, 
                                                                                                            avg_loss,
                                                                                                            avg_bleu,
                                                                                                            total_word_count // elapsed_time,
                                                                                                            elapsed_time
                                                                                                            ))

                total_loss, total_bleu, total_sample_count, total_word_count, total_parameter_norm, total_grad_norm = 0, 0, 0, 0, 0, 0
                start_time = time.time()

                train_bleu = avg_bleu

            # Another important line in this method.
            # In order to avoid exploding gradients, we apply gradient clipping.
            torch_utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
            # Take a step of gradient descent.
            optimizer.step()

            sample_cnt += batch_size
            if sample_cnt >= len(train_iter.dataset.examples) * config.rl_ratio_per_epoch:
                break

        sample_cnt = 0
        total_reward = 0

        with torch.no_grad():
            model.eval()

            for batch_index, batch in enumerate(valid_iter):
                current_batch_word_cnt = torch.sum(batch.tgt[1])
                x = batch.src
                y = batch.tgt[0][:, 1:]
                batch_size = y.size(0)
                # |x| = (batch_size, length)
                # |y| = (batch_size, length)

                # feed-forward
                y_hat, indice = model.search(x, is_greedy = True, max_length = config.max_length)
                # |y_hat| = (batch_size, length, output_size)
                # |indice| = (batch_size, length)

                reward = get_reward(y, indice)

                total_reward += float(reward.sum())
                sample_cnt += batch_size
                if sample_cnt >= len(valid_iter.dataset.examples):
                    break

            avg_bleu = total_reward / sample_cnt
            print("valid BLEU: %.4f" % avg_bleu)

            if highest_valid_bleu < avg_bleu:
                highest_valid_bleu = avg_bleu
                no_improve_cnt = 0
            else:
                no_improve_cnt += 1

            model.train()

        model_fn = config.model.split(".")
        model_fn = model_fn[:-1] + ["%02d" % (config.n_epochs + epoch), "%.2f-%.4f" % (train_bleu, avg_bleu)] + [model_fn[-1]]

        # PyTorch provides an efficient way to save and load models, using Python pickle.
        to_save = {"model": model.state_dict(),
                    "config": config,
                    "epoch": config.n_epochs + epoch + 1,
                    "current_lr": current_lr
                    }
        if others_to_save is not None:
            for k, v in others_to_save.items():
                to_save[k] = v
        torch.save(to_save, '.'.join(model_fn))

        if config.early_stop > 0 and no_improve_cnt > config.early_stop:
            break
Example #16
def train_epoch(model,
                criterion,
                train_iter,
                valid_iter,
                config,
                start_epoch=1,
                others_to_save=None):
    current_lr = config.lr

    lowest_valid_loss = np.inf
    no_improve_cnt = 0

    for epoch in range(start_epoch, config.n_epochs + 1):
        if config.adam:
            optimizer = optim.Adam(model.parameters(), lr=current_lr)
        else:
            optimizer = optim.SGD(model.parameters(), lr=current_lr)
        print("current learning rate: %f" % current_lr)
        print(optimizer)

        sample_cnt = 0
        total_loss, total_word_count, total_parameter_norm, total_grad_norm = 0, 0, 0, 0
        start_time = time.time()
        train_loss = np.inf

        for batch_index, batch in enumerate(train_iter):
            # You have to reset the gradients of all model parameters before taking another step of gradient descent.
            optimizer.zero_grad()

            current_batch_word_cnt = torch.sum(batch.tgt[1])
            x = batch.src
            # The raw target variable has both BOS and EOS tokens.
            # The output of the sequence-to-sequence model does not have a BOS token.
            # Thus, remove the BOS token from the reference.
            y = batch.tgt[0][:, 1:]
            # |x| = (batch_size, length)
            # |y| = (batch_size, length)

            # Run the feed-forward pass.
            # As before, the input of the decoder does not have an EOS token.
            # Thus, remove the EOS token from the decoder input.
            y_hat = model(x, batch.tgt[0][:, :-1])
            # |y_hat| = (batch_size, length, output_size)

            # Calculate loss and gradients with back-propagation.
            loss = get_loss(y, y_hat, criterion)

            # Simple math to show stats.
            total_loss += float(loss)
            total_word_count += int(current_batch_word_cnt)
            total_parameter_norm += float(
                utils.get_parameter_norm(model.parameters()))
            total_grad_norm += float(utils.get_grad_norm(model.parameters()))

            # Print the current training status every time this many mini-batches have been processed.
            if (batch_index + 1) % config.print_every == 0:
                avg_loss = total_loss / total_word_count
                avg_parameter_norm = total_parameter_norm / config.print_every
                avg_grad_norm = total_grad_norm / config.print_every
                elapsed_time = time.time() - start_time

                # You can check the current status using parameter norm and gradient norm.
                # Also, you can check the speed of the training.
                print(
                    "epoch: %d batch: %d/%d\t|param|: %.2f\t|g_param|: %.2f\tloss: %.4f\tPPL: %.2f\t%5d words/s %3d secs"
                    % (epoch, batch_index + 1,
                       int(
                           len(train_iter.dataset.examples) //
                           config.batch_size), avg_parameter_norm,
                       avg_grad_norm, avg_loss, np.exp(avg_loss),
                       total_word_count // elapsed_time, elapsed_time))

                total_loss, total_word_count, total_parameter_norm, total_grad_norm = 0, 0, 0, 0
                start_time = time.time()

                train_loss = avg_loss

            # In order to avoid exploding gradients, we apply gradient clipping.
            torch_utils.clip_grad_norm_(model.parameters(),
                                        config.max_grad_norm)
            # Gradients accumulate across the timesteps, so longer sequences pile up larger gradients.

            # Take a step of gradient descent.
            optimizer.step()

            sample_cnt += batch.tgt[0].size(0)
            if sample_cnt >= len(train_iter.dataset.examples):
                break

        sample_cnt = 0
        total_loss, total_word_count = 0, 0

        with torch.no_grad():  # In validation, we don't need to get gradients.
            model.eval()  # Turn on evaluation mode.

            for batch_index, batch in enumerate(valid_iter):
                current_batch_word_cnt = torch.sum(batch.tgt[1])
                x = batch.src
                y = batch.tgt[0][:, 1:]
                # |x| = (batch_size, length)
                # |y| = (batch_size, length)

                # Take feed-forward
                y_hat = model(x, batch.tgt[0][:, :-1])
                # |y_hat| = (batch_size, length, output_size)

                loss = get_loss(y, y_hat, criterion, do_backward=False)

                total_loss += float(loss)
                total_word_count += int(current_batch_word_cnt)

                sample_cnt += batch.tgt[0].size(0)
                if sample_cnt >= len(valid_iter.dataset.examples):
                    break

            # Print result of validation.
            avg_loss = total_loss / total_word_count
            print("valid loss: %.4f\tPPL: %.2f" % (avg_loss, np.exp(avg_loss)))

            if lowest_valid_loss > avg_loss:
                lowest_valid_loss = avg_loss
                no_improve_cnt = 0

                # Although there was an improvement in the last epoch, we still need to decay the learning rate if it meets the requirements.
                if epoch >= config.lr_decay_start_at:
                    current_lr = max(config.min_lr,
                                     current_lr * config.lr_decay_rate)
            else:
                # Decrease the learning rate if there is no improvement.
                current_lr = max(config.min_lr,
                                 current_lr * config.lr_decay_rate)
                no_improve_cnt += 1

            # Again, turn-on the training mode.
            model.train()

        # Set a filename for model of last epoch.
        # We need to put every information to filename, as much as possible.
        model_fn = config.model.split(".")
        model_fn = model_fn[:-1] + [
            "%02d" % epoch,
            "%.2f-%.2f" % (train_loss, np.exp(train_loss)),
            "%.2f-%.2f" % (avg_loss, np.exp(avg_loss))
        ] + [model_fn[-1]]

        # PyTorch provides an efficient way to save and load models, using Python pickle.
        to_save = {
            "model": model.state_dict(),
            "config": config,
            "epoch": epoch + 1,
            "current_lr": current_lr
        }
        if others_to_save is not None:  # Add others if it is necessary.
            for k, v in others_to_save.items():
                to_save[k] = v
        torch.save(to_save, '.'.join(model_fn))

        # Take early stopping if it meets the requirement.
        if config.early_stop > 0 and no_improve_cnt > config.early_stop:
            break
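The checkpoints above store the model weights, config, next epoch, and current learning rate in one dict. Restoring is the mirror image; a sketch, where path is whatever '.'.join(model_fn) produced:

import torch

def load_checkpoint(path, model):
    # Mirror image of the torch.save(to_save, '.'.join(model_fn)) calls above.
    saved = torch.load(path, map_location='cpu')
    model.load_state_dict(saved['model'])
    # 'epoch' was saved as epoch + 1, i.e., the epoch to resume from.
    return saved['config'], saved['epoch'], saved['current_lr']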
Example #17
def train_epoch(model,
                bimpm,
                criterion,
                train_iter,
                valid_iter,
                config,
                start_epoch=1,
                others_to_save=None,
                valid_nli_iter=None):
    current_lr = config.rl_lr

    highest_valid_bleu = -np.inf
    no_improve_cnt = 0

    # Print initial valid BLEU before we start RL.
    model.eval()
    total_reward, sample_cnt = 0, 0
    for batch_index, batch in enumerate(valid_iter):
        current_batch_word_cnt = torch.sum(batch.tgt[1])
        x = batch.src
        y = batch.tgt[0][:, 1:]
        batch_size = y.size(0)
        # |x| = (batch_size, length)
        # |y| = (batch_size, length)

        # feed-forward
        y_hat, indice = model.search(x,
                                     is_greedy=True,
                                     max_length=config.max_length)
        # |y_hat| = (batch_size, length, output_size)
        # |indice| = (batch_size, length)

        reward = get_bleu_reward(y,
                                 indice,
                                 n_gram=min(config.rl_n_gram, indice.size(1)))

        total_reward += float(reward.sum())
        sample_cnt += batch_size
        if sample_cnt >= len(valid_iter.dataset.examples):
            break
    avg_bleu = total_reward / sample_cnt
    print("initial valid BLEU: %.4f" %
          avg_bleu)  # You can figure-out improvement.

    if valid_nli_iter:
        nli_validation(valid_nli_iter, model, bimpm, config)
    model.train()  # Now, begin training.

    # Start RL
    nli_criterion = nn.CrossEntropyLoss(reduction='none')
    print("start epoch:", start_epoch)
    print("number of epoch to complete:", config.rl_n_epochs + 1)

    if config.reward_mode == 'combined':
        if config.gpu_id >= 0:
            nli_weight = torch.tensor([1.0], requires_grad=True, device="cuda")
            bleu_weight = torch.tensor([1.0],
                                       requires_grad=True,
                                       device="cuda")
        else:
            nli_weight = torch.tensor([1.0], requires_grad=True)
            bleu_weight = torch.tensor([1.0], requires_grad=True)

        print("nli_weight, bleu_weight:",
              nli_weight.data.cpu().numpy()[0],
              bleu_weight.data.cpu().numpy()[0])
        weight_optimizer = optim.Adam(iter([nli_weight, bleu_weight]),
                                      lr=0.0001)

    optimizer = optim.SGD(
        model.parameters(), lr=current_lr,
        momentum=0.9)  # The default hyper-parameters are set for SGD.
    print("current learning rate: %f" % current_lr)
    print(optimizer)

    for epoch in range(start_epoch, config.rl_n_epochs + 1):
        sample_cnt = 0
        total_risk, total_errors, total_sample_count, total_word_count, total_parameter_norm, total_grad_norm = 0, 0, 0, 0, 0, 0
        total_label_correct = np.zeros(3)
        total_label_size = np.zeros(3)
        start_time = time.time()
        train_loss = np.inf

        for batch_index, batch in enumerate(train_iter):
            optimizer.zero_grad()

            current_batch_word_cnt = torch.sum(batch.tgt[1])
            x = batch.src
            y = batch.tgt[0][:, 1:]
            batch_size = y.size(0)
            epoch_accuracy = []
            max_sample_length = 0
            sequence_probs, errors = [], []
            if config.reward_mode != 'bleu':
                premise = batch.premise
                hypothesis = batch.hypothesis
                isSrcPremise = batch.isSrcPremise
                label = batch.labels

            # |x| = (batch_size, length)
            # |y| = (batch_size, length)

            for _ in range(config.n_samples):
                # Take the sampling process, since is_greedy is set to False.
                y_hat, indice = model.search(x,
                                             is_greedy=False,
                                             max_length=config.max_length)
                max_sample_length = max(max_sample_length, indice.size(1))
                prob = y_hat.gather(2, indice.unsqueeze(2)).squeeze(2)
                sequence_probs.append(prob)
                # |prob| = (batch_size, length)

                if config.reward_mode == 'bleu':
                    bleu = get_bleu_reward(y, indice, n_gram=config.rl_n_gram)
                    reward = 1 - bleu / 100
                    epoch_accuracy.append(bleu.sum() / batch_size)
                else:
                    padded_indice, padded_premise, padded_hypothesis = padding_three_tensors(
                        indice, premise, hypothesis, batch_size)

                    # Put the predicted sentence into either the premise or the hypothesis.
                    for i in range(batch_size):
                        if isSrcPremise[i]:
                            padded_premise[i] = padded_indice[i]
                        else:
                            padded_hypothesis[i] = padded_indice[i]

                    with torch.no_grad():
                        kwargs = {'p': padded_premise, 'h': padded_hypothesis}
                        pred_logit = bimpm(**kwargs)
                        accuracy = get_accuracy(pred_logit, label)
                        num_correct, size = get_label_accuracy(
                            pred_logit, label)
                        total_label_correct += num_correct
                        total_label_size += size
                    epoch_accuracy.append(accuracy)

                    # Based on the result of sampling, get the reward.
                    if config.reward_mode == 'nli':
                        reward = get_nli_reward(pred_logit, label,
                                                nli_criterion)
                    else:
                        reward = 1/(2 * nli_weight.pow(2)) * get_nli_reward(pred_logit, label, nli_criterion) \
                            + 1/(2 * bleu_weight.pow(2)) * (1 - get_bleu_reward(y, indice, n_gram=config.rl_n_gram)/100) \
                            + torch.log(nli_weight * bleu_weight)
                # |y_hat| = (batch_size, length, output_size)
                # |indice| = (batch_size, length)
                # |reward| = (batch_size)
                errors.append(reward)

            padded_probs = pad_probs(sequence_probs, max_sample_length)
            sequence_probs = torch.stack(padded_probs, dim=2)
            # |sequence_probs| = (batch_size, max_sample_length, sample_size)
            errors = torch.stack(errors, dim=1)
            # |errors| = (batch_size, sample_size)
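            # Minimum-risk training: renormalize the per-sample sequence
            # probabilities with a softmax, then minimize the expected error
            # over the sampled translations.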

            avg_probs = sequence_probs.sum(dim=1) / max_sample_length
            if config.temperature != 1.0:
                probs = avg_probs.exp_().pow(1 / config.temperature)
                probs = nn.functional.softmax(probs, dim=1)
            else:
                probs = nn.functional.softmax(avg_probs.exp_(), dim=1)
            risk = (probs * errors).sum() / batch_size
            risk.backward()

            # simple math to show stats
            total_risk += float(risk.sum())
            total_errors += float(reward.sum())
            total_sample_count += batch_size
            total_word_count += int(current_batch_word_cnt)
            total_parameter_norm += float(
                utils.get_parameter_norm(model.parameters()))
            total_grad_norm += float(utils.get_grad_norm(model.parameters()))

            if (batch_index + 1) % config.print_every == 0:
                avg_risk = total_risk / total_sample_count
                avg_errors = total_errors / total_sample_count
                avg_parameter_norm = total_parameter_norm / config.print_every
                avg_grad_norm = total_grad_norm / config.print_every
                avg_epoch_accuracy = sum(epoch_accuracy) / len(epoch_accuracy)
                elapsed_time = time.time() - start_time

                print(
                    "epoch: %d batch: %d/%d\t|param|: %.2f\t|g_param|: %.2f\trisk: %.4f\terror: %.4f\tAccuracy: %.2f\t%5d words/s %3d secs"
                    % (epoch, batch_index + 1,
                       int(
                           len(train_iter.dataset.examples) //
                           config.batch_size), avg_parameter_norm,
                       avg_grad_norm, avg_risk, avg_errors, avg_epoch_accuracy,
                       total_word_count // elapsed_time, elapsed_time))

                print("train label accuracy:",
                      total_label_correct / total_label_size)
                if config.reward_mode == 'combined':
                    print("nli_weight, bleu_weight:",
                          nli_weight.data.cpu().numpy()[0],
                          bleu_weight.data.cpu().numpy()[0])

                total_risk, total_errors, total_sample_count, total_word_count, total_parameter_norm, total_grad_norm = 0, 0, 0, 0, 0, 0
                epoch_accuracy = []
                total_label_correct = np.zeros(3)
                total_label_size = np.zeros(3)
                start_time = time.time()

                # Keep the most recent average risk for reporting at save time.
                train_loss = avg_risk

            # In order to avoid exploding gradients, we apply gradient clipping.
            torch_utils.clip_grad_norm_(model.parameters(),
                                        config.max_grad_norm)
            # Take a step of gradient descent.
            optimizer.step()
            if config.reward_mode == 'combined':
                weight_optimizer.step()

            sample_cnt += batch_size
            if sample_cnt >= len(train_iter.dataset.examples):
                break

        sample_cnt = 0
        total_reward = 0

        # Start validation
        with torch.no_grad():
            model.eval()  # Turn off dropout.

            for batch_index, batch in enumerate(valid_iter):
                current_batch_word_cnt = torch.sum(batch.tgt[1])
                x = batch.src
                y = batch.tgt[0][:, 1:]
                batch_size = y.size(0)
                # |x| = (batch_size, length)
                # |y| = (batch_size, length)

                # feed-forward
                y_hat, indice = model.search(x,
                                             is_greedy=True,
                                             max_length=config.max_length)
                # |y_hat| = (batch_size, length, output_size)
                # |indice| = (batch_size, length)

                reward = get_bleu_reward(y, indice, n_gram=config.rl_n_gram)

                total_reward += float(reward.sum())
                sample_cnt += batch_size
                if sample_cnt >= len(valid_iter.dataset.examples):
                    break

            avg_bleu = total_reward / sample_cnt
            print("valid BLEU: %.4f" % avg_bleu)

            if highest_valid_bleu < avg_bleu:
                highest_valid_bleu = avg_bleu
                no_improve_cnt = 0
            else:
                no_improve_cnt += 1

            if valid_nli_iter:
                nli_validation(valid_nli_iter, model, bimpm, config)
            model.train()

        model_fn = config.model.split(".")
        model_fn = model_fn[:-1] + [
            "%02d" % (config.n_epochs + epoch),
            "%.2f-%.4f" % (train_loss, avg_bleu)
        ] + [model_fn[-1]] + [config.reward_mode]

        # PyTorch provides an efficient way to save and load models, based on Python's pickle.
        to_save = {
            "model": model.state_dict(),
            "config": config,
            "epoch": config.n_epochs + epoch + 1,
            "current_lr": current_lr
        }
        if others_to_save is not None:
            for k, v in others_to_save.items():
                to_save[k] = v
        torch.save(to_save, '.'.join(model_fn))

        if config.early_stop > 0 and no_improve_cnt > config.early_stop:
            break
Example #18
0
def train_epoch(model,
                bimpm,
                criterion,
                train_iter,
                valid_iter,
                config,
                start_epoch=1,
                others_to_save=None,
                valid_nli_iter=None):
    current_lr = config.rl_lr

    highest_valid_bleu = -np.inf
    no_improve_cnt = 0

    # Print initial valid BLEU before we start RL.
    model.eval()
    total_reward, sample_cnt = 0, 0
    for batch_index, batch in enumerate(valid_iter):
        current_batch_word_cnt = torch.sum(batch.tgt[1])
        x = batch.src
        y = batch.tgt[0][:, 1:]
        batch_size = y.size(0)
        # |x| = (batch_size, length)
        # |y| = (batch_size, length)

        # feed-forward
        y_hat, indice = model.search(x,
                                     is_greedy=True,
                                     max_length=config.max_length)
        # |y_hat| = (batch_size, length, output_size)
        # |indice| = (batch_size, length)

        reward = get_bleu_reward(y, indice, n_gram=config.rl_n_gram)

        total_reward += float(reward.sum())
        sample_cnt += batch_size
        if sample_cnt >= len(valid_iter.dataset.examples):
            break
    avg_bleu = total_reward / sample_cnt
    print("initial valid BLEU: %.4f" %
          avg_bleu)  # You can figure-out improvement.

    if valid_nli_iter:
        nli_validation(valid_nli_iter, model, bimpm, config)
    model.train()  # Now, begin training.

    # Start RL
    nli_criterion = nn.CrossEntropyLoss(reduction='none')
    print("start rl epoch:", start_epoch)
    print("number of epoch to complete:", config.rl_n_epochs + 1)

    if config.reward_mode == 'combined':
        if config.gpu_id >= 0:
            nli_weight = torch.tensor([1.0], requires_grad=True, device="cuda")
            bleu_weight = torch.tensor([1.0],
                                       requires_grad=True,
                                       device="cuda")
        else:
            nli_weight = torch.tensor([1.0], requires_grad=True)
            bleu_weight = torch.tensor([1.0], requires_grad=True)

        print("nli_weight, bleu_weight:",
              nli_weight.data.cpu().numpy()[0],
              bleu_weight.data.cpu().numpy()[0])
        weight_optimizer = optim.Adam([nli_weight, bleu_weight], lr=0.0001)

    optimizer = optim.SGD(
        model.parameters(),
        lr=current_lr,
    )  # Default hyper-parameters are set for SGD.
    print("current learning rate: %f" % current_lr)
    print(optimizer)

    for epoch in range(start_epoch, config.rl_n_epochs + 1):
        sample_cnt = 0
        total_loss, total_actor_loss, total_sample_count, total_word_count, total_parameter_norm, total_grad_norm = 0, 0, 0, 0, 0, 0
        start_time = time.time()
        train_loss = np.inf
        epoch_accuracy = []

        for batch_index, batch in enumerate(train_iter):
            optimizer.zero_grad()

            current_batch_word_cnt = torch.sum(batch.tgt[1])
            x = batch.src
            y = batch.tgt[0][:, 1:]
            batch_size = y.size(0)
            if config.reward_mode != 'bleu':
                premise = batch.premise
                hypothesis = batch.hypothesis
                isSrcPremise = batch.isSrcPremise
                label = batch.labels

            # |x| = (batch_size, length)
            # |y| = (batch_size, length)

            # Take the sampling process, since is_greedy is set to False.
            y_hat, indice = model.search(x,
                                         is_greedy=False,
                                         max_length=config.max_length)

            if config.reward_mode == 'bleu':
                q_actor = get_bleu_reward(y, indice, n_gram=config.rl_n_gram)
                epoch_accuracy.append(q_actor.sum() / batch_size)
            else:
                padded_indice, padded_premise, padded_hypothesis = padding_three_tensors(
                    indice, premise, hypothesis, batch_size)

                # Put the predicted sentence into either the premise or the hypothesis.
                for i in range(batch_size):
                    if not isSrcPremise[i]:
                        padded_premise[i] = padded_indice[i]
                    else:
                        padded_hypothesis[i] = padded_indice[i]

                kwargs = {'p': padded_premise, 'h': padded_hypothesis}
                pred_logit = bimpm(**kwargs)
                accuracy = get_accuracy(pred_logit, label)
                epoch_accuracy.append(accuracy)

                # Based on the result of sampling, get reward.
                if config.reward_mode == 'nli':
                    q_actor = -get_nli_reward(pred_logit, label, nli_criterion)
                else:
                    q_actor = 1/(2 * nli_weight.pow(2)) * -get_nli_reward(pred_logit, label, nli_criterion) \
                        + 1/(2 * bleu_weight.pow(2)) * (get_bleu_reward(y, indice, n_gram=config.rl_n_gram)/100) \
                        + torch.log(nli_weight * bleu_weight)
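            # Unlike an error to be minimized, q_actor is a reward to be
            # maximized: the NLI term enters with a minus sign and the BLEU
            # term with a plus sign, each scaled by its learned uncertainty
            # weight.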
            # |y_hat| = (batch_size, length, output_size)
            # |indice| = (batch_size, length)
            # |q_actor| = (batch_size)

            # Take as many samples as n_samples and average their rewards.
            # In practice, n_samples = 1 turned out to be enough.
            baseline = []
            with torch.no_grad():
                for i in range(config.n_samples):
                    _, sampled_indice = model.search(
                        x, is_greedy=False, max_length=config.max_length)

                    if config.reward_mode == 'bleu':
                        baseline_reward = get_bleu_reward(
                            y, sampled_indice, n_gram=config.rl_n_gram)
                        epoch_accuracy.append(baseline_reward.sum() /
                                              batch_size)
                    else:
                        padded_sampled_indice, padded_premise, padded_hypothesis = padding_three_tensors(
                            sampled_indice, premise, hypothesis, batch_size)

                        # Put the predicted sentence into either the premise or the hypothesis.
                        for i in range(batch_size):
                            if not isSrcPremise[i]:
                                padded_premise[i] = padded_sampled_indice[i]
                            else:
                                padded_hypothesis[i] = padded_sampled_indice[i]

                        kwargs = {'p': padded_premise, 'h': padded_hypothesis}
                        pred_logit = bimpm(**kwargs)
                        accuracy = get_accuracy(pred_logit, label)
                        epoch_accuracy.append(accuracy)

                        # Based on the result of sampling, get reward.
                        if config.reward_mode == 'nli':
                            baseline_reward = -get_nli_reward(
                                pred_logit, label, nli_criterion)
                        else:
                            baseline_reward = 1/(2 * nli_weight.pow(2)) * -get_nli_reward(pred_logit, label, nli_criterion) \
                                + 1/(2 * bleu_weight.pow(2)) * (get_bleu_reward(y, sampled_indice, n_gram=config.rl_n_gram)/100) \
                                + torch.log(nli_weight * bleu_weight)

                    baseline += [baseline_reward]
                baseline = torch.stack(baseline).sum(dim=0).div(
                    config.n_samples)
                # |baseline| = (n_samples, batch_size) --> (batch_size)

            # Now we have a relative expected cumulative reward:
            # the advantage is q_actor minus the baseline.
            tmp_reward = q_actor - baseline
            # |tmp_reward| = (batch_size)
            # Calculate gradients with back-propagation.
            get_gradient(indice, y_hat, criterion, reward=tmp_reward)

            # simple math to show stats
            total_loss += float(tmp_reward.sum())
            total_actor_loss += float(q_actor.sum())
            total_sample_count += batch_size
            total_word_count += int(current_batch_word_cnt)
            total_parameter_norm += float(
                utils.get_parameter_norm(model.parameters()))
            total_grad_norm += float(utils.get_grad_norm(model.parameters()))

            if (batch_index + 1) % config.print_every == 0:
                avg_loss = total_loss / total_sample_count
                avg_actor_loss = total_actor_loss / total_sample_count
                avg_parameter_norm = total_parameter_norm / config.print_every
                avg_grad_norm = total_grad_norm / config.print_every
                avg_epoch_accuracy = sum(epoch_accuracy) / len(epoch_accuracy)
                elapsed_time = time.time() - start_time

                print(
                    "epoch: %d batch: %d/%d\t|param|: %.2f\t|g_param|: %.2f\trwd: %.4f\tactor loss: %.4f\tAccuracy: %.2f\t%5d words/s %3d secs"
                    %
                    (epoch, batch_index + 1,
                     int(
                         len(train_iter.dataset.examples) //
                         config.batch_size), avg_parameter_norm, avg_grad_norm,
                     avg_loss, avg_actor_loss, avg_epoch_accuracy,
                     total_word_count // elapsed_time, elapsed_time))

                if config.reward_mode == 'combined':
                    print("nli_weight, bleu_weight:",
                          nli_weight.data.cpu().numpy()[0],
                          bleu_weight.data.cpu().numpy()[0])

                total_loss, total_actor_loss, total_sample_count, total_word_count, total_parameter_norm, total_grad_norm = 0, 0, 0, 0, 0, 0
                epoch_accuracy = []
                start_time = time.time()

                train_loss = avg_actor_loss

            # In order to avoid exploding gradients, we apply gradient clipping.
            torch_utils.clip_grad_norm_(model.parameters(),
                                        config.max_grad_norm)
            # Take a step of gradient descent.
            optimizer.step()
            if config.reward_mode == 'combined':
                weight_optimizer.step()

            sample_cnt += batch_size
            if sample_cnt >= len(train_iter.dataset.examples):
                break

        sample_cnt = 0
        total_reward = 0

        # Start validation
        with torch.no_grad():
            model.eval()  # Turn off dropout.

            for batch_index, batch in enumerate(valid_iter):
                current_batch_word_cnt = torch.sum(batch.tgt[1])
                x = batch.src
                y = batch.tgt[0][:, 1:]
                batch_size = y.size(0)
                # |x| = (batch_size, length)
                # |y| = (batch_size, length)

                # feed-forward
                y_hat, indice = model.search(x,
                                             is_greedy=True,
                                             max_length=config.max_length)
                # |y_hat| = (batch_size, length, output_size)
                # |indice| = (batch_size, length)

                reward = get_bleu_reward(y, indice, n_gram=config.rl_n_gram)

                total_reward += float(reward.sum())
                sample_cnt += batch_size
                if sample_cnt >= len(valid_iter.dataset.examples):
                    break

            avg_bleu = total_reward / sample_cnt
            print("valid BLEU: %.4f" % avg_bleu)

            if highest_valid_bleu < avg_bleu:
                highest_valid_bleu = avg_bleu
                no_improve_cnt = 0
            else:
                no_improve_cnt += 1

            if valid_nli_iter:
                nli_validation(valid_nli_iter, model, bimpm, config)
            model.train()

        model_fn = config.model.split(".")
        model_fn = model_fn[:-1] + [
            "%02d" % (config.n_epochs + epoch),
            "%.2f-%.4f" % (train_loss, avg_bleu)
        ] + [model_fn[-1]] + [config.reward_mode]

        # PyTorch provides an efficient way to save and load models, based on Python's pickle.
        to_save = {
            "model": model.state_dict(),
            "config": config,
            "epoch": config.n_epochs + epoch + 1,
            "current_lr": current_lr
        }
        if others_to_save is not None:
            for k, v in others_to_save.items():
                to_save[k] = v
        torch.save(to_save, '.'.join(model_fn))

        if config.early_stop > 0 and no_improve_cnt > config.early_stop:
            break
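get_gradient above is assumed rather than defined. Under the usual REINFORCE formulation, a minimal sketch could look like the following, where criterion is assumed to be an nn.NLLLoss(reduction='none') over the log-probabilities in y_hat:

import torch

def get_gradient(indice, y_hat, criterion, reward=None):
    # |indice| = (batch_size, length)
    # |y_hat| = (batch_size, length, output_size)
    # |reward| = (batch_size)
    batch_size = y_hat.size(0)

    # Negative log-likelihood of each sampled token, summed per sequence.
    nll = criterion(y_hat.contiguous().view(-1, y_hat.size(-1)),
                    indice.contiguous().view(-1))
    nll = nll.view(batch_size, -1).sum(dim=-1)
    # |nll| = (batch_size)

    # Minimizing reward-weighted NLL performs gradient ascent on the
    # reward-weighted log-likelihood; reward is already
    # baseline-subtracted here, so it acts as an advantage.
    loss = (nll * reward).sum().div(batch_size)
    loss.backward()

Note that in 'combined' mode the reward still carries gradients for nli_weight and bleu_weight, so this same backward pass also feeds weight_optimizer.step() above.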
Example #19
0
def train_epoch(model, criterion, train_iter, valid_iter, config):
    current_lr = config.lr

    lowest_valid_loss = np.inf
    no_improve_cnt = 0

    for epoch in range(1, config.n_epochs):
        optimizer = optim.SGD(model.parameters(), lr=current_lr)
        print("current learning rate: %f" % current_lr)
        print(optimizer)

        sample_cnt = 0
        total_loss, total_word_count, total_parameter_norm, total_grad_norm = 0, 0, 0, 0
        start_time = time.time()
        train_loss = np.inf

        for batch_index, batch in enumerate(train_iter):
            optimizer.zero_grad()

            batch = prepare_data(batch)

            current_batch_word_cnt = torch.sum(batch[1])
            # The most important lines in this method:
            # since the model takes BOS + sentence as input and sentence + EOS as output,
            # x (input) excludes the last index, and y (target) excludes the first index.
            x = batch[0][:, :-1]
            y = batch[0][:, 1:]
            # feed-forward
            hidden = model.init_hidden(config.batch_size)
            # print("hidden : ", hidden[0].shape, hidden[1].shape)
            # print("x : ", x.shape)
            # print("batch[1]", batch[1])
            y_hat = model(x, batch[1], hidden)

            # Calculate loss and gradients with back-propagation.
            loss = get_loss(y, y_hat, criterion)

            # simple math to show stats
            total_loss += float(loss)
            total_word_count += int(current_batch_word_cnt)
            total_parameter_norm += float(
                utils.get_parameter_norm(model.parameters()))
            total_grad_norm += float(utils.get_grad_norm(model.parameters()))

            if (batch_index + 1) % config.print_every == 0:
                avg_loss = total_loss / total_word_count
                avg_parameter_norm = total_parameter_norm / config.print_every
                avg_grad_norm = total_grad_norm / config.print_every
                elapsed_time = time.time() - start_time

                print(
                    "epoch: %d batch: %d/%d\t|param|: %.2f\t|g_param|: %.2f\tloss: %.4f\tPPL: %.2f\t%5d words/s %3d secs"
                    % (epoch, batch_index + 1,
                       int((len(train_iter.dataset) // config.batch_size)),
                       avg_parameter_norm, avg_grad_norm, avg_loss,
                       np.exp(avg_loss), total_word_count // elapsed_time,
                       elapsed_time))

                total_loss, total_word_count, total_parameter_norm, total_grad_norm = 0, 0, 0, 0
                start_time = time.time()

                train_loss = avg_loss

            # Another important line in this method:
            # in order to avoid exploding gradients, we apply gradient clipping.
            torch_utils.clip_grad_norm_(model.parameters(),
                                        config.max_grad_norm)
            # Take a step of gradient descent.
            optimizer.step()

            sample_cnt += batch[0].size(0)
            if sample_cnt >= len(train_iter.dataset):
                break

        sample_cnt = 0
        total_loss, total_word_count = 0, 0

        model.eval()
        for batch_index, batch in enumerate(valid_iter):
            batch = prepare_data(batch)
            current_batch_word_cnt = torch.sum(batch[1])
            x = batch[0][:, :-1]
            y = batch[0][:, 1:]
            hidden = model.init_hidden(config.batch_size)

            y_hat = model(x, batch[1], hidden)

            loss = get_loss(y, y_hat, criterion, do_backward=False)

            total_loss += float(loss)
            total_word_count += int(current_batch_word_cnt)

            sample_cnt += batch[0].size(0)
            if sample_cnt >= len(valid_iter.dataset):
                break

        avg_loss = total_loss / total_word_count
        print("valid loss: %.4f\tPPL: %.2f" % (avg_loss, np.exp(avg_loss)))

        if lowest_valid_loss > avg_loss:
            lowest_valid_loss = avg_loss
            no_improve_cnt = 0
        else:
            # Decrease the learning rate if there is no improvement.
            current_lr /= 10.
            no_improve_cnt += 1

        model.train()

        model_fn = config.model.split(".")
        model_fn = model_fn[:-1] + [
            "%02d" % epoch,
            "%.2f-%.2f" % (train_loss, np.exp(train_loss)),
            "%.2f-%.2f" % (avg_loss, np.exp(avg_loss))
        ] + [model_fn[-1]]

        # PyTorch provides an efficient way to save and load models, based on Python's pickle.
        torch.save(
            {
                "model": model.state_dict(),
                "config": config,
                "epoch": epoch + 1,
                "current_lr": current_lr
            }, ".".join(model_fn))

        if config.early_stop > 0 and no_improve_cnt > config.early_stop:
            break
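get_loss is another assumed helper. A plausible minimal sketch for this language-model loop, assuming criterion sums cross-entropy over the non-PAD words, so that dividing the running total by the word count gives per-word loss and np.exp gives perplexity:

def get_loss(y, y_hat, criterion, do_backward=True):
    # |y| = (batch_size, length)
    # |y_hat| = (batch_size, length, output_size)
    batch_size = y.size(0)

    # Cross-entropy summed over every word in the mini-batch.
    loss = criterion(y_hat.contiguous().view(-1, y_hat.size(-1)),
                     y.contiguous().view(-1))

    if do_backward:
        # Normalize by batch size before back-propagation.
        loss.div(batch_size).backward()

    return loss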
Example #20
0
    def step(engine, mini_batch):
        from utils import get_grad_norm, get_parameter_norm

        for language_model, model, optimizer in zip(engine.language_models,
                                                    engine.models,
                                                    engine.optimizers):
            language_model.eval()
            model.train()
            optimizer.zero_grad()

        # X2Y
        x, y = (mini_batch.src[0][:, 1:-1],
                mini_batch.src[1] - 2), mini_batch.tgt[0][:, :-1]
        # |x| = (batch_size, n)
        # |y| = (batch_size, m)
        y_hat = engine.models[X2Y](x, y)
        # |y_hat| = (batch_size, m, y_vocab_size)
        with torch.no_grad():
            p_hat_y = engine.language_models[X2Y](y)
            # |p_hat_y| = |y_hat|

        # Y2X
        # Since the encoder in seq2seq takes a packed_sequence instance,
        # we need to re-sort when we swap src and tgt.
        x, y, restore_indice = DualSupervisedTrainer._reordering(
            mini_batch.src[0][:, :-1],
            mini_batch.tgt[0][:, 1:-1],
            mini_batch.tgt[1] - 2,
        )
        # |x| = (batch_size, n)
        # |y| = (batch_size, m)
        x_hat = engine.models[Y2X](y, x).index_select(dim=0,
                                                      index=restore_indice)
        # |x_hat| = (batch_size, n, x_vocab_size)

        with torch.no_grad():
            p_hat_x = engine.language_models[Y2X](x).index_select(
                dim=0, index=restore_indice)
            # |p_hat_x| = |x_hat|

        x, y = mini_batch.src[0][:, 1:], mini_batch.tgt[0][:, 1:]
        loss_x2y, loss_y2x, dual_loss = DualSupervisedTrainer._get_loss(
            x,
            y,
            x_hat,
            y_hat,
            engine.crits,
            p_hat_x,
            p_hat_y,
            # According to the paper, DSL should be warm-started.
            # Thus, we turn off the regularization at the beginning.
            lagrange=engine.config.dsl_lambda
            if engine.state.epoch >= engine.config.n_epochs else .0)

        loss_x2y.div(y.size(0)).backward()
        loss_y2x.div(x.size(0)).backward()

        p_norm = float(
            get_parameter_norm(
                list(engine.models[X2Y].parameters()) +
                list(engine.models[Y2X].parameters())))
        g_norm = float(
            get_grad_norm(
                list(engine.models[X2Y].parameters()) +
                list(engine.models[Y2X].parameters())))

        for model, optimizer in zip(engine.models, engine.optimizers):
            torch_utils.clip_grad_norm_(
                model.parameters(),
                engine.config.max_grad_norm,
            )
            # Take a step of gradient descent.
            optimizer.step()

        return (
            float(loss_x2y / mini_batch.src[1].sum()),
            float(loss_y2x / mini_batch.tgt[1].sum()),
            float(dual_loss / x.size(0)),
            p_norm,
            g_norm,
        )
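The step above relies on DualSupervisedTrainer._reordering, which is not shown. A minimal sketch, assuming from the call sites that it returns the re-sorted target tokens, a (source tokens, lengths) tuple for the packed-sequence encoder, and the inverse permutation:

    @staticmethod
    def _reordering(x, y, l):
        # Sort the mini-batch by length in descending order, as required
        # by pack_padded_sequence, and keep the indices needed to restore
        # the original order afterwards.
        # |x| = (batch_size, n)
        # |y| = (batch_size, m)
        # |l| = (batch_size)
        indice = l.topk(l.size(0))[1]

        x_ = x.index_select(dim=0, index=indice).contiguous()
        y_ = y.index_select(dim=0, index=indice).contiguous()
        l_ = l.index_select(dim=0, index=indice).contiguous()

        # Arg-sorting the sort indices yields the inverse permutation.
        restore_indice = indice.sort(descending=False)[1]

        return x_, (y_, l_), restore_indice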
Example #21
0
    def step(engine, mini_batch):
        from utils import get_grad_norm, get_parameter_norm

        # You have to reset the gradients of all model parameters
        # before taking another step of gradient descent.
        engine.model.train()
        engine.optimizer.zero_grad()

        # Raw target variable has both BOS and EOS token.
        # The output of sequence-to-sequence does not have BOS token.
        # Thus, remove BOS token for reference.
        x, y = mini_batch.src, mini_batch.tgt[0][:, 1:]
        # |x| = (batch_size, length)
        # |y| = (batch_size, length)

        # Take the sampling process, since is_greedy is set to False.
        y_hat, indice = engine.model.search(
            x, is_greedy=False, max_length=engine.config.max_length)
        # Based on the result of sampling, get reward.
        actor_reward = MinimumRiskTrainer.get_reward(
            indice, y, n_gram=engine.config.rl_n_gram)
        # |y_hat| = (batch_size, length, output_size)
        # |indice| = (batch_size, length)
        # |actor_reward| = (batch_size)

        # Take as many samples as rl_n_samples and average their rewards.
        # In practice, rl_n_samples = 1 turned out to be enough.
        baseline = []
        with torch.no_grad():
            for _ in range(engine.config.rl_n_samples):
                _, sampled_indice = engine.model.search(
                    x,
                    is_greedy=False,
                    max_length=engine.config.max_length,
                )
                baseline += [
                    MinimumRiskTrainer.get_reward(
                        sampled_indice,
                        y,
                        n_gram=engine.config.rl_n_gram,
                    )
                ]

            baseline = torch.stack(baseline).sum(dim=0).div(
                engine.config.rl_n_samples)
            # |baseline| = (n_samples, batch_size) --> (batch_size)

        # Now we have a relative expected cumulative reward:
        # the advantage is actor_reward minus the baseline.
        final_reward = actor_reward - baseline
        # |final_reward| = (batch_size)

        # calculate gradients with back-propagation
        MinimumRiskTrainer.get_gradient(y_hat,
                                        indice,
                                        engine.crit,
                                        reward=final_reward)

        p_norm = float(get_parameter_norm(engine.model.parameters()))
        g_norm = float(get_grad_norm(engine.model.parameters()))

        # In order to avoid exploding gradients, we apply gradient clipping.
        torch_utils.clip_grad_norm_(
            engine.model.parameters(),
            engine.config.max_grad_norm,
        )
        # Take a step of gradient descent.
        engine.optimizer.step()

        return (
            float(actor_reward.mean()),
            float(baseline.mean()),
            float(final_reward.mean()),
            p_norm,
            g_norm,
        )
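MinimumRiskTrainer.get_reward is assumed as well. A minimal sketch using NLTK's sentence-level GLEU, where the EOS index is a hypothetical placeholder for whatever your vocabulary uses:

import torch
from nltk.translate.gleu_score import sentence_gleu

EOS = 3  # hypothetical EOS index; substitute your vocabulary's value

def get_reward(indice, y, n_gram=6):
    # |indice| = (batch_size, hyp_length)
    # |y| = (batch_size, ref_length)
    with torch.no_grad():
        scores = []
        for b in range(y.size(0)):
            # Cut reference and hypothesis at the first EOS token.
            ref, hyp = [], []
            for t in range(y.size(-1)):
                ref += [int(y[b, t])]
                if y[b, t] == EOS:
                    break
            for t in range(indice.size(-1)):
                hyp += [int(indice[b, t])]
                if indice[b, t] == EOS:
                    break
            # Per-sentence GLEU in [0, 1], scaled to [0, 100].
            scores += [sentence_gleu([ref], hyp, max_len=n_gram) * 100.]

        return torch.tensor(scores, device=y.device)
        # |return| = (batch_size)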