# Script entry point: set the hyperparameters, load the lyrics corpus, build
# the character index and construct the LSTM model that is handed to
# train_rnn right after this span.
if '__main__' == __name__:
    # Project-local helpers, imported inside the guard so they are only
    # loaded when the file is run as a script.
    from utility.jaychou.load_jaychou_lyrics import get_jaychou_lyrics
    from utility.gen_char_index import gen_char_index
    from utility.train_rnn import train_rnn

    # Hyperparameters forwarded positionally to train_rnn below.
    num_epoch = 250
    lr = 0.01
    # presumably the gradient-clipping threshold -- confirm against train_rnn
    theta = 0.01
    # Passed as the second argument of LSTM (presumably the hidden size).
    hiddens = 256
    batch_size = 128
    num_step = 35
    # NOTE(review): same value as num_epoch above and not referenced in the
    # visible part of the train_rnn call -- looks like a duplicate; confirm
    # before removing.
    num_epochs = 250
    # Not referenced in the visible part of the train_rnn call; presumably
    # the length of the generated sample text -- TODO confirm.
    pred_len = 50

    corpus_data = get_jaychou_lyrics()
    index_to_char_list, char_to_index_dict = gen_char_index(corpus_data)
    # The first and third LSTM arguments are both the number of indexed
    # characters (presumably input size and output size -- confirm against
    # the LSTM constructor).
    net = LSTM(len(index_to_char_list), hiddens, len(index_to_char_list))
    # NOTE(review): params and device are not used in the visible part of
    # the train_rnn call; they may be consumed by arguments that follow.
    params = list(net.parameters())
    # Select the GPU when torch reports CUDA as available, else the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    train_rnn(
        corpus_data,
        index_to_char_list,
        char_to_index_dict,
        True,
        lr,
        nn.CrossEntropyLoss(),
        num_step,
        batch_size,
        num_epoch,
        theta,
# Example #2
# 0
                yy_device = yy.to(self.device)
                loss_sum = loss(y_hat, yy_device).sum()
                loss_sum.backward()

                gradient_clipping(params=params,
                                  theta=0.01,
                                  device=self.device)
                sgd(params, learn_rate, 1)
                n += yy.shape[0]
                l_sum += loss_sum.cpu().item() * yy.shape[0]
            if (epoch + 1) % pred_period == 0:
                print('epoch %d, perplexity %f, time %.2f sec' %
                      (epoch + 1, math.exp(l_sum / n), time.time() - start))
                for prefix in prefixes:
                    print(
                        ' -',
                        self.predict(prefix, pred_len, index_to_char_list,
                                     char_to_index_dict))


if __name__ == '__main__':
    # Script entry point: load the lyrics corpus, index its characters,
    # build an RNN on the selected device and start training it.
    from torch import nn
    from utility.jaychou.load_jaychou_lyrics import get_jaychou_lyrics

    # NOTE(review): unusually large for a learning rate -- confirm intended.
    lr = 100

    lyrics = get_jaychou_lyrics()
    idx_to_char, char_to_idx = gen_char_index(lyrics)
    # Both size arguments of RNN are the number of indexed characters.
    vocab_size = len(idx_to_char)

    # Use the GPU when torch reports CUDA as available, otherwise the CPU.
    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device_name)

    # 256 is presumably the hidden size -- confirm against RNN.__init__.
    model = RNN(vocab_size, vocab_size, 256, device)

    # The literals after the loss are presumably the number of time steps
    # (35), the batch size (32) and the epoch count (1000); the trailing 50
    # is presumably the length of each generated sample -- TODO confirm
    # against the signature of RNN.train.
    prefixes = ['分开', '不分开']
    loss = nn.CrossEntropyLoss()
    model.train(lyrics, False, lr, loss, 35, 32, 1000, prefixes, 50)