def trans_word_to_index(data_set, word_cout_dic):
    """Convert a dataset of word sequences into index sequences.

    Builds the vocabulary from the keys of ``word_cout_dic`` via
    ``gen_char_index`` and maps every line of ``data_set`` to a list of
    indices, silently dropping words that are not in the vocabulary.

    Returns a tuple of
    (indexed dataset, index->word list, word->index dict).
    """
    vocab = set(word_cout_dic.keys())
    idx_to_word, word_to_idx = gen_char_index(vocab)
    # Translate each line; out-of-vocabulary words are skipped.
    indexed_lines = []
    for line in data_set:
        indexed_lines.append(
            [word_to_idx[word] for word in line if word in word_cout_dic])
    return indexed_lines, idx_to_word, word_to_idx
def train(self, corpus_data, is_random_iter, learn_rate, loss, num_step,
          batch_size, num_epoch, prefixes, pred_len):
    """Train this RNN on a character corpus and periodically sample text.

    Args:
        corpus_data: raw corpus (sequence of characters) to train on.
        is_random_iter: if True, use random minibatch sampling
            (state re-initialized every batch); otherwise consecutive
            sampling (state carried over and detached between batches).
        learn_rate: SGD learning rate passed to ``sgd``.
        loss: loss module/callable, e.g. ``nn.CrossEntropyLoss()``.
        num_step: number of time steps per minibatch.
        batch_size: minibatch size.
        num_epoch: number of training epochs.
        prefixes: iterable of prefix strings to seed prediction samples.
        pred_len: number of characters to generate for each prefix.
    """
    # Build vocabulary and encode the corpus as index list.
    index_to_char_list, char_to_index_dict = gen_char_index(corpus_data)
    vocab_size = len(index_to_char_list)
    corpus_chars_index = gen_index_list(corpus_data, char_to_index_dict)
    # Pick the minibatch iterator matching the sampling scheme.
    iter_func = char_data_iter_consecutive
    if is_random_iter:
        iter_func = char_data_iter_random
    # NOTE(review): hidden size is taken as len(self.hidden_bias) —
    # assumes hidden_bias is a 1-D bias vector; confirm against __init__.
    state = self.init_state(batch_size, len(self.hidden_bias))
    # Report perplexity and sample predictions every `pred_period` epochs.
    pred_period = 50
    for epoch in range(num_epoch):
        n = 0          # total number of target characters seen this epoch
        l_sum = 0      # accumulated loss (summed per character) this epoch
        start = time.time()
        data_iter = iter_func(corpus_chars_index, num_step, batch_size)
        for x, y in data_iter:
            if is_random_iter:
                # Random sampling: batches are not contiguous in the
                # corpus, so restart from a fresh state each batch.
                state = self.init_state(batch_size, len(self.hidden_bias))
            else:
                # Consecutive sampling: keep the state value but cut the
                # autograd graph so backprop stays within this batch.
                # NOTE(review): detach_() assumes `state` is a single
                # tensor; an LSTM-style (h, c) tuple would need each
                # element detached — confirm what init_state returns.
                state.detach_()
            x_vector = self.to_one_hot(x, num_step, vocab_size)
            y_hat, state = self.forward((x_vector, state))
            # forward returns one output per time step; stack them into
            # a single (num_step * batch_size, vocab_size) matrix.
            y_hat = torch.cat(y_hat, dim=0)
            params = self.get_params()
            # Zero gradients before backward. If the first param has no
            # grad yet (very first batch), none do, so skip zeroing.
            for one_param in params:
                if one_param.grad is None:
                    break
                one_param.grad.data.zero_()
            # Reorder targets time-major to line up with the concatenated
            # y_hat rows, then flatten to a 1-D label vector.
            yy = torch.transpose(y, 0, 1).contiguous().view(-1)
            yy_device = yy.to(self.device)
            # With CrossEntropyLoss's default mean reduction this is a
            # scalar; .sum() is then a no-op.
            loss_sum = loss(y_hat, yy_device).sum()
            loss_sum.backward()
            # NOTE(review): clipping threshold is hard-coded here (0.01)
            # rather than taken as a parameter — confirm intended.
            gradient_clipping(params=params, theta=0.01, device=self.device)
            # batch size 1 because the loss is already averaged.
            sgd(params, learn_rate, 1)
            n += yy.shape[0]
            # Undo the mean reduction to accumulate total per-char loss.
            l_sum += loss_sum.cpu().item() * yy.shape[0]
        if (epoch + 1) % pred_period == 0:
            # exp(average per-character loss) == perplexity.
            print('epoch %d, perplexity %f, time %.2f sec' %
                  (epoch + 1, math.exp(l_sum / n), time.time() - start))
            for prefix in prefixes:
                print(' -', self.predict(prefix, pred_len,
                                         index_to_char_list,
                                         char_to_index_dict))
if '__main__' == __name__: from utility.jaychou.load_jaychou_lyrics import get_jaychou_lyrics from utility.gen_char_index import gen_char_index from utility.train_rnn import train_rnn num_epoch = 250 lr = 0.01 theta = 0.01 hiddens = 256 batch_size = 128 num_step = 35 num_epochs = 250 pred_len = 50 corpus_data = get_jaychou_lyrics() index_to_char_list, char_to_index_dict = gen_char_index(corpus_data) net = LSTM(len(index_to_char_list), hiddens, len(index_to_char_list)) params = list(net.parameters()) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') train_rnn( corpus_data, index_to_char_list, char_to_index_dict, True, lr, nn.CrossEntropyLoss(), num_step, batch_size, num_epoch, theta, ['分开', '不分开'],
import zipfile


def get_jaychou_lyrics(file_path='../../data/jaychou_lyrics.txt.zip'):
    """Load the Jay Chou lyrics corpus from a zip archive.

    Args:
        file_path: path to a zip archive containing a member named
            ``jaychou_lyrics.txt`` encoded as UTF-8.

    Returns:
        The lyrics as a single string with newline and carriage-return
        characters replaced by spaces, or None if the member was empty
        in a way that left ``result`` unset.

    Raises:
        FileNotFoundError / zipfile.BadZipFile: if the archive is
            missing or not a valid zip file.
        KeyError: if the expected member is not in the archive.
    """
    result = None
    with zipfile.ZipFile(file_path) as zn:
        with zn.open('jaychou_lyrics.txt') as f:
            result = f.read().decode('utf-8')
    if result is not None:
        # Flatten line structure: the downstream character models treat
        # the corpus as one continuous stream.
        result = result.replace('\n', ' ').replace('\r', ' ')
    return result


if '__main__' == __name__:
    from utility.gen_char_index import gen_char_index, gen_index_list
    from utility.char_data_iter import char_data_iter_random, char_data_iter_consecutive

    result = get_jaychou_lyrics()
    index_to_char_list, char_to_index_dict = gen_char_index(result)
    # BUG FIX: this previously called gen_char_index(result,
    # char_to_index_dict) — gen_char_index takes a single corpus argument
    # and returns the (list, dict) pair built above. Converting the corpus
    # into an index list is gen_index_list's job (and it was imported but
    # never used).
    char_index_list = gen_index_list(result, char_to_index_dict)