def evaluate(data_source, batch_size=10):
    """Return the length-weighted mean loss of the global ``model``.

    ``data_source`` is walked in chunks of ``args.bptt`` rows.  Each chunk's
    ``criterion`` value is weighted by ``len(data)`` and the accumulated sum
    is divided by ``len(data_source)``.

    Args:
        data_source: Tensor-like object supporting ``size(0)`` and ``len()``;
            presumably the batchified evaluation split laid out as
            (sequence, batch) -- TODO confirm against ``get_batch``.
        batch_size: Batch size passed to ``model.init_hidden``.

    Returns:
        The averaged loss as a Python float.  If no chunk is evaluated
        (``data_source.size(0) <= 1``) the accumulated loss is 0.

    Raises:
        ZeroDivisionError: If ``len(data_source)`` is 0.
    """
    model.eval()
    if args.model == 'QRNN':
        model.reset()
    total_loss = 0
    hidden = model.init_hidden(batch_size)
    # No gradients are needed here; disabling autograd avoids building a
    # graph for every forward pass.
    with torch.no_grad():
        for i in range(0, data_source.size(0) - 1, args.bptt):
            data, targets = get_batch(data_source, i, args, evaluation=True)
            output, hidden = model(data, hidden)
            chunk_loss = criterion(
                model.decoder.weight, model.decoder.bias, output, targets)
            # Weight by chunk length so the final division by
            # len(data_source) yields a per-row average.
            total_loss += len(data) * chunk_loss.data
            hidden = repackage_hidden(hidden)
    # float() accepts both the initial int 0 (loop never ran) and a
    # one-element tensor, unlike ``.item()`` which only exists on tensors.
    return float(total_loss) / len(data_source)
def train():
    """Run one training pass over the global ``train_data``.

    Relies on module-level state: ``model``, ``optimizer``, ``criterion``,
    ``params``, ``args``, ``train_data`` and ``epoch`` (the last is used
    only in the progress message).

    Each step draws a random sequence length, temporarily rescales the
    learning rate of the first parameter group by ``seq_len / args.bptt``,
    performs one optimisation step, and then restores the learning rate.
    Every ``args.log_interval`` batches a progress line is printed and the
    running loss and timer are reset.

    Returns:
        None.
    """
    if args.model == 'QRNN':
        model.reset()
    total_loss = 0
    start_time = time.time()
    hidden = model.init_hidden(args.batch_size)
    batch, i = 0, 0
    while i < train_data.size(0) - 1 - 1:
        # Use the nominal bptt 95% of the time and half of it otherwise,
        # then jitter with a normal (std 5), never going below 5.
        bptt = args.bptt if np.random.random() < 0.95 else args.bptt / 2.
        seq_len = max(5, int(np.random.normal(bptt, 5)))

        # Scale the learning rate in proportion to the sampled length.
        lr2 = optimizer.param_groups[0]['lr']
        optimizer.param_groups[0]['lr'] = lr2 * seq_len / args.bptt
        try:
            model.train()
            data, targets = get_batch(train_data, i, args, seq_len=seq_len)

            # Carry the hidden state over from the previous batch via
            # ``repackage_hidden`` (presumably detaches it from the old
            # graph -- TODO confirm against that helper).
            hidden = repackage_hidden(hidden)
            optimizer.zero_grad()

            output, hidden, rnn_hs, dropped_rnn_hs = model(
                data, hidden, return_h=True)
            raw_loss = criterion(
                model.decoder.weight, model.decoder.bias, output, targets)

            loss = raw_loss
            # Penalty on the squared values of the last entry of
            # ``dropped_rnn_hs``, scaled by alpha.
            if args.alpha:
                loss = loss + sum(
                    args.alpha * dropped_rnn_h.pow(2).mean()
                    for dropped_rnn_h in dropped_rnn_hs[-1:])
            # Penalty on squared differences between consecutive slices
            # along the first axis of the last entry of ``rnn_hs``
            # (presumably the time axis -- TODO confirm), scaled by beta.
            if args.beta:
                loss = loss + sum(
                    args.beta * (rnn_h[1:] - rnn_h[:-1]).pow(2).mean()
                    for rnn_h in rnn_hs[-1:])
            loss.backward()

            if args.clip:
                torch.nn.utils.clip_grad_norm_(params, args.clip)
            optimizer.step()
        finally:
            # Always undo the temporary scaling, even if the step raised
            # or was interrupted, so the optimizer never keeps a
            # length-scaled learning rate.
            optimizer.param_groups[0]['lr'] = lr2

        # Only the unregularised loss is tracked for reporting.
        total_loss += raw_loss.data

        if batch % args.log_interval == 0 and batch > 0:
            cur_loss = total_loss.item() / args.log_interval
            elapsed = time.time() - start_time
            print(
                '| epoch {:3d} | {:5d}/{:5d} batches | lr {:05.5f} | ms/batch {:5.2f} | '
                'loss {:5.2f} | ppl {:8.2f} | bpc {:8.3f}'.format(
                    epoch, batch, len(train_data) // args.bptt,
                    optimizer.param_groups[0]['lr'],
                    elapsed * 1000 / args.log_interval,
                    cur_loss, math.exp(cur_loss), cur_loss / math.log(2)))
            total_loss = 0
            start_time = time.time()

        batch += 1
        i += seq_len