コード例 #1
0
            # Snippet from inside a per-batch training loop; the enclosing
            # function and loop headers are not part of this example.
            # Accumulate words_num so the epoch loss below can be normalised
            # by total_token.
            total_token += words_num
            loss = 0
            
            # Step through the target sequence: at step i the decoder is fed the
            # ground-truth token de_seq[:, i] and scored against de_seq[:, i + 1].
            # Assumes de_seq is (batch, time) — TODO confirm.
            # NOTE(review): every step up to max_len is scored; padded positions
            # only drop out if loss_function ignores the pad index — confirm
            # (the masking call below is commented out).
            for i in range(max_len - 1):
                decoder_inputs = de_seq[:,i]
                decoder_targets = de_seq[:,i + 1]

                decoder_outputs, decoder_state = decoder(decoder_inputs, encoder_output, decoder_state)
                #outputs, targets = mask_select(ones_matrix, decoder_inputs, decoder_targets, decoder_outputs)

                loss += loss_function(decoder_outputs, decoder_targets)

            # Summed step losses divided by words_num for this batch.
            batch_loss = loss.item() / words_num
            # NOTE(review): '\{' is not a recognised escape sequence, so Python
            # keeps the backslash (a SyntaxWarning since 3.12) and this prints
            # e.g. "Batch: 3\10"; '/' was probably intended.
            print('Epoch: {}, Batch: {}\{}, Batch Loss: {}'.format(epoch + 1, batch_idx + 1, batch_total, batch_loss))

            # NOTE(review): no zero_grad() call is visible in this snippet; unless
            # gradients are cleared earlier in the loop, they accumulate across
            # batches.
            loss.backward()
            # Clip each model's total gradient norm to 5 before stepping.
            torch.nn.utils.clip_grad_norm_(encoder.parameters(), 5)
            torch.nn.utils.clip_grad_norm_(decoder.parameters(), 5)
            en_optimizer.step()
            de_optimizer.step()
            
            # loss.item() is a Python float.
            total_loss += loss.item()

        # NOTE(review): if total_loss starts as a plain number (e.g. 0), it is a
        # Python float at this point and .item() raises AttributeError; this only
        # works if total_loss was initialised as a tensor — confirm, otherwise
        # drop the .item() call.
        total_loss = total_loss.item() / total_token
        print('Epoch: {}, Total Loss: {}'.format(epoch + 1, total_loss))
        
        ''' save model '''
        # One checkpoint pair per epoch, e.g. 'encoder_rev.1.pt'.
        torch.save(encoder.state_dict(), 'encoder_rev.' + str(epoch + 1) + '.pt')
        torch.save(decoder.state_dict(), 'decoder_rev.' + str(epoch + 1) + '.pt')

コード例 #2
0
def train(article,
          title,
          word2idx,
          target2idx,
          source_lengths,
          target_lengths,
          args,
          val_article=None,
          val_title=None,
          val_source_lengths=None,
          val_target_lengths=None):
    """Train the Encoder/Decoder pair on (article, title) data.

    When ``./temp/x.pkl`` does not exist yet, a validation split of 5% of
    ``len(article)`` is drawn with ``utils.sampling`` and all data is handed
    to ``utils.save_everything`` (presumably what creates ``./temp/x.pkl`` —
    confirm). Otherwise the caller must pass the validation data explicitly.

    Args:
        article, title: training inputs / targets, sliced per batch and
            indexed with ``utils.batch_index``.
        word2idx, target2idx: source / target vocabularies.
        source_lengths, target_lengths: per-example lengths aligned with
            ``article`` / ``title``.
        args: namespace providing ``batch`` (batch size), ``use_cuda`` and
            ``show_decode`` (checkpoint + sample-decode interval in batches;
            0 disables it).
        val_article, val_title, val_source_lengths, val_target_lengths:
            validation data; required when ``./temp/x.pkl`` already exists.

    Raises:
        ValueError: if no validation data is available.
    """
    if not os.path.exists('./temp/x.pkl'):
        size_of_val = int(len(article) * 0.05)
        # NOTE(review): the training lists are not reassigned here, so unless
        # utils.sampling removes the sampled rows in place, validation
        # examples also stay in the training set — confirm.
        val_article, val_title, val_source_lengths, val_target_lengths = \
            utils.sampling(article, title, source_lengths, target_lengths,
                           size_of_val)

        utils.save_everything(article, title, source_lengths, target_lengths,
                              val_article, val_title, val_source_lengths,
                              val_target_lengths, word2idx)

    if val_article is None:
        # Previously this surfaced as an opaque TypeError from len(None).
        raise ValueError(
            "validation data is required when './temp/x.pkl' already exists")

    batch_size = args.batch
    train_size = len(article)
    val_size = len(val_article)
    max_a = max(source_lengths)
    max_t = max(target_lengths)
    print("source vocab size:", len(word2idx))
    print("target vocab size:", len(target2idx))
    print("max a:{}, max t:{}".format(max_a, max_t))
    print("train_size:", train_size)
    print("val size:", val_size)
    print("batch_size:", batch_size)
    print("-" * 30)

    encoder = Encoder(len(word2idx))
    decoder = Decoder(len(target2idx), 50)
    if os.path.exists('decoder_model'):
        # map_location='cpu' lets GPU-saved checkpoints load on CPU-only runs;
        # the models are moved to the GPU below when args.use_cuda is set.
        encoder.load_state_dict(torch.load('encoder_model', map_location='cpu'))
        decoder.load_state_dict(torch.load('decoder_model', map_location='cpu'))

    optimizer = torch.optim.Adam(list(encoder.parameters()) +
                                 list(decoder.parameters()),
                                 lr=0.001)
    n_epoch = 5
    print("Making word index and extend vocab")
    # Indexing is done per batch by utils.batch_index inside the loop below.
    print("preprocess done")

    if args.use_cuda:
        encoder.cuda()
        decoder.cuda()

    # Trailing examples that do not fill a whole batch are skipped.
    batch_n = train_size // batch_size
    # The validation data never changes during training, so it is indexed
    # lazily on first use and then reused.
    val_batch = None

    print("start training")
    for epoch in range(n_epoch):
        total_loss = 0
        # Coverage is enabled from the second epoch onward.
        use_coverage = epoch > 0
        for b in range(batch_n):
            start, end = b * batch_size, (b + 1) * batch_size
            batch_x, batch_x_ext, batch_y, extend_vocab, extend_lengths = \
                utils.batch_index(article[start:end], title[start:end],
                                  word2idx, target2idx)

            if args.use_cuda:
                batch_x = batch_x.cuda()
                batch_y = batch_y.cuda()
                batch_x_ext = batch_x_ext.cuda()
            x_lengths = source_lengths[start:end]
            y_lengths = target_lengths[start:end]

            # Pack/unpack round trip trims batch_x_ext to the longest length in
            # x_lengths. With the default enforce_sorted=True,
            # pack_padded_sequence requires x_lengths in descending order.
            pack = pack_padded_sequence(batch_x_ext,
                                        x_lengths,
                                        batch_first=True)
            batch_x_ext_var, _ = pad_packed_sequence(pack, batch_first=True)
            current_loss = train_on_batch(encoder, decoder, optimizer, batch_x,
                                          batch_y, x_lengths, y_lengths,
                                          word2idx, target2idx,
                                          batch_x_ext_var, extend_lengths,
                                          use_coverage)

            # Drop the batch tensors instead of copying them back with .cpu():
            # this releases the references without a device-to-host transfer.
            del batch_x, batch_y, batch_x_ext, batch_x_ext_var, pack

            print('epoch:{}/{}, batch:{}/{}, loss:{}'.format(
                epoch + 1, n_epoch, b + 1, batch_n, current_loss))
            if args.show_decode and (b + 1) % args.show_decode == 0:
                torch.save(encoder.state_dict(), 'encoder_model')
                torch.save(decoder.state_dict(), 'decoder_model')
                if val_batch is None:
                    val_batch = utils.batch_index(val_article, val_title,
                                                  word2idx, target2idx)
                (batch_x_val, batch_x_ext_val, batch_y_val,
                 val_extend_vocab, val_extend_lengths) = val_batch
                # Beam-search decode one random validation example.
                # NOTE(review): these tensors are never moved to the GPU here;
                # confirm decode.beam_search handles device placement when
                # args.use_cuda is set.
                idx = np.random.randint(0, val_size)
                decode.beam_search(encoder, decoder,
                                   batch_x_val[idx].unsqueeze(0),
                                   batch_y_val[idx].unsqueeze(0), word2idx,
                                   target2idx, batch_x_ext_val[idx],
                                   val_extend_lengths[idx],
                                   val_extend_vocab[idx])

            total_loss += current_loss
            print('-' * 30)

        # total_loss used to be accumulated but never reported.
        if batch_n:
            print('epoch:{}/{}, average loss:{}'.format(
                epoch + 1, n_epoch, total_loss / batch_n))

    # The in-loop save only runs every args.show_decode batches; persist the
    # final weights so the last batches of training are not lost.
    torch.save(encoder.state_dict(), 'encoder_model')
    torch.save(decoder.state_dict(), 'decoder_model')

    print()
    print("training finished")
コード例 #3
0
        loss = loss / batch_size
        enc.zero_grad()
        dec.zero_grad()
        loss.backward()
        epoch_loss += loss.item()
        print(f'Batch loss : {loss.item()}')
        nn.utils.clip_grad_norm_(enc.parameters(), 5)
        nn.utils.clip_grad_norm_(dec.parameters(), 5)
        trn_src_t.detach()
        trn_tgt_t.detach()
        opt_dec.step()
        opt_enc.step()

    opt_enc.zero_grad()
    opt_dec.zero_grad()
    torch.save(enc.state_dict(), f'encoder_{e}.pkl')
    torch.save(dec.state_dict(), f'decoder_{e}.pkl')
    print(f'Epoch training loss : {epoch_loss / n_batch}')

    enc.eval()
    dec.eval()
    test_loss = 0
    for i in range(len(tst_tgt_t) // batch_size):
        lengths = torch.LongTensor(l_tst_src[batch_size * i:batch_size *
                                             (i + 1)])
        out, h_n = enc(tst_src_t[batch_size * i:batch_size * (i + 1)], lengths)
        output = dec.teacher_force(
            tst_tgt_t[batch_size * i:batch_size * (i + 1)].reshape(
                [batch_size, tgt_max, 1]), h_n,
            torch.LongTensor(l_tst_tgt[batch_size * i:batch_size * (i + 1)]))
        for o, l, t in zip(output,