Example #1
def trans_word_to_index(data_set, word_cout_dic):
    # Build the index <-> character mappings from the vocabulary keys.
    index_to_char_list, char_to_index_dict = gen_char_index(
        set(word_cout_dic.keys()))
    # Convert every line to a list of integer ids, dropping tokens that are
    # not present in the vocabulary dictionary.
    data_set_index = [[
        char_to_index_dict[one_word] for one_word in line
        if one_word in word_cout_dic
    ] for line in data_set]
    return data_set_index, index_to_char_list, char_to_index_dict
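
For orientation, a hedged usage sketch: the toy data and the stand-in gen_char_index below are illustration-only assumptions; the real helper comes from the utility package used by this example.

# Hypothetical stand-in for gen_char_index, assumed to return
# (index_to_char_list, char_to_index_dict) for a set of tokens.
def gen_char_index(tokens):
    index_to_char_list = sorted(tokens)
    char_to_index_dict = {ch: i for i, ch in enumerate(index_to_char_list)}
    return index_to_char_list, char_to_index_dict


word_cout_dic = {'a': 3, 'b': 1, 'c': 2}    # token -> frequency (assumed shape)
data_set = [['a', 'b', 'x'], ['c', 'a']]    # 'x' is outside the vocabulary and is dropped

data_set_index, idx2char, char2idx = trans_word_to_index(data_set, word_cout_dic)
print(data_set_index)  # [[0, 1], [2, 0]] with the sorted ordering above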
Example #2
    def train(self, corpus_data, is_random_iter, learn_rate, loss, num_step,
              batch_size, num_epoch, prefixes, pred_len):

        # Build the vocabulary and encode the corpus as a flat list of indices.
        index_to_char_list, char_to_index_dict = gen_char_index(corpus_data)
        vocab_size = len(index_to_char_list)
        corpus_chars_index = gen_index_list(corpus_data, char_to_index_dict)

        # Choose the minibatch sampler: consecutive or random character windows.
        iter_func = char_data_iter_consecutive
        if is_random_iter:
            iter_func = char_data_iter_random

        # The hidden state is created once up front; with consecutive sampling
        # it is carried (detached) across batches, with random sampling it is
        # re-initialised inside the batch loop below.
        state = self.init_state(batch_size, len(self.hidden_bias))

        # Report perplexity and sample predictions every `pred_period` epochs.
        pred_period = 50

        for epoch in range(num_epoch):
            n = 0      # number of target characters seen this epoch
            l_sum = 0  # accumulated loss, weighted by character count
            start = time.time()
            data_iter = iter_func(corpus_chars_index, num_step, batch_size)
            for x, y in data_iter:
                if is_random_iter:
                    # Random sampling: adjacent batches are not contiguous in
                    # the corpus, so the hidden state must be re-initialised.
                    state = self.init_state(batch_size, len(self.hidden_bias))
                else:
                    # Consecutive sampling: keep the state values but detach
                    # them from the previous graph to limit backprop length.
                    state.detach_()
                # One-hot encode the inputs; forward returns one output per
                # time step, concatenated along dim 0 into y_hat.
                x_vector = self.to_one_hot(x, num_step, vocab_size)
                y_hat, state = self.forward((x_vector, state))
                y_hat = torch.cat(y_hat, dim=0)

                params = self.get_params()

                # Clear gradients from the previous step; on the first pass no
                # gradients exist yet, so the loop stops at the first parameter.
                for one_param in params:
                    if one_param.grad is None:
                        break
                    one_param.grad.data.zero_()

                # Flatten the labels in time-major order so they line up with
                # y_hat, which was concatenated per time step along dim 0.
                yy = torch.transpose(y, 0, 1).contiguous().view(-1)

                yy_device = yy.to(self.device)
                loss_sum = loss(y_hat, yy_device).sum()
                loss_sum.backward()

                # Clip the gradient norm to stabilise training, then take a
                # plain SGD step (batch size 1: the loss is already averaged).
                gradient_clipping(params=params,
                                  theta=0.01,
                                  device=self.device)
                sgd(params, learn_rate, 1)
                n += yy.shape[0]
                l_sum += loss_sum.cpu().item() * yy.shape[0]
            if (epoch + 1) % pred_period == 0:
                print('epoch %d, perplexity %f, time %.2f sec' %
                      (epoch + 1, math.exp(l_sum / n), time.time() - start))
                for prefix in prefixes:
                    print(
                        ' -',
                        self.predict(prefix, pred_len, index_to_char_list,
                                     char_to_index_dict))
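
The loop above relies on gradient_clipping and sgd helpers that are not shown on this page. Below is a minimal sketch of what such from-scratch helpers usually look like, with signatures inferred from the calls above; the actual utility module may differ.

import torch


def gradient_clipping(params, theta, device):
    # Rescale all gradients so their global L2 norm does not exceed theta.
    norm = torch.tensor([0.0], device=device)
    for param in params:
        if param.grad is not None:
            norm += (param.grad.data ** 2).sum()
    norm = norm.sqrt().item()
    if norm > theta:
        for param in params:
            if param.grad is not None:
                param.grad.data *= (theta / norm)


def sgd(params, learn_rate, batch_size):
    # Plain SGD update; the training loop passes batch_size=1 because the
    # cross-entropy loss it backpropagates is already averaged.
    for param in params:
        param.data -= learn_rate * param.grad / batch_size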
Example #3
if '__main__' == __name__:
    import torch
    from torch import nn

    from utility.jaychou.load_jaychou_lyrics import get_jaychou_lyrics
    from utility.gen_char_index import gen_char_index
    from utility.train_rnn import train_rnn

    num_epoch = 250
    lr = 0.01
    theta = 0.01
    hiddens = 256
    batch_size = 128
    num_step = 35
    pred_len = 50

    corpus_data = get_jaychou_lyrics()
    index_to_char_list, char_to_index_dict = gen_char_index(corpus_data)
    # LSTM is assumed to be defined in the enclosing module (not shown here).
    net = LSTM(len(index_to_char_list), hiddens, len(index_to_char_list))
    params = list(net.parameters())
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    train_rnn(
        corpus_data,
        index_to_char_list,
        char_to_index_dict,
        True,
        lr,
        nn.CrossEntropyLoss(),
        num_step,
        batch_size,
        num_epoch,
        theta,
        ['分开', '不分开'],
        pred_len)  # closing argument assumed; the original snippet is truncated here


import zipfile


def get_jaychou_lyrics(file_path='../../data/jaychou_lyrics.txt.zip'):
    # Read the zipped lyrics corpus and collapse line breaks into spaces.
    result = None
    with zipfile.ZipFile(file_path) as zn:
        with zn.open('jaychou_lyrics.txt') as f:
            result = f.read().decode('utf-8')
    if result is not None:
        result = result.replace('\n', ' ').replace('\r', ' ')
    return result


if '__main__' == __name__:
    from utility.gen_char_index import gen_char_index, gen_index_list
    from utility.char_data_iter import char_data_iter_random, char_data_iter_consecutive
    result = get_jaychou_lyrics()
    index_to_char_list, char_to_index_dict = gen_char_index(result)
    char_index_list = gen_index_list(result, char_to_index_dict)
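
char_data_iter_random and char_data_iter_consecutive are imported above but not shown. A rough sketch of the random-sampling variant, assuming it yields (x, y) minibatches of shape (batch_size, num_step) with y shifted one character ahead of x; the actual utility module may differ.

import random

import torch


def char_data_iter_random(corpus_index, num_step, batch_size):
    # Cut the index list into non-overlapping windows of length num_step;
    # the labels y are the inputs x shifted forward by one character.
    num_example = (len(corpus_index) - 1) // num_step
    example_indices = list(range(num_example))
    random.shuffle(example_indices)

    def _data(pos):
        return corpus_index[pos:pos + num_step]

    for i in range(0, num_example // batch_size * batch_size, batch_size):
        batch = example_indices[i:i + batch_size]
        x = [_data(j * num_step) for j in batch]
        y = [_data(j * num_step + 1) for j in batch]
        yield torch.tensor(x), torch.tensor(y)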