def train_one_epoch(config, model, data_loader, optimizer, epoch,
                    lr_scheduler):
    model.train()
    optimizer.zero_grad()

    num_steps = len(data_loader)
    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    norm_meter = AverageMeter()

    start = time.time()
    end = time.time()
    for idx, (samples_1, samples_2, targets) in enumerate(data_loader):
        samples_1 = samples_1.cuda(non_blocking=True)
        samples_2 = samples_2.cuda(non_blocking=True)
        targets = targets.cuda(non_blocking=True)

        loss = model(samples_1, samples_2)

        optimizer.zero_grad()
        if config.AMP_OPT_LEVEL != "O0":
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            if config.TRAIN.CLIP_GRAD:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    amp.master_params(optimizer), config.TRAIN.CLIP_GRAD)
            else:
                grad_norm = get_grad_norm(amp.master_params(optimizer))
        else:
            loss.backward()
            if config.TRAIN.CLIP_GRAD:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), config.TRAIN.CLIP_GRAD)
            else:
                grad_norm = get_grad_norm(model.parameters())
        optimizer.step()
        lr_scheduler.step_update(epoch * num_steps + idx)

        torch.cuda.synchronize()

        loss_meter.update(loss.item(), targets.size(0))
        norm_meter.update(grad_norm)
        batch_time.update(time.time() - end)
        end = time.time()

        if idx % config.PRINT_FREQ == 0:
            lr = optimizer.param_groups[0]['lr']
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            etas = batch_time.avg * (num_steps - idx)
            logger.info(
                f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t'
                f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t'
                f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t'
                f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t'
                f'mem {memory_used:.0f}MB')
    epoch_time = time.time() - start
    logger.info(
        f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}"
    )
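The snippets on this page call get_grad_norm without showing its definition. A minimal sketch consistent with the call in Example #1 (an iterable of parameters, total L2 norm of their gradients) could look like the following; the exact signature is an assumption, not the original helper.

import torch

def get_grad_norm(parameters, norm_type=2.0):
    # Total norm over all gradients, mirroring torch.nn.utils.clip_grad_norm_
    # but without modifying the gradients in place.
    grads = [p.grad.detach() for p in parameters if p.grad is not None]
    if not grads:
        return 0.0
    total_norm = torch.norm(
        torch.stack([torch.norm(g, norm_type) for g in grads]), norm_type)
    return total_norm.item()

Note that Examples #9, #12, and #16 pass the model itself rather than model.parameters(), so their version of the helper presumably iterates model.parameters() internally.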
Example #2
def train(epoch):
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    count = 0

    conv_param_names = []
    conv_params = []
    for name, param in net.named_parameters():
        if "conv" in name:
            conv_params += [param]
            conv_param_names += [name]

    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        # if random.uniform(0,1) < 0.2 and count < 5:
        #     count +=1
        #     get_gradient_stats(net, epoch, batch_idx)

        if batch_idx % 10 == 0:
            # conv params
            param_stats, bin_counts = get_param_stats(conv_params,
                                                      conv_param_names)
            grad_norm_stats = get_grad_norm(conv_params, conv_param_names)
            log_stats(param_stats,
                      bin_counts,
                      grad_norm_stats,
                      dir="GradientStatsPercentile_Abs_Norm",
                      epoch=epoch,
                      iteration=batch_idx)
            param_stats, bin_counts = get_param_stats(conv_params,
                                                      conv_param_names,
                                                      take_abs=True)
            grad_norm_stats = get_grad_norm(conv_params, conv_param_names)
            log_stats(param_stats,
                      bin_counts,
                      grad_norm_stats,
                      dir="GradientStatsPercentile_Abs_Norm",
                      epoch=epoch,
                      iteration=batch_idx,
                      param_file="PerParamStatsAbs.log",
                      bin_counts_file="OverallStatsAbs.log",
                      grad_norm_file="GradNormStatsAbs.log")

        optimizer.step()
        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        progress_bar(
            batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' %
            (train_loss /
             (batch_idx + 1), 100. * correct / total, correct, total))
Example #3
    def train(engine, mini_batch):
        # You have to reset the gradients of all model parameters
        # before taking another step of gradient descent.
        engine.model.train()  # Because we assign the model as a class variable, we can easily access it.
        engine.optimizer.zero_grad()

        x, y = mini_batch
        x, y = x.to(engine.device), y.to(engine.device)

        # Take feed-forward
        y_hat = engine.model(x)

        loss = engine.crit(y_hat, y)
        loss.backward()

        # Calculate accuracy only if 'y' is a LongTensor,
        # which means that 'y' contains class indices.
        if isinstance(y, torch.LongTensor) or isinstance(y, torch.cuda.LongTensor):
            accuracy = (torch.argmax(y_hat, dim=-1) == y).sum() / float(y.size(0))
        else:
            accuracy = 0

        p_norm = float(get_parameter_norm(engine.model.parameters()))
        g_norm = float(get_grad_norm(engine.model.parameters()))

        # Take a step of gradient descent.
        engine.optimizer.step()

        return {
            'loss': float(loss),
            'accuracy': float(accuracy),
            '|param|': p_norm,
            '|g_param|': g_norm,
        }
    def train(engine, mini_batch):
        engine.model.train()
        engine.optimizer.zero_grad()

        x, y = mini_batch
        x, y = x.to(engine.device), y.to(engine.device)

        y_hat = engine.model(x)
        y_hat = y_hat.view(-1)

        loss = engine.crit(y_hat, y)
        loss.backward()

        if isinstance(y, torch.LongTensor) or isinstance(
                y, torch.cuda.LongTensor):
            accuracy = (torch.argmax(y_hat, dim=-1) == y).sum() / float(
                y.size(0))
        else:
            accuracy = 0

        # Watch this grow to confirm that the model is actually learning.
        p_norm = float(get_parameter_norm(engine.model.parameters()))

        # Watch the gradient norm shrink to confirm that descent is taking place.
        g_norm = float(get_grad_norm(engine.model.parameters()))

        engine.optimizer.step()

        return {
            "loss": float(loss),
            "accuracy": float(accuracy),
            "|param|": p_norm,
            "|g_param|": g_norm,
        }
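The ignite-style steps above also log |param| via get_parameter_norm, which is likewise not defined on this page. A sketch under the same assumption (total L2 norm over an iterable of parameters):

import torch

def get_parameter_norm(parameters, norm_type=2.0):
    # Total norm of the parameter values themselves (not their gradients).
    params = [p.detach() for p in parameters]
    total_norm = torch.norm(
        torch.stack([torch.norm(p, norm_type) for p in params]), norm_type)
    return total_norm.item()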
Example #5
def train_reorder_dream():
    dr_model.train()  # turn on training mode for dropout
    dr_hidden = dr_model.init_hidden(dr_config.batch_size)

    total_loss = 0
    start_time = time()
    num_batchs = ceil(len(train_ub) / dr_config.batch_size)
    for i, x in enumerate(batchify(train_ub, dr_config.batch_size, is_reordered=True)):
        baskets, lens, ids, r_baskets, h_baskets = x
        dr_hidden = repackage_hidden(dr_hidden)  # repackage hidden state for RNN
        dr_model.zero_grad()  # optim.zero_grad()
        dynamic_user, _ = dr_model(baskets, lens, dr_hidden)
        loss = reorder_bpr_loss(r_baskets, h_baskets, dynamic_user, dr_model.encode.weight, dr_config)

        try:
            loss.backward()
        except RuntimeError:  # for debugging
            print('caching')
            tmp = {'baskets': baskets, 'ids': ids, 'r_baskets': r_baskets, 'h_baskets': h_baskets,
                   'dynamic_user': dynamic_user, 'item_embedding': dr_model.encode.weight}
            print(baskets)
            print(ids)
            print(r_baskets)
            print(h_baskets)
            print(dr_model.encode.weight)
            print(dynamic_user.data)
            with open('tmp.pkl', 'wb') as f:
                pickle.dump(tmp, f, pickle.HIGHEST_PROTOCOL)
            break

        # Clip to avoid exploding gradients.
        torch.nn.utils.clip_grad_norm_(dr_model.parameters(), dr_config.clip)

        # Parameter updating
        # manual SGD
        # for p in dr_model.parameters(): # Update parameters by -lr*grad
        #    p.data.add_(- dr_config.learning_rate, p.grad.data)
        # adam
        grad_norm = get_grad_norm(dr_model)
        previous_params = deepcopy(list(dr_model.parameters()))
        optim.step()

        total_loss += loss.data
        params = deepcopy(list(dr_model.parameters()))
        delta = get_weight_update(previous_params, params)
        weight_update_ratio = get_ratio_update(delta, params)

        # Logging
        if i % dr_config.log_interval == 0 and i > 0:
            elapsed = (time() - start_time) * 1000 / dr_config.log_interval
            cur_loss = total_loss.item() / dr_config.log_interval / dr_config.batch_size  # turn tensor into float
            total_loss = 0
            start_time = time()
            print(
                '[Training]| Epochs {:3d} | Batch {:5d} / {:5d} | ms/batch {:02.2f} | Loss {:05.2f} |'.format(epoch, i,
                                                                                                              num_batchs,
                                                                                                              elapsed,
                                                                                                              cur_loss))
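repackage_hidden (Examples #5 and #8) is another undefined helper. In RNN training loops it conventionally detaches the hidden state so back-propagation does not reach through previous batches; a minimal sketch, assuming the hidden state is a tensor or a (nested) tuple of tensors:

import torch

def repackage_hidden(h):
    # Detach hidden states from the graph of the previous batch.
    if isinstance(h, torch.Tensor):
        return h.detach()
    return tuple(repackage_hidden(v) for v in h)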
Example #6
    def train_epoch(self, train, optimizer):
        '''
        Train an epoch with given train iterator and optimizer.
        '''
        total_loss, total_word_count = 0, 0
        total_grad_norm = 0
        avg_loss, avg_grad_norm = 0, 0
        sample_cnt = 0

        # Iterate whole train-set.
        for idx, mini_batch in enumerate(train):
            # The raw target variable has both BOS and EOS tokens.
            # The output of the sequence-to-sequence model does not have a BOS token.
            # Thus, remove the BOS token from the reference.
            x, y = mini_batch.src, mini_batch.tgt[0][:, 1:]
            # |x| = (batch_size, length)
            # |y| = (batch_size, length)

            # You have to reset the gradients of all model parameters before taking another step of gradient descent.
            optimizer.zero_grad()

            # Take feed-forward
            # As before, the decoder input does not have an EOS token.
            # Thus, remove the EOS token from the decoder input.
            y_hat = self.model(x, mini_batch.tgt[0][:, :-1])
            # |y_hat| = (batch_size, length, output_size)

            # Calculate loss and gradients with back-propagation.
            loss = self._get_loss(y_hat, y)
            loss.div(y.size(0)).backward()

            # Simple math to show stats.
            # Don't forget to detach final variables.
            total_loss += float(loss)
            total_word_count += int(mini_batch.tgt[1].sum())
            param_norm = float(
                utils.get_parameter_norm(self.model.parameters()))
            total_grad_norm += float(
                utils.get_grad_norm(self.model.parameters()))

            avg_loss = total_loss / total_word_count
            avg_grad_norm = total_grad_norm / (idx + 1)

            # In order to avoid exploding gradients, we apply gradient clipping.
            torch_utils.clip_grad_norm_(self.model.parameters(),
                                        self.config.max_grad_norm)
            # Take a step of gradient descent.
            optimizer.step()

            sample_cnt += mini_batch.tgt[0].size(0)

            if idx >= len(train) * self.config.train_ratio_per_epoch:
                break

        return avg_loss, param_norm, avg_grad_norm
Example #7
    def train_epoch(self, 
                    train, 
                    optimizer, 
                    batch_size=64, 
                    verbose=VERBOSE_SILENT
                    ):
        '''
        Train an epoch with given train iterator and optimizer.
        '''
        total_loss, total_param_norm, total_grad_norm = 0, 0, 0
        avg_loss, avg_param_norm, avg_grad_norm = 0, 0, 0
        sample_cnt = 0

        progress_bar = tqdm(train, 
                            desc='Training: ', 
                            unit='batch'
                            ) if verbose is VERBOSE_BATCH_WISE else train
        # Iterate whole train-set.
        for idx, mini_batch in enumerate(progress_bar):
            x, y = mini_batch.text, mini_batch.label
            # Don't forget to zero the gradients before another back-prop.
            optimizer.zero_grad()

            y_hat = self.model(x)

            loss = self.get_loss(y_hat, y)
            loss.backward()

            total_loss += float(loss)
            total_param_norm += float(utils.get_parameter_norm(self.model.parameters()))
            total_grad_norm += float(utils.get_grad_norm(self.model.parameters()))

            # Calculation to show status
            avg_loss = total_loss / (idx + 1)
            avg_param_norm = total_param_norm / (idx + 1)
            avg_grad_norm = total_grad_norm / (idx + 1)

            if verbose is VERBOSE_BATCH_WISE:
                progress_bar.set_postfix_str('|param|=%.2f |g_param|=%.2f loss=%.4e' % (avg_param_norm,
                                                                                        avg_grad_norm,
                                                                                        avg_loss
                                                                                        ))

            optimizer.step()

            sample_cnt += mini_batch.text.size(0)
            if sample_cnt >= len(train.dataset.examples):
                break

        if verbose is VERBOSE_BATCH_WISE:
            progress_bar.close()

        return avg_loss, avg_param_norm, avg_grad_norm
Example #8
def train_dream():
    dr_model.train()  # turn on training mode for dropout
    dr_hidden = dr_model.init_hidden(dr_config.batch_size)
    total_loss = 0
    start_time = time()
    num_batchs = ceil(len(train_ub) / dr_config.batch_size)
    for i, x in enumerate(batchify(train_ub, dr_config.batch_size)):
        baskets, lens, _ = x
        dr_hidden = repackage_hidden(
            dr_hidden)  # repackage hidden state for RNN
        dr_model.zero_grad()  # optim.zero_grad()
        dynamic_user, _ = dr_model(baskets, lens, dr_hidden)
        loss = bpr_loss(baskets, dynamic_user, dr_model.encode.weight,
                        dr_config)
        loss.backward()

        # Clip to avoid exploding gradients.
        torch.nn.utils.clip_grad_norm_(dr_model.parameters(), dr_config.clip)

        # Parameter updating
        # manual SGD
        # for p in dr_model.parameters(): # Update parameters by -lr*grad
        #    p.data.add_(- dr_config.learning_rate, p.grad.data)
        # adam
        grad_norm = get_grad_norm(dr_model)
        previous_params = deepcopy(list(dr_model.parameters()))
        optim.step()

        total_loss += loss.data
        params = deepcopy(list(dr_model.parameters()))
        delta = get_weight_update(previous_params, params)
        weight_update_ratio = get_ratio_update(delta, params)

        # Logging
        if i % dr_config.log_interval == 0 and i > 0:
            elapsed = (time() - start_time) * 1000 / dr_config.log_interval
            cur_loss = total_loss.item(
            ) / dr_config.log_interval / dr_config.batch_size  # turn tensor into float
            total_loss = 0
            start_time = time()
            print(
                '[Training]| Epochs {:3d} | Batch {:5d} / {:5d} | ms/batch {:02.2f} | Loss {:05.2f} |'
                .format(epoch, i, num_batchs, elapsed, cur_loss))
            writer.add_scalar('model/train_loss', cur_loss,
                              epoch * num_batchs + i)
            writer.add_scalar('model/grad_norm', grad_norm,
                              epoch * num_batchs + i)
            writer.add_scalar('model/weight_update_ratio', weight_update_ratio,
                              epoch * num_batchs + i)
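Examples #5 and #8 also log a weight-update ratio via get_weight_update and get_ratio_update, neither of which is defined here. A hypothetical sketch of that diagnostic (the ratio of the update norm to the parameter norm, a common learning-rate sanity check):

import torch

def get_weight_update(previous_params, params):
    # Per-tensor difference between parameter snapshots taken
    # before and after optim.step().
    return [new.detach() - old.detach()
            for old, new in zip(previous_params, params)]

def get_ratio_update(delta, params):
    # Ratio of the total update norm to the total parameter norm.
    delta_norm = torch.norm(torch.stack([torch.norm(d) for d in delta]))
    param_norm = torch.norm(torch.stack([torch.norm(p.detach()) for p in params]))
    return (delta_norm / param_norm).item()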
Example #9
def train_model(epoch):
    model.train()

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    end = time.time()

    total = 0
    correct = 0
    for batch_idx, (inputs, inputs_id, targets) in enumerate(train_loader):
        if inputs.size(0) < args.batch_size:
            continue
        inputs, inputs_id, targets = inputs.to(device), inputs_id.to(
            device), targets.to(device)
        targets = targets.long().squeeze(-1)

        if args.model == 'charcnn':
            outputs = model(inputs)
        else:
            outputs = model(inputs_id)
        loss = F.cross_entropy(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        grad_norm = get_grad_norm(model)
        # torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        acc, _ = topk_accuracy(outputs, targets, topk=(1, 1))
        top1.update(acc[0].item(), args.batch_size)

        losses.update(loss.item(), args.batch_size)
        batch_time.update(time.time() - end)
        end = time.time()

        if batch_idx % args.print_freq == 0:
            print(
                'Train Epoch: {} [{}/{}]| Loss: {:.3f} | acc: {:.3f} | grad norm: {:.3f} | batch time: {:.3f}'
                .format(epoch, batch_idx, len(train_loader), losses.val,
                        top1.val, grad_norm, batch_time.avg))

    writer.add_scalar('log/train accuracy', top1.avg, epoch)
    writer.add_scalar('log/train loss', losses.avg, epoch)

    for name, param in model.named_parameters():
        writer.add_histogram(name, param.clone().cpu().data.numpy(), epoch)
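AverageMeter (Examples #1, #9, #12, and #19) is not defined in these snippets either; the .val / .avg / .update(value, n) usage matches the common minimal implementation below, which is an assumption rather than the original class.

class AverageMeter:
    # Tracks the latest value and a running average of a metric.
    def __init__(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count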
Example #10
    def train(engine, mini_batch):
        from utils import get_grad_norm, get_parameter_norm

        # You have to reset the gradients of all model parameters
        # before taking another step of gradient descent.
        engine.model.train()
        engine.optimizer.zero_grad()

        # If 'is_src_target' is true, the trainer trains the language model for the source language.
        # In the DSL case, both x and y have BOS and EOS tokens.
        # Thus, we need to remove BOS and EOS before training.
        x = mini_batch.src[
            0][:, :-1] if engine.is_src_target else mini_batch.tgt[0][:, :-1]
        y = mini_batch.src[
            0][:, 1:] if engine.is_src_target else mini_batch.tgt[0][:, 1:]
        # |x| = |y| = (batch_size, length)

        y_hat = engine.model(x)
        # |y_hat| = (batch_size, length, output_size)

        loss = engine.crit(
            y_hat.contiguous().view(-1, y_hat.size(-1)),
            y.contiguous().view(-1),
        ).sum()
        loss.div(y.size(0)).backward()
        word_count = int(
            mini_batch.src[1].sum()) if engine.is_src_target else int(
                mini_batch.tgt[1].sum())

        p_norm = float(get_parameter_norm(engine.model.parameters()))
        g_norm = float(get_grad_norm(engine.model.parameters()))

        # In order to avoid exploding gradients, we apply gradient clipping.
        torch_utils.clip_grad_norm_(
            engine.model.parameters(),
            engine.config.max_grad_norm,
        )
        # Take a step of gradient descent.
        engine.optimizer.step()

        return {
            'loss': float(loss / word_count),
            '|param|': p_norm,
            '|g_param|': g_norm,
        }
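The train(engine, mini_batch) functions in Examples #3, #10, #11, #13, #14, and #17 are written as pytorch-ignite process functions that read model, crit, optimizer, and device off the engine. A sketch of how such a step could be wired up; assigning these attributes directly is an assumption (the original projects may subclass Engine instead):

from ignite.engine import Engine

def run_training(train_step, model, crit, optimizer, device, train_loader, n_epochs):
    trainer = Engine(train_step)   # train_step is one of the functions above
    trainer.model = model          # attributes referenced inside the step
    trainer.crit = crit
    trainer.optimizer = optimizer
    trainer.device = device
    trainer.run(train_loader, max_epochs=n_epochs)
    return trainer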
Example #11
    def step(engine, mini_batch):
        from utils import get_grad_norm, get_parameter_norm

        engine.model.train()
        engine.optimizer.zero_grad()

        x, y = mini_batch.text, mini_batch.label

        y_hat = engine.model(x)
        loss = engine.crit(y_hat, y)
        loss.backward()

        p_norm = float(get_parameter_norm(engine.model.parameters()))
        g_norm = float(get_grad_norm(engine.model.parameters()))

        engine.optimizer.step()

        return float(loss), p_norm, g_norm
Example #12
def train_model(epoch):
    model.train()

    batch_time = AverageMeter()
    losses = AverageMeter()

    end = time.time()
    ntokens = len(corpus.dictionary)

    hidden = (torch.zeros(2, args.batch_size, args.lstm_dim).to(device),
              torch.zeros(2, args.batch_size, args.lstm_dim).to(device))

    for batch, i in enumerate(
            range(0,
                  train_inputs.size(0) - 1, args.seq_length)):
        data, targets = get_batch(train_inputs, train_targets, i, args)
        data = data.to(device)
        targets = targets.to(device)

        model.zero_grad()
        hidden = [state.detach() for state in hidden]
        output, hidden = model(data, hidden)

        loss = F.cross_entropy(output, targets)
        loss.backward()
        grad_norm = get_grad_norm(model)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        optimizer.step()

        losses.update(loss.item(), args.batch_size)
        batch_time.update(time.time() - end)
        end = time.time()

        if batch % args.print_freq == 0:
            print(
                'Train Epoch: {} [{}]| Loss: {:.3f} | perplexity: {:.3f} | grad norm: {:.3f} | batch time: {:.3f}'
                .format(epoch, batch, losses.val, np.exp(losses.avg),
                        grad_norm, batch_time.avg))

    writer.add_scalar('log/train loss', losses.avg, epoch)
    writer.add_scalar('log/train perplexity', np.exp(losses.avg), epoch)

    for name, param in model.named_parameters():
        writer.add_histogram(name, param.clone().cpu().data.numpy(), epoch)
Example #13
    def train(engine, mini_batch):
        # You have to reset the gradients of all model parameters
        # before taking another step of gradient descent.
        engine.model.train()  # Because we assign the model as a class variable, we can easily access it.
        engine.optimizer.zero_grad()

        images, targets = mini_batch
        images = list(image.to(engine.device) for image in images)
        targets = [{k: v.to(engine.device)
                    for k, v in t.items()} for t in targets]

        # Take feed-forward

        losses = engine.model(images, targets)
        loss = sum(loss for loss in losses.values())
        loss.backward()

        # Detection targets here are dicts of float tensors (e.g. 'boxes'),
        # so classification accuracy is not meaningful; report 0 instead.
        accuracy = 0

        p_norm = float(get_parameter_norm(engine.model.parameters()))
        g_norm = float(get_grad_norm(engine.model.parameters()))

        # Take a step of gradient descent.
        engine.optimizer.step()

        # Take a step of scheduler
        engine.scheduler.step(loss)

        return {
            'loss': float(loss),
            'accuracy': float(accuracy),
            '|param|': p_norm,
            '|g_param|': g_norm,
        }
Example #14
    def train(engine, mini_batch):
        # You have to reset the gradients of all model parameters
        # before taking another step of gradient descent.
        engine.model.train()  # Because we assign the model as a class variable, we can easily access it.
        engine.optimizer.zero_grad()

        x, y = mini_batch
        x, y = x.to(engine.device), y.to(engine.device)
        # Move the data to the same device as the model.

        # Take feed-forward
        y_hat = engine.model(x)
        # |y_hat| = (batch_size, 10): 10-dimensional class scores

        loss = engine.crit(y_hat, y)  # passing through crit reduces the loss to a scalar
        loss.backward()

        # Calculate accuracy only if 'y' is a LongTensor,
        # which means that 'y' contains class indices.
        if isinstance(y, torch.LongTensor) or isinstance(
                y, torch.cuda.LongTensor):
            accuracy = (torch.argmax(y_hat, dim=-1) == y).sum() / float(
                y.size(0))
        else:
            accuracy = 0

        p_norm = float(get_parameter_norm(engine.model.parameters()))
        # L2 norm of the parameters (expected to grow as training progresses).
        g_norm = float(get_grad_norm(engine.model.parameters()))
        # L2 norm of the gradients (expected to shrink as training progresses).
        # p_norm and g_norm are useful indicators of whether training is going well.

        # Take a step of gradient descent.
        engine.optimizer.step()

        return {
            'loss': float(loss),
            'accuracy': float(accuracy),
            '|param|': p_norm,
            '|g_param|': g_norm,
        }
Example #15
    def train(engine, mini_batch):
        from utils import get_grad_norm, get_parameter_norm

        # You have to reset the gradients of all model parameters
        # before taking another step of gradient descent.
        engine.model.train()
        engine.optimizer.zero_grad()

        # The raw target variable has both BOS and EOS tokens.
        # The output of the sequence-to-sequence model does not have a BOS token.
        # Thus, remove the BOS token from the reference.
        x, y = mini_batch.src, mini_batch.tgt[0][:, 1:]
        # |x| = (batch_size, length)
        # |y| = (batch_size, length)

        # Take feed-forward
        # As before, the decoder input does not have an EOS token.
        # Thus, remove the EOS token from the decoder input.
        y_hat = engine.model(x, mini_batch.tgt[0][:, :-1])
        # |y_hat| = (batch_size, length, output_size)

        loss = engine.crit(y_hat.contiguous().view(-1, y_hat.size(-1)),
                           y.contiguous().view(-1)
                           )
        loss.div(y.size(0)).backward()
        word_count = int(mini_batch.tgt[1].sum())

        p_norm = float(get_parameter_norm(engine.model.parameters()))
        g_norm = float(get_grad_norm(engine.model.parameters()))

        # In order to avoid exploding gradients, we apply gradient clipping.
        torch_utils.clip_grad_norm_(
            engine.model.parameters(),
            engine.config.max_grad_norm,
        )
        # Take a step of gradient descent.
        engine.optimizer.step()

        if engine.config.use_noam_decay and engine.lr_scheduler is not None:
            engine.lr_scheduler.step()

        return float(loss / word_count), p_norm, g_norm
Example #16
def optimize_from_buffer(model,
                         loss_fn,
                         optim,
                         replay_buffer,
                         epochs=1,
                         prefix=""):
    loss = torch.tensor(0.)
    grad_norm = torch.tensor(0.)
    extra = {}
    for _ in range(epochs):
        transition = replay_buffer.transpose()
        model.zero_grad()
        loss, extra = loss_fn(transition)
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(model, 1.)
        grad_norm = utils.get_grad_norm(model)
        optim.step()
    return {
        f"{prefix}/loss": loss.detach(),
        f"{prefix}/grad_norm": grad_norm,
        **extra
    }
Example #17
    def train(engine, mini_batch):
        engine.model.train()
        engine.optimizer.zero_grad()
        x, y = mini_batch
        # Transfer the data to the GPU during training, batch by batch.
        x, y = x.to(engine.device), y.to(engine.device)

        pred_y = engine.model(x)

        loss = engine.crit(pred_y, y)
        loss.backward()

        # For regression, accuracy cannot be measured, so skip it.
        # If y is an integer (Long) tensor, assume the task is classification.
        if isinstance(y, torch.LongTensor) or isinstance(
                y, torch.cuda.LongTensor):
            accuracy = (torch.argmax(pred_y, dim=-1) == y).sum() / float(
                y.size(0))
        else:
            accuracy = 0

        # Parameter norm: grows gradually as training progresses.
        p_norm = float(get_parameter_norm(engine.model.parameters()))
        # Gradient norm: the larger it is, the more the model is learning from this step.
        # It is large at the start, when there is a lot to learn (steep gradients),
        # and usually shrinks gradually as training progresses (though not always).
        # If it fluctuates wildly, or the loss blows up to NaN, training may be failing.
        # In short, it is a check on the stability of training.
        g_norm = float(get_grad_norm(engine.model.parameters()))

        engine.optimizer.step()

        return {
            'loss': float(loss),
            'accuracy': float(accuracy),
            '|param|': p_norm,
            '|g_param|': g_norm,
        }
Example #18
def train_epoch(model, criterion, train_iter, valid_iter, config, start_epoch = 1, others_to_save = None):
    current_lr = config.rl_lr

    highest_valid_bleu = -np.inf
    no_improve_cnt = 0

    # Print initial valid BLEU before we start RL.
    model.eval()
    total_reward, sample_cnt = 0, 0
    for batch_index, batch in enumerate(valid_iter):
        current_batch_word_cnt = torch.sum(batch.tgt[1])
        x = batch.src
        y = batch.tgt[0][:, 1:]
        batch_size = y.size(0)
        # |x| = (batch_size, length)
        # |y| = (batch_size, length)

        # feed-forward
        y_hat, indice = model.search(x, is_greedy = True, max_length = config.max_length)
        # |y_hat| = (batch_size, length, output_size)
        # |indice| = (batch_size, length)

        reward = get_reward(y, indice, n_gram = config.rl_n_gram)

        total_reward += float(reward.sum())
        sample_cnt += batch_size
        if sample_cnt >= len(valid_iter.dataset.examples):
            break
    avg_bleu = total_reward / sample_cnt
    print("initial valid BLEU: %.4f" % avg_bleu) # You can figure-out improvement.
    model.train() # Now, begin training.

    # Start RL
    for epoch in range(start_epoch, config.rl_n_epochs + 1):
        #optimizer = optim.Adam(model.parameters(), lr = current_lr)
        optimizer = optim.SGD(model.parameters(), lr = current_lr) # Default hyper-parameter is set for SGD.
        print("current learning rate: %f" % current_lr)
        print(optimizer)

        sample_cnt = 0
        total_loss, total_bleu, total_sample_count, total_word_count, total_parameter_norm, total_grad_norm = 0, 0, 0, 0, 0, 0
        start_time = time.time()
        train_bleu = np.inf

        for batch_index, batch in enumerate(train_iter):
            optimizer.zero_grad()

            current_batch_word_cnt = torch.sum(batch.tgt[1])
            x = batch.src  
            y = batch.tgt[0][:, 1:]  # Reference: drop the leading BOS token (since we proceed by sampling?)
            batch_size = y.size(0)
            # |x| = (batch_size, length)
            # |y| = (batch_size, length)

            # Take the sampling process by setting is_greedy to False.
            y_hat, indice = model.search(x, is_greedy = False, max_length = config.max_length)
            # Based on the result of sampling, get the reward.
            q_actor = get_reward(y, indice, n_gram = config.rl_n_gram)
            # |y_hat| = (batch_size, length, output_size)  # softmax values before sampling
            # |indice| = (batch_size, length)  # LongTensor produced by sampling
            # |q_actor| = (batch_size)

            # Take as many samples as n_samples, and average their rewards.
            # I found that n_samples = 1 is enough.
            baseline = []
            with torch.no_grad():
                for i in range(config.n_samples):
                    _, sampled_indice = model.search(x, is_greedy = False, max_length = config.max_length)
                    baseline += [get_reward(y, sampled_indice, n_gram = config.rl_n_gram)]
                baseline = torch.stack(baseline).sum(dim = 0).div(config.n_samples)
                # |baseline| = (n_samples, batch_size) --> (batch_size)

            # Now we have an estimate of the expected cumulative reward (the baseline).
            # The advantage is q_actor minus the baseline.
            tmp_reward = q_actor - baseline
            # |tmp_reward| = (batch_size)
            # Calculate gradients with back-propagation
            get_gradient(indice, y_hat, criterion, reward = tmp_reward)

            # simple math to show stats
            total_loss += float(tmp_reward.sum())
            total_bleu += float(q_actor.sum())
            total_sample_count += batch_size
            total_word_count += int(current_batch_word_cnt)
            total_parameter_norm += float(utils.get_parameter_norm(model.parameters()))
            total_grad_norm += float(utils.get_grad_norm(model.parameters()))

            if (batch_index + 1) % config.print_every == 0:
                avg_loss = total_loss / total_sample_count
                avg_bleu = total_bleu / total_sample_count
                avg_parameter_norm = total_parameter_norm / config.print_every
                avg_grad_norm = total_grad_norm / config.print_every
                elapsed_time = time.time() - start_time

                print("epoch: %d batch: %d/%d\t|param|: %.2f\t|g_param|: %.2f\trwd: %.4f\tBLEU: %.4f\t%5d words/s %3d secs" % (epoch, 
                                                                                                            batch_index + 1, 
                                                                                                            int(len(train_iter.dataset.examples) // config.batch_size), 
                                                                                                            avg_parameter_norm, 
                                                                                                            avg_grad_norm, 
                                                                                                            avg_loss,
                                                                                                            avg_bleu,
                                                                                                            total_word_count // elapsed_time,
                                                                                                            elapsed_time
                                                                                                            ))

                total_loss, total_bleu, total_sample_count, total_word_count, total_parameter_norm, total_grad_norm = 0, 0, 0, 0, 0, 0
                start_time = time.time()

                train_bleu = avg_bleu

            # In order to avoid exploding gradients, we apply gradient clipping.
            torch_utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
            # Take a step of gradient descent.
            optimizer.step()

            sample_cnt += batch_size
            if sample_cnt >= len(train_iter.dataset.examples):
                break

        sample_cnt = 0
        total_reward = 0

        # Start validation
        with torch.no_grad():
            model.eval() # Turn-off drop-out

            for batch_index, batch in enumerate(valid_iter):
                current_batch_word_cnt = torch.sum(batch.tgt[1])
                x = batch.src
                y = batch.tgt[0][:, 1:]
                batch_size = y.size(0)
                # |x| = (batch_size, length)
                # |y| = (batch_size, length)

                # feed-forward
                y_hat, indice = model.search(x, is_greedy = True, max_length = config.max_length)
                # |y_hat| = (batch_size, length, output_size)
                # |indice| = (batch_size, length)

                reward = get_reward(y, indice, n_gram = config.rl_n_gram)

                total_reward += float(reward.sum())
                sample_cnt += batch_size
                if sample_cnt >= len(valid_iter.dataset.examples):
                    break

            avg_bleu = total_reward / sample_cnt
            print("valid BLEU: %.4f" % avg_bleu)

            if highest_valid_bleu < avg_bleu:
                highest_valid_bleu = avg_bleu
                no_improve_cnt = 0
            else:
                no_improve_cnt += 1

            model.train()

        model_fn = config.model.split(".")
        model_fn = model_fn[:-1] + ["%02d" % (config.n_epochs + epoch), "%.2f-%.4f" % (train_bleu, avg_bleu)] + [model_fn[-1]]

        # PyTorch provides an efficient way to save and load models, using Python pickle.
        to_save = {"model": model.state_dict(),
                    "config": config,
                    "epoch": config.n_epochs + epoch + 1,
                    "current_lr": current_lr
                    }
        if others_to_save is not None:
            for k, v in others_to_save.items():
                to_save[k] = v
        torch.save(to_save, '.'.join(model_fn))

        if config.early_stop > 0 and no_improve_cnt > config.early_stop:
            break
Example #19
def train_one_epoch(config, model, criterion, data_loader, optimizer, epoch,
                    mixup_fn, lr_scheduler, writer):
    model.train()
    optimizer.zero_grad()

    num_steps = len(data_loader)
    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    norm_meter = AverageMeter()

    start = time.time()
    end = time.time()
    for idx, (samples, targets) in enumerate(data_loader):
        samples = samples.cuda(non_blocking=True)
        targets = targets.cuda(non_blocking=True)

        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)

        outputs = model(samples)

        if config.TRAIN.ACCUMULATION_STEPS > 1:
            loss = criterion(outputs, targets)
            loss = loss / config.TRAIN.ACCUMULATION_STEPS
            if config.AMP_OPT_LEVEL != "O0":
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                if config.TRAIN.CLIP_GRAD:
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), config.TRAIN.CLIP_GRAD)
                else:
                    grad_norm = get_grad_norm(amp.master_params(optimizer))
            else:
                loss.backward()
                if config.TRAIN.CLIP_GRAD:
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        model.parameters(), config.TRAIN.CLIP_GRAD)
                else:
                    grad_norm = get_grad_norm(model.parameters())
            if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0:
                optimizer.step()
                optimizer.zero_grad()
                lr_scheduler.step_update(epoch * num_steps + idx)
            writer.add_scalar("train/loss",
                              scalar_value=loss,
                              global_step=(epoch * num_steps + idx))
            writer.add_scalar("train/lr",
                              scalar_value=optimizer.param_groups[0]['lr'],
                              global_step=epoch)
            writer.add_scalar("train/grad_norm",
                              scalar_value=grad_norm,
                              global_step=(epoch * num_steps + idx))
        else:
            loss = criterion(outputs, targets)
            optimizer.zero_grad()
            if config.AMP_OPT_LEVEL != "O0":
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                if config.TRAIN.CLIP_GRAD:
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), config.TRAIN.CLIP_GRAD)
                else:
                    grad_norm = get_grad_norm(amp.master_params(optimizer))
            else:
                loss.backward()
                if config.TRAIN.CLIP_GRAD:
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        model.parameters(), config.TRAIN.CLIP_GRAD)
                else:
                    grad_norm = get_grad_norm(model.parameters())
            optimizer.step()
            lr_scheduler.step_update(epoch * num_steps + idx)
            writer.add_scalar("train/loss",
                              scalar_value=loss,
                              global_step=(epoch * num_steps + idx))
            writer.add_scalar("train/lr",
                              scalar_value=optimizer.param_groups[0]['lr'],
                              global_step=epoch)
            writer.add_scalar("train/grad_norm",
                              scalar_value=grad_norm,
                              global_step=(epoch * num_steps + idx))

        torch.cuda.synchronize()

        loss_meter.update(loss.item(), targets.size(0))
        norm_meter.update(grad_norm)
        batch_time.update(time.time() - end)
        end = time.time()

        if idx % config.PRINT_FREQ == 0:
            lr = optimizer.param_groups[0]['lr']
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            etas = batch_time.avg * (num_steps - idx)
            logger.info(
                f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t'
                f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t'
                f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t'
                f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t'
                f'mem {memory_used:.0f}MB')
    epoch_time = time.time() - start
    logger.info(
        f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}"
    )
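The amp.scale_loss / amp.master_params calls in Examples #1 and #19 follow the NVIDIA apex mixed-precision API, which expects the model and optimizer to be wrapped once before the training loop; a minimal sketch of that setup, reusing config.AMP_OPT_LEVEL from the examples (newer code would typically use torch.cuda.amp instead):

from apex import amp  # NVIDIA apex package

# One-time setup before calling train_one_epoch.
if config.AMP_OPT_LEVEL != "O0":
    model, optimizer = amp.initialize(model, optimizer,
                                      opt_level=config.AMP_OPT_LEVEL)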
Example #20
    def train_epoch(self,
                    train,
                    optimizer,
                    max_grad_norm=5,
                    verbose=VERBOSE_SILENT
                    ):
        '''
        Train an epoch with given train iterator and optimizer.
        '''
        total_reward, total_actor_reward = 0, 0
        total_grad_norm = 0
        avg_reward, avg_actor_reward = 0, 0
        avg_param_norm, avg_grad_norm = 0, 0
        sample_cnt = 0

        progress_bar = tqdm(train,
                            desc='Training: ',
                            unit='batch'
                            ) if verbose is VERBOSE_BATCH_WISE else train
        # Iterate whole train-set.
        for idx, mini_batch in enumerate(progress_bar):
            # The raw target variable has both BOS and EOS tokens.
            # The output of the sequence-to-sequence model does not have a BOS token.
            # Thus, remove the BOS token from the reference.
            x, y = mini_batch.src, mini_batch.tgt[0][:, 1:]
            # |x| = (batch_size, length)
            # |y| = (batch_size, length)
            
            # You have to reset the gradients of all model parameters before taking another step of gradient descent.
            optimizer.zero_grad()

            # Take the sampling process by setting is_greedy to False.
            y_hat, indice = self.model.search(x,
                                              is_greedy=False,
                                              max_length=self.config.max_length
                                              )
            # Based on the result of sampling, get reward.
            actor_reward = self._get_reward(indice,
                                            y,
                                            n_gram=self.config.rl_n_gram
                                            )
            # |y_hat| = (batch_size, length, output_size)
            # |indice| = (batch_size, length)
            # |actor_reward| = (batch_size)

            # Take as many samples as n_samples, and average their rewards.
            # I found that n_samples = 1 is enough.
            baseline = []
            with torch.no_grad():
                for i in range(self.config.n_samples):
                    _, sampled_indice = self.model.search(x,
                                                          is_greedy=False,
                                                          max_length=self.config.max_length
                                                          )
                    baseline += [self._get_reward(sampled_indice,
                                                  y,
                                                  n_gram=self.config.rl_n_gram
                                                  )]
                baseline = torch.stack(baseline).sum(dim=0).div(self.config.n_samples)
                # |baseline| = (n_samples, batch_size) --> (batch_size)

            # Now we have an estimate of the expected cumulative reward (the baseline).
            # The advantage is actor_reward minus the baseline.
            final_reward = actor_reward - baseline
            # |final_reward| = (batch_size)

            # Calculate gradients with back-propagation
            self._get_gradient(y_hat, indice, reward=final_reward)
            
            # Simple math to show stats.
            total_reward += float(final_reward.sum())
            total_actor_reward += float(actor_reward.sum())
            sample_cnt += int(actor_reward.size(0))
            param_norm = float(utils.get_parameter_norm(self.model.parameters()))
            total_grad_norm += float(utils.get_grad_norm(self.model.parameters()))

            avg_reward = total_reward / sample_cnt
            avg_actor_reward = total_actor_reward / sample_cnt
            avg_grad_norm = total_grad_norm / (idx + 1)

            if verbose is VERBOSE_BATCH_WISE:
                progress_bar.set_postfix_str('|g_param|=%.2f rwd=%4.2f avg_frwd=%.2e BLEU=%.4f' % (avg_grad_norm,
                                                                                                 float(actor_reward.sum().div(y.size(0))),
                                                                                                 avg_reward,
                                                                                                 avg_actor_reward
                                                                                                 ))

            # In order to avoid exploding gradients, we apply gradient clipping.
            torch_utils.clip_grad_norm_(self.model.parameters(),
                                        self.config.max_grad_norm
                                        )
            # Take a step of gradient descent.
            optimizer.step()


            if idx >= len(progress_bar) * self.config.train_ratio_per_epoch:
                break

        if verbose is VERBOSE_BATCH_WISE:
            progress_bar.close()

        return avg_actor_reward, param_norm, avg_grad_norm
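Examples #18, #20, and #21 back-propagate through get_gradient / self._get_gradient, which is not shown on this page. A plausible sketch of that reward-weighted (REINFORCE-with-baseline) backward pass, assuming y_hat holds per-token log-probabilities and indice the sampled tokens; the real helpers also take a criterion argument, and their signatures may differ:

import torch

def get_gradient(y_hat, indice, reward):
    # |y_hat|  = (batch_size, length, output_size), log-probabilities
    # |indice| = (batch_size, length), sampled token indices
    # |reward| = (batch_size,), advantage = actor reward - baseline
    batch_size = y_hat.size(0)
    log_prob = y_hat.gather(dim=-1, index=indice.unsqueeze(-1)).squeeze(-1)
    # Scale each sample's log-likelihood by its advantage, so that
    # higher-reward samples are made more likely by the update.
    loss = -(log_prob.sum(dim=-1) * reward.detach()).sum()
    loss.div(batch_size).backward()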
Example #21
def train_epoch(model,
                bimpm,
                criterion,
                train_iter,
                valid_iter,
                config,
                start_epoch=1,
                others_to_save=None,
                valid_nli_iter=None):
    current_lr = config.rl_lr

    highest_valid_bleu = -np.inf
    no_improve_cnt = 0

    # Print initial valid BLEU before we start RL.
    model.eval()
    total_reward, sample_cnt = 0, 0
    for batch_index, batch in enumerate(valid_iter):
        current_batch_word_cnt = torch.sum(batch.tgt[1])
        x = batch.src
        y = batch.tgt[0][:, 1:]
        batch_size = y.size(0)
        # |x| = (batch_size, length)
        # |y| = (batch_size, length)

        # feed-forward
        y_hat, indice = model.search(x,
                                     is_greedy=True,
                                     max_length=config.max_length)
        # |y_hat| = (batch_size, length, output_size)
        # |indice| = (batch_size, length)

        reward = get_bleu_reward(y, indice, n_gram=config.rl_n_gram)

        total_reward += float(reward.sum())
        sample_cnt += batch_size
        if sample_cnt >= len(valid_iter.dataset.examples):
            break
    avg_bleu = total_reward / sample_cnt
    print("initial valid BLEU: %.4f" %
          avg_bleu)  # You can figure-out improvement.

    if valid_nli_iter:
        nli_validation(valid_nli_iter, model, bimpm, config)
    model.train()  # Now, begin training.

    # Start RL
    nli_criterion = nn.CrossEntropyLoss(reduction='none')
    print("start rl epoch:", start_epoch)
    print("number of epoch to complete:", config.rl_n_epochs + 1)

    if config.reward_mode == 'combined':
        if config.gpu_id >= 0:
            nli_weight = torch.tensor([1.0], requires_grad=True, device="cuda")
            bleu_weight = torch.tensor([1.0],
                                       requires_grad=True,
                                       device="cuda")
        else:
            nli_weight = torch.tensor([1.0], requires_grad=True)
            bleu_weight = torch.tensor([1.0], requires_grad=True)

        print("nli_weight, bleu_weight:",
              nli_weight.data.cpu().numpy()[0],
              bleu_weight.data.cpu().numpy()[0])
        weight_optimizer = optim.Adam(iter([nli_weight, bleu_weight]),
                                      lr=0.0001)

    optimizer = optim.SGD(
        model.parameters(),
        lr=current_lr,
    )  # Default hyper-parameter is set for SGD.
    print("current learning rate: %f" % current_lr)
    print(optimizer)

    for epoch in range(start_epoch, config.rl_n_epochs + 1):
        sample_cnt = 0
        total_loss, total_actor_loss, total_sample_count, total_word_count, total_parameter_norm, total_grad_norm = 0, 0, 0, 0, 0, 0
        start_time = time.time()
        train_loss = np.inf
        epoch_accuracy = []

        for batch_index, batch in enumerate(train_iter):
            optimizer.zero_grad()

            current_batch_word_cnt = torch.sum(batch.tgt[1])
            x = batch.src
            y = batch.tgt[0][:, 1:]
            batch_size = y.size(0)
            if config.reward_mode != 'bleu':
                premise = batch.premise
                hypothesis = batch.hypothesis
                isSrcPremise = batch.isSrcPremise
                label = batch.labels

            # |x| = (batch_size, length)
            # |y| = (batch_size, length)

            # Take the sampling process by setting is_greedy to False.
            y_hat, indice = model.search(x,
                                         is_greedy=False,
                                         max_length=config.max_length)

            if config.reward_mode == 'bleu':
                q_actor = get_bleu_reward(y, indice, n_gram=config.rl_n_gram)
                epoch_accuracy.append(q_actor.sum() / batch_size)
            else:
                padded_indice, padded_premise, padded_hypothesis = padding_three_tensors(
                    indice, premise, hypothesis, batch_size)

                # Put the predicted sentence into either the premise or the hypothesis
                for i in range(batch_size):
                    if not isSrcPremise[i]:
                        padded_premise[i] = padded_indice[i]
                    else:
                        padded_hypothesis[i] = padded_indice[i]

                kwargs = {'p': padded_premise, 'h': padded_hypothesis}
                pred_logit = bimpm(**kwargs)
                accuracy = get_accuracy(pred_logit, label)
                epoch_accuracy.append(accuracy)

                # Based on the result of sampling, get reward.
                if config.reward_mode == 'nli':
                    q_actor = -get_nli_reward(pred_logit, label, nli_criterion)
                else:
                    q_actor = 1/(2 * nli_weight.pow(2)) * -get_nli_reward(pred_logit, label, nli_criterion) \
                        + 1/(2 * bleu_weight.pow(2)) * (get_bleu_reward(y, indice, n_gram=config.rl_n_gram)/100) \
                        + torch.log(nli_weight * bleu_weight)
            # |y_hat| = (batch_size, length, output_size)
            # |indice| = (batch_size, length)
            # |q_actor| = (batch_size)

            # Take as many samples as n_samples, and average their rewards.
            # I found that n_samples = 1 is enough.
            baseline = []
            with torch.no_grad():
                for i in range(config.n_samples):
                    _, sampled_indice = model.search(
                        x, is_greedy=False, max_length=config.max_length)

                    if config.reward_mode == 'bleu':
                        baseline_reward = get_bleu_reward(
                            y, sampled_indice, n_gram=config.rl_n_gram)
                        epoch_accuracy.append(baseline_reward.sum() /
                                              batch_size)
                    else:
                        padded_sampled_indice, padded_premise, padded_hypothesis = padding_three_tensors(
                            sampled_indice, premise, hypothesis, batch_size)

                        # Put the predicted sentence into either the premise or the hypothesis
                        for i in range(batch_size):
                            if not isSrcPremise[i]:
                                padded_premise[i] = padded_sampled_indice[i]
                            else:
                                padded_hypothesis[i] = padded_sampled_indice[i]

                        kwargs = {'p': padded_premise, 'h': padded_hypothesis}
                        pred_logit = bimpm(**kwargs)
                        accuracy = get_accuracy(pred_logit, label)
                        epoch_accuracy.append(accuracy)

                        # Based on the result of sampling, get reward.
                        if config.reward_mode == 'nli':
                            baseline_reward = -get_nli_reward(
                                pred_logit, label, nli_criterion)
                        else:
                            baseline_reward = 1/(2 * nli_weight.pow(2)) * -get_nli_reward(pred_logit, label, nli_criterion) \
                                + 1/(2 * bleu_weight.pow(2)) * (get_bleu_reward(y, sampled_indice, n_gram=config.rl_n_gram)/100) \
                                + torch.log(nli_weight * bleu_weight)

                    baseline += [baseline_reward]
                baseline = torch.stack(baseline).sum(dim=0).div(
                    config.n_samples)
                # |baseline| = (n_samples, batch_size) --> (batch_size)

            # Now we have an estimate of the expected cumulative reward (the baseline).
            # The advantage is q_actor minus the baseline.
            tmp_reward = q_actor - baseline
            # |tmp_reward| = (batch_size)
            # Calculate gradients with back-propagation
            get_gradient(indice, y_hat, criterion, reward=tmp_reward)

            # simple math to show stats
            total_loss += float(tmp_reward.sum())
            total_actor_loss += float(q_actor.sum())
            total_sample_count += batch_size
            total_word_count += int(current_batch_word_cnt)
            total_parameter_norm += float(
                utils.get_parameter_norm(model.parameters()))
            total_grad_norm += float(utils.get_grad_norm(model.parameters()))

            if (batch_index + 1) % config.print_every == 0:
                avg_loss = total_loss / total_sample_count
                avg_actor_loss = total_actor_loss / total_sample_count
                avg_parameter_norm = total_parameter_norm / config.print_every
                avg_grad_norm = total_grad_norm / config.print_every
                avg_epoch_accuracy = sum(epoch_accuracy) / len(epoch_accuracy)
                elapsed_time = time.time() - start_time

                print(
                    "epoch: %d batch: %d/%d\t|param|: %.2f\t|g_param|: %.2f\trwd: %.4f\tactor loss: %.4f\tAccuracy: %.2f\t%5d words/s %3d secs"
                    %
                    (epoch, batch_index + 1,
                     int(
                         len(train_iter.dataset.examples) //
                         config.batch_size), avg_parameter_norm, avg_grad_norm,
                     avg_loss, avg_actor_loss, avg_epoch_accuracy,
                     total_word_count // elapsed_time, elapsed_time))

                if config.reward_mode == 'combined':
                    print("nli_weight, bleu_weight:",
                          nli_weight.data.cpu().numpy()[0],
                          bleu_weight.data.cpu().numpy()[0])

                total_loss, total_actor_loss, total_sample_count, total_word_count, total_parameter_norm, total_grad_norm = 0, 0, 0, 0, 0, 0
                epoch_accuracy = []
                start_time = time.time()

                train_loss = avg_actor_loss

            # In order to avoid exploding gradients, we apply gradient clipping.
            torch_utils.clip_grad_norm_(model.parameters(),
                                        config.max_grad_norm)
            # Take a step of gradient descent.
            optimizer.step()
            if config.reward_mode == 'combined':
                weight_optimizer.step()

            sample_cnt += batch_size
            if sample_cnt >= len(train_iter.dataset.examples):
                break

        sample_cnt = 0
        total_reward = 0

        # Start validation
        with torch.no_grad():
            model.eval()  # Turn off dropout.

            for batch_index, batch in enumerate(valid_iter):
                current_batch_word_cnt = torch.sum(batch.tgt[1])
                x = batch.src
                y = batch.tgt[0][:, 1:]
                batch_size = y.size(0)
                # |x| = (batch_size, length)
                # |y| = (batch_size, length)

                # feed-forward
                y_hat, indice = model.search(x,
                                             is_greedy=True,
                                             max_length=config.max_length)
                # |y_hat| = (batch_size, length, output_size)
                # |indice| = (batch_size, length)

                reward = get_bleu_reward(y, indice, n_gram=config.rl_n_gram)

                total_reward += float(reward.sum())
                sample_cnt += batch_size
                if sample_cnt >= len(valid_iter.dataset.examples):
                    break

            avg_bleu = total_reward / sample_cnt
            print("valid BLEU: %.4f" % avg_bleu)

            if highest_valid_bleu < avg_bleu:
                highest_valid_bleu = avg_bleu
                no_improve_cnt = 0
            else:
                no_improve_cnt += 1

            if valid_nli_iter:
                nli_validation(valid_nli_iter, model, bimpm, config)
            model.train()

        model_fn = config.model.split(".")
        model_fn = model_fn[:-1] + [
            "%02d" % (config.n_epochs + epoch),
            "%.2f-%.4f" % (train_loss, avg_bleu)
        ] + [model_fn[-1]] + [config.reward_mode]

        # PyTorch provides an efficient way to save and load models, based on Python pickle.
        to_save = {
            "model": model.state_dict(),
            "config": config,
            "epoch": config.n_epochs + epoch + 1,
            "current_lr": current_lr
        }
        if others_to_save is not None:
            for k, v in others_to_save.items():
                to_save[k] = v
        torch.save(to_save, '.'.join(model_fn))

        if config.early_stop > 0 and no_improve_cnt > config.early_stop:
            break
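Nearly every snippet in this collection calls get_grad_norm, and several also call get_parameter_norm, from a utils module that is not shown. The helpers below are a minimal sketch of what such functions typically compute (the total p-norm of the gradients and of the parameter values, respectively); they are an assumption for illustration, not the original utils implementation.

import torch


def get_grad_norm(parameters, norm_type=2):
    # Total norm over the gradients of all parameters that currently have one.
    parameters = [p for p in parameters if p.grad is not None]
    total_norm = 0.0
    for p in parameters:
        total_norm += p.grad.data.norm(norm_type).item() ** norm_type
    return total_norm ** (1.0 / norm_type)


def get_parameter_norm(parameters, norm_type=2):
    # Total norm over the parameter values themselves.
    total_norm = 0.0
    for p in parameters:
        total_norm += p.data.norm(norm_type).item() ** norm_type
    return total_norm ** (1.0 / norm_type)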
Example #22
def train_model(net, optimizer, transition, args):

    actions = torch.Tensor(transition.action).long().to(device)
    rewards = torch.Tensor(transition.reward).to(device)
    masks = torch.Tensor(transition.mask).to(device)
    goals = torch.stack(transition.goal).to(device)
    policies = torch.stack(transition.policy).to(device)
    m_states = torch.stack(transition.m_state).to(device)
    m_values = torch.stack(transition.m_value).to(device)
    w_values_ext = torch.stack(transition.w_value_ext).to(device)
    w_values_int = torch.stack(transition.w_value_int).to(device)

    m_returns = get_returns(rewards, masks, args.m_gamma, m_values)
    w_returns = get_returns(rewards, masks, args.w_gamma, w_values_ext)

    intrinsic_rewards = torch.zeros_like(rewards).to(device)
    # todo: how to get intrinsic reward before 10 steps

    for i in range(args.horizon, len(rewards)):
        cos_sum = 0
        for j in range(1, args.horizon + 1):
            alpha = m_states[i] - m_states[i - j]
            beta = goals[i - j]
            cosine_sim = F.cosine_similarity(alpha, beta)
            cos_sum = cos_sum + cosine_sim
        intrinsic_reward = cos_sum / args.horizon
        intrinsic_rewards[i] = intrinsic_reward.detach()
    returns_int = get_returns(intrinsic_rewards, masks, args.w_gamma,
                              w_values_int)

    m_loss = torch.zeros_like(w_returns).to(device)
    w_loss = torch.zeros_like(m_returns).to(device)

    # todo: how to update manager near end state
    for i in range(0, len(rewards) - args.horizon):
        m_advantage = m_returns[i] - m_values[i].squeeze(-1)
        alpha = m_states[i + args.horizon] - m_states[i]
        beta = goals[i]
        cosine_sim = F.cosine_similarity(alpha.detach(), beta)
        m_loss[i] = -m_advantage * cosine_sim

        log_policy = torch.log(policies[i] + 1e-5)
        w_advantage = w_returns[i] + returns_int[i] - w_values_ext[i].squeeze(
            -1) - w_values_int[i].squeeze(-1)
        log_policy = log_policy.gather(-1, actions[i].unsqueeze(-1))
        w_loss[i] = -w_advantage * log_policy.squeeze(-1)

    m_loss = m_loss.mean()
    w_loss = w_loss.mean()
    m_loss_value = F.mse_loss(m_values.squeeze(-1), m_returns.detach())
    w_loss_value_ext = F.mse_loss(w_values_ext.squeeze(-1), w_returns.detach())
    w_loss_value_int = F.mse_loss(w_values_int.squeeze(-1),
                                  returns_int.detach())

    loss = w_loss + w_loss_value_ext + w_loss_value_int + m_loss + m_loss_value
    # TODO: Add entropy to loss for exploration

    optimizer.zero_grad()
    loss.backward(retain_graph=True)
    grad_norm = get_grad_norm(net)
    torch.nn.utils.clip_grad_norm_(net.parameters(), args.clip_grad_norm)
    optimizer.step()
    return loss, grad_norm
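Example #22 relies on a get_returns helper that is not listed. A plausible minimal sketch follows: discounted returns that are reset at episode boundaries (mask == 0) and bootstrapped from the final value estimate. The exact bootstrapping in the original code may differ, so treat this as an assumption.

import torch


def get_returns(rewards, masks, gamma, values):
    # Discounted returns, reset where mask == 0 and bootstrapped from the
    # last value estimate (a hypothetical reconstruction of the helper).
    returns = torch.zeros_like(rewards)
    running_return = values[-1].squeeze(-1).detach()
    for t in reversed(range(len(rewards))):
        running_return = rewards[t] + gamma * running_return * masks[t]
        returns[t] = running_return
    return returns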
Example #23
def train_one_epoch_distill(config,
                            model,
                            model_teacher,
                            data_loader,
                            optimizer,
                            epoch,
                            mixup_fn,
                            lr_scheduler,
                            criterion_soft=None,
                            criterion_truth=None,
                            criterion_attn=None,
                            criterion_hidden=None):

    layer_id_s_list = config.DISTILL.STUDENT_LAYER_LIST
    layer_id_t_list = config.DISTILL.TEACHER_LAYER_LIST

    model.train()
    optimizer.zero_grad()

    model_teacher.eval()

    num_steps = len(data_loader)
    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    norm_meter = AverageMeter()
    loss_soft_meter = AverageMeter()
    loss_truth_meter = AverageMeter()
    loss_attn_meter = AverageMeter()
    loss_hidden_meter = AverageMeter()

    acc1_meter = AverageMeter()
    acc5_meter = AverageMeter()
    teacher_acc1_meter = AverageMeter()
    teacher_acc5_meter = AverageMeter()

    start = time.time()
    end = time.time()
    for idx, (samples, targets) in enumerate(data_loader):
        samples = samples.cuda(non_blocking=True)
        targets = targets.cuda(non_blocking=True)
        original_targets = targets

        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)

        if config.DISTILL.ATTN_LOSS and config.DISTILL.HIDDEN_LOSS:
            outputs, qkv_s, hidden_s = model(
                samples,
                layer_id_s_list,
                is_attn_loss=True,
                is_hidden_loss=True,
                is_hidden_org=config.DISTILL.HIDDEN_RELATION)
        elif config.DISTILL.ATTN_LOSS:
            outputs, qkv_s = model(
                samples,
                layer_id_s_list,
                is_attn_loss=True,
                is_hidden_loss=False,
                is_hidden_org=config.DISTILL.HIDDEN_RELATION)
        elif config.DISTILL.HIDDEN_LOSS:
            outputs, hidden_s = model(
                samples,
                layer_id_s_list,
                is_attn_loss=False,
                is_hidden_loss=True,
                is_hidden_org=config.DISTILL.HIDDEN_RELATION)
        else:
            outputs = model(samples)

        with torch.no_grad():
            acc1, acc5 = accuracy(outputs, original_targets, topk=(1, 5))
            if config.DISTILL.ATTN_LOSS or config.DISTILL.HIDDEN_LOSS:
                outputs_teacher, qkv_t, hidden_t = model_teacher(
                    samples,
                    layer_id_t_list,
                    is_attn_loss=True,
                    is_hidden_loss=True)
            else:
                outputs_teacher = model_teacher(samples)
            teacher_acc1, teacher_acc5 = accuracy(outputs_teacher,
                                                  original_targets,
                                                  topk=(1, 5))

        if config.TRAIN.ACCUMULATION_STEPS > 1:
            loss_truth = config.DISTILL.ALPHA * criterion_truth(
                outputs, targets)
            loss_soft = (1.0 - config.DISTILL.ALPHA) * criterion_soft(
                outputs / config.DISTILL.TEMPERATURE,
                outputs_teacher / config.DISTILL.TEMPERATURE)
            if config.DISTILL.ATTN_LOSS:
                loss_attn = config.DISTILL.QKV_LOSS_WEIGHT * criterion_attn(
                    qkv_s, qkv_t, config.DISTILL.AR)
            else:
                loss_attn = torch.zeros(loss_truth.shape)
            if config.DISTILL.HIDDEN_LOSS:
                loss_hidden = config.DISTILL.HIDDEN_LOSS_WEIGHT * criterion_hidden(
                    hidden_s, hidden_t)
            else:
                loss_hidden = torch.zeros(loss_truth.shape)
            loss = loss_truth + loss_soft + loss_attn + loss_hidden

            loss = loss / config.TRAIN.ACCUMULATION_STEPS
            if config.AMP_OPT_LEVEL != "O0":
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                if config.TRAIN.CLIP_GRAD:
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), config.TRAIN.CLIP_GRAD)
                else:
                    grad_norm = get_grad_norm(amp.master_params(optimizer))
            else:
                loss.backward()
                if config.TRAIN.CLIP_GRAD:
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        model.parameters(), config.TRAIN.CLIP_GRAD)
                else:
                    grad_norm = get_grad_norm(model.parameters())
            if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0:
                optimizer.step()
                optimizer.zero_grad()
                lr_scheduler.step_update(epoch * num_steps + idx)
        else:
            loss_truth = config.DISTILL.ALPHA * criterion_truth(
                outputs, targets)
            loss_soft = (1.0 - config.DISTILL.ALPHA) * criterion_soft(
                outputs / config.DISTILL.TEMPERATURE,
                outputs_teacher / config.DISTILL.TEMPERATURE)
            if config.DISTILL.ATTN_LOSS:
                loss_attn = config.DISTILL.QKV_LOSS_WEIGHT * criterion_attn(
                    qkv_s, qkv_t, config.DISTILL.AR)
            else:
                loss_attn = torch.zeros(loss_truth.shape)
            if config.DISTILL.HIDDEN_LOSS:
                loss_hidden = config.DISTILL.HIDDEN_LOSS_WEIGHT * criterion_hidden(
                    hidden_s, hidden_t)
            else:
                loss_hidden = torch.zeros(loss_truth.shape)
            loss = loss_truth + loss_soft + loss_attn + loss_hidden

            optimizer.zero_grad()
            if config.AMP_OPT_LEVEL != "O0":
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                if config.TRAIN.CLIP_GRAD:
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), config.TRAIN.CLIP_GRAD)
                else:
                    grad_norm = get_grad_norm(amp.master_params(optimizer))
            else:
                loss.backward()
                if config.TRAIN.CLIP_GRAD:
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        model.parameters(), config.TRAIN.CLIP_GRAD)
                else:
                    grad_norm = get_grad_norm(model.parameters())
            optimizer.step()
            lr_scheduler.step_update(epoch * num_steps + idx)

        torch.cuda.synchronize()

        loss_meter.update(loss.item(), targets.size(0))
        loss_soft_meter.update(loss_soft.item(), targets.size(0))
        loss_truth_meter.update(loss_truth.item(), targets.size(0))
        loss_attn_meter.update(loss_attn.item(), targets.size(0))
        loss_hidden_meter.update(loss_hidden.item(), targets.size(0))
        norm_meter.update(grad_norm)
        batch_time.update(time.time() - end)
        end = time.time()

        acc1_meter.update(acc1.item(), targets.size(0))
        acc5_meter.update(acc5.item(), targets.size(0))
        teacher_acc1_meter.update(teacher_acc1.item(), targets.size(0))
        teacher_acc5_meter.update(teacher_acc5.item(), targets.size(0))

        if idx % config.PRINT_FREQ == 0:
            lr = optimizer.param_groups[0]['lr']
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            etas = batch_time.avg * (num_steps - idx)
            logger.info(
                f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t'
                f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t'
                f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t'
                f'Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}\t'
                f'Teacher_Acc@1 {teacher_acc1_meter.avg:.3f} Teacher_Acc@5 {teacher_acc5_meter.avg:.3f}\t'
                f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                f'loss_soft {loss_soft_meter.val:.4f} ({loss_soft_meter.avg:.4f})\t'
                f'loss_truth {loss_truth_meter.val:.4f} ({loss_truth_meter.avg:.4f})\t'
                f'loss_attn {loss_attn_meter.val:.4f} ({loss_attn_meter.avg:.4f})\t'
                f'loss_hidden {loss_hidden_meter.val:.4f} ({loss_hidden_meter.avg:.4f})\t'
                f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t'
                f'mem {memory_used:.0f}MB')
    epoch_time = time.time() - start
    logger.info(
        f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}"
    )
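The distillation loop above takes criterion_soft, criterion_attn and criterion_hidden as arguments but never defines them. A common choice for criterion_soft is a KL divergence between the temperature-scaled student and teacher distributions; a minimal sketch under that assumption follows (the loop already divides both logit tensors by config.DISTILL.TEMPERATURE before calling the criterion, and Hinton-style distillation often multiplies the result by the squared temperature as well).

import torch
import torch.nn as nn
import torch.nn.functional as F


class SoftTargetKLLoss(nn.Module):
    # Hypothetical soft-label criterion: KL divergence between the student
    # and teacher distributions. The caller is expected to pass logits that
    # are already divided by the temperature, as the loop above does.
    def forward(self, student_logits, teacher_logits):
        log_p_student = F.log_softmax(student_logits, dim=-1)
        p_teacher = F.softmax(teacher_logits, dim=-1)
        return F.kl_div(log_p_student, p_teacher, reduction='batchmean')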
Example #24
    def _train_epoch(sess,
                     model,
                     optimizer,
                     pretrain,
                     losses,
                     verbose=verbose_config.verbose):
        global teach_rate
        global step
        if captioning:
            data_loader = data_loaders['train']
            encoder.eval()
        else:
            data_iterator.restart_dataset(sess, 'train')
            feed_dict = {
                data_iterator.handle: data_iterator.get_handle(sess, 'train')
            }
        model.train()

        if not captioning:

            def _get_data_loader():
                while True:
                    try:
                        yield sess.run(data_batch, feed_dict=feed_dict)
                    except tf.errors.OutOfRangeError:
                        break

            data_loader = _get_data_loader()

        for batch_i, batch in enumerate(data_loader):
            if batch_i >= train_config.train_batches:
                break

            sample_verbose = verbose and (step +
                                          1) % verbose_config.steps_sample == 0
            if captioning:
                images, tgt_ids, lengths = batch
                res = run_model(model,
                                encoder,
                                batch,
                                target_vocab,
                                teach_rate=teach_rate,
                                device=device,
                                verbose=sample_verbose)
            else:
                tgt_ids = batch['target_text_ids']
                res = run_model(model,
                                None,
                                batch,
                                target_vocab,
                                teach_rate=teach_rate,
                                device=device,
                                verbose=sample_verbose)
            batch_size = tgt_ids.shape[0]
            if train_config.enable_cross_entropy:
                cel = res['ce']['loss']
                cel_ = cel.cpu().data.numpy()
            else:
                cel_ = -1.
            if train_config.enable_bleu:
                probs = res['mb']['X']
                if sample_verbose and verbose_config.probs_verbose:
                    probs.retain_grad()
                gen_ids = res['mb']['gen_ids']
                gen_probs = res['mb']['gen_probs']
                mbl = res['mb']['loss']
                mbl_ = mbl.cpu().data.numpy()
            else:
                mbl_ = -1.
            loss = res['loss']
            if pretrain:
                if sample_verbose:
                    logging.info('pretraining')
                loss = cel
            loss_ = loss.cpu().data.numpy()

            if train_config.enable_bleu and sample_verbose and verbose_config.probs_verbose:
                mbls_ = res['mb']['mbls_']
                grad_ = []
                for order in range(1, criterion_bleu.max_order + 1):
                    optimizer.zero_grad()
                    mbls_[order - 1].backward(retain_graph=True)
                    grad_.append(probs.grad)
                grad_ = torch.stack(grad_, dim=1)

            optimizer.zero_grad()
            loss.backward()
            if train_config.clip_grad_norm is None:
                grad_norm = get_grad_norm(model.parameters())
            else:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), train_config.clip_grad_norm)

            if train_config.enable_bleu and sample_verbose:
                samples = min(verbose_config.samples, batch_size)
                gen_words, tgt_words = map(ids_to_words, (gen_ids, tgt_ids))
                if verbose_config.probs_verbose:
                    gen_grads = torch.gather(probs.grad, -1,
                                             gen_ids.unsqueeze(-1)).squeeze(-1)
                    max_grads, max_ids = probs.grad.min(-1)
                    max_probs = torch.gather(probs, -1,
                                             max_ids.unsqueeze(-1)).squeeze(-1)
                    max_words = ids_to_words(max_ids)
                    max_grad_, max_id_ = grad_.min(-1)
                    max_word_ = ids_to_words(max_id_)
                for sample_i, (gen_sent, tgt_sent) in enumerate(
                        zip(gen_words, tgt_words)):
                    if sample_i >= samples:
                        break
                    l = list(tgt_sent).index(
                        target_vocab.eos_token.encode(
                            data_config.encoding)) + 1
                    logging.info('tgt: {}'.format(b' '.join(
                        tgt_sent[:l]).decode(data_config.encoding)))
                    logging.info('gen: {}'.format(b' '.join(
                        gen_sent[:l]).decode(data_config.encoding)))
                    if verbose_config.probs_verbose:
                        logging.info('max: {}'.format(b' '.join(
                            max_words[sample_i][:l]).decode(
                                data_config.encoding)))
                        logging.info('gen probs:\n{}'.format(
                            gen_probs[sample_i][:l]))
                        logging.info('gen grads:\n{}'.format(
                            gen_grads[sample_i][:l]))
                        logging.info('max probs:\n{}'.format(
                            max_probs[sample_i][:l]))
                        logging.info('max grads:\n{}'.format(
                            max_grads[sample_i][:l]))
                        for order in range(1, criterion_bleu.max_order + 1):
                            logging.info('{}-gram max: {}'.format(
                                order, b' '.join(
                                    max_word_[sample_i][order - 1][:l]).decode(
                                        data_config.encoding)))
                            logging.info('{}-gram max grads:\n{}'.format(
                                order, max_grad_[sample_i][order - 1][:l]))
            losses.append([loss_, cel_, mbl_, grad_norm])
            writer.add_scalar('train/loss', loss_, step)
            writer.add_scalar('train/cel', cel_, step)
            writer.add_scalar('train/mbl', mbl_, step)
            writer.add_scalar('train/grad_norm', grad_norm, step)
            step += 1
            if step % verbose_config.steps_loss == 0:
                logging.info(
                    'step: {}\tloss: {:.3f}\tcel: {:.3f}\tmbl: {:.3f}\tgrad_norm: {:.3f}'
                    .format(step, loss_, cel_, mbl_, grad_norm))

            optimizer.step()

            if step % verbose_config.steps_eval == 0:
                _eval_on_dev_set()
                if captioning:
                    encoder.eval()
                model.train()
                #losses.plot(os.path.join(logdir, 'train_losses'))

            if train_config.checkpoints and step % verbose_config.steps_ckpt == 0:
                _save_model(epoch, step)

            if train_config.enable_bleu and step % train_config.teach_rate_anneal_steps == 0:
                teach_rate *= train_config.teach_rate_anneal
                logging.info("teach rate: {}".format(teach_rate))
Example #25
def train_epoch(model, criterion, train_iter, valid_iter, config):
    current_lr = config.lr

    lowest_valid_loss = np.inf
    no_improve_cnt = 0

    for epoch in range(1, config.n_epochs):
        optimizer = optim.SGD(model.parameters(), lr=current_lr)
        print("current learning rate: %f" % current_lr)
        print(optimizer)

        sample_cnt = 0
        total_loss, total_word_count, total_parameter_norm, total_grad_norm = 0, 0, 0, 0
        start_time = time.time()
        train_loss = np.inf

        for batch_index, batch in enumerate(train_iter):
            optimizer.zero_grad()

            batch = prepare_data(batch)

            current_batch_word_cnt = torch.sum(batch[1])
            # Most important lines in this method.
            # Since the model takes BOS + sentence as input and sentence + EOS as output,
            # x (input) excludes the last index, and y (output) excludes the first index.
            x = batch[0][:, :-1]
            y = batch[0][:, 1:]
            # feed-forward
            hidden = model.init_hidden(config.batch_size)
            # print("hidden : ", hidden[0].shape, hidden[1].shape)
            # print("x : ", x.shape)
            # print("batch[1]", batch[1])
            y_hat = model(x, batch[1], hidden)

            # Calculate the loss and gradients with back-propagation.
            loss = get_loss(y, y_hat, criterion)

            # simple math to show stats
            total_loss += float(loss)
            total_word_count += int(current_batch_word_cnt)
            total_parameter_norm += float(
                utils.get_parameter_norm(model.parameters()))
            total_grad_norm += float(utils.get_grad_norm(model.parameters()))

            if (batch_index + 1) % config.print_every == 0:
                avg_loss = total_loss / total_word_count
                avg_parameter_norm = total_parameter_norm / config.print_every
                avg_grad_norm = total_grad_norm / config.print_every
                elapsed_time = time.time() - start_time

                print(
                    "epoch: %d batch: %d/%d\t|param|: %.2f\t|g_param|: %.2f\tloss: %.4f\tPPL: %.2f\t%5d words/s %3d secs"
                    % (epoch, batch_index + 1,
                       int((len(train_iter.dataset) // config.batch_size)),
                       avg_parameter_norm, avg_grad_norm, avg_loss,
                       np.exp(avg_loss), total_word_count // elapsed_time,
                       elapsed_time))

                total_loss, total_word_count, total_parameter_norm, total_grad_norm = 0, 0, 0, 0
                start_time = time.time()

                train_loss = avg_loss

            # Another important line in this method.
            # In order to avoid exploding gradients, we apply gradient clipping.
            torch_utils.clip_grad_norm_(model.parameters(),
                                        config.max_grad_norm)
            # Take a step of gradient descent.
            optimizer.step()

            sample_cnt += batch[0].size(0)
            if sample_cnt >= len(train_iter.dataset):
                break

        sample_cnt = 0
        total_loss, total_word_count = 0, 0

        model.eval()
        for batch_index, batch in enumerate(valid_iter):
            batch = prepare_data(batch)
            current_batch_word_cnt = torch.sum(batch[1])
            x = batch[0][:, :-1]
            y = batch[0][:, 1:]
            hidden = model.init_hidden(config.batch_size)

            y_hat = model(x, batch[1], hidden)

            loss = get_loss(y, y_hat, criterion, do_backward=False)

            total_loss += float(loss)
            total_word_count += int(current_batch_word_cnt)

            sample_cnt += batch[0].size(0)
            if sample_cnt >= len(valid_iter.dataset):
                break

        avg_loss = total_loss / total_word_count
        print("valid loss: %.4f\tPPL: %.2f" % (avg_loss, np.exp(avg_loss)))

        if lowest_valid_loss > avg_loss:
            lowest_valid_loss = avg_loss
            no_improve_cnt = 0
        else:
            # Decrease the learning rate if there is no improvement.
            current_lr /= 10.
            no_improve_cnt += 1

        model.train()

        model_fn = config.model.split(".")  # model filename
        model_fn = model_fn[:-1] + [
            "%02d" % epoch,
            "%.2f-%.2f" % (train_loss, np.exp(train_loss)),
            "%.2f-%.2f" % (avg_loss, np.exp(avg_loss))
        ] + [model_fn[-1]]

        # PyTorch provides an efficient way to save and load models, based on Python pickle.
        torch.save(
            {
                "model": model.state_dict(),
                "config": config,
                "epoch": epoch + 1,
                "current_lr": current_lr
            }, ".".join(model_fn))

        if config.early_stop > 0 and no_improve_cnt > config.early_stop:
            break
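Example #25 delegates the actual loss computation to a get_loss helper that is not listed. A minimal sketch, assuming y_hat holds per-token log-probabilities and criterion sums the token-level losses (e.g. nn.NLLLoss(reduction='sum', ignore_index=PAD), with PAD a placeholder for the pad index):

def get_loss(y, y_hat, criterion, do_backward=True):
    # Hypothetical reconstruction: flatten (batch, length, vocab) predictions
    # and (batch, length) targets, sum the token losses, and optionally
    # back-propagate the batch-averaged loss.
    batch_size = y.size(0)
    loss = criterion(y_hat.contiguous().view(-1, y_hat.size(-1)),
                     y.contiguous().view(-1))
    if do_backward:
        loss.div(batch_size).backward()
    return loss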
Example #26
    def step(engine, mini_batch):
        from utils import get_grad_norm, get_parameter_norm

        # You have to reset the gradients of all model parameters
        # before taking another step of gradient descent.
        engine.model.train()
        engine.optimizer.zero_grad()

        # The raw target variable has both BOS and EOS tokens.
        # The output of sequence-to-sequence does not have a BOS token.
        # Thus, remove the BOS token for the reference.
        x, y = mini_batch.src, mini_batch.tgt[0][:, 1:]
        # |x| = (batch_size, length)
        # |y| = (batch_size, length)

        # Take the sampling process by setting is_greedy=False.
        y_hat, indice = engine.model.search(
            x, is_greedy=False, max_length=engine.config.max_length)
        # Based on the result of sampling, get reward.
        actor_reward = MinimumRiskTrainer.get_reward(
            indice, y, n_gram=engine.config.rl_n_gram)
        # |y_hat| = (batch_size, length, output_size)
        # |indice| = (batch_size, length)
        # |actor_reward| = (batch_size)

        # Take samples as many as n_samples, and get average rewards for them.
        # I figured out that n_samples = 1 would be enough.
        baseline = []
        with torch.no_grad():
            for _ in range(engine.config.rl_n_samples):
                _, sampled_indice = engine.model.search(
                    x,
                    is_greedy=False,
                    max_length=engine.config.max_length,
                )
                baseline += [
                    MinimumRiskTrainer.get_reward(
                        sampled_indice,
                        y,
                        n_gram=engine.config.rl_n_gram,
                    )
                ]

            baseline = torch.stack(baseline).sum(dim=0).div(
                engine.config.rl_n_samples)
            # |baseline| = (n_samples, batch_size) --> (batch_size)

        # Now we have a relative expected cumulative reward,
        # which is obtained by subtracting the baseline from actor_reward.
        final_reward = actor_reward - baseline
        # |final_reward| = (batch_size)

        # calculate gradients with back-propagation
        MinimumRiskTrainer.get_gradient(y_hat,
                                        indice,
                                        engine.crit,
                                        reward=final_reward)

        p_norm = float(get_parameter_norm(engine.model.parameters()))
        g_norm = float(get_grad_norm(engine.model.parameters()))

        # In order to avoid exploding gradients, we apply gradient clipping.
        torch_utils.clip_grad_norm_(
            engine.model.parameters(),
            engine.config.max_grad_norm,
        )
        # Take a step of gradient descent.
        engine.optimizer.step()

        return (
            float(actor_reward.mean()),
            float(baseline.mean()),
            float(final_reward.mean()),
            p_norm,
            g_norm,
        )
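MinimumRiskTrainer.get_gradient is called above but not shown. Below is a minimal REINFORCE-style sketch: scale the log-probability of the sampled tokens by the baselined reward and back-propagate. The real implementation presumably also handles padding via the criterion; this sketch ignores crit and is only an assumption about its behaviour.

import torch


def get_gradient(y_hat, indice, crit, reward=None):
    # y_hat: (batch, length, vocab) log-probabilities of the sampled sequence,
    # indice: (batch, length) sampled token ids, reward: (batch,).
    # Simplified sketch: crit (the padded NLL criterion) is not used here.
    batch_size = y_hat.size(0)
    log_prob = y_hat.gather(2, indice.unsqueeze(-1)).squeeze(-1)
    # Policy-gradient objective: minimize -reward * log P(sampled sequence).
    loss = -(log_prob.sum(dim=-1) * reward.detach()).sum() / batch_size
    loss.backward()
    return loss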
Example #27
    def train_epoch(self, train, optimizer, verbose=VERBOSE_BATCH_WISE):
        '''
        Train an epoch with given train iterator and optimizer.
        '''
        total_loss, total_word_count = 0, 0
        total_grad_norm = 0
        avg_loss, avg_grad_norm = 0, 0
        sample_cnt = 0

        progress_bar = tqdm(
            train, desc='Training: ',
            unit='batch') if verbose is VERBOSE_BATCH_WISE else train
        # Iterate whole train-set.
        for idx, mini_batch in enumerate(progress_bar):
            # The raw target variable has both BOS and EOS tokens.
            # The output of sequence-to-sequence does not have a BOS token.
            # Thus, remove the BOS token for the reference.
            x = mini_batch.src[0][:, :-1] if self.is_src else mini_batch.tgt[
                0][:, :-1]
            y = mini_batch.src[0][:,
                                  1:] if self.is_src else mini_batch.tgt[0][:,
                                                                            1:]
            # |x| = (batch_size, length)

            # You have to reset the gradients of all model parameters before taking another step of gradient descent.
            optimizer.zero_grad()

            # Take feed-forward
            y_hat = self.model(x)
            loss = self._get_loss(y_hat, y).sum()
            loss.div(y.size(0)).backward()

            # Simple math to show stats.
            # Don't forget to detach final variables.
            total_loss += float(loss)
            total_word_count += int(mini_batch.src[1].sum() if self.
                                    is_src else mini_batch.tgt[1].sum())
            param_norm = float(
                utils.get_parameter_norm(self.model.parameters()))
            total_grad_norm += float(
                utils.get_grad_norm(self.model.parameters()))

            avg_loss = total_loss / total_word_count
            avg_grad_norm = total_grad_norm / (idx + 1)

            if verbose is VERBOSE_BATCH_WISE:
                progress_bar.set_postfix_str(
                    '|param|=%.2f |g_param|=%.2f loss=%.4e PPL=%.2f' %
                    (param_norm, avg_grad_norm, avg_loss, exp(avg_loss)))

            # In order to avoid exploding gradients, we apply gradient clipping.
            torch_utils.clip_grad_norm_(self.model.parameters(),
                                        self.config.max_grad_norm)
            # Take a step of gradient descent.
            optimizer.step()

            sample_cnt += x.size(0)
            if idx >= len(progress_bar):
                break

        if verbose is VERBOSE_BATCH_WISE:
            progress_bar.close()

        return avg_loss, param_norm, avg_grad_norm
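The progress bar in Example #27 reports PPL=exp(avg_loss), which is the usual definition of perplexity when avg_loss is the average per-word negative log-likelihood:

import math


def perplexity(total_nll, total_word_count):
    # Perplexity = exp(average per-word negative log-likelihood),
    # i.e. exactly what exp(avg_loss) computes in the loop above.
    return math.exp(total_nll / total_word_count)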
Example #28
def train_epoch(model,
                bimpm,
                criterion,
                train_iter,
                valid_iter,
                config,
                start_epoch=1,
                others_to_save=None,
                valid_nli_iter=None):
    current_lr = config.rl_lr

    highest_valid_bleu = -np.inf
    no_improve_cnt = 0

    # Print initial valid BLEU before we start RL.
    model.eval()
    total_reward, sample_cnt = 0, 0
    for batch_index, batch in enumerate(valid_iter):
        current_batch_word_cnt = torch.sum(batch.tgt[1])
        x = batch.src
        y = batch.tgt[0][:, 1:]
        batch_size = y.size(0)
        # |x| = (batch_size, length)
        # |y| = (batch_size, length)

        # feed-forward
        y_hat, indice = model.search(x,
                                     is_greedy=True,
                                     max_length=config.max_length)
        # |y_hat| = (batch_size, length, output_size)
        # |indice| = (batch_size, length)

        reward = get_bleu_reward(y,
                                 indice,
                                 n_gram=min(config.rl_n_gram, indice.size(1)))

        total_reward += float(reward.sum())
        sample_cnt += batch_size
        if sample_cnt >= len(valid_iter.dataset.examples):
            break
    avg_bleu = total_reward / sample_cnt
    print("initial valid BLEU: %.4f" %
          avg_bleu)  # You can measure improvement against this baseline.

    if valid_nli_iter:
        nli_validation(valid_nli_iter, model, bimpm, config)
    model.train()  # Now, begin training.

    # Start RL
    nli_criterion = nn.CrossEntropyLoss(reduction='none')
    print("start epoch:", start_epoch)
    print("number of epoch to complete:", config.rl_n_epochs + 1)

    if config.reward_mode == 'combined':
        if config.gpu_id >= 0:
            nli_weight = torch.tensor([1.0], requires_grad=True, device="cuda")
            bleu_weight = torch.tensor([1.0],
                                       requires_grad=True,
                                       device="cuda")
        else:
            nli_weight = torch.tensor([1.0], requires_grad=True)
            bleu_weight = torch.tensor([1.0], requires_grad=True)

        print("nli_weight, bleu_weight:",
              nli_weight.data.cpu().numpy()[0],
              bleu_weight.data.cpu().numpy()[0])
        weight_optimizer = optim.Adam(iter([nli_weight, bleu_weight]),
                                      lr=0.0001)

    optimizer = optim.SGD(
        model.parameters(), lr=current_lr,
        momentum=0.9)  # Default hyper-parameter is set for SGD.
    print("current learning rate: %f" % current_lr)
    print(optimizer)

    for epoch in range(start_epoch, config.rl_n_epochs + 1):
        sample_cnt = 0
        total_risk, total_errors, total_sample_count, total_word_count, total_parameter_norm, total_grad_norm = 0, 0, 0, 0, 0, 0
        total_label_correct = np.zeros(3)
        total_label_size = np.zeros(3)
        start_time = time.time()
        train_loss = np.inf

        for batch_index, batch in enumerate(train_iter):
            optimizer.zero_grad()

            current_batch_word_cnt = torch.sum(batch.tgt[1])
            x = batch.src
            y = batch.tgt[0][:, 1:]
            batch_size = y.size(0)
            epoch_accuracy = []
            max_sample_length = 0
            sequence_probs, errors = [], []
            if config.reward_mode != 'bleu':
                premise = batch.premise
                hypothesis = batch.hypothesis
                isSrcPremise = batch.isSrcPremise
                label = batch.labels

            # |x| = (batch_size, length)
            # |y| = (batch_size, length)

            for _ in range(config.n_samples):
                # Take the sampling process by setting is_greedy=False.
                y_hat, indice = model.search(x,
                                             is_greedy=False,
                                             max_length=config.max_length)
                max_sample_length = max(max_sample_length, indice.size(1))
                prob = y_hat.gather(2, indice.unsqueeze(2)).squeeze(2)
                sequence_probs.append(prob)
                # |prob| = (batch_size, length)

                if config.reward_mode == 'bleu':
                    bleu = get_bleu_reward(y, indice, n_gram=config.rl_n_gram)
                    reward = 1 - bleu / 100
                    epoch_accuracy.append(bleu.sum() / batch_size)
                else:
                    padded_indice, padded_premise, padded_hypothesis = padding_three_tensors(
                        indice, premise, hypothesis, batch_size)

                    # Put the predicted sentence into either the premise or the hypothesis.
                    for i in range(batch_size):
                        if isSrcPremise[i]:
                            padded_premise[i] = padded_indice[i]
                        else:
                            padded_hypothesis[i] = padded_indice[i]

                    with torch.no_grad():
                        kwargs = {'p': padded_premise, 'h': padded_hypothesis}
                        pred_logit = bimpm(**kwargs)
                        accuracy = get_accuracy(pred_logit, label)
                        num_correct, size = get_label_accuracy(
                            pred_logit, label)
                        total_label_correct += num_correct
                        total_label_size += size
                    epoch_accuracy.append(accuracy)

                    # Based on the result of sampling, get reward.
                    if config.reward_mode == 'nli':
                        reward = get_nli_reward(pred_logit, label,
                                                nli_criterion)
                    else:
                        reward = 1/(2 * nli_weight.pow(2)) * get_nli_reward(pred_logit, label, nli_criterion) \
                            + 1/(2 * bleu_weight.pow(2)) * (1 - get_bleu_reward(y, indice, n_gram=config.rl_n_gram)/100) \
                            + torch.log(nli_weight * bleu_weight)
                # |y_hat| = (batch_size, length, output_size)
                # |indice| = (batch_size, length)
                # |reward| = (batch_size)
                errors.append(reward)

            padded_probs = pad_probs(sequence_probs, max_sample_length)
            sequence_probs = torch.stack(padded_probs, dim=2)
            # |sequence_probs| = (batch_size, max_sample_length, sample_size)
            errors = torch.stack(errors, dim=1)
            # |errors| = (batch_size, sample_size)

            avg_probs = sequence_probs.sum(dim=1) / max_sample_length
            if config.temperature != 1.0:
                probs = avg_probs.exp_().pow(1 / config.temperature)
                probs = nn.functional.softmax(probs, dim=1)
            else:
                probs = nn.functional.softmax(avg_probs.exp_(), dim=1)
            risk = (probs * errors).sum() / batch_size
            risk.backward()

            # simple math to show stats
            total_risk += float(risk.sum())
            total_errors += float(reward.sum())
            total_sample_count += batch_size
            total_word_count += int(current_batch_word_cnt)
            total_parameter_norm += float(
                utils.get_parameter_norm(model.parameters()))
            total_grad_norm += float(utils.get_grad_norm(model.parameters()))

            if (batch_index + 1) % config.print_every == 0:
                avg_risk = total_risk / total_sample_count
                avg_errors = total_errors / total_sample_count
                avg_parameter_norm = total_parameter_norm / config.print_every
                avg_grad_norm = total_grad_norm / config.print_every
                avg_epoch_accuracy = sum(epoch_accuracy) / len(epoch_accuracy)
                elapsed_time = time.time() - start_time

                print(
                    "epoch: %d batch: %d/%d\t|param|: %.2f\t|g_param|: %.2f\trisk: %.4f\terror: %.4f\tAccuracy: %.2f\t%5d words/s %3d secs"
                    % (epoch, batch_index + 1,
                       int(
                           len(train_iter.dataset.examples) //
                           config.batch_size), avg_parameter_norm,
                       avg_grad_norm, avg_risk, avg_errors, avg_epoch_accuracy,
                       total_word_count // elapsed_time, elapsed_time))

                print("train label accuracy:",
                      total_label_correct / total_label_size)
                if config.reward_mode == 'combined':
                    print("nli_weight, bleu_weight:",
                          nli_weight.data.cpu().numpy()[0],
                          bleu_weight.data.cpu().numpy()[0])

                total_risk, total_errors, total_sample_count, total_word_count, total_parameter_norm, total_grad_norm = 0, 0, 0, 0, 0, 0
                epoch_accuracy = []
                total_label_correct = np.zeros(3)
                total_label_size = np.zeros(3)
                start_time = time.time()

                train_loss = avg_bleu

            # In order to avoid exploding gradients, we apply gradient clipping.
            torch_utils.clip_grad_norm_(model.parameters(),
                                        config.max_grad_norm)
            # Take a step of gradient descent.
            optimizer.step()
            if config.reward_mode == 'combined':
                weight_optimizer.step()

            sample_cnt += batch_size
            if sample_cnt >= len(train_iter.dataset.examples):
                break

        sample_cnt = 0
        total_reward = 0

        # Start validation
        with torch.no_grad():
            model.eval()  # Turn off dropout.

            for batch_index, batch in enumerate(valid_iter):
                current_batch_word_cnt = torch.sum(batch.tgt[1])
                x = batch.src
                y = batch.tgt[0][:, 1:]
                batch_size = y.size(0)
                # |x| = (batch_size, length)
                # |y| = (batch_size, length)

                # feed-forward
                y_hat, indice = model.search(x,
                                             is_greedy=True,
                                             max_length=config.max_length)
                # |y_hat| = (batch_size, length, output_size)
                # |indice| = (batch_size, length)

                reward = get_bleu_reward(y, indice, n_gram=config.rl_n_gram)

                total_reward += float(reward.sum())
                sample_cnt += batch_size
                if sample_cnt >= len(valid_iter.dataset.examples):
                    break

            avg_bleu = total_reward / sample_cnt
            print("valid BLEU: %.4f" % avg_bleu)

            if highest_valid_bleu < avg_bleu:
                highest_valid_bleu = avg_bleu
                no_improve_cnt = 0
            else:
                no_improve_cnt += 1

            if valid_nli_iter:
                nli_validation(valid_nli_iter, model, bimpm, config)
            model.train()

        model_fn = config.model.split(".")
        model_fn = model_fn[:-1] + [
            "%02d" % (config.n_epochs + epoch),
            "%.2f-%.4f" % (train_loss, avg_bleu)
        ] + [model_fn[-1]] + [config.reward_mode]

        # PyTorch provides an efficient way to save and load models, based on Python pickle.
        to_save = {
            "model": model.state_dict(),
            "config": config,
            "epoch": config.n_epochs + epoch + 1,
            "current_lr": current_lr
        }
        if others_to_save is not None:
            for k, v in others_to_save.items():
                to_save[k] = v
        torch.save(to_save, '.'.join(model_fn))

        if config.early_stop > 0 and no_improve_cnt > config.early_stop:
            break
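In Example #28 the 'combined' reward mode weights the NLI loss and the BLEU term with the learnable scalars nli_weight and bleu_weight, in the uncertainty-weighting style 1/(2*w^2)*L + log w. The helper below just isolates that expression for readability (hypothetical name; the arithmetic mirrors the in-loop code):

import torch


def combined_reward(nli_loss, bleu_score, nli_weight, bleu_weight):
    # Mirrors the expression used inside the training loop:
    #   1/(2*w_nli^2) * L_nli + 1/(2*w_bleu^2) * (1 - BLEU/100)
    #   + log(w_nli * w_bleu)
    return (nli_loss / (2 * nli_weight.pow(2))
            + (1 - bleu_score / 100) / (2 * bleu_weight.pow(2))
            + torch.log(nli_weight * bleu_weight))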
Example #29
    def step(engine, mini_batch):
        from utils import get_grad_norm, get_parameter_norm

        for language_model, model, optimizer in zip(engine.language_models,
                                                    engine.models,
                                                    engine.optimizers):
            language_model.eval()
            model.train()
            optimizer.zero_grad()

        # X2Y
        x, y = (mini_batch.src[0][:, 1:-1],
                mini_batch.src[1] - 2), mini_batch.tgt[0][:, :-1]
        # |x| = (batch_size, n)
        # |y| = (batch_size, m)
        y_hat = engine.models[X2Y](x, y)
        # |y_hat| = (batch_size, m, y_vocab_size)
        with torch.no_grad():
            p_hat_y = engine.language_models[X2Y](y)
            # |p_hat_y| = |y_hat|

        # Y2X
        # Since the encoder in seq2seq takes a packed_sequence instance,
        # we need to re-sort when we use the reversed src and tgt.
        x, y, restore_indice = DualSupervisedTrainer._reordering(
            mini_batch.src[0][:, :-1],
            mini_batch.tgt[0][:, 1:-1],
            mini_batch.tgt[1] - 2,
        )
        # |x| = (batch_size, n)
        # |y| = (batch_size, m)
        x_hat = engine.models[Y2X](y, x).index_select(dim=0,
                                                      index=restore_indice)
        # |x_hat| = (batch_size, n, x_vocab_size)

        with torch.no_grad():
            p_hat_x = engine.language_models[Y2X](x).index_select(
                dim=0, index=restore_indice)
            # |p_hat_x| = |x_hat|

        x, y = mini_batch.src[0][:, 1:], mini_batch.tgt[0][:, 1:]
        loss_x2y, loss_y2x, dual_loss = DualSupervisedTrainer._get_loss(
            x,
            y,
            x_hat,
            y_hat,
            engine.crits,
            p_hat_x,
            p_hat_y,
            # According to the paper, DSL should be warm-started.
            # Thus, we turn off the regularization at the beginning.
            lagrange=engine.config.dsl_lambda
            if engine.state.epoch >= engine.config.n_epochs else .0)

        loss_x2y.div(y.size(0)).backward()
        loss_y2x.div(x.size(0)).backward()

        p_norm = float(
            get_parameter_norm(
                list(engine.models[X2Y].parameters()) +
                list(engine.models[Y2X].parameters())))
        g_norm = float(
            get_grad_norm(
                list(engine.models[X2Y].parameters()) +
                list(engine.models[Y2X].parameters())))

        for model, optimizer in zip(engine.models, engine.optimizers):
            torch_utils.clip_grad_norm_(
                model.parameters(),
                engine.config.max_grad_norm,
            )
            # Take a step of gradient descent.
            optimizer.step()

        return (
            float(loss_x2y / mini_batch.src[1].sum()),
            float(loss_y2x / mini_batch.tgt[1].sum()),
            float(dual_loss / x.size(0)),
            p_norm,
            g_norm,
        )
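DualSupervisedTrainer._reordering (used in Example #29) is not shown; the comment explains that the batch must be re-sorted because the encoder consumes a packed sequence, which expects samples in descending length order. Below is a minimal sketch of such a re-sort, including the permutation needed to restore the original order afterwards; it is a hypothetical helper with a simplified signature, not the original method.

import torch


def reorder_by_length(x, y, lengths):
    # Sort the batch by descending length (as packed sequences expect) and
    # return the index permutation that restores the original order.
    sorted_lengths, sort_indice = lengths.sort(descending=True)
    restore_indice = sort_indice.argsort()
    return x[sort_indice], y[sort_indice], sorted_lengths, restore_indice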
Example #30
def train_epoch(model,
                criterion,
                train_iter,
                valid_iter,
                config,
                start_epoch=1,
                others_to_save=None):
    current_lr = config.lr

    lowest_valid_loss = np.inf
    no_improve_cnt = 0

    for epoch in range(start_epoch, config.n_epochs + 1):
        if config.adam:
            optimizer = optim.Adam(model.parameters(), lr=current_lr)
        else:
            optimizer = optim.SGD(model.parameters(), lr=current_lr)
        print("current learning rate: %f" % current_lr)
        print(optimizer)

        sample_cnt = 0
        total_loss, total_word_count, total_parameter_norm, total_grad_norm = 0, 0, 0, 0
        start_time = time.time()
        train_loss = np.inf

        for batch_index, batch in enumerate(train_iter):
            # You have to reset the gradients of all model parameters before taking another step of gradient descent.
            optimizer.zero_grad()

            current_batch_word_cnt = torch.sum(batch.tgt[1])
            x = batch.src
            # The raw target variable has both BOS and EOS tokens.
            # The output of sequence-to-sequence does not have a BOS token.
            # Thus, remove the BOS token for the reference.
            y = batch.tgt[0][:, 1:]
            # |x| = (batch_size, length)
            # |y| = (batch_size, length)

            # Take feed-forward
            # Similar as before, the input of decoder does not have EOS token.
            # Thus, remove EOS token for decoder input.
            y_hat = model(x, batch.tgt[0][:, :-1])
            # |y_hat| = (batch_size, length, output_size)

            # Calculate the loss and gradients with back-propagation.
            loss = get_loss(y, y_hat, criterion)

            # Simple math to show stats.
            total_loss += float(loss)
            total_word_count += int(current_batch_word_cnt)
            total_parameter_norm += float(
                utils.get_parameter_norm(model.parameters()))
            total_grad_norm += float(utils.get_grad_norm(model.parameters()))

            # Print the current training status every time this many mini-batches have been processed.
            if (batch_index + 1) % config.print_every == 0:
                avg_loss = total_loss / total_word_count
                avg_parameter_norm = total_parameter_norm / config.print_every
                avg_grad_norm = total_grad_norm / config.print_every
                elapsed_time = time.time() - start_time

                # You can check the current status using parameter norm and gradient norm.
                # Also, you can check the speed of the training.
                print(
                    "epoch: %d batch: %d/%d\t|param|: %.2f\t|g_param|: %.2f\tloss: %.4f\tPPL: %.2f\t%5d words/s %3d secs"
                    % (epoch, batch_index + 1,
                       int(
                           len(train_iter.dataset.examples) //
                           config.batch_size), avg_parameter_norm,
                       avg_grad_norm, avg_loss, np.exp(avg_loss),
                       total_word_count // elapsed_time, elapsed_time))

                total_loss, total_word_count, total_parameter_norm, total_grad_norm = 0, 0, 0, 0
                start_time = time.time()

                train_loss = avg_loss

            # In order to avoid exploding gradients, we apply gradient clipping.
            torch_utils.clip_grad_norm_(model.parameters(),
                                        config.max_grad_norm)
            ##### Gradients accumulate as the number of time-steps grows.

            # Take a step of gradient descent.
            optimizer.step()

            sample_cnt += batch.tgt[0].size(0)
            if sample_cnt >= len(train_iter.dataset.examples):
                break

        sample_cnt = 0
        total_loss, total_word_count = 0, 0

        with torch.no_grad():  # In validation, we don't need to get gradients.
            model.eval()  # Turn on evaluation mode.

            for batch_index, batch in enumerate(valid_iter):
                current_batch_word_cnt = torch.sum(batch.tgt[1])
                x = batch.src
                y = batch.tgt[0][:, 1:]
                # |x| = (batch_size, length)
                # |y| = (batch_size, length)

                # Take feed-forward
                y_hat = model(x, batch.tgt[0][:, :-1])
                # |y_hat| = (batch_size, length, output_size)

                loss = get_loss(y, y_hat, criterion, do_backward=False)

                total_loss += float(loss)
                total_word_count += int(current_batch_word_cnt)

                sample_cnt += batch.tgt[0].size(0)
                if sample_cnt >= len(valid_iter.dataset.examples):
                    break

            # Print result of validation.
            avg_loss = total_loss / total_word_count
            print("valid loss: %.4f\tPPL: %.2f" % (avg_loss, np.exp(avg_loss)))

            if lowest_valid_loss > avg_loss:
                lowest_valid_loss = avg_loss
                no_improve_cnt = 0

                # Although there was an improvement in the last epoch, we still decay the learning rate if the schedule requires it.
                if epoch >= config.lr_decay_start_at:
                    current_lr = max(config.min_lr,
                                     current_lr * config.lr_decay_rate)
            else:
                # Decrease the learning rate if there is no improvement.
                current_lr = max(config.min_lr,
                                 current_lr * config.lr_decay_rate)
                no_improve_cnt += 1

            # Again, turn on training mode.
            model.train()

        # Set a filename for the model of the last epoch.
        # We pack as much information into the filename as possible.
        model_fn = config.model.split(".")
        model_fn = model_fn[:-1] + [
            "%02d" % epoch,
            "%.2f-%.2f" % (train_loss, np.exp(train_loss)),
            "%.2f-%.2f" % (avg_loss, np.exp(avg_loss))
        ] + [model_fn[-1]]

        # PyTorch provides an efficient way to save and load models, based on Python pickle.
        to_save = {
            "model": model.state_dict(),
            "config": config,
            "epoch": epoch + 1,
            "current_lr": current_lr
        }
        if others_to_save is not None:  # Add others if it is necessary.
            for k, v in others_to_save.items():
                to_save[k] = v
        torch.save(to_save, '.'.join(model_fn))

        # Apply early stopping if the patience threshold is exceeded.
        if config.early_stop > 0 and no_improve_cnt > config.early_stop:
            break
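Each of these loops saves a checkpoint dictionary with torch.save. For completeness, a short sketch of restoring such a checkpoint; the keys match the to_save dictionaries built above, and the helper name is illustrative.

import torch


def load_checkpoint(path, model):
    # Restore a checkpoint produced by the training loops above and
    # return the bookkeeping values needed to resume training.
    saved = torch.load(path, map_location='cpu')
    model.load_state_dict(saved['model'])
    return saved['config'], saved['epoch'], saved['current_lr']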