import math
import sys
import time

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim


def batchify(data, bsz):
    # Work out how cleanly we can divide the dataset into bsz parts.
    nbatch = data.size(0) // bsz
    # Trim off any extra elements that wouldn't cleanly fit (remainders).
    data = data.narrow(0, 0, nbatch * bsz)
    # Evenly divide the data across the bsz batches: the result is a
    # (seq_len, batch) matrix of bsz parallel token streams.
    data = data.view(bsz, -1).t().contiguous()
    if args.cuda:
        data = data.cuda()
    return data
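
# A minimal, self-contained sketch (hypothetical toy data, not from the
# source) of the reshape batchify performs; assumes CPU tensors
# (i.e. args.cuda is False).
def _batchify_demo():
    stream = torch.arange(26)                   # pretend token ids 0..25
    bsz = 4
    nbatch = stream.size(0) // bsz              # 6 full columns of bsz tokens
    stream = stream.narrow(0, 0, nbatch * bsz)  # drop the 2 leftover tokens
    batched = stream.view(bsz, -1).t().contiguous()
    print(batched.shape)                        # torch.Size([6, 4]): (seq_len, batch)
    print(batched[:, 0])                        # first stream: tensor([0, 1, 2, 3, 4, 5])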
def train(data_source):
    # Turn on training mode which enables dropout.
    model.train()
    total_loss = 0
    start_time = time.time()
    ntokens = len(corpus.dictionary)
    hidden = model.init_hidden(args.batch_size)
    criterion = nn.CrossEntropyLoss()
    # Plain SGD carries no per-step state, so build the optimizer once
    # rather than re-creating it on every batch.
    opt = optim.SGD(model.parameters(), lr=lr)
    # Visit the BPTT windows in a random order each epoch.
    num_batch = math.ceil(data_source.size(0) / args.bptt)
    indices = np.arange(num_batch)
    np.random.shuffle(indices)
    for batch, i in enumerate(range(0, data_source.size(0) - 1, args.bptt)):
        sys.stdout.flush()
        data, targets = get_batch(data_source, indices[batch] * args.bptt)
        # Starting each batch, we detach the hidden state from how it was
        # previously produced. If we didn't, the model would try
        # backpropagating all the way to the start of the dataset.
        hidden = repackage_hidden(hidden)
        model.zero_grad()
        output, hidden = model(data, hidden)
        loss = criterion(output.view(-1, ntokens), targets)
        loss.backward()
        # `clip_grad_norm_` helps prevent the exploding gradient problem
        # in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
        opt.step()
        total_loss += loss.data
        if batch % args.log_interval == 0 and batch > 0:
            cur_loss = total_loss.item() / args.log_interval
            elapsed = time.time() - start_time
            log = ('| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.2f} '
                   '| ms/batch {:5.2f} | word_loss {:5.2f} | ppl {:8.2f}').format(
                epoch, batch, num_batch, lr,
                elapsed * 1000 / args.log_interval, cur_loss, math.exp(cur_loss))
            printhelper.print_log(log_file, log)
            total_loss = 0
            start_time = time.time()
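
# train() assumes two helpers defined elsewhere in the project. The sketches
# below are hedged reconstructions following the standard PyTorch
# word-language-model example; the project's actual definitions may differ.

def get_batch(source, i):
    # Slice a (seq_len, batch) window starting at row i, with targets
    # shifted one token ahead and flattened to line up with
    # output.view(-1, ntokens) in train().
    seq_len = min(args.bptt, len(source) - 1 - i)
    data = source[i:i + seq_len]
    target = source[i + 1:i + 1 + seq_len].view(-1)
    return data, target


def repackage_hidden(h):
    # Detach hidden states from their graph history so backpropagation
    # stops at the current BPTT window; handles plain tensors and, e.g.
    # for LSTMs, tuples of tensors.
    if isinstance(h, torch.Tensor):
        return h.detach()
    return tuple(repackage_hidden(v) for v in h)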