Example no. 1
def forward(self, X, state):
    # The output `X` shape: (`num_steps`, `batch_size`, `embed_size`)
    X = self.embedding(X).permute(1, 0, 2)
    # Broadcast `context` so it has the same `num_steps` as `X`
    context = state[-1].repeat(X.shape[0], 1, 1)
    X_and_context = d2l.concat((X, context), 2)
    output, state = self.rnn(X_and_context, state)
    output = self.dense(output).permute(1, 0, 2)
    # `output` shape: (`batch_size`, `num_steps`, `vocab_size`)
    # `state` shape: (`num_layers`, `batch_size`, `num_hiddens`)
    return output, state
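
For orientation, here is a minimal, self-contained sketch of the same forward pass; the class name `TinyDecoder` and the hyperparameters are made up for this illustration, not taken from the example. It shows the shapes flowing through the embedding, the broadcast context, and the GRU.

import torch
from torch import nn

class TinyDecoder(nn.Module):
    """Illustrative decoder wrapping the forward pass shown above."""
    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        # The GRU consumes the embedding concatenated with the context
        self.rnn = nn.GRU(embed_size + num_hiddens, num_hiddens, num_layers)
        self.dense = nn.Linear(num_hiddens, vocab_size)

    def forward(self, X, state):
        X = self.embedding(X).permute(1, 0, 2)        # (num_steps, batch_size, embed_size)
        context = state[-1].repeat(X.shape[0], 1, 1)  # broadcast along the time axis
        output, state = self.rnn(torch.cat((X, context), 2), state)
        return self.dense(output).permute(1, 0, 2), state

decoder = TinyDecoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)
X = torch.zeros((4, 7), dtype=torch.long)   # (batch_size, num_steps)
state = torch.zeros((2, 4, 16))             # (num_layers, batch_size, num_hiddens)
output, state = decoder(X, state)
print(output.shape, state.shape)  # torch.Size([4, 7, 10]) torch.Size([2, 4, 16])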
Example no. 2
def train_seq2seq(net, data_iter, lr, num_epochs, tgt_vocab, device):
    """Train a model for sequence to sequence."""
    def xavier_init_weights(m):
        if type(m) == nn.Linear:
            nn.init.xavier_uniform_(m.weight)
        if type(m) == nn.GRU:
            for param in m._flat_weights_names:
                if "weight" in param:
                    nn.init.xavier_uniform_(m._parameters[param])

    net.apply(xavier_init_weights)
    net.to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    loss = MaskedSoftmaxCELoss()
    net.train()
    animator = d2l.Animator(xlabel='epoch',
                            ylabel='loss',
                            xlim=[10, num_epochs])
    for epoch in range(num_epochs):
        timer = d2l.Timer()
        metric = d2l.Accumulator(2)  # Sum of training loss, no. of tokens
        first = True
        for i, batch in enumerate(data_iter):
            X, X_valid_len, Y, Y_valid_len = [x.to(device) for x in batch]
            bos = torch.tensor([tgt_vocab['<bos>']] * Y.shape[0],
                               device=device).reshape(-1, 1)

            # Exercise 4: replace teacher forcing with feeding the model's own
            # predictions into the decoder and see how this influences
            # performance. Here only the first batch of each epoch uses teacher
            # forcing; later batches reuse the predictions made for the
            # previous mini-batch, truncated to the current batch size.
            if first:
                dec_input = d2l.concat([bos, Y[:, :-1]], 1)  # Teacher forcing
                first = False
            else:
                # `Y_hat` is the prediction from the previous mini-batch
                dec_input = Y_hat.argmax(dim=2)
                dec_input = dec_input[:X.shape[0]]

            optimizer.zero_grad()  # Clear gradients from the previous batch
            Y_hat, _ = net(X, dec_input, X_valid_len)
            l = loss(Y_hat, Y, Y_valid_len)
            l.sum().backward()  # Make the loss scalar for `backward`
            d2l.grad_clipping(net, 1)
            num_tokens = Y_valid_len.sum()
            optimizer.step()
            with torch.no_grad():
                metric.add(l.sum(), num_tokens)
        if (epoch + 1) % 10 == 0:
            animator.add(epoch + 1, (metric[0] / metric[1], ))
    print(f'loss {metric[0] / metric[1]:.3f}, {metric[1] / timer.stop():.1f} '
          f'tokens/sec on {str(device)}')
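
The exercise comment above asks what changes when teacher forcing is dropped. The loop approximates this by reusing the predictions computed for the previous mini-batch; feeding the prediction from the previous time step proper means unrolling the decoder one step at a time. Below is a hedged sketch of that variant, assuming a d2l-style `net` with `encoder` and `decoder` attributes (as in `d2l.EncoderDecoder`); the function name is illustrative. Its output can be scored with `MaskedSoftmaxCELoss` just like `Y_hat` above; gradients flow through each step's logits, while the argmax feedback itself is non-differentiable.

import torch

def decode_feeding_own_predictions(net, X, X_valid_len, bos, num_steps):
    """Unroll the decoder step by step, feeding back its own argmax tokens."""
    enc_outputs = net.encoder(X, X_valid_len)
    dec_state = net.decoder.init_state(enc_outputs, X_valid_len)
    dec_X = bos                        # shape: (`batch_size`, 1)
    outputs = []
    for _ in range(num_steps):
        Y_step, dec_state = net.decoder(dec_X, dec_state)
        outputs.append(Y_step)         # (`batch_size`, 1, `vocab_size`)
        dec_X = Y_step.argmax(dim=2)   # feed the prediction, not the label
    return torch.cat(outputs, dim=1)   # (`batch_size`, `num_steps`, `vocab_size`)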
Example no. 3
def train(model,
          training_batches,
          lr,
          vocab,
          device,
          model_save_dir,
          model_save_file=None):
    print("Training...")

    def xavier_init_weights(m):
        if type(m) == nn.Linear:
            torch.nn.init.xavier_uniform_(m.weight)
        if type(m) == nn.GRU:
            for param in m._flat_weights_names:
                if "weight" in param:
                    torch.nn.init.xavier_uniform_(m._parameters[param])

    model.apply(xavier_init_weights)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss = MaskedSoftmaxCELoss()

    start_epoch = 0

    if model_save_file and os.path.exists(model_save_file):
        checkpoint = torch.load(model_save_file, map_location=device)
        start_epoch = checkpoint['epoch']
        model.encoder.load_state_dict(checkpoint['en'])
        model.decoder.load_state_dict(checkpoint['de'])
        optimizer.load_state_dict(checkpoint['opt'])
    model.train()
    model.to(device)
    start = time.time()

    if not os.path.exists(model_save_dir):
        os.makedirs(model_save_dir)

    # Note: each "epoch" here consumes a single pre-built mini-batch
    for epoch in range(start_epoch, len(training_batches)):
        batch = training_batches[epoch]
        X, X_valid_len, Y, Y_valid_len = [x.to(device) for x in batch]
        bos = torch.tensor([vocab['<bos>']] * Y.shape[0],
                           device=device).reshape(-1, 1)
        dec_input = d2l.concat([bos, Y[:, :-1]], 1)  # Teacher forcing
        optimizer.zero_grad()  # Clear gradients left over from the previous step
        Y_hat, _ = model(X, dec_input, X_valid_len)
        l = loss(Y_hat, Y, Y_valid_len)
        l.sum().backward()  # Make the loss scalar for `backward`
        d2l.grad_clipping(model, 1)
        optimizer.step()
        print("Progress:{%.2f}%% Total time: %.2f s" % (round(
            (epoch + 1) * 100 / len(training_batches)), time.time() - start),
              end="\r")
        if (epoch + 1) % 100 == 0:
            torch.save(
                {
                    'epoch': epoch + 1,
                    'en': model.encoder.state_dict(),
                    'de': model.decoder.state_dict(),
                    'opt': optimizer.state_dict(),
                    'loss': loss,
                },
                os.path.join(model_save_dir,
                             '{}_{}.tar'.format(epoch + 1, MODEL_FILE_NAME)))
    print(model)
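
For context, here is a hedged sketch of how `train` might be driven; the hyperparameters, paths, and model construction (mirroring Example no. 4) are illustrative, not from the original. The important detail is that `training_batches` must be an indexable sequence of `(X, X_valid_len, Y, Y_valid_len)` tuples, so the iterator returned by `d2l.load_data_nmt` is materialized into a list first.

embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1
train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size=64, num_steps=10)
training_batches = list(train_iter)  # `train` indexes batches by position

encoder = d2l.Seq2SeqEncoder(len(src_vocab), embed_size, num_hiddens,
                             num_layers, dropout)
decoder = Seq2SeqAttentionDecoder(len(tgt_vocab), embed_size, num_hiddens,
                                  num_layers, dropout)
model = d2l.EncoderDecoder(encoder, decoder)

train(model, training_batches, lr=0.005, vocab=tgt_vocab, device=d2l.try_gpu(),
      model_save_dir='./checkpoints', model_save_file=None)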
Example no. 4
output.shape, len(state), state[0].shape, len(state[1]), state[1][0].shape
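
The names `output` and `state` in the shape check above come from a decoder smoke test omitted from this excerpt; a hedged reconstruction in the d2l style (the hyperparameters are assumptions) would look like this:

encoder = d2l.Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16,
                             num_layers=2)
encoder.eval()
decoder = Seq2SeqAttentionDecoder(vocab_size=10, embed_size=8, num_hiddens=16,
                                  num_layers=2)
decoder.eval()
X = torch.zeros((4, 7), dtype=torch.long)  # (`batch_size`, `num_steps`)
state = decoder.init_state(encoder(X), None)
output, state = decoder(X, state)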

embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1
batch_size, num_steps = 64, 10
lr, num_epochs, device = 0.005, 250, d2l.try_gpu()

train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)
encoder = d2l.Seq2SeqEncoder(len(src_vocab), embed_size, num_hiddens,
                             num_layers, dropout)
decoder = Seq2SeqAttentionDecoder(len(tgt_vocab), embed_size, num_hiddens,
                                  num_layers, dropout)
net = d2l.EncoderDecoder(encoder, decoder)
d2l.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)

engs = ['go .', "i lost .", 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
for eng, fra in zip(engs, fras):
    translation, dec_attention_weight_seq = d2l.predict_seq2seq(
        net, eng, src_vocab, tgt_vocab, num_steps, device, True)
    print(f'{eng} => {translation}, ',
          f'bleu {d2l.bleu(translation, fra, k=2):.3f}')

# Stack the attention weights recorded at each decoding step into a tensor of
# shape (1, 1, num_decoding_steps, num_steps) for plotting
attention_weights = d2l.reshape(
    d2l.concat([step[0][0][0] for step in dec_attention_weight_seq], 0),
    (1, 1, -1, num_steps))

# Plus one to include the end-of-sequence token
d2l.show_heatmaps(attention_weights[:, :, :, :len(engs[-1].split()) + 1].cpu(),
                  xlabel='Key positions',
                  ylabel='Query positions')