Example No. 1
def train_ch8(net,
              train_iter,
              vocab,
              lr,
              num_epochs,
              device,
              use_random_iter=False):
    loss = nn.CrossEntropyLoss()
    animator = d2l.Animator(xlabel='epoch',
                            ylabel='perplexity',
                            legend=['train'],
                            xlim=[10, num_epochs])
    # Initialize
    if isinstance(net, nn.Module):
        updater = torch.optim.Adam(net.parameters(), lr)
    else:
        updater = lambda batch_size: d2l.sgd(net.params, lr, batch_size)
    predict = lambda prefix: predict_ch8(prefix, 50, net, vocab, device)
    # Train and predict
    for epoch in range(num_epochs):
        ppl, speed = train_epoch_ch8(net, train_iter, loss, updater, device,
                                     use_random_iter)
        if (epoch + 1) % 10 == 0:
            print(predict('time traveller'))
            animator.add(epoch + 1, [ppl])
    print(f'perplexity {ppl:.1f}, {speed:.1f} tokens/sec on {str(device)}')
    print(predict('time traveller'))
    print(predict('traveller'))
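A hypothetical invocation of this trainer, assuming the d2l time-machine loader and the chapter's RNNModel wrapper (both names come from the d2l package, not from this example; hyperparameters are illustrative):

import torch
from torch import nn
from d2l import torch as d2l

batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
rnn_layer = nn.RNN(len(vocab), 256)  # one recurrent layer, 256 hidden units
net = d2l.RNNModel(rnn_layer, vocab_size=len(vocab)).to(d2l.try_gpu())
train_ch8(net, train_iter, vocab, lr=0.01, num_epochs=500, device=d2l.try_gpu())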
Example No. 2
def train_ch8(net,
              train_iter,
              vocab,
              lr,
              num_epochs,
              device,
              use_random_iter=False):
    loss = nn.CrossEntropyLoss()
    animator = d2l.Animator(xlabel="epoch",
                            ylabel="perplexity",
                            legend=["train"],
                            xlim=[10, num_epochs])
    if isinstance(net, nn.Module):
        updater = torch.optim.SGD(net.parameters(), lr)
    else:
        updater = lambda batch_size: d2l.sgd(net.params, lr, batch_size)
    predict = lambda prefix: predict_ch8(prefix, 50, net, vocab, device)
    for epoch in range(num_epochs):
        ppl, speed = train_epoch_ch8(net, train_iter, loss, updater, device,
                                     use_random_iter)
        if (epoch + 1) % 10 == 0:
            print(predict("time traveller"))
            animator.add(epoch + 1, [ppl])
    d2l.plt.show()
    print(f"perplexity {ppl:.1f}, {speed:.1f} tokens/sec on {str(device)}")
    print(predict("time traveller"))
    print(predict("traveller"))
Example No. 3
def train_ch8(model,
              train_iter,
              vocab,
              lr,
              num_epochs,
              device,
              use_random_iter=False):
    """Train a model (defined in Chapter 8)."""
    loss = nn.CrossEntropyLoss()
    animator = d2l.Animator(xlabel='epoch',
                            ylabel='perplexity',
                            legend=['train'],
                            xlim=[10, num_epochs])
    # Initialize
    if isinstance(model, nn.Module):
        updater = torch.optim.SGD(model.parameters(), lr)
    else:
        updater = lambda batch_size: d2l.sgd(model.params, lr, batch_size)
    predict = lambda prefix: predict_ch8(prefix, 50, model, vocab, device)
    # Train and predict
    for epoch in range(num_epochs):
        ppl, speed = train_epoch_ch8(model, train_iter, loss, updater, device,
                                     use_random_iter)
        if (epoch + 1) % 10 == 0:
            print(predict('time traveller'))
            animator.add(epoch + 1, [ppl])
    print(f'perplexity {ppl:.1f}, {speed:.1f} tokens/sec on {str(device)}')
    print(predict('time traveller'))
    print(predict('traveller'))
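The train_ch8 variants above delegate text generation to predict_ch8. For reference, the d2l version of that helper is essentially the following (assuming the usual torch import); treat it as a sketch if your d2l version differs:

def predict_ch8(prefix, num_preds, net, vocab, device):
    """Generate new characters following the prefix."""
    state = net.begin_state(batch_size=1, device=device)
    outputs = [vocab[prefix[0]]]
    get_input = lambda: torch.tensor([outputs[-1]],
                                     device=device).reshape((1, 1))
    for y in prefix[1:]:  # Warm-up: feed the prefix, keep only the state
        _, state = net(get_input(), state)
        outputs.append(vocab[y])
    for _ in range(num_preds):  # Greedily predict num_preds steps
        y, state = net(get_input(), state)
        outputs.append(int(y.argmax(dim=1).reshape(1)))
    return ''.join([vocab.idx_to_token[i] for i in outputs])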
Example No. 4
def train(self,
          net,
          train_iter,
          lr,
          num_epochs,
          device,
          use_random_iter=False):
    """Train a model (defined in Chapter 8)."""
    loss = nn.MSELoss()
    animator = d2l.Animator(xlabel='epoch',
                            ylabel='MSE',
                            legend=['train'],
                            xlim=[10, num_epochs])
    # Initialize
    if isinstance(net, nn.Module):
        updater = torch.optim.SGD(net.parameters(), lr)
    else:
        updater = lambda batch_size: d2l.sgd(net.params, lr, batch_size)
    # Train
    for epoch in range(num_epochs):
        mse, speed = self.train_epoch(net, train_iter, loss, updater,
                                      device, use_random_iter)
        if (epoch + 1) % 10 == 0:
            animator.add(epoch + 1, [mse])
    # plt.show()
    print(
        f'mean squared loss {mse:.1f}, {speed:.1f} tokens/sec on {str(device)}'
    )
Example No. 5
def train(lambd):
    w, b = init_params()
    net, loss = lambda X: d2l.linreg(X, w, b), d2l.squared_loss
    num_epochs, lr = 100, 0.003
    animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',
                            xlim=[5, num_epochs], legend=['train', 'test'])
    for epoch in range(num_epochs):
        for X, y in train_iter:
            # The L2 norm penalty term is added to the loss; broadcasting
            # the scalar penalty makes the result a vector of length
            # `batch_size`
            l = loss(net(X), y) + lambd * l2_penalty(w)
            l.sum().backward()
            d2l.sgd([w, b], lr, batch_size)
        if (epoch + 1) % 5 == 0:
            animator.add(epoch + 1, (d2l.evaluate_loss(net, train_iter, loss),
                                     d2l.evaluate_loss(net, test_iter, loss)))
    print('L2 norm of w:', torch.norm(w).item())
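This example depends on init_params and l2_penalty (plus train_iter, test_iter, batch_size, and num_inputs) from the enclosing weight-decay section. Sketches consistent with how they are called here:

def init_params():
    # Fresh weight vector and bias with gradients enabled
    w = torch.normal(0, 1, size=(num_inputs, 1), requires_grad=True)
    b = torch.zeros(1, requires_grad=True)
    return [w, b]

def l2_penalty(w):
    # Half the squared L2 norm; lambd scales it in the training loop
    return torch.sum(w.pow(2)) / 2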
Example No. 6
def train_ch8_slim(net,
                   train_iter,
                   vocab,
                   lr,
                   num_epochs,
                   device,
                   use_random_iter=False):
    """Train a model (defined in Chapter 8).
    Slimmed down for binary searching
    """
    loss = nn.CrossEntropyLoss()
    # Initialize
    if isinstance(net, nn.Module):
        updater = torch.optim.SGD(net.parameters(), lr)
    else:
        updater = lambda batch_size: d2l.sgd(net.params, lr, batch_size)
    predict = lambda prefix: predict_ch8(prefix, 50, net, vocab, device)
    # Train and predict
    for epoch in range(num_epochs):
        ppl, speed = train_epoch_ch8(net, train_iter, loss, updater, device,
                                     use_random_iter)
    print(f'perplexity {ppl:.1f}, {speed:.1f} tokens/sec on {str(device)}')
    return ppl
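Because this slimmed variant returns the final perplexity, it can drive a simple hyperparameter search. A hypothetical sweep over learning rates; make_net is a placeholder factory (not defined here) so each run starts from fresh weights:

best_lr, best_ppl = None, float('inf')
for lr in (0.5, 1.0, 2.0, 4.0):
    ppl = train_ch8_slim(make_net(), train_iter, vocab, lr,
                         num_epochs=100, device=d2l.try_gpu())
    if ppl < best_ppl:
        best_lr, best_ppl = lr, ppl
print(f'best lr {best_lr}, perplexity {best_ppl:.1f}')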
Example No. 7
def updater(batch_size):
    return d2l.sgd([W, b], lr, batch_size)
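This closure adapts d2l.sgd to the updater(batch_size) signature the trainers above expect. The d2l implementation of sgd is essentially:

def sgd(params, lr, batch_size):
    """Minibatch stochastic gradient descent."""
    with torch.no_grad():
        for param in params:
            # Normalize by batch size so the step uses the average gradient
            param -= lr * param.grad / batch_size
            param.grad.zero_()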
Example No. 8
class Accumulator:
    """Accumulate sums over `n` variables."""
    def __init__(self, n):
        self.data = [0.0] * n

    def add(self, *args):
        self.data = [a + float(b) for a, b in zip(self.data, args)]

    def reset(self):
        self.data = [0.0] * len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


# Evaluate the accuracy of any model net on the dataset accessed
# through the data iterator data_iter
def evaluate_accuracy(net, data_iter):
    """Compute the accuracy of the model on the given dataset."""
    metric = Accumulator(2)  # no. of correct predictions, no. of predictions
    for X, y in data_iter:
        metric.add(accuracy(net(X), y), y.numel())
    return metric[0] / metric[1]


# Training
lr = 0.1
num_epochs = 10

for epoch in range(num_epochs):
    # Sum of training loss, sum of training accuracy, no. of examples
    metric = Accumulator(3)
    for X, y in train_iter:
        y_hat = net(X)
        l = cross_entropy(y_hat, y)
        l.sum().backward()
        d2l.sgd([W, b], lr, batch_size)
        metric.add(float(l.sum()), accuracy(y_hat, y), y.numel())
    test_acc = evaluate_accuracy(net, test_iter)
    print(f'epoch {epoch + 1}, train loss {metric[0] / metric[2]:.5f}, '
          f'train acc {metric[1] / metric[2]:.5f}, test acc {test_acc:.5f}')
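The loop above assumes the cross_entropy and accuracy helpers from the softmax-regression chapter. The d2l versions are roughly:

def cross_entropy(y_hat, y):
    # Negative log-likelihood of the true class under the predicted
    # (already softmax-normalized) probabilities
    return -torch.log(y_hat[range(len(y_hat)), y])

def accuracy(y_hat, y):
    """Compute the number of correct predictions."""
    if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
        y_hat = y_hat.argmax(axis=1)  # class index with the highest score
    cmp = y_hat.type(y.dtype) == y
    return float(cmp.type(y.dtype).sum())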