import math

import numpy as np
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchnet import meter

# PoemData, PoetryModel and generate_poet come from the project's own code (not shown here).


def train_torch_lstm(conf, args=None):
    pdata = PoemData()
    pdata.read_data(conf)
    pdata.get_vocab()
    device = torch.device('cuda') if conf.use_gpu else torch.device('cpu')
    model = PoetryModel(pdata.vocab_size, conf, device)
    train_data = pdata.train_data
    test_data = pdata.test_data
    train_data = torch.from_numpy(np.array(train_data['pad_words']))
    dev_data = torch.from_numpy(np.array(test_data['pad_words']))
    dataloader = DataLoader(train_data, batch_size=conf.batch_size,
                            shuffle=True, num_workers=conf.num_workers)
    # The dev set is only evaluated, so shuffling it is unnecessary.
    devloader = DataLoader(dev_data, batch_size=conf.batch_size,
                           shuffle=False, num_workers=conf.num_workers)

    if conf.load_best_model:
        model.load_state_dict(torch.load(conf.best_model_path))

    criterion = nn.CrossEntropyLoss()
    if conf.use_gpu:
        model.cuda()
        criterion.cuda()
    # Build the optimizer after the model is on its final device.
    optimizer = Adam(model.parameters(), lr=conf.learning_rate)
    loss_meter = meter.AverageValueMeter()

    step = 0
    bestppl = 1e9
    early_stop_controller = 0
    for epoch in range(conf.n_epochs):
        loss_meter.reset()
        model.train()
        for i, data in enumerate(dataloader):
            # (batch, seq_len) -> (seq_len, batch), the layout the LSTM expects.
            data = data.long().transpose(1, 0).contiguous()
            if conf.use_gpu:
                data = data.cuda()
            # Teacher forcing: predict each character from the ones before it.
            input, target = data[:-1, :], data[1:, :]
            optimizer.zero_grad()
            output, _ = model(input)
            loss = criterion(output, target.contiguous().view(-1))
            loss.backward()
            optimizer.step()
            loss_meter.add(loss.item())
            step += 1
            if step % 100 == 0:
                print("epoch_%d_step_%d_loss:%0.4f" % (epoch + 1, step, loss.item()))
        train_loss = float(loss_meter.value()[0])

        # Reset the meter before evaluation so the perplexity reflects
        # the dev loss alone instead of a mix of train and dev batches.
        loss_meter.reset()
        model.eval()
        with torch.no_grad():
            for i, data in enumerate(devloader):
                data = data.long().transpose(1, 0).contiguous()
                if conf.use_gpu:
                    data = data.cuda()
                input, target = data[:-1, :], data[1:, :]
                output, _ = model(input)
                loss = criterion(output, target.contiguous().view(-1))
                loss_meter.add(loss.item())
        ppl = math.exp(loss_meter.value()[0])
        print("epoch_%d_loss:%0.4f , ppl:%0.4f" % (epoch + 1, train_loss, ppl))

        if epoch % conf.save_every == 0:
            torch.save(model.state_dict(), "{0}_{1}".format(conf.model_prefix, epoch))
            with open("{0}out_{1}".format(conf.out_path, epoch), 'w', encoding='utf-8') as fout:
                for word in list('日红山夜湖海月'):
                    gen_poetry = generate_poet(model, word, pdata.vocab, conf)
                    fout.write("".join(gen_poetry) + '\n\n')

        # Early stopping: keep the best dev perplexity, give up after
        # conf.patience epochs without improvement.
        if ppl < bestppl:
            bestppl = ppl
            early_stop_controller = 0
            torch.save(model.state_dict(), "{0}_{1}".format(conf.best_model_path, "best_model"))
        else:
            early_stop_controller += 1
        if early_stop_controller > conf.patience:
            print("early stop.")
            break
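The trainer above only works if PoetryModel returns logits already flattened to (sequence length × batch, vocab_size), since the loss compares them against target.view(-1). A minimal sketch of a model with that interface follows; the config fields embedding_dim, hidden_dim, and n_layers are assumptions, not the project's actual keys.

class PoetryModel(nn.Module):
    # Minimal sketch of the interface train_torch_lstm relies on; the
    # conf.embedding_dim / conf.hidden_dim / conf.n_layers names are assumed.
    def __init__(self, vocab_size, conf, device):
        super(PoetryModel, self).__init__()
        self.device = device
        self.embedding = nn.Embedding(vocab_size, conf.embedding_dim)
        self.lstm = nn.LSTM(conf.embedding_dim, conf.hidden_dim,
                            num_layers=conf.n_layers)
        self.fc = nn.Linear(conf.hidden_dim, vocab_size)

    def forward(self, input, hidden=None):
        # input: (seq_len, batch) of token ids.
        seq_len, batch_size = input.size()
        emb = self.embedding(input)              # (seq_len, batch, emb)
        output, hidden = self.lstm(emb, hidden)  # (seq_len, batch, hid)
        # Flatten so CrossEntropyLoss can compare against target.view(-1).
        logits = self.fc(output.reshape(seq_len * batch_size, -1))
        return logits, hidden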
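train_torch_lstm also calls generate_poet to write sample poems at each checkpoint. Its implementation is not shown here; the greedy-decoding sketch below is a hypothetical stand-in, and vocab.word2id / vocab.id2word and the max_len cap are assumed names.

def generate_poet(model, start_word, vocab, conf, max_len=32):
    # Hypothetical sketch: feed the start character, then repeatedly
    # pick the most likely next character until max_len (assumed cap).
    model.eval()
    result = [start_word]
    input = torch.tensor([[vocab.word2id[start_word]]], dtype=torch.long)
    if conf.use_gpu:
        input = input.cuda()
    hidden = None
    with torch.no_grad():
        for _ in range(max_len):
            output, hidden = model(input, hidden)
            top = output[-1].argmax().item()  # greedy choice of next char
            result.append(vocab.id2word[top])
            input = torch.tensor([[top]], dtype=torch.long)
            if conf.use_gpu:
                input = input.cuda()
    return result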
# Evaluation helper: averages the loss over the cases in batch v.
# (The original fragment began mid-scope; the enclosing function is assumed.)
def evaluate(v):
    model.eval()
    loss = 0
    counts = 0
    with torch.no_grad():
        for case in range(v * batch, min((v + 1) * batch, TRAINSIZE)):
            s = data[case]
            hidden = model.initHidden()
            t, o = makeForOneCase(s, one_hot_var_target)
            output, hidden = model(t.cuda(), hidden)
            loss += criterion(output, o.cuda())
            counts += 1
    loss = loss / counts
    print("=====", loss.item())

print("start training")
for epoch in range(epochNum):
    model.train()
    for batchIndex in range(int(TRAINSIZE / batch)):
        model.zero_grad()
        loss = 0
        counts = 0
        for case in range(batchIndex * batch, min((batchIndex + 1) * batch, TRAINSIZE)):
            s = data[case]
            hidden = model.initHidden()
            t, o = makeForOneCase(s, one_hot_var_target)
            output, hidden = model(t.cuda(), hidden)
            loss += criterion(output, o.cuda())
            counts += 1
        # Average the per-case losses before backpropagating.
        loss = loss / counts
        loss.backward()
        # The original fragment never updated the weights after backward();
        # an optimizer built earlier over model.parameters() is assumed here.
        optimizer.step()
        print(epoch, loss.item())
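The snippet above leans on makeForOneCase to turn one poem into a training pair. Its definition is not included; the sketch below is one plausible reading, assuming one_hot_var_target maps characters to vocabulary indices and the model accepts one-hot input of shape (seq_len, 1, vocab_size) while returning (seq_len, vocab_size) logits.

def makeForOneCase(s, one_hot_var_target):
    # Hypothetical reconstruction: turn one poem string s into
    # (one-hot input sequence, next-character target indices).
    idxs = [one_hot_var_target[ch] for ch in s]
    vocab_size = len(one_hot_var_target)
    # Inputs are all characters except the last, one-hot encoded,
    # shaped (seq_len, 1, vocab_size) for a batch of one.
    t = torch.zeros(len(idxs) - 1, 1, vocab_size)
    for pos, idx in enumerate(idxs[:-1]):
        t[pos, 0, idx] = 1.0
    # Targets are the same sequence shifted by one character.
    o = torch.tensor(idxs[1:], dtype=torch.long)
    return t, o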